diff --git a/.ai/active/SPRINT_PACKET.md b/.ai/active/SPRINT_PACKET.md new file mode 100644 index 0000000..445936f --- /dev/null +++ b/.ai/active/SPRINT_PACKET.md @@ -0,0 +1,123 @@ +# SPRINT_PACKET.md + +## Sprint Title + +Sprint 5A: Task Workspace Records and Provisioning + +## Sprint Type + +feature + +## Sprint Reason + +Milestone 5 should start at the workspace boundary, not at document ingestion or connectors. The repo now has the task and execution substrate needed to add one deterministic, user-scoped task workspace seam without expanding product scope. + +## Sprint Intent + +Begin Milestone 5 by adding user-scoped task workspace records plus deterministic local workspace provisioning, so later artifact handling, document ingestion, and read-only connectors have a governed workspace boundary to build on. + +## Git Instructions + +- Branch Name: `codex/sprint-5a-task-workspaces` +- Base Branch: `main` +- PR Strategy: one sprint branch, one PR, no stacked PRs unless Control Tower explicitly opens a follow-up sprint on top of this branch +- Merge Policy: squash merge only after reviewer `PASS`; if review fails, repair on the same branch until pass or explicit abandonment + +## Why This Sprint + +- Sprint 4S is implemented and passed: approvals and executions now both use explicit task-step linkage, so the Milestone 4 lifecycle substrate is in place. +- The roadmap says workspace and artifact boundaries should land before document-heavy or connector-heavy flows rely on them. +- The narrowest safe Milestone 5 entry slice is workspace provisioning only, not artifact indexing, document ingestion, or connectors. +- This keeps sequencing boring and maintainable by establishing the workspace boundary first. + +## In Scope + +- Add schema and migration support for: + - `task_workspaces` +- Define typed contracts for: + - workspace create responses + - workspace list responses + - workspace detail responses +- Implement a minimal workspace seam that: + - provisions one deterministic local workspace path for a visible task + - persists one user-scoped workspace record linked to that task + - validates the workspace path is rooted under one configured workspace base directory + - prevents duplicate active workspace creation for the same task + - exposes deterministic list and detail reads +- Implement the minimal API or service paths needed for: + - creating a workspace for a task + - listing workspaces + - reading one workspace by id +- Add unit and integration tests for: + - workspace creation + - deterministic path generation + - duplicate-create rejection for the same task + - per-user isolation + - stable response shape + +## Out of Scope + +- No artifact inventory or artifact metadata table yet. +- No document ingestion. +- No chunking, embeddings, or document retrieval. +- No Gmail or Calendar connector scope. +- No runner-style orchestration. +- No new proxy handlers or broader side-effect expansion. + +## Required Deliverables + +- Migration for `task_workspaces`. +- Stable workspace create/list/detail contracts. +- Minimal deterministic task-workspace provisioning and persistence path. +- Unit and integration coverage for provisioning, path safety, duplicate protection, and isolation. +- Updated `BUILD_REPORT.md` with exact verification results and explicit deferred scope. + +## Acceptance Criteria + +- A client can provision one user-scoped workspace for a visible task. +- Every workspace record stores a deterministic local path under the configured workspace root. 
+- Duplicate active workspace creation for the same task is rejected deterministically. +- Workspace list and detail reads are deterministic and user-scoped. +- `./.venv/bin/python -m pytest tests/unit` passes. +- `./.venv/bin/python -m pytest tests/integration` passes. +- No artifact indexing, document ingestion, connector, runner, handler-expansion, or broader side-effect scope enters the sprint. + +## Implementation Constraints + +- Keep the workspace seam narrow and boring. +- Provision only local workspace boundaries; do not invent remote storage abstractions in this sprint. +- Keep workspace paths deterministic, explicit, and rooted under one configured base directory. +- Reuse existing task ownership and isolation seams rather than creating a parallel authorization path. +- Do not add artifact scanning, file sync, or document parsing in the same sprint. + +## Suggested Work Breakdown + +1. Add `task_workspaces` schema and migration. +2. Define workspace create/list/detail contracts. +3. Implement deterministic workspace path generation rooted under the configured base directory. +4. Implement workspace create, list, and detail behavior with duplicate protection. +5. Add unit and integration tests. +6. Update `BUILD_REPORT.md` with executed verification. + +## Build Report Requirements + +`BUILD_REPORT.md` must include: +- the exact workspace schema and contract changes introduced +- the configured workspace root and path-generation rule used +- exact commands run +- unit and integration test results +- one example workspace create response +- one example workspace detail response +- what remains intentionally deferred to later milestones + +## Review Focus + +`REVIEW_REPORT.md` should verify: +- the sprint stayed limited to task workspace records and provisioning +- workspace paths are deterministic, rooted safely, and user-scoped +- duplicate protection, ordering, and isolation are test-backed +- no hidden artifact indexing, document ingestion, connector, runner, handler-expansion, or broader side-effect scope entered the sprint + +## Exit Condition + +This sprint is complete when the repo can provision deterministic user-scoped task workspace records under a configured local workspace root, expose stable workspace reads, and verify the full path with Postgres-backed tests, while still deferring artifact handling, document ingestion, and connector work. diff --git a/.ai/agents/.gitkeep b/.ai/agents/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/.ai/agents/.gitkeep @@ -0,0 +1 @@ + diff --git a/.ai/handoff/CURRENT_STATE.md b/.ai/handoff/CURRENT_STATE.md new file mode 100644 index 0000000..be3a187 --- /dev/null +++ b/.ai/handoff/CURRENT_STATE.md @@ -0,0 +1,53 @@ +# Current State + +## What Exists Today + +- Canonical project docs now describe the shipped repo state through Sprint 4O. +- `apps/api` implements the accepted backend seams for continuity, tracing, context compilation, governed memory, memory review, embeddings, semantic retrieval, entities, policies, tools, approvals, approved proxy execution, execution budgets, execution review, tasks, task steps, and explicit manual continuation lineage. +- The live schema now includes continuity tables, trace tables, memory tables, embedding tables, entity tables, governance tables, plus `tasks` and `task_steps`. +- `apps/web` and `workers` remain starter scaffolds only; no workspace UI, runner, or background-job orchestration is shipped. 
+ +## Stable / Trusted Areas + +- Immutable event log and persisted trace model with per-user isolation. +- Deterministic context compilation and deterministic prompt assembly over durable sources. +- Governed memory admission, narrow deterministic explicit-preference extraction, explicit embedding storage, semantic retrieval, and deterministic hybrid memory merge during compile. +- Deterministic policy evaluation, tool allowlist evaluation, tool routing, approval persistence, approval resolution, approved-only `proxy.echo` execution, durable execution review, and execution-budget enforcement. +- Durable task and task-step reads, deterministic task-step sequencing, explicit task-step transitions, and explicit manual continuation with lineage validated against the parent step outcome. +- Sprint 4O review verification: + - `./.venv/bin/python -m pytest tests/unit` -> `284 passed` + - `./.venv/bin/python -m pytest tests/integration` -> `95 passed` + +## Incomplete / At-Risk Areas + +- Auth beyond DB user context is still unimplemented. +- Memory extraction and retrieval quality remain major ship-gating risks. +- Document ingestion, scoped task workspaces, artifact handling, and read-only connectors have not started in code. +- The current multi-step boundary is still narrow: approval-resolution and execution-synchronization helpers continue to target `task_steps.sequence_no = 1`, even though manual continuation is now implemented for later steps. + +## Current Milestone Position + +- The repo has completed the implementation planned through Milestone 4. +- Milestone 5 has not started in shipped code. +- The project is at a truth-sync checkpoint before Milestone 5 entry. + +## Latest State Summary + +- Local runtime assets exist for Docker Compose, Postgres bootstrap, API startup, migrations, and backend tests. +- `POST /v0/approvals/requests` now creates one durable task plus one initial task step for each routed governed request, with task and task-step lifecycle traces. +- `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `GET /v0/tasks/{task_id}/steps`, and `GET /v0/task-steps/{task_step_id}` expose durable task/task-step review reads with deterministic ordering. +- `POST /v0/tasks/{task_id}/steps` now appends exactly one manual continuation step when the latest step is appendable and explicit lineage points to that latest visible parent step. +- `POST /v0/task-steps/{task_step_id}/transition` now advances only the latest visible step through the explicit status graph and keeps the parent task status synchronized. +- Task-step lineage is trace-visible through `task.step.continuation.request`, `task.step.continuation.lineage`, and `task.step.continuation.summary` events. + +## Critical Constraints + +- Do not treat planned workspace, connector, runner, or broader side-effect work as implemented. +- Do not bypass approval boundaries for consequential actions. +- Do not replace compiled durable context with raw transcript stuffing. +- Appended task steps must carry explicit lineage; do not infer provenance heuristically from task history. +- Keep the current multi-step boundary explicit until the first-step lifecycle helpers are removed or constrained. + +## Immediate Next Move + +- Take the smallest follow-up sprint that removes or explicitly constrains the remaining `task_steps.sequence_no = 1` approval/execution synchronization assumptions before any runner, workspace, or connector work begins. 
diff --git a/.ai/templates/.gitkeep b/.ai/templates/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/.ai/templates/.gitkeep @@ -0,0 +1 @@ + diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..923dfc8 --- /dev/null +++ b/.env.example @@ -0,0 +1,12 @@ +APP_ENV=development +APP_HOST=127.0.0.1 +APP_PORT=8000 +DATABASE_URL=postgresql://alicebot_app:alicebot_app@localhost:5432/alicebot +DATABASE_ADMIN_URL=postgresql://alicebot_admin:alicebot_admin@localhost:5432/alicebot +REDIS_URL=redis://localhost:6379/0 +S3_ENDPOINT_URL=http://localhost:9000 +S3_ACCESS_KEY=alicebot +S3_SECRET_KEY=alicebot-secret +S3_BUCKET=alicebot-local +HEALTHCHECK_TIMEOUT_SECONDS=2 +TASK_WORKSPACE_ROOT=/tmp/alicebot/task-workspaces diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6acc4a7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +.env +.pytest_cache/ +.venv/ +*.egg-info/ +__pycache__/ +*.pyc +apps/web/.next/ +apps/web/node_modules/ diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..9cefafb --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,207 @@ +# Architecture + +## Current Implemented Slice + +AliceBot now implements the accepted repo slice through Sprint 5A. The shipped backend includes: + +- foundation continuity storage over `users`, `threads`, `sessions`, and append-only `events` +- deterministic tracing and context compilation over durable continuity, memory, entity, and entity-edge records +- governed memory admission, explicit-preference extraction, memory review labels, review queue reads, evaluation summary reads, explicit embedding config and memory-embedding storage, direct semantic retrieval, and deterministic hybrid compile-path memory merge +- deterministic prompt assembly and one no-tools response path that persists assistant replies as immutable continuity events +- user-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, tool routing, approval request persistence, approval resolution, approved-only proxy execution through the in-process `proxy.echo` handler, durable execution review, and execution-budget lifecycle plus enforcement +- durable `tasks`, `task_steps`, and `task_workspaces`, deterministic task-step sequencing, explicit task-step transitions, explicit manual continuation with lineage through `parent_step_id`, `source_approval_id`, and `source_execution_id`, explicit `tool_executions.task_step_id` linkage for execution synchronization, and deterministic rooted local task-workspace provisioning + +The current multi-step boundary is narrow and explicit. Manual continuation is implemented and review-passed. Approval resolution and proxy execution now both use explicit task-step linkage rather than first-step inference. Task workspaces are now implemented only as deterministic rooted local boundaries. Broader runner-style orchestration, automatic multi-step progression, artifact indexing, document ingestion, connectors, and new side-effect surfaces are still planned later and must not be described as live behavior. + +## Implemented Now + +### Runtime + +- `docker-compose.yml` starts local Postgres with `pgvector`, Redis, and MinIO. +- `scripts/dev_up.sh`, `scripts/migrate.sh`, and `scripts/api_dev.sh` provide the local startup path, with readiness gating before migrations. 
+- `apps/api` exposes FastAPI endpoints for: + - health and compile: `/healthz`, `POST /v0/context/compile`, `POST /v0/responses` + - memory and retrieval: `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `GET /v0/memories`, `GET /v0/memories/review-queue`, `GET /v0/memories/evaluation-summary`, `POST /v0/memories/semantic-retrieval`, `GET /v0/memories/{memory_id}`, `GET /v0/memories/{memory_id}/revisions`, `POST /v0/memories/{memory_id}/labels`, `GET /v0/memories/{memory_id}/labels` + - embeddings and graph seams: `POST /v0/embedding-configs`, `GET /v0/embedding-configs`, `POST /v0/memory-embeddings`, `GET /v0/memories/{memory_id}/embeddings`, `GET /v0/memory-embeddings/{memory_embedding_id}`, `POST /v0/entities`, `GET /v0/entities`, `GET /v0/entities/{entity_id}`, `POST /v0/entity-edges`, `GET /v0/entities/{entity_id}/edges` + - governance: `POST /v0/consents`, `GET /v0/consents`, `POST /v0/policies`, `GET /v0/policies`, `GET /v0/policies/{policy_id}`, `POST /v0/policies/evaluate`, `POST /v0/tools`, `GET /v0/tools`, `GET /v0/tools/{tool_id}`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, `POST /v0/approvals/{approval_id}/execute` + - task and execution review: `GET /v0/tasks`, `GET /v0/tasks/{task_id}`, `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, `GET /v0/task-workspaces/{task_workspace_id}`, `GET /v0/tasks/{task_id}/steps`, `GET /v0/task-steps/{task_step_id}`, `POST /v0/tasks/{task_id}/steps`, `POST /v0/task-steps/{task_step_id}/transition`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, `GET /v0/tool-executions/{execution_id}` +- `apps/web` and `workers` remain starter shells only. + +### Data Foundation + +- Postgres is the current system of record. +- Alembic manages schema changes through `apps/api/alembic`. +- The live schema includes: + - continuity tables: `users`, `threads`, `sessions`, `events` + - trace tables: `traces`, `trace_events` + - memory and retrieval tables: `memories`, `memory_revisions`, `memory_review_labels`, `embedding_configs`, `memory_embeddings` + - graph tables: `entities`, `entity_edges` + - governance tables: `consents`, `policies`, `tools`, `approvals`, `tool_executions`, `execution_budgets` + - task lifecycle tables: `tasks`, `task_steps`, `task_workspaces` +- `events`, `trace_events`, and `memory_revisions` are append-only by application contract and database enforcement. +- `memory_review_labels` are append-only by database enforcement. +- `tasks` are explicit user-scoped lifecycle records keyed to one thread and one tool, with durable request/tool snapshots, status in `pending_approval | approved | executed | denied | blocked`, and latest approval/execution pointers for the current narrow lifecycle seam. +- `task_steps` are explicit user-scoped ordered lifecycle records keyed by `(user_id, task_id, sequence_no)`, with `kind = 'governed_request'`, status in `created | approved | executed | blocked | denied`, durable request/outcome snapshots, and one trace reference describing the latest mutation. 
+- Sprint 4O added lineage columns on `task_steps`: + - `parent_step_id` + - `source_approval_id` + - `source_execution_id` +- Lineage fields are guarded by composite user-scoped foreign keys and a self-reference check so a step cannot cite itself as its parent. +- `tool_executions` now persist an explicit `task_step_id` linked by a composite foreign key to `task_steps(id, user_id)`. +- `task_workspaces` persist one active workspace record per visible task and user, store a deterministic `local_path`, and enforce that active uniqueness through a partial unique index on `(user_id, task_id)`. +- `execution_budgets` enforce at most one active budget per `(user_id, tool_key, domain_hint)` selector scope through a partial unique index. +- Per-request user context is set in the database through `app.current_user_id()`. +- `TASK_WORKSPACE_ROOT` defines the only allowed base directory for workspace provisioning, and the live path rule is `resolved_root / user_id / task_id`. + +### Repo Boundaries In This Slice + +- `apps/api`: implemented API, store, contracts, service logic, and migrations for continuity, tracing, memory, embeddings, entities, policies, tools, approvals, proxy execution, execution budgets, tasks, task steps, and task workspaces. +- `apps/web`: minimal shell only; no shipped workflow UI. +- `workers`: scaffold only; no background jobs or runner logic are implemented. +- `infra`: local development bootstrap assets only. +- `tests`: unit and Postgres-backed integration coverage for the shipped seams above, including Sprint 4O task-step lineage/manual continuation, Sprint 4S step-linked execution synchronization, and Sprint 5A task-workspace provisioning. + +## Core Flows Implemented Now + +### Deterministic Context Compilation + +1. Accept a user-scoped `POST /v0/context/compile` request. +2. Read durable continuity records in deterministic order. +3. Merge in active memories, entities, and entity edges through the currently shipped symbolic and optional semantic retrieval paths. +4. Persist a `context.compile` trace plus explicit inclusion and exclusion events. +5. Return one deterministic `context_pack` describing scope, limits, selected context, and trace metadata. + +### Governed Memory And Retrieval + +1. Accept explicit memory candidates through `POST /v0/memories/admit`. +2. Require cited source events, default to `NOOP`, and persist `memory_revisions` only for evidence-backed non-`NOOP` mutations. +3. Support a narrow deterministic explicit-preference extractor over stored `message.user` events. +4. Persist user-scoped embedding configs and memory embeddings explicitly. +5. Support direct semantic retrieval over active memories for a caller-selected embedding config. +6. Merge symbolic and semantic memory results deterministically into the compile path with trace-visible source provenance. +7. Expose review reads, unlabeled review queue reads, evaluation summary reads, and append-only memory-review labels. + +### Policy, Tool, Approval, And Execution Governance + +1. Evaluate policies deterministically over active user-scoped policy and consent state. +2. Evaluate tool allowlists against active tool metadata plus policy decisions. +3. Route one requested invocation deterministically to `ready`, `denied`, or `approval_required`. +4. Persist durable approval rows only for `approval_required` outcomes. +5. Resolve approvals explicitly through approve and reject endpoints. +6. Execute approved requests only through the registered proxy-handler map. +7. 
In the current repo, only `proxy.echo` is enabled, and it performs no external I/O. +8. Persist one durable `tool_executions` row for every approved execution attempt, including budget-blocked attempts. +9. Enforce narrow execution budgets by selector scope and optional rolling window before approved dispatch. + +### Task Lifecycle Creation + +1. `POST /v0/approvals/requests` always creates one durable `tasks` row and one initial `task_steps` row, even when no approval row is persisted. +2. The initial task and task step reflect the routing decision: + - `approval_required` creates `task.status = pending_approval` and `task_step.status = created` + - `ready` creates `task.status = approved` and `task_step.status = approved` + - `denied` creates `task.status = denied` and `task_step.status = denied` +3. The initial task step is always `sequence_no = 1`. +4. Approval-request traces include task lifecycle and task-step lifecycle events alongside the approval request events. + +### Approval Resolution And Proxy Execution Synchronization + +1. Approval resolution reuses the existing task seam and updates the durable task plus the explicitly linked task step from `approvals.task_step_id`. +2. Approval resolution rejects missing, invisible, cross-task, and inconsistent approval-to-step linkage deterministically. +3. Approved proxy execution validates the approval’s linked task step before dispatch and persists `tool_executions.task_step_id` on every durable execution row. +4. Execution synchronization now reuses `tool_executions.task_step_id` and updates the explicitly linked step by id rather than inferring `sequence_no = 1`. +5. Execution synchronization rejects missing, invisible, cross-task, and inconsistent execution-to-step linkage deterministically before mutating task or task-step state. + +### Task-Step Manual Continuation + +1. Accept a user-scoped `POST /v0/tasks/{task_id}/steps` request to append exactly one next step to an existing task. +2. Lock the task-step sequence before allocating the next `sequence_no`. +3. Require the task to already have visible steps. +4. Allow append only when the latest visible step is in `executed`, `blocked`, or `denied`. +5. Require explicit lineage: + - `lineage.parent_step_id` must be present + - the parent step must belong to the same visible task + - the parent step must be the latest visible task step +6. Optionally allow `lineage.source_approval_id` and `lineage.source_execution_id`, but only when: + - the referenced records are visible in the current user scope + - the referenced records already appear on the parent step outcome +7. Persist the new `task_steps` row with the lineage fields and incremented `sequence_no`. +8. Update the parent `tasks` row to the task status implied by the appended step status. +9. Persist one `task.step.continuation` trace plus request, lineage, summary, task lifecycle, and task-step lifecycle events. +10. Return the updated task, the appended step, deterministic sequencing metadata, and trace summary. + +### Task-Step Transition + +1. Accept a user-scoped `POST /v0/task-steps/{task_step_id}/transition` request. +2. Require the referenced step to be the latest visible step on its task. +3. Enforce the explicit status graph: + - `created -> approved | denied` + - `approved -> executed | blocked` + - terminal states have no further transitions +4. Require approval linkage when the step must reflect approval state and execution linkage when the step must reflect execution state. +5. 
Update the target step in place with a new trace reference and outcome snapshot. +6. Update the parent task status and latest approval/execution pointers consistently. +7. Persist one `task.step.transition` trace plus request, state, summary, task lifecycle, and task-step lifecycle events. + +### Task And Task-Step Reads + +1. `GET /v0/tasks` lists durable task rows in deterministic `created_at ASC, id ASC` order. +2. `GET /v0/tasks/{task_id}` returns one user-visible task detail record. +3. `GET /v0/tasks/{task_id}/steps` returns task steps in deterministic `sequence_no ASC, created_at ASC, id ASC` order plus sequencing summary metadata. +4. `GET /v0/task-steps/{task_step_id}` returns one user-visible task-step detail record. +5. Task-step list and detail reads expose lineage fields directly. + +### Task Workspace Provisioning + +1. Accept a user-scoped `POST /v0/tasks/{task_id}/workspace` request for one visible task. +2. Resolve the configured `TASK_WORKSPACE_ROOT`. +3. Build the deterministic local path as `resolved_root / user_id / task_id`. +4. Reject provisioning if the resolved workspace path escapes the resolved workspace root. +5. Lock workspace creation for the target task before checking for an existing active workspace. +6. Reject duplicate active workspace creation for the same visible task deterministically. +7. Create the local directory boundary and persist one `task_workspaces` row with `status = active` and the rooted `local_path`. +8. `GET /v0/task-workspaces` lists visible workspaces in deterministic `created_at ASC, id ASC` order. +9. `GET /v0/task-workspaces/{task_workspace_id}` returns one user-visible workspace detail record. + +## Security Model Implemented Now + +- User-owned continuity, trace, memory, embedding, entity, governance, task, task-step, and task-workspace tables enforce row-level security. +- The runtime role is limited to the narrow `SELECT` / `INSERT` / `UPDATE` permissions required by the shipped seams; there is no broad DDL or unrestricted table access at runtime. +- Cross-user references are constrained through composite foreign keys on `(id, user_id)` where the schema needs ownership-linked joins. +- Approval, execution, memory, entity, task/task-step, and task-workspace reads all operate only inside the current user scope. +- Task-step manual continuation adds both schema-level and service-level lineage protection: + - schema-level: user-scoped foreign keys and parent-not-self check + - service-level: same-task, latest-step, visible-approval, visible-execution, and parent-outcome-match validation +- In-place updates and deletes remain blocked for append-only continuity and trace records. + +## Testing Coverage Implemented Now + +- Unit and integration tests cover continuity, compiler, response generation, memory admission, review labels, review queue, embeddings, semantic retrieval, entities, policies, tools, approvals, proxy execution, execution budgets, and execution review. 
+- Sprint 4O, Sprint 4S, and Sprint 5A added explicit task lifecycle coverage: + - migrations for `tasks`, `task_steps`, and task-step lineage + - staged/backfilled migration coverage for `tool_executions.task_step_id` + - task and task-step store contracts + - task list/detail and task-step list/detail reads + - deterministic sequencing summaries + - manual continuation success paths + - task-step transition success paths + - explicit later-step execution synchronization by linked `task_step_id` + - deterministic task-workspace path generation and rooted-path enforcement + - workspace create/list/detail response shape + - duplicate active workspace rejection + - task-workspace per-user isolation + - trace visibility for continuation and transition events + - user isolation for task and task-step reads and mutations + - adversarial lineage validation for cross-task, cross-user, and parent-step mismatch cases + +## Planned Later + +The following areas remain planned later and must not be described as implemented: + +- runner-style orchestration and automatic multi-step progression beyond the current explicit manual continuation seam +- artifact storage, artifact indexing, and document ingestion beyond the current rooted local workspace boundary +- read-only Gmail and Calendar connectors +- broader tool proxying and real-world side effects beyond the current no-I/O `proxy.echo` handler +- model-driven extraction, reranking, and broader memory review automation +- production deployment automation beyond the local developer stack + +Future docs and code should continue to distinguish the implemented seams above from these later milestones. diff --git a/ARCHIVE_RECOMMENDATIONS.md b/ARCHIVE_RECOMMENDATIONS.md new file mode 100644 index 0000000..a7c178e --- /dev/null +++ b/ARCHIVE_RECOMMENDATIONS.md @@ -0,0 +1,19 @@ +# Archive Recommendations + +## Archive Instead Of Keeping In Live Agent Memory + +- Investor framing, executive rhetoric, and narrative persuasion. +- Long strategy memos once their decisions have been distilled into product, architecture, roadmap, and rules. +- Raw brainstorms, option dumps, and redundant scope alternatives. +- Verbose roadmap history and schedule speculation. +- Meeting notes, implementation diaries, and retrospective prose. +- Detailed vendor pricing snapshots and model-cost assumptions that will drift quickly. +- Duplicate sprint plans or decomposition notes once a current sprint packet exists. +- Example-heavy explanatory text that does not change the operating rules. + +## Keep As Archived Reference Only + +- Original source plans and memos that fed the bootstrap. +- Full schema sketches and endpoint catalogs before implementation-specific docs are created. +- Older roadmap versions after milestone sequencing changes. +- Historical task and review notes that may help reconstruct decision context later. diff --git a/BUILD_REPORT.md b/BUILD_REPORT.md new file mode 100644 index 0000000..c525cd9 --- /dev/null +++ b/BUILD_REPORT.md @@ -0,0 +1,181 @@ +# BUILD_REPORT + +## sprint objective + +Implement Sprint 5A: Task Workspace Records and Provisioning by adding user-scoped `task_workspaces`, deterministic local workspace provisioning under one configured root, duplicate-active protection per task, and stable workspace create/list/detail reads. 
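As a reading aid, the following minimal sketch shows the deterministic path rule and the rooted-path safety check described in this report (`resolved_root / str(user_id) / str(task_id)`, with provisioning failing before persistence if the resolved path escapes the resolved root). The function name `build_workspace_path` is illustrative; the shipped logic lives in `apps/api/src/alicebot_api/workspaces.py` and may differ in detail.

```python
from pathlib import Path
from uuid import UUID


def build_workspace_path(task_workspace_root: str, user_id: UUID, task_id: UUID) -> Path:
    """Sketch of the documented rule: resolved_root / str(user_id) / str(task_id)."""
    resolved_root = Path(task_workspace_root).resolve()
    candidate = (resolved_root / str(user_id) / str(task_id)).resolve()
    # Rooted-path safety rule: fail before any persistence if the resolved
    # workspace path would escape the resolved configured root.
    if not candidate.is_relative_to(resolved_root):
        raise ValueError("workspace path escapes the configured workspace root")
    return candidate
```

With the default root this yields paths of the form `/tmp/alicebot/task-workspaces/<user_id>/<task_id>`, which matches the example responses later in this report.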
+ +## completed work + +- Added workspace schema and migration: + - new migration `apps/api/alembic/versions/20260313_0022_task_workspaces.py` + - new table `task_workspaces` with `id`, `user_id`, `task_id`, `status`, `local_path`, `created_at`, and `updated_at` + - user/task foreign key `(task_id, user_id) -> tasks(id, user_id)` + - partial unique index enforcing one active workspace per task and user + - RLS policy plus runtime grants limited to `SELECT, INSERT` +- Added workspace configuration and deterministic pathing: + - new setting `TASK_WORKSPACE_ROOT` + - default workspace root: `/tmp/alicebot/task-workspaces` + - path-generation rule: `//` + - workspace provisioning validates the resolved path stays rooted under the resolved workspace root before creating the directory +- Added typed contracts and service behavior: + - `TaskWorkspaceStatus` + - `TaskWorkspaceCreateInput` + - `TaskWorkspaceRecord` + - `TaskWorkspaceCreateResponse` + - `TaskWorkspaceListResponse` + - `TaskWorkspaceDetailResponse` + - new workspace service in `apps/api/src/alicebot_api/workspaces.py` + - duplicate active workspace creation for the same visible task now raises a deterministic conflict +- Added minimal API paths: + - `POST /v0/tasks/{task_id}/workspace` + - `GET /v0/task-workspaces` + - `GET /v0/task-workspaces/{task_workspace_id}` +- Added coverage for: + - deterministic path generation + - rooted path safety validation + - workspace creation + - duplicate-create rejection + - per-user isolation + - stable response shape + - migration upgrade/downgrade expectations including the new table, RLS, and privileges + +## exact workspace schema and contract changes introduced + +- Schema: + - `task_workspaces.id uuid PRIMARY KEY DEFAULT gen_random_uuid()` + - `task_workspaces.user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE` + - `task_workspaces.task_id uuid NOT NULL` + - `task_workspaces.status text NOT NULL CHECK (status IN ('active'))` + - `task_workspaces.local_path text NOT NULL CHECK (length(local_path) > 0)` + - `task_workspaces.created_at timestamptz NOT NULL DEFAULT now()` + - `task_workspaces.updated_at timestamptz NOT NULL DEFAULT now()` + - `CONSTRAINT task_workspaces_task_user_fk FOREIGN KEY (task_id, user_id) REFERENCES tasks(id, user_id) ON DELETE CASCADE` + - `CREATE UNIQUE INDEX task_workspaces_active_task_idx ON task_workspaces (user_id, task_id) WHERE status = 'active'` +- Store layer: + - `TaskWorkspaceRow` + - `ContinuityStore.lock_task_workspaces(...)` + - `ContinuityStore.create_task_workspace(...)` + - `ContinuityStore.get_task_workspace_optional(...)` + - `ContinuityStore.get_active_task_workspace_for_task_optional(...)` + - `ContinuityStore.list_task_workspaces(...)` +- Contracts: + - `TaskWorkspaceStatus = Literal["active"]` + - `TaskWorkspaceCreateInput.task_id` + - `TaskWorkspaceCreateInput.status` + - `TaskWorkspaceRecord.id` + - `TaskWorkspaceRecord.task_id` + - `TaskWorkspaceRecord.status` + - `TaskWorkspaceRecord.local_path` + - `TaskWorkspaceRecord.created_at` + - `TaskWorkspaceRecord.updated_at` + - `TaskWorkspaceCreateResponse.workspace` + - `TaskWorkspaceListResponse.items` + - `TaskWorkspaceListResponse.summary` + - `TaskWorkspaceDetailResponse.workspace` + +## configured workspace root and path-generation rule used + +- Default configured workspace root: `/tmp/alicebot/task-workspaces` +- Test override root: per-test temp directory via `Settings(task_workspace_root=...)` +- Deterministic path rule: `resolved_root / str(user_id) / str(task_id)` +- Safety 
rule: the resolved workspace path must remain under the resolved configured root or provisioning fails before persistence + +## incomplete work + +- None inside Sprint 5A scope. + +## files changed + +- `apps/api/alembic/versions/20260313_0022_task_workspaces.py` +- `apps/api/src/alicebot_api/config.py` +- `apps/api/src/alicebot_api/contracts.py` +- `apps/api/src/alicebot_api/main.py` +- `apps/api/src/alicebot_api/store.py` +- `apps/api/src/alicebot_api/workspaces.py` +- `tests/integration/test_migrations.py` +- `tests/integration/test_task_workspaces_api.py` +- `tests/unit/test_20260313_0022_task_workspaces.py` +- `tests/unit/test_config.py` +- `tests/unit/test_main.py` +- `tests/unit/test_task_workspace_store.py` +- `tests/unit/test_workspaces.py` +- `tests/unit/test_workspaces_main.py` +- `BUILD_REPORT.md` + +## exact commands run + +- `./.venv/bin/python -m pytest tests/unit/test_workspaces.py tests/unit/test_workspaces_main.py tests/unit/test_task_workspace_store.py tests/unit/test_20260313_0022_task_workspaces.py tests/unit/test_config.py tests/unit/test_main.py` +- `./.venv/bin/python -m pytest tests/integration/test_task_workspaces_api.py tests/integration/test_migrations.py` + - initial sandbox run failed because sandboxed localhost Postgres access was blocked +- `./.venv/bin/python -m pytest tests/unit` +- `./.venv/bin/python -m pytest tests/integration` + +## tests run + +- `./.venv/bin/python -m pytest tests/unit/test_workspaces.py tests/unit/test_workspaces_main.py tests/unit/test_task_workspace_store.py tests/unit/test_20260313_0022_task_workspaces.py tests/unit/test_config.py tests/unit/test_main.py` + - passed: `56 passed in 0.50s` +- `./.venv/bin/python -m pytest tests/integration/test_task_workspaces_api.py tests/integration/test_migrations.py` + - sandboxed run failed before test execution could start against Postgres: `3 errors in 0.21s` +- `./.venv/bin/python -m pytest tests/unit` + - passed: `315 passed in 0.57s` +- `./.venv/bin/python -m pytest tests/integration` + - passed outside the sandbox: `99 passed in 28.56s` + +## unit and integration test results + +- Unit suite: + - green + - covers config loading, migration statement order, store queries, workspace service behavior, rooted path safety, duplicate rejection, route registration, and endpoint error mapping +- Integration suite: + - green + - covers migration upgrade/downgrade expectations, workspace API provisioning, duplicate rejection, deterministic list/detail responses, and per-user isolation against Postgres + +## one example workspace create response + +```json +{ + "workspace": { + "id": "11111111-1111-1111-1111-111111111111", + "task_id": "22222222-2222-2222-2222-222222222222", + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/33333333-3333-3333-3333-333333333333/22222222-2222-2222-2222-222222222222", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00" + } +} +``` + +## one example workspace detail response + +```json +{ + "workspace": { + "id": "11111111-1111-1111-1111-111111111111", + "task_id": "22222222-2222-2222-2222-222222222222", + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/33333333-3333-3333-3333-333333333333/22222222-2222-2222-2222-222222222222", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00" + } +} +``` + +## blockers/issues + +- No implementation blocker remains. 
+- Verification note: + - Postgres-backed integration tests required unsandboxed access to `localhost:5432`; the initial sandboxed focused integration run failed with connection-permission errors before being rerun successfully outside the sandbox. + +## what remains intentionally deferred to later milestones + +- Artifact inventory and artifact metadata tables +- Document ingestion +- Chunking, embeddings, or document retrieval tied to workspaces +- Gmail or Calendar connector scope +- Runner-style orchestration +- New proxy handlers or broader side-effect expansion +- Any remote storage abstraction beyond the local deterministic workspace boundary added here + +## recommended next step + +Build the next workspace-dependent milestone slice on top of this boundary without widening the seam: artifact or document work should consume `task_workspaces` records and the configured rooted local path instead of inventing a parallel storage contract. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3fd0cc0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,27 @@ +# Changelog + +## 2026-03-11 + +- Redacted embedded Redis credentials from `/healthz` so the endpoint no longer echoes `REDIS_URL` secrets back to callers. +- Added readiness gating to `./scripts/dev_up.sh` so bootstrap waits for Postgres and `alicebot_app` role initialization before running migrations. +- Bound local Postgres, Redis, and MinIO ports to `127.0.0.1` by default and removed the unnecessary runtime-role `CONNECT` grant on the shared `postgres` database. +- Removed the redundant `(thread_id, sequence_no)` events index from the base continuity migration because the unique constraint already provides that index. +- Tightened architecture, roadmap, handoff, and builder-report wording so exposed routes and environment-specific verification claims stay accurate. +- Tightened the runtime Postgres role so the continuity tables are insert/select-only in the migration chain and for upgraded databases. +- Stopped the base migration downgrade from dropping shared `pgcrypto` and `vector` extensions. +- Made the local helper scripts prefer `.venv/bin/python` when the project virtualenv exists, falling back to `python3` otherwise. +- Corrected `/healthz` so only Postgres is reported as live-checked, while Redis and MinIO are surfaced as configured but `not_checked`. +- Fixed Alembic runtime URL handling so migrations use the installed `psycopg` SQLAlchemy driver instead of the missing `psycopg2` default. +- Fixed concurrent event append sequencing by acquiring the per-thread advisory lock before reading the next `sequence_no`. +- Verified the local foundation runtime with `docker compose up -d`, `./scripts/migrate.sh`, `./.venv/bin/python -m pytest tests/unit tests/integration`, and a live `GET /healthz`. + +## 2026-03-10 + +- Bootstrapped the canonical project operating files. +- Created the initial AI handoff snapshot and first sprint packet. +- Added the recommended repo scaffolding directories for implementation work. +- Added local Docker Compose infrastructure for Postgres with `pgvector`, Redis, and MinIO. +- Added the FastAPI foundation scaffold, configuration loading, `/healthz`, and Alembic migration plumbing. +- Added continuity tables for `users`, `threads`, `sessions`, and append-only `events` with RLS and isolation tests. +- Fixed the local quick-start path so repo scripts source `.env`, use `python3`, and keep migrations pointed at the `alicebot` database. 
+- Serialized same-thread event appends before sequence allocation and added an integration test for concurrent event numbering. diff --git a/PRODUCT_BRIEF.md b/PRODUCT_BRIEF.md new file mode 100644 index 0000000..1735c23 --- /dev/null +++ b/PRODUCT_BRIEF.md @@ -0,0 +1,77 @@ +# Product Brief + +## Product Summary + +AliceBot is a private, permissioned personal AI operating system for a single primary user. It is designed to preserve durable personal context, retrieve the right context at the right time, and move safely from conversation to action without hiding why it acted. + +## Problem + +General-purpose assistants forget preferences, prior decisions, and relationships across sessions. They also make it difficult to audit why they answered a certain way or whether a tool action was properly governed. The result is low trust, repeated user effort, and unsafe action handling. + +## Target Users + +- Primary v1 user: one power user with recurring life and work workflows. +- Delivery model: a human lead working with AI builders and reviewers. +- Architectural assumption: v1 UX is single-user, but the data model must support strict per-user isolation from day one. + +## Core Value Proposition + +- Durable memory for preferences, relationships, prior decisions, and recurring tasks. +- Deterministic context compilation instead of ad hoc prompt stuffing. +- Safe action orchestration with policy checks, approvals, and budgets. +- Clear explainability through traces, memory evidence, and tool history. + +## V1 Scope + +- Web-based chat and task orchestration. +- Immutable thread and session continuity. +- Structured memory with admission controls, revision history, and user review. +- Entity and relationship tracking for people, merchants, products, projects, and routines. +- Hybrid retrieval across memories, entities, relationships, and documents. +- Policy engine, tool proxy, approval workflows, and task budgets. +- Scoped task workspaces and artifact storage. +- Read-only document ingestion plus read-only Gmail and Calendar connectors. +- Hot consolidation for immediate truth updates and cold consolidation for cleanup and summarization. +- Explain-why views for important responses and actions. + +## Non-Goals + +- Autonomous side effects without user approval. +- Multi-user collaboration UX in v1. +- Mobile-first delivery. +- Dedicated graph or vector infrastructure in v1. +- Browser automation, write-capable connectors, proactive automations, and voice at launch. + +## Key User Journeys + +1. Ask a question that depends on prior preferences, purchases, or relationships and get a context-aware answer without restating history. +2. Correct a preference or fact and have the next turn reflect the new truth immediately. +3. Inspect why the system answered or proposed an action by reviewing memories, retrieval choices, and tool traces. +4. Run a repeat-purchase workflow that gathers prior context, proposes the order, pauses for approval, and records the outcome. +5. Retrieve relevant context from documents, Gmail, or Calendar without granting write access. + +## Constraints + +- Single-user product experience, multi-tenant-safe architecture. +- Web-first v1. +- Explicit approval for consequential actions. +- Operational simplicity beats platform sprawl in v1. +- Memory quality, retrieval quality, and explainability are ship-gating concerns. + +## Success Criteria + +- The system recalls relevant preferences, past purchases, relationships, and prior decisions without repeated user restatement. 
+- The repeat magnesium reorder workflow succeeds end to end with approval gating and memory write-back. +- Every consequential action is explainable through trace, memory, rule, and tool evidence. +- Purchases, emails, bookings, and other side effects never occur without explicit approval. +- Standard retrieval-plus-response interactions reach p95 latency under 5 seconds. +- Prompt and cache reuse exceeds 70% on repeated patterns. +- Memory extraction precision exceeds 80% at ship. + +## Product Non-Negotiables + +- The user stays in control of consequential actions. +- Durable context must come from governed storage, not raw transcript stuffing. +- Explainability is a product requirement, not a debugging feature. +- Preference contradictions must be reflected immediately. +- The repeat magnesium reorder scenario is the canonical v1 ship gate. diff --git a/README.md b/README.md index 55496eb..a16d490 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,36 @@ # AliceBot -AliceBot is a private, permissioned personal AI operating system. This repository currently holds the canonical product, architecture, roadmap, and AI handoff documents that future implementation work should follow. +AliceBot is a private, permissioned personal AI operating system. The repository now includes the runnable foundation slice plus the first tracing/context-compilation seam, the first governed memory/admissions-and-embeddings slice, the first deterministic response-generation seam, the first governance routing seam for non-executing tool requests, the first durable approval-request persistence seam for `approval_required` routing outcomes, the explicit approval-resolution seam, the first minimal approved-only proxy-execution seam, the first durable execution-review seam over that proxy path, the narrow execution-budget lifecycle seam over approved proxy execution, and the first deterministic task-workspace provisioning seam: local infrastructure, an API scaffold, migration tooling, continuity primitives, persisted traces, a deterministic continuity-only compiler, explicit memory admission, a narrow deterministic explicit-preference extraction path, explicit embedding-config and memory-embedding storage paths, a direct semantic memory retrieval primitive, deterministic hybrid compile-path memory merge, a no-tools model invocation path over deterministically assembled prompts, deterministic policy and tool-governance seams, a narrow no-side-effect proxy handler path, durable `tool_executions` records, durable `execution_budgets` records, durable `task_workspaces` records, execution-budget create/list/detail reads, budget deactivate/supersede lifecycle operations, active-only budget enforcement, budget-blocked execution persistence, task-workspace create/list/detail reads, and backend verification coverage. ## Status -- Planning has been distilled into durable operating docs. -- Application code has not been scaffolded yet. -- The first execution target is the foundation sprint in [.ai/active/SPRINT_PACKET.md](/Users/samirusani/Desktop/Codex/AliceBot/.ai/active/SPRINT_PACKET.md). +- Local Docker Compose infrastructure is defined for Postgres with `pgvector`, Redis, and MinIO. 
+- `apps/api` contains FastAPI health, compile, response-generation, memory-admission, explicit-preference extraction, semantic-memory-retrieval, policy, tool-registry, tool-allowlist, tool-routing, approval-request, approval-resolution, proxy-execution, execution-budget, execution-review, task, and task-workspace endpoints, configuration loading, Alembic migrations, continuity storage primitives, the Sprint 2A trace/compiler path, the Sprint 3A memory-admission path, the Sprint 3I deterministic extraction path, the Sprint 3K embedding substrate, the Sprint 3L semantic retrieval primitive, the Sprint 3M compile-path semantic retrieval adoption, the Sprint 3N deterministic hybrid memory merge, the Sprint 4A deterministic prompt-assembly and no-tools response path, the Sprint 4D deterministic non-executing tool-routing seam, the Sprint 4E durable approval-request persistence seam, the Sprint 4F approval-resolution seam, the Sprint 4G minimal approved-only proxy-execution seam, the Sprint 4H durable execution-review seam, the Sprint 4I execution-budget guard seam, the Sprint 4J execution-budget lifecycle seam, the Sprint 4K time-windowed execution-budget seam, the Sprint 4S explicit execution-to-task-step linkage seam, and the Sprint 5A task-workspace provisioning seam. +- `apps/web` and `workers` contain minimal starter scaffolds for later milestone work. +- The active sprint is documented in [.ai/active/SPRINT_PACKET.md](/Users/samirusani/Desktop/Codex/AliceBot/.ai/active/SPRINT_PACKET.md). -## Quick Start Assumptions +## Quick Start -- Assumption: local development will use Docker Compose for Postgres, Redis, and S3-compatible storage. -- Assumption: backend work will use Python 3.12 and FastAPI. -- Assumption: frontend work will use Node.js 20, `pnpm`, and Next.js. -- Secrets must stay out of the repo; use `.env` files locally and a secret manager in deployed environments. +1. Create a local env file: `cp .env.example .env` +2. Start required infrastructure with one command: `docker compose up -d` +3. Create a project virtualenv and install Python dependencies: `python3 -m venv .venv && ./.venv/bin/python -m pip install -e '.[dev]'` +4. Run database migrations: `./scripts/migrate.sh` +5. Start the API locally: `./scripts/api_dev.sh` + +The health endpoint is exposed at [http://127.0.0.1:8000/healthz](http://127.0.0.1:8000/healthz). +The minimal context-compilation API path is `POST /v0/context/compile`. +The minimal response-generation API path is `POST /v0/responses`. +The minimal memory-admission API path is `POST /v0/memories/admit`. +The explicit-preference extraction API path is `POST /v0/memories/extract-explicit-preferences`. +The minimal non-executing tool-routing API path is `POST /v0/tools/route`. +The minimal approval API paths are `POST /v0/approvals/requests`, `GET /v0/approvals`, `GET /v0/approvals/{approval_id}`, `POST /v0/approvals/{approval_id}/approve`, `POST /v0/approvals/{approval_id}/reject`, and `POST /v0/approvals/{approval_id}/execute`. +The execution-budget API paths are `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `GET /v0/execution-budgets/{execution_budget_id}`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, and `POST /v0/execution-budgets/{execution_budget_id}/supersede`. +The execution-review API paths are `GET /v0/tool-executions` and `GET /v0/tool-executions/{execution_id}`. +The task-workspace API paths are `POST /v0/tasks/{task_id}/workspace`, `GET /v0/task-workspaces`, and `GET /v0/task-workspaces/{task_workspace_id}`. 
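The snippet below is an illustrative local walkthrough of the task-workspace endpoints just listed; the exact create request body, headers, and user-context mechanism are defined by the shipped contracts in `apps/api`, so treat the empty payload and the hard-coded task id as placeholders.

```python
# Illustrative only: adjust the request body, headers, and user-context
# handling to match the shipped contracts before relying on this.
import json
import urllib.request

BASE = "http://127.0.0.1:8000"
task_id = "22222222-2222-2222-2222-222222222222"  # an existing visible task id (placeholder)

create_req = urllib.request.Request(
    f"{BASE}/v0/tasks/{task_id}/workspace",
    data=json.dumps({}).encode(),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(create_req) as resp:
    workspace = json.load(resp)["workspace"]
print(workspace["local_path"])  # deterministic path rooted under TASK_WORKSPACE_ROOT

# Deterministic, user-scoped list and detail reads.
with urllib.request.urlopen(f"{BASE}/v0/task-workspaces") as resp:
    items = json.load(resp)["items"]
with urllib.request.urlopen(f"{BASE}/v0/task-workspaces/{workspace['id']}") as resp:
    detail = json.load(resp)["workspace"]
```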
+The helper scripts load the repo-root `.env` automatically and prefer `.venv/bin/python` when that virtualenv exists, falling back to `python3` otherwise. The default migration/admin URL targets the same local `alicebot` database as the app runtime. +`/healthz` currently performs a live Postgres check only. Redis and MinIO are reported as configured endpoints with `not_checked` status. +`TASK_WORKSPACE_ROOT` controls the single rooted base directory used for deterministic local task-workspace provisioning. By default it is `/tmp/alicebot/task-workspaces`, and each workspace path is created as `//`. +The current backend path has been verified in a local developer environment with `docker compose up -d`, `./scripts/migrate.sh`, `./.venv/bin/python -m pytest tests/unit tests/integration`, a live `GET /healthz`, and the Postgres-backed `POST /v0/context/compile`, `POST /v0/responses`, `POST /v0/memories/admit`, `POST /v0/memories/extract-explicit-preferences`, `POST /v0/memories/semantic-retrieval`, `POST /v0/tools/allowlist/evaluate`, `POST /v0/tools/route`, `POST /v0/approvals/requests`, `POST /v0/approvals/{approval_id}/execute`, `POST /v0/execution-budgets`, `GET /v0/execution-budgets`, `POST /v0/execution-budgets/{execution_budget_id}/deactivate`, `POST /v0/execution-budgets/{execution_budget_id}/supersede`, `GET /v0/tool-executions`, and `GET /v0/tool-executions/{execution_id}` integration paths, including compile requests that explicitly enable the hybrid memory merge, response requests that persist assistant events and response traces, deterministic non-executing tool-routing requests that persist `tool.route.*` traces, approval-request persistence requests that persist `approval.request.*` traces plus durable approval rows only for `approval_required` outcomes, approved proxy execution that persists `tool.proxy.execute.*` traces plus durable `tool_executions` rows for approved execution attempts, deterministic budget-management requests over durable `execution_budgets` rows, lifecycle requests that persist `execution_budget.lifecycle.*` traces and change budget status deterministically, budget-prechecked proxy execution that emits `tool.proxy.execute.budget` trace events against active budgets only, and execution-review reads over those durable records including budget-blocked attempts. ## Repo Structure @@ -23,23 +40,34 @@ AliceBot is a private, permissioned personal AI operating system. This repositor - [RULES.md](/Users/samirusani/Desktop/Codex/AliceBot/RULES.md): durable engineering and scope rules. - [.ai/handoff/CURRENT_STATE.md](/Users/samirusani/Desktop/Codex/AliceBot/.ai/handoff/CURRENT_STATE.md): fresh-thread recovery snapshot. - [.ai/active/SPRINT_PACKET.md](/Users/samirusani/Desktop/Codex/AliceBot/.ai/active/SPRINT_PACKET.md): current builder sprint. -- `docs/adr/`: architecture decision records. -- `docs/runbooks/`: operational procedures. -- `docs/archive/`: source material and retired planning docs. -- `apps/api/`, `apps/web/`, `workers/`, `tests/`, `scripts/`: planned implementation areas. +- `docker-compose.yml`: local Postgres, Redis, and MinIO stack. +- `infra/postgres/init/`: Postgres bootstrap SQL, including the non-superuser app role. +- `apps/api/`: FastAPI app, config, continuity store, and Alembic migrations. +- `apps/web/`: minimal Next.js shell for later dashboard work. +- `workers/`: placeholder Python worker package for future background jobs. +- `tests/`: unit and Postgres-backed integration tests for the foundation slice. 
+- `scripts/`: local development and migration entrypoints. ## Essential Commands -- `docker compose up -d`: expected local infra start command once the foundation sprint lands. -- `alembic upgrade head`: expected database migration command once the API scaffold exists. -- `pytest`: expected backend and integration test entrypoint. -- `pnpm test`: expected frontend test entrypoint. -- `pnpm lint`: expected frontend lint entrypoint. +- `docker compose up -d`: start Postgres, Redis, and MinIO on `127.0.0.1`. +- `./scripts/dev_up.sh`: start local infrastructure, wait for Postgres and role bootstrap readiness, and apply Alembic migrations. +- `./scripts/migrate.sh`: apply Alembic migrations with the admin database URL from `.env` or the built-in defaults. +- `./scripts/api_dev.sh`: run the FastAPI service with auto-reload. +- `./.venv/bin/python -m pytest tests/unit tests/integration`: run backend tests from the project virtualenv. +- `pnpm --dir apps/web dev`: start the web shell after frontend dependencies are installed. ## Environment Notes -- Postgres is the planned system of record and must support `pgvector`. -- Redis is planned for queues, locks, and short-lived cache data. -- Object storage is planned for documents and task artifacts. -- Authentication, row-level security, and approval boundaries are first-class requirements from the start. -# AliceBot +- Postgres is the system of record and the live schema now includes continuity tables, trace tables, policy-governance tables including `approvals`, `tool_executions`, and `execution_budgets`, task lifecycle tables including `tasks`, `task_steps`, and `task_workspaces`, memory tables, entity tables, and the embedding substrate tables `embedding_configs` and `memory_embeddings`. +- Sprint 2A adds persisted `traces` and `trace_events` plus a deterministic continuity-only context compiler over existing durable continuity records. +- Sprint 3A adds governed `memories` and append-only `memory_revisions` plus an explicit `NOOP`-first admission path over cited source events. +- The app and migration defaults both target the local `alicebot` database to keep quick-start behavior deterministic. +- `TASK_WORKSPACE_ROOT` defaults to `/tmp/alicebot/task-workspaces` and defines the only allowed root for deterministic local task-workspace provisioning. +- Local service ports are bound to `127.0.0.1` by default to avoid exposing fixed development credentials on non-loopback interfaces. +- Redis is reserved for future queue, lock, and cache work; no retrieval or orchestration features are enabled in this sprint. +- MinIO provides the local S3-compatible endpoint for future document and artifact storage. +- Continuity tables enforce row-level security from the start and `events` are append-only by application contract plus database trigger, with concurrent appends serialized per thread. +- Trace tables follow the same per-user isolation model, with append-only `trace_events` for compiler explainability. 
+- Memory admission remains explicit and evidence-backed; automatic extraction is currently limited to a narrow deterministic explicit-preference path over stored user messages. The repo now includes explicit versioned embedding-config storage; direct memory-embedding persistence; a direct semantic retrieval API over active durable memories; compile-path hybrid memory merge into one `context_pack["memories"]` section with `memory_summary.hybrid_retrieval` metadata; one deterministic no-tools response path that assembles prompts from durable compiled context and persists assistant replies plus response traces; one deterministic approval-request persistence path over `approval_required` tool-routing outcomes; explicit approval resolution; one minimal approved-only proxy execution path through the no-side-effect `proxy.echo` handler; durable execution-review records plus list/detail reads for approved execution attempts; one narrow deterministic execution-budget seam that can activate, deactivate, supersede, and enforce both lifetime and rolling-window limits using durable `tool_executions` history while keeping blocked attempts reviewable; and one narrow deterministic task-workspace seam that provisions rooted local workspace directories and persists durable `task_workspaces` rows. Broader extraction, reranking, external-connector tool execution, artifact indexing, document ingestion, orchestration, and review UI remain deferred. +- The runtime database role is limited to `SELECT`/`INSERT` on continuity and trace tables, `SELECT`/`INSERT` on `memory_revisions`, `memory_review_labels`, `embedding_configs`, `entities`, and `entity_edges`, plus `SELECT`/`INSERT`/`UPDATE` on `consents`, `memories`, `memory_embeddings`, and `execution_budgets`. diff --git a/RECOMMENDED_ADRS.md b/RECOMMENDED_ADRS.md new file mode 100644 index 0000000..0fc8050 --- /dev/null +++ b/RECOMMENDED_ADRS.md @@ -0,0 +1,61 @@ +# Recommended ADRs + +## ADR-001: Modular Monolith for V1 + +- Why it deserves an ADR: service boundaries, deployment complexity, team workflow, and failure modes all depend on this choice. +- Proposed status: Proposed + +## ADR-002: Postgres + `pgvector` as V1 System of Record and Retrieval Store + +- Why it deserves an ADR: it sets the data platform, query model, operational burden, and later migration path. +- Proposed status: Proposed + +## ADR-003: Append-Only Continuity Model for Threads, Sessions, and Events + +- Why it deserves an ADR: this decision defines auditability, replay behavior, and how memory derives from source truth. +- Proposed status: Proposed + +## ADR-004: Memory as a Derived, Revisioned Projection + +- Why it deserves an ADR: it governs data integrity, contradiction handling, consolidation, and user trust. +- Proposed status: Proposed + +## ADR-005: Deterministic Context Compiler Contract + +- Why it deserves an ADR: it affects explainability, cache reuse, testing strategy, and model portability. +- Proposed status: Proposed + +## ADR-006: Auth and Per-User Isolation Model + +- Why it deserves an ADR: username/password plus TOTP, database user context, and RLS policy shape are hard security boundaries. +- Proposed status: Proposed + +## ADR-007: Policy Engine + Tool Proxy + Approval Boundary + +- Why it deserves an ADR: this is the core safety architecture for any external action or sensitive data access.
+- Proposed status: Proposed + +## ADR-008: Relational Entity and Relationship Storage in V1 + +- Why it deserves an ADR: choosing relational storage over a graph database affects schema design, query strategy, and scale assumptions. +- Proposed status: Proposed + +## ADR-009: Object Storage and Scoped Task Workspace Strategy + +- Why it deserves an ADR: artifact handling, document ingestion, and task isolation depend on this storage boundary. +- Proposed status: Proposed + +## ADR-010: Read-Only Connector Strategy for Gmail and Calendar + +- Why it deserves an ADR: connector permission scope has major product, security, and delivery consequences. +- Proposed status: Proposed + +## ADR-011: Trace-First Observability and Audit Logging Model + +- Why it deserves an ADR: explainability, incident review, and ship-gate validation depend on what is logged and retained. +- Proposed status: Proposed + +## ADR-012: Deployment Architecture for V1 + +- Why it deserves an ADR: VPS versus managed container hosting, secret handling, backup posture, and runtime topology affect both cost and risk. +- Proposed status: Proposed diff --git a/REVIEW_REPORT.md b/REVIEW_REPORT.md new file mode 100644 index 0000000..0fc8efe --- /dev/null +++ b/REVIEW_REPORT.md @@ -0,0 +1,52 @@ +# REVIEW_REPORT + +## verdict + +PASS + +## criteria met + +- The sprint stayed inside the Sprint 5A boundary. I found no artifact indexing, document ingestion, connector work, runner orchestration, new proxy handlers, or broader side-effect expansion. +- `apps/api/alembic/versions/20260313_0022_task_workspaces.py` adds the required `task_workspaces` schema with user ownership, task linkage through `(task_id, user_id)`, row-level security, and a partial unique index enforcing one active workspace per task and user. +- The workspace seam in `apps/api/src/alicebot_api/workspaces.py` is narrow and deterministic: it resolves one configured root, builds the path as `resolved_root / user_id / task_id`, rejects rooted-path escapes before provisioning, and persists a single active workspace row. +- Stable create/list/detail contracts and the minimal API surface are present for the required endpoints: + - `POST /v0/tasks/{task_id}/workspace` + - `GET /v0/task-workspaces` + - `GET /v0/task-workspaces/{task_workspace_id}` +- Duplicate active workspace creation is rejected deterministically through the advisory lock plus active-workspace lookup, with the database unique index providing backstop enforcement. +- User isolation, deterministic ordering, and stable response shape are test-backed in both unit and Postgres-backed integration coverage, including `tests/integration/test_task_workspaces_api.py`. +- `BUILD_REPORT.md` accurately describes the schema change, contract change, rooted path rule, exact commands, sample responses, and deferred scope. +- Independent verification passed: + - `./.venv/bin/python -m pytest tests/unit` -> `315 passed in 0.62s` + - `./.venv/bin/python -m pytest tests/integration` -> `99 passed in 28.66s` + +## criteria missed + +- None. + +## quality issues + +- Non-blocking: `create_task_workspace_record()` provisions the directory before the insert is durably committed and uses `mkdir(..., exist_ok=True)`. If the insert or transaction commit fails after directory creation, the code can leave behind an orphaned directory that a later successful create would silently reuse. 
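+
+One possible remediation, shown only as a minimal sketch (the function name and the `persist_row` callable are illustrative, not the shipped `workspaces.py` API): create the directory, then remove it again if the durable row never lands.
+
+```python
+from pathlib import Path
+from typing import Callable
+
+
+def provision_workspace_dir(path: Path, persist_row: Callable[[], None]) -> None:
+    """Create the workspace directory only if the task_workspaces row also commits."""
+    created_here = not path.exists()
+    path.mkdir(parents=True, exist_ok=True)
+    try:
+        persist_row()  # INSERT the task_workspaces row and commit the transaction
+    except Exception:
+        if created_here:
+            path.rmdir()  # undo only the empty directory this call just created
+        raise
+```
+
+The inverse ordering (commit the row first, then create the directory) avoids orphaned directories but can instead leave a committed row with no backing directory, so either ordering needs a compensating step like the one sketched above.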
+ +## regression risks + +- Runtime regression risk is low because both acceptance suites passed and the workspace behavior is covered at service, route, migration, and integration boundaries. +- Operational note: Postgres-backed integration tests require unsandboxed localhost access. The sandboxed run fails with `Operation not permitted` against `localhost:5432`, which matches the note in `BUILD_REPORT.md`. +- The main residual behavior risk is filesystem/database drift if provisioning fails after directory creation. + +## docs issues + +- None. `README.md`, `ARCHITECTURE.md`, and `.env.example` all reflect the Sprint 5A workspace seam and deferred scope accurately. + +## should anything be added to RULES.md? + +- No. The current rules already cover sprint scope control, doc accuracy, and schema/test expectations for this slice. + +## should anything update ARCHITECTURE.md? + +- No further update is needed for Sprint 5A. + +## recommended next action + +- Accept Sprint 5A. +- In the next workspace-dependent sprint, tighten provisioning hygiene so filesystem creation cannot drift from durable row persistence on failure. diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..ed7c2ba --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,97 @@ +# Roadmap + +## Current State + +- The repo has shipped the implementation slices originally planned as Milestones 1 through 4. +- Sprint 4O added the latest accepted backend seam: durable `tasks` and `task_steps` with explicit manual continuation lineage and deterministic task-step transitions. +- The project is no longer at Foundation. The current repo state is a post-Milestone-4 checkpoint, and this sprint is synchronizing project-truth docs before Milestone 5 work begins. +- No task runner, workspace/artifact layer, document ingestion, read-only connector, or broader side-effect surface has landed yet. + +## Completed Milestones + +### Milestone 1: Foundation + +- Repo scaffold, local Docker Compose infra, FastAPI app shell, config loading, migration tooling, and backend test harness. +- Postgres continuity primitives: `users`, `threads`, `sessions`, and append-only `events`. +- Row-level-security foundation and concurrent event sequencing hardening. + +Status on March 13, 2026: +- Complete. + +### Milestone 2: Context Compiler and Tracing + +- Deterministic context compilation over durable continuity records. +- Persisted `traces` and append-only `trace_events`. +- Trace-visible inclusion and exclusion reasoning for compiled context. + +Status on March 13, 2026: +- Complete. + +### Milestone 3: Memory and Retrieval + +- Governed memory admission with append-only revisions. +- Narrow deterministic explicit-preference extraction from stored user events. +- Memory review labels, review queue reads, and evaluation summary reads. +- Explicit entities and temporal entity edges backed by cited memories. +- Versioned embedding configs, durable memory embeddings, direct semantic retrieval, and deterministic hybrid compile-path memory merge. + +Status on March 13, 2026: +- Complete. + +### Milestone 4: Governance and Safe Action + +- Deterministic response generation over compiled context. +- User-scoped consents, policies, policy evaluation, tool registry, allowlist evaluation, and tool routing. +- Durable approval requests and explicit approval resolution. +- Approved-only proxy execution through the in-process `proxy.echo` handler. +- Durable execution review, execution-budget enforcement, lifecycle mutations, and optional rolling-window limits. 
+- Durable `tasks` and `task_steps`, deterministic task-step reads, explicit task-step transitions, and explicit manual continuation with lineage. + +Status on March 13, 2026: +- Complete through Sprint 4O. + +## Current Milestone Position + +- The repo is at the boundary after Milestone 4. +- Milestone 5 has not started in shipped code yet. +- The immediate work is documentation synchronization and narrow lifecycle-boundary hardening so Milestone 5 planning and review start from truthful artifacts. + +## Next Milestones + +### Immediate Next Narrow Boundary + +- Preserve the current manual-continuation seam as the only shipped multi-step task path. +- Remove or explicitly constrain the remaining approval/execution helpers that still synchronize against `task_steps.sequence_no = 1` before starting runner-style orchestration or workspace-heavy task flows. + +### Milestone 5: Documents, Workspaces, and Read-Only Connectors + +- Add document ingestion and chunk retrieval. +- Add scoped task workspaces and artifact handling. +- Add read-only Gmail and Calendar sync. +- Keep connector scope read-only and approval-aware. + +### Sequencing After Milestone 5 + +- Generalize task lifecycle handling beyond the current manual continuation seam. +- Introduce runner-style orchestration only after the first-step lifecycle assumption is removed. +- Expand tool execution breadth only after the governance and task seams stay deterministic under multi-step flows. + +## Dependencies + +- Truth artifacts must stay synchronized before milestone planning and review work can be trusted. +- The current first-step lifecycle assumption must be resolved before broader runner or workspace work can safely depend on `tasks` / `task_steps`. +- Scoped workspace and artifact boundaries should land before document-heavy or connector-heavy flows rely on them. +- Connector scope should remain deferred until the core memory, governance, and task seams stay stable under the shipped workload. + +## Blockers and Risks + +- Memory extraction and retrieval quality remain the biggest product risk. +- Auth beyond DB user context is still unimplemented. +- The remaining first-step approval/execution synchronization helpers are a forward-compatibility risk for broader multi-step orchestration. +- Workspace or connector work could create hidden scope drift if it starts before the current task-lifecycle boundary is hardened. + +## Recently Completed + +- Durable approval, execution review, and execution-budget seams over the approved proxy path. +- Durable `tasks` and `task_steps` with deterministic reads and status transitions. +- Explicit task-step lineage and manual continuation, including adversarial validation for cross-task, cross-user, and parent-step mismatch cases. diff --git a/RULES.md b/RULES.md new file mode 100644 index 0000000..f6ac44b --- /dev/null +++ b/RULES.md @@ -0,0 +1,52 @@ +# Rules + +## Product / Scope Rules + +- The active sprint packet is the top priority scope boundary for implementation work and overrides broader roadmap intent when they conflict. +- Never represent planned architecture as implemented behavior in docs, handoffs, or build reports. +- Never execute a consequential external action without explicit user approval. +- Always treat explainability as a product feature, not an internal debugging aid. +- Treat the repeat magnesium reorder as the v1 ship-gate scenario. 
+- Never expand v1 scope with proactive automation, write-capable connectors, voice, or browser automation without an explicit roadmap change. +- Do not start runner, workspace/artifact, document-ingestion, or connector work unless the active sprint explicitly opens that boundary. + +## Architecture Rules + +- Treat the immutable event store as ground truth; memories, tasks, and summaries are derived or governed views over durable records. +- Always compile context per invocation from durable sources. +- Keep prompt prefixes, tool schemas, and serialized context ordering deterministic. +- Treat Postgres as the v1 system of record unless measured constraints justify a platform split. +- Appended task steps must carry explicit lineage to a prior visible task step. Do not relink approvals or executions heuristically from broader task history. +- Manual continuation is the current multi-step boundary. Until the older first-step lifecycle helpers are removed or constrained, do not describe broader automatic multi-step orchestration as implemented. + +## Coding Rules + +- Always build against typed contracts and migration-backed schemas first. +- Never mutate tool schemas mid-session; enforce access through policy and proxy layers. +- Keep changes small, module-scoped, and test-backed. +- Stop long-running tasks with a clear progress summary when budgets or circuit breakers trip. +- Sprint-scoped docs must clearly separate what exists now from what is only planned later. + +## Data / Schema Rules + +- Enforce row-level security on every user-owned table from the start. +- Default memory admission to `NOOP`; promote only evidence-backed changes. +- Always keep memory revision history for non-`NOOP` changes. +- Task-step lineage references must stay inside the current user scope and must validate against the intended parent step and its recorded outcome. +- Apply domain and sensitivity filters before semantic retrieval. + +## Deployment / Ops Rules + +- Keep v1 operations simple: one modular monolith, one primary database, one cache, one object store. +- Never store secrets in source control, committed config, or logs. +- Any repo-advertised bootstrap script that starts dependencies and then runs dependent commands must wait for service readiness before proceeding. +- When external side effects are introduced, route them through approval-aware tool execution paths. +- Backups and object versioning are required before production use. + +## Testing Rules + +- Schema changes are not complete without forward and rollback coverage. +- Every module needs unit tests and at least one integration boundary test. +- Approval boundaries, RLS isolation, and audit logging require adversarial tests. +- Lineage changes require adversarial tests for cross-task, cross-user, and parent-step mismatch cases. +- Memory quality and retrieval quality need labeled evaluations before release claims. 
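+
+A minimal sketch of what forward and rollback coverage can look like (assuming a disposable database reachable through a hypothetical `ADMIN_TEST_DATABASE_URL` variable; illustrative only, not a shipped test):
+
+```python
+import os
+
+from alembic import command
+from alembic.config import Config
+
+
+def test_migrations_upgrade_and_downgrade_roundtrip() -> None:
+    # apps/api/alembic/env.py prefers DATABASE_ADMIN_URL, so point it at a throwaway database.
+    os.environ["DATABASE_ADMIN_URL"] = os.environ["ADMIN_TEST_DATABASE_URL"]
+    config = Config("apps/api/alembic.ini")
+
+    command.upgrade(config, "head")    # forward: apply every revision
+    command.downgrade(config, "base")  # rollback: remove everything the revisions created
+    command.upgrade(config, "head")    # re-apply to prove the downgrade left a clean slate
+```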
diff --git a/apps/api/.gitkeep b/apps/api/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/apps/api/.gitkeep @@ -0,0 +1 @@ + diff --git a/apps/api/alembic.ini b/apps/api/alembic.ini new file mode 100644 index 0000000..2ca852a --- /dev/null +++ b/apps/api/alembic.ini @@ -0,0 +1,37 @@ +[alembic] +script_location = apps/api/alembic +prepend_sys_path = apps/api/src +path_separator = os +sqlalchemy.url = postgresql://alicebot_admin:alicebot_admin@localhost:5432/alicebot + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = console +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s diff --git a/apps/api/alembic/env.py b/apps/api/alembic/env.py new file mode 100644 index 0000000..b8880aa --- /dev/null +++ b/apps/api/alembic/env.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from logging.config import fileConfig +import os + +from alembic import context +from sqlalchemy import engine_from_config, pool + + +config = context.config + +target_metadata = None + + +def normalize_sqlalchemy_url(database_url: str) -> str: + if database_url.startswith("postgresql://"): + return database_url.replace("postgresql://", "postgresql+psycopg://", 1) + return database_url + + +def get_url() -> str: + database_url = ( + os.getenv("DATABASE_ADMIN_URL") + or os.getenv("DATABASE_URL") + or config.get_main_option("sqlalchemy.url") + ) + return normalize_sqlalchemy_url(database_url) + + +def configure_logging() -> None: + if config.config_file_name is not None: + fileConfig(config.config_file_name) + + +def run_migrations_offline() -> None: + context.configure( + url=get_url(), + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + configuration = config.get_section(config.config_ini_section, {}) + configuration["sqlalchemy.url"] = get_url() + connectable = engine_from_config( + configuration, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations() -> None: + configure_logging() + if context.is_offline_mode(): + run_migrations_offline() + else: + run_migrations_online() + + +run_migrations() diff --git a/apps/api/alembic/versions/20260310_0001_foundation_continuity.py b/apps/api/alembic/versions/20260310_0001_foundation_continuity.py new file mode 100644 index 0000000..eeb1d3b --- /dev/null +++ b/apps/api/alembic/versions/20260310_0001_foundation_continuity.py @@ -0,0 +1,167 @@ +"""Create continuity foundation tables with RLS and append-only events.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260310_0001" +down_revision = None +branch_labels = None +depends_on = None + +_RLS_TABLES = ("users", "threads", "sessions", "events") + +_UPGRADE_BOOTSTRAP_STATEMENTS = ( + "CREATE EXTENSION IF NOT EXISTS pgcrypto", + "CREATE EXTENSION IF NOT EXISTS vector", + "CREATE SCHEMA IF NOT EXISTS app", + """ + CREATE OR REPLACE 
FUNCTION app.current_user_id() + RETURNS uuid + LANGUAGE sql + STABLE + AS $$ + SELECT NULLIF(current_setting('app.current_user_id', true), '')::uuid + $$; + """, + """ + CREATE OR REPLACE FUNCTION app.reject_event_mutation() + RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE EXCEPTION 'events are append-only'; + END; + $$; + """, +) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE users ( + id uuid PRIMARY KEY, + email text NOT NULL UNIQUE, + display_name text, + created_at timestamptz NOT NULL DEFAULT now() + ); + + CREATE TABLE threads ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + title text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id) + ); + + CREATE TABLE sessions ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + thread_id uuid NOT NULL, + status text NOT NULL DEFAULT 'active', + started_at timestamptz NOT NULL DEFAULT now(), + ended_at timestamptz, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE + ); + + CREATE TABLE events ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + thread_id uuid NOT NULL, + session_id uuid, + sequence_no bigint NOT NULL, + kind text NOT NULL, + payload jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (thread_id, sequence_no), + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE, + FOREIGN KEY (session_id, user_id) + REFERENCES sessions(id, user_id) + ON DELETE CASCADE + ); + + CREATE INDEX sessions_thread_created_idx + ON sessions (thread_id, created_at); + CREATE INDEX threads_user_created_idx + ON threads (user_id, created_at); + """ + +_UPGRADE_TRIGGER_STATEMENT = """ + CREATE TRIGGER events_append_only + BEFORE UPDATE OR DELETE ON events + FOR EACH ROW + EXECUTE FUNCTION app.reject_event_mutation(); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT USAGE ON SCHEMA public TO alicebot_app", + "GRANT USAGE ON SCHEMA app TO alicebot_app", + "GRANT SELECT, INSERT ON users TO alicebot_app", + "GRANT SELECT, INSERT ON threads TO alicebot_app", + "GRANT SELECT, INSERT ON sessions TO alicebot_app", + "GRANT SELECT, INSERT ON events TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY users_is_owner ON users + USING (id = app.current_user_id()) + WITH CHECK (id = app.current_user_id()); + + CREATE POLICY threads_is_owner ON threads + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY sessions_is_owner ON sessions + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY events_read_own ON events + FOR SELECT + USING (user_id = app.current_user_id()); + + CREATE POLICY events_insert_own ON events + FOR INSERT + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TRIGGER IF EXISTS events_append_only ON events", + "DROP TABLE IF EXISTS events", + "DROP TABLE IF EXISTS sessions", + "DROP TABLE IF EXISTS threads", + "DROP TABLE IF EXISTS users", + "DROP FUNCTION IF EXISTS app.reject_event_mutation()", + "DROP FUNCTION IF EXISTS app.current_user_id()", + "DROP SCHEMA IF EXISTS app", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + 
op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_BOOTSTRAP_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_TRIGGER_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260311_0002_tighten_runtime_privileges.py b/apps/api/alembic/versions/20260311_0002_tighten_runtime_privileges.py new file mode 100644 index 0000000..5935399 --- /dev/null +++ b/apps/api/alembic/versions/20260311_0002_tighten_runtime_privileges.py @@ -0,0 +1,39 @@ +"""Tighten the runtime role to insert/select-only continuity access.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260311_0002" +down_revision = "20260310_0001" +branch_labels = None +depends_on = None + +_UPGRADE_STATEMENTS = ( + "REVOKE UPDATE ON users FROM alicebot_app", + "REVOKE UPDATE ON threads FROM alicebot_app", + "REVOKE UPDATE ON sessions FROM alicebot_app", +) + +# Revision 20260310_0001 already leaves the runtime role with no UPDATE grants +# on these tables. Downgrading back to that revision should therefore preserve +# the same privilege floor explicitly rather than re-introducing broader access. +_DOWNGRADE_STATEMENTS = ( + "REVOKE UPDATE ON users FROM alicebot_app", + "REVOKE UPDATE ON threads FROM alicebot_app", + "REVOKE UPDATE ON sessions FROM alicebot_app", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260311_0003_trace_backbone.py b/apps/api/alembic/versions/20260311_0003_trace_backbone.py new file mode 100644 index 0000000..6028ff4 --- /dev/null +++ b/apps/api/alembic/versions/20260311_0003_trace_backbone.py @@ -0,0 +1,117 @@ +"""Add persisted traces and trace events for context compilation.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260311_0003" +down_revision = "20260311_0002" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("traces", "trace_events") + +_UPGRADE_BOOTSTRAP_STATEMENTS = ( + """ + CREATE OR REPLACE FUNCTION app.reject_trace_event_mutation() + RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE EXCEPTION 'trace events are append-only'; + END; + $$; + """, +) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE traces ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + thread_id uuid NOT NULL, + kind text NOT NULL, + compiler_version text NOT NULL, + status text NOT NULL, + limits jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE + ); + + CREATE TABLE trace_events ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + trace_id uuid NOT NULL, + sequence_no bigint NOT NULL, + kind text NOT NULL, + payload jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (trace_id, sequence_no), + 
FOREIGN KEY (trace_id, user_id) + REFERENCES traces(id, user_id) + ON DELETE CASCADE + ); + + CREATE INDEX traces_thread_created_idx + ON traces (thread_id, created_at); + """ + +_UPGRADE_TRIGGER_STATEMENT = """ + CREATE TRIGGER trace_events_append_only + BEFORE UPDATE OR DELETE ON trace_events + FOR EACH ROW + EXECUTE FUNCTION app.reject_trace_event_mutation(); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON traces TO alicebot_app", + "GRANT SELECT, INSERT ON trace_events TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY traces_is_owner ON traces + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY trace_events_read_own ON trace_events + FOR SELECT + USING (user_id = app.current_user_id()); + + CREATE POLICY trace_events_insert_own ON trace_events + FOR INSERT + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TRIGGER IF EXISTS trace_events_append_only ON trace_events", + "DROP TABLE IF EXISTS trace_events", + "DROP TABLE IF EXISTS traces", + "DROP FUNCTION IF EXISTS app.reject_trace_event_mutation()", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_BOOTSTRAP_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_TRIGGER_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260311_0004_memory_admission.py b/apps/api/alembic/versions/20260311_0004_memory_admission.py new file mode 100644 index 0000000..c782d3b --- /dev/null +++ b/apps/api/alembic/versions/20260311_0004_memory_admission.py @@ -0,0 +1,123 @@ +"""Add governed memory tables and append-only memory revisions.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260311_0004" +down_revision = "20260311_0003" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("memories", "memory_revisions") + +_UPGRADE_BOOTSTRAP_STATEMENTS = ( + """ + CREATE OR REPLACE FUNCTION app.reject_memory_revision_mutation() + RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE EXCEPTION 'memory revisions are append-only'; + END; + $$; + """, +) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE memories ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + memory_key text NOT NULL, + value jsonb NOT NULL, + status text NOT NULL, + source_event_ids jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + deleted_at timestamptz, + UNIQUE (id, user_id), + UNIQUE (user_id, memory_key) + ); + + CREATE TABLE memory_revisions ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + memory_id uuid NOT NULL, + sequence_no bigint NOT NULL, + action text NOT NULL, + memory_key text NOT NULL, + previous_value jsonb, + new_value jsonb, + source_event_ids jsonb NOT NULL, + candidate jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (memory_id, 
sequence_no), + FOREIGN KEY (memory_id, user_id) + REFERENCES memories(id, user_id) + ON DELETE CASCADE + ); + + CREATE INDEX memories_user_status_updated_idx + ON memories (user_id, status, updated_at); + CREATE INDEX memory_revisions_memory_created_idx + ON memory_revisions (memory_id, created_at); + """ + +_UPGRADE_TRIGGER_STATEMENT = """ + CREATE TRIGGER memory_revisions_append_only + BEFORE UPDATE OR DELETE ON memory_revisions + FOR EACH ROW + EXECUTE FUNCTION app.reject_memory_revision_mutation(); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON memories TO alicebot_app", + "GRANT SELECT, INSERT ON memory_revisions TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY memories_is_owner ON memories + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY memory_revisions_read_own ON memory_revisions + FOR SELECT + USING (user_id = app.current_user_id()); + + CREATE POLICY memory_revisions_insert_own ON memory_revisions + FOR INSERT + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TRIGGER IF EXISTS memory_revisions_append_only ON memory_revisions", + "DROP TABLE IF EXISTS memory_revisions", + "DROP TABLE IF EXISTS memories", + "DROP FUNCTION IF EXISTS app.reject_memory_revision_mutation()", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_BOOTSTRAP_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_TRIGGER_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0005_memory_review_labels.py b/apps/api/alembic/versions/20260312_0005_memory_review_labels.py new file mode 100644 index 0000000..2b7ede5 --- /dev/null +++ b/apps/api/alembic/versions/20260312_0005_memory_review_labels.py @@ -0,0 +1,99 @@ +"""Add append-only memory review labels for human evaluation.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0005" +down_revision = "20260311_0004" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("memory_review_labels",) + +_UPGRADE_BOOTSTRAP_STATEMENTS = ( + """ + CREATE OR REPLACE FUNCTION app.reject_memory_review_label_mutation() + RETURNS trigger + LANGUAGE plpgsql + AS $$ + BEGIN + RAISE EXCEPTION 'memory review labels are append-only'; + END; + $$; + """, +) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE memory_review_labels ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL, + memory_id uuid NOT NULL, + label text NOT NULL, + note text, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + FOREIGN KEY (memory_id, user_id) + REFERENCES memories(id, user_id) + ON DELETE CASCADE, + CONSTRAINT memory_review_labels_label_check + CHECK (label IN ('correct', 'incorrect', 'outdated', 'insufficient_evidence')), + CONSTRAINT memory_review_labels_note_length_check + CHECK (note IS NULL OR char_length(note) <= 280) + ); + + CREATE INDEX memory_review_labels_memory_created_idx + ON memory_review_labels 
(memory_id, created_at, id); + """ + +_UPGRADE_TRIGGER_STATEMENT = """ + CREATE TRIGGER memory_review_labels_append_only + BEFORE UPDATE OR DELETE ON memory_review_labels + FOR EACH ROW + EXECUTE FUNCTION app.reject_memory_review_label_mutation(); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON memory_review_labels TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY memory_review_labels_read_own ON memory_review_labels + FOR SELECT + USING (user_id = app.current_user_id()); + + CREATE POLICY memory_review_labels_insert_own ON memory_review_labels + FOR INSERT + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TRIGGER IF EXISTS memory_review_labels_append_only ON memory_review_labels", + "DROP TABLE IF EXISTS memory_review_labels", + "DROP FUNCTION IF EXISTS app.reject_memory_review_label_mutation()", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + _execute_statements(_UPGRADE_BOOTSTRAP_STATEMENTS) + op.execute(_UPGRADE_SCHEMA_STATEMENT) + op.execute(_UPGRADE_TRIGGER_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0006_entities_backbone.py b/apps/api/alembic/versions/20260312_0006_entities_backbone.py new file mode 100644 index 0000000..a1d3bcb --- /dev/null +++ b/apps/api/alembic/versions/20260312_0006_entities_backbone.py @@ -0,0 +1,72 @@ +"""Add explicit user-scoped entities backed by durable source memories.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0006" +down_revision = "20260312_0005" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("entities",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE entities ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + entity_type text NOT NULL, + name text NOT NULL, + source_memory_ids jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT entities_type_check + CHECK (entity_type IN ('person', 'merchant', 'product', 'project', 'routine')), + CONSTRAINT entities_name_length_check + CHECK (char_length(name) BETWEEN 1 AND 200), + CONSTRAINT entities_source_memory_ids_array_check + CHECK (jsonb_typeof(source_memory_ids) = 'array'), + CONSTRAINT entities_source_memory_ids_nonempty_check + CHECK (jsonb_array_length(source_memory_ids) > 0) + ); + + CREATE INDEX entities_user_created_idx + ON entities (user_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON entities TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY entities_is_owner ON entities + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS entities", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE 
{table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0007_entity_edges.py b/apps/api/alembic/versions/20260312_0007_entity_edges.py new file mode 100644 index 0000000..fa08bda --- /dev/null +++ b/apps/api/alembic/versions/20260312_0007_entity_edges.py @@ -0,0 +1,83 @@ +"""Add explicit user-scoped entity edges with simple temporal metadata.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0007" +down_revision = "20260312_0006" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("entity_edges",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE entity_edges ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + from_entity_id uuid NOT NULL, + to_entity_id uuid NOT NULL, + relationship_type text NOT NULL, + valid_from timestamptz NULL, + valid_to timestamptz NULL, + source_memory_ids jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT entity_edges_from_entity_fkey + FOREIGN KEY (from_entity_id, user_id) REFERENCES entities(id, user_id) ON DELETE CASCADE, + CONSTRAINT entity_edges_to_entity_fkey + FOREIGN KEY (to_entity_id, user_id) REFERENCES entities(id, user_id) ON DELETE CASCADE, + CONSTRAINT entity_edges_relationship_type_length_check + CHECK (char_length(relationship_type) BETWEEN 1 AND 100), + CONSTRAINT entity_edges_source_memory_ids_array_check + CHECK (jsonb_typeof(source_memory_ids) = 'array'), + CONSTRAINT entity_edges_source_memory_ids_nonempty_check + CHECK (jsonb_array_length(source_memory_ids) > 0), + CONSTRAINT entity_edges_valid_range_check + CHECK (valid_from IS NULL OR valid_to IS NULL OR valid_to >= valid_from) + ); + + CREATE INDEX entity_edges_user_created_idx + ON entity_edges (user_id, created_at, id); + CREATE INDEX entity_edges_user_from_created_idx + ON entity_edges (user_id, from_entity_id, created_at, id); + CREATE INDEX entity_edges_user_to_created_idx + ON entity_edges (user_id, to_entity_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON entity_edges TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY entity_edges_is_owner ON entity_edges + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS entity_edges", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0008_embedding_substrate.py b/apps/api/alembic/versions/20260312_0008_embedding_substrate.py new file mode 100644 index 0000000..d83551e --- /dev/null 
+++ b/apps/api/alembic/versions/20260312_0008_embedding_substrate.py @@ -0,0 +1,115 @@ +"""Add versioned embedding configs and user-scoped memory embeddings.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0008" +down_revision = "20260312_0007" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("embedding_configs", "memory_embeddings") + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE embedding_configs ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider text NOT NULL, + model text NOT NULL, + version text NOT NULL, + dimensions integer NOT NULL, + status text NOT NULL, + metadata jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, provider, model, version), + CONSTRAINT embedding_configs_provider_length_check + CHECK (char_length(provider) BETWEEN 1 AND 100), + CONSTRAINT embedding_configs_model_length_check + CHECK (char_length(model) BETWEEN 1 AND 200), + CONSTRAINT embedding_configs_version_length_check + CHECK (char_length(version) BETWEEN 1 AND 100), + CONSTRAINT embedding_configs_dimensions_check + CHECK (dimensions > 0), + CONSTRAINT embedding_configs_status_check + CHECK (status IN ('active', 'deprecated', 'disabled')), + CONSTRAINT embedding_configs_metadata_object_check + CHECK (jsonb_typeof(metadata) = 'object') + ); + + CREATE INDEX embedding_configs_user_created_idx + ON embedding_configs (user_id, created_at, id); + + CREATE TABLE memory_embeddings ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + memory_id uuid NOT NULL, + embedding_config_id uuid NOT NULL, + dimensions integer NOT NULL, + vector jsonb NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, memory_id, embedding_config_id), + CONSTRAINT memory_embeddings_memory_fkey + FOREIGN KEY (memory_id, user_id) REFERENCES memories(id, user_id) ON DELETE CASCADE, + CONSTRAINT memory_embeddings_embedding_config_fkey + FOREIGN KEY (embedding_config_id, user_id) + REFERENCES embedding_configs(id, user_id) ON DELETE CASCADE, + CONSTRAINT memory_embeddings_dimensions_check + CHECK (dimensions > 0), + CONSTRAINT memory_embeddings_vector_array_check + CHECK (jsonb_typeof(vector) = 'array'), + CONSTRAINT memory_embeddings_vector_nonempty_check + CHECK (jsonb_array_length(vector) > 0), + CONSTRAINT memory_embeddings_vector_dimensions_match_check + CHECK (jsonb_array_length(vector) = dimensions) + ); + + CREATE INDEX memory_embeddings_user_memory_created_idx + ON memory_embeddings (user_id, memory_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON embedding_configs TO alicebot_app", + "GRANT SELECT, INSERT, UPDATE ON memory_embeddings TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY embedding_configs_is_owner ON embedding_configs + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY memory_embeddings_is_owner ON memory_embeddings + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS memory_embeddings", + "DROP TABLE IF EXISTS embedding_configs", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) 
+ + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0009_policy_and_consent_core.py b/apps/api/alembic/versions/20260312_0009_policy_and_consent_core.py new file mode 100644 index 0000000..25fcf20 --- /dev/null +++ b/apps/api/alembic/versions/20260312_0009_policy_and_consent_core.py @@ -0,0 +1,111 @@ +"""Add user-scoped consents and deterministic policy storage.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0009" +down_revision = "20260312_0008" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("consents", "policies") + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE consents ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + consent_key text NOT NULL, + status text NOT NULL, + metadata jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, consent_key), + CONSTRAINT consents_key_length_check + CHECK (char_length(consent_key) BETWEEN 1 AND 200), + CONSTRAINT consents_status_check + CHECK (status IN ('granted', 'revoked')), + CONSTRAINT consents_metadata_object_check + CHECK (jsonb_typeof(metadata) = 'object') + ); + + CREATE INDEX consents_user_key_created_idx + ON consents (user_id, consent_key, created_at, id); + + CREATE TABLE policies ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + name text NOT NULL, + action text NOT NULL, + scope text NOT NULL, + effect text NOT NULL, + priority integer NOT NULL, + active boolean NOT NULL DEFAULT TRUE, + conditions jsonb NOT NULL DEFAULT '{}'::jsonb, + required_consents jsonb NOT NULL DEFAULT '[]'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT policies_name_length_check + CHECK (char_length(name) BETWEEN 1 AND 200), + CONSTRAINT policies_action_length_check + CHECK (char_length(action) BETWEEN 1 AND 100), + CONSTRAINT policies_scope_length_check + CHECK (char_length(scope) BETWEEN 1 AND 200), + CONSTRAINT policies_effect_check + CHECK (effect IN ('allow', 'deny', 'require_approval')), + CONSTRAINT policies_priority_check + CHECK (priority >= 0), + CONSTRAINT policies_conditions_object_check + CHECK (jsonb_typeof(conditions) = 'object'), + CONSTRAINT policies_required_consents_array_check + CHECK (jsonb_typeof(required_consents) = 'array') + ); + + CREATE INDEX policies_user_active_priority_created_idx + ON policies (user_id, active, priority, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON consents TO alicebot_app", + "GRANT SELECT, INSERT ON policies TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY consents_is_owner ON consents + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + + CREATE POLICY policies_is_owner ON policies + USING (user_id = app.current_user_id()) + WITH CHECK 
(user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS policies", + "DROP TABLE IF EXISTS consents", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0010_tools_registry_and_allowlist.py b/apps/api/alembic/versions/20260312_0010_tools_registry_and_allowlist.py new file mode 100644 index 0000000..6d58470 --- /dev/null +++ b/apps/api/alembic/versions/20260312_0010_tools_registry_and_allowlist.py @@ -0,0 +1,96 @@ +"""Add stable tool registry storage for deterministic allowlist evaluation.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0010" +down_revision = "20260312_0009" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("tools",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE tools ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + tool_key text NOT NULL, + name text NOT NULL, + description text NOT NULL, + version text NOT NULL, + metadata_version text NOT NULL, + active boolean NOT NULL DEFAULT TRUE, + tags jsonb NOT NULL DEFAULT '[]'::jsonb, + action_hints jsonb NOT NULL DEFAULT '[]'::jsonb, + scope_hints jsonb NOT NULL DEFAULT '[]'::jsonb, + domain_hints jsonb NOT NULL DEFAULT '[]'::jsonb, + risk_hints jsonb NOT NULL DEFAULT '[]'::jsonb, + metadata jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + UNIQUE (user_id, tool_key, version), + CONSTRAINT tools_key_length_check + CHECK (char_length(tool_key) BETWEEN 1 AND 200), + CONSTRAINT tools_name_length_check + CHECK (char_length(name) BETWEEN 1 AND 200), + CONSTRAINT tools_description_length_check + CHECK (char_length(description) BETWEEN 1 AND 500), + CONSTRAINT tools_version_length_check + CHECK (char_length(version) BETWEEN 1 AND 100), + CONSTRAINT tools_metadata_version_check + CHECK (metadata_version = 'tool_metadata_v0'), + CONSTRAINT tools_tags_array_check + CHECK (jsonb_typeof(tags) = 'array'), + CONSTRAINT tools_action_hints_array_check + CHECK (jsonb_typeof(action_hints) = 'array'), + CONSTRAINT tools_scope_hints_array_check + CHECK (jsonb_typeof(scope_hints) = 'array'), + CONSTRAINT tools_domain_hints_array_check + CHECK (jsonb_typeof(domain_hints) = 'array'), + CONSTRAINT tools_risk_hints_array_check + CHECK (jsonb_typeof(risk_hints) = 'array'), + CONSTRAINT tools_metadata_object_check + CHECK (jsonb_typeof(metadata) = 'object') + ); + + CREATE INDEX tools_user_active_key_version_created_idx + ON tools (user_id, active, tool_key, version, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON tools TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY tools_is_owner ON tools + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS tools", +) + + +def _execute_statements(statements: 
tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0011_approval_request_records.py b/apps/api/alembic/versions/20260312_0011_approval_request_records.py new file mode 100644 index 0000000..49aff5c --- /dev/null +++ b/apps/api/alembic/versions/20260312_0011_approval_request_records.py @@ -0,0 +1,88 @@ +"""Add durable approval request records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0011" +down_revision = "20260312_0010" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("approvals",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE approvals ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + thread_id uuid NOT NULL, + tool_id uuid NOT NULL, + status text NOT NULL DEFAULT 'pending', + request jsonb NOT NULL, + tool jsonb NOT NULL, + routing jsonb NOT NULL, + routing_trace_id uuid NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT approvals_thread_user_fk + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE, + CONSTRAINT approvals_tool_user_fk + FOREIGN KEY (tool_id, user_id) + REFERENCES tools(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT approvals_routing_trace_user_fk + FOREIGN KEY (routing_trace_id, user_id) + REFERENCES traces(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT approvals_status_check + CHECK (status = 'pending'), + CONSTRAINT approvals_request_object_check + CHECK (jsonb_typeof(request) = 'object'), + CONSTRAINT approvals_tool_object_check + CHECK (jsonb_typeof(tool) = 'object'), + CONSTRAINT approvals_routing_object_check + CHECK (jsonb_typeof(routing) = 'object') + ); + + CREATE INDEX approvals_user_created_idx + ON approvals (user_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON approvals TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY approvals_is_owner ON approvals + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS approvals", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260312_0012_approval_resolution.py b/apps/api/alembic/versions/20260312_0012_approval_resolution.py new file mode 100644 index 0000000..7ef2907 --- /dev/null +++ 
b/apps/api/alembic/versions/20260312_0012_approval_resolution.py @@ -0,0 +1,63 @@ +"""Add approval resolution state and runtime update access.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260312_0012" +down_revision = "20260312_0011" +branch_labels = None +depends_on = None + +_UPGRADE_SCHEMA_STATEMENT = """ + ALTER TABLE approvals + DROP CONSTRAINT approvals_status_check, + ADD COLUMN resolved_at timestamptz, + ADD COLUMN resolved_by_user_id uuid REFERENCES users(id) ON DELETE RESTRICT, + ADD CONSTRAINT approvals_status_check + CHECK (status IN ('pending', 'approved', 'rejected')), + ADD CONSTRAINT approvals_resolution_consistency_check + CHECK ( + (status = 'pending' AND resolved_at IS NULL AND resolved_by_user_id IS NULL) + OR ( + status IN ('approved', 'rejected') + AND resolved_at IS NOT NULL + AND resolved_by_user_id IS NOT NULL + ) + ), + ADD CONSTRAINT approvals_resolved_by_owner_check + CHECK (resolved_by_user_id IS NULL OR resolved_by_user_id = user_id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT UPDATE ON approvals TO alicebot_app", +) + +_DOWNGRADE_STATEMENTS = ( + "REVOKE UPDATE ON approvals FROM alicebot_app", + """ + ALTER TABLE approvals + DROP CONSTRAINT approvals_resolved_by_owner_check, + DROP CONSTRAINT approvals_resolution_consistency_check, + DROP CONSTRAINT approvals_status_check, + DROP COLUMN resolved_by_user_id, + DROP COLUMN resolved_at, + ADD CONSTRAINT approvals_status_check + CHECK (status = 'pending'); + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0013_tool_executions.py b/apps/api/alembic/versions/20260313_0013_tool_executions.py new file mode 100644 index 0000000..9bcfdfe --- /dev/null +++ b/apps/api/alembic/versions/20260313_0013_tool_executions.py @@ -0,0 +1,118 @@ +"""Add durable tool execution review records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0013" +down_revision = "20260312_0012" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("tool_executions",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE tool_executions ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + approval_id uuid NOT NULL, + thread_id uuid NOT NULL, + tool_id uuid NOT NULL, + trace_id uuid NOT NULL, + request_event_id uuid, + result_event_id uuid, + status text NOT NULL, + handler_key text, + request jsonb NOT NULL, + tool jsonb NOT NULL, + result jsonb NOT NULL, + executed_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT tool_executions_approval_user_fk + FOREIGN KEY (approval_id, user_id) + REFERENCES approvals(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tool_executions_thread_user_fk + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE, + CONSTRAINT tool_executions_tool_user_fk + FOREIGN KEY (tool_id, user_id) + REFERENCES tools(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tool_executions_trace_user_fk + FOREIGN KEY (trace_id, user_id) + REFERENCES traces(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tool_executions_request_event_user_fk + FOREIGN KEY (request_event_id, user_id) + REFERENCES events(id, user_id) + ON 
DELETE RESTRICT, + CONSTRAINT tool_executions_result_event_user_fk + FOREIGN KEY (result_event_id, user_id) + REFERENCES events(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tool_executions_status_check + CHECK (status IN ('completed', 'blocked')), + CONSTRAINT tool_executions_request_object_check + CHECK (jsonb_typeof(request) = 'object'), + CONSTRAINT tool_executions_tool_object_check + CHECK (jsonb_typeof(tool) = 'object'), + CONSTRAINT tool_executions_result_object_check + CHECK (jsonb_typeof(result) = 'object'), + CONSTRAINT tool_executions_status_event_consistency_check + CHECK ( + ( + status = 'completed' + AND handler_key IS NOT NULL + AND request_event_id IS NOT NULL + AND result_event_id IS NOT NULL + ) + OR ( + status = 'blocked' + AND request_event_id IS NULL + AND result_event_id IS NULL + ) + ) + ); + + CREATE INDEX tool_executions_user_executed_idx + ON tool_executions (user_id, executed_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON tool_executions TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY tool_executions_is_owner ON tool_executions + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS tool_executions", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0014_execution_budgets.py b/apps/api/alembic/versions/20260313_0014_execution_budgets.py new file mode 100644 index 0000000..f6c3519 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0014_execution_budgets.py @@ -0,0 +1,71 @@ +"""Add deterministic execution budget records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0014" +down_revision = "20260313_0013" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("execution_budgets",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE execution_budgets ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + tool_key text, + domain_hint text, + max_completed_executions integer NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT execution_budgets_selector_check + CHECK (tool_key IS NOT NULL OR domain_hint IS NOT NULL), + CONSTRAINT execution_budgets_max_completed_executions_check + CHECK (max_completed_executions > 0) + ); + + CREATE INDEX execution_budgets_user_created_idx + ON execution_budgets (user_id, created_at, id); + + CREATE INDEX execution_budgets_user_match_idx + ON execution_budgets (user_id, tool_key, domain_hint, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON execution_budgets TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY execution_budgets_is_owner ON execution_budgets + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF 
EXISTS execution_budgets", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0015_execution_budget_lifecycle.py b/apps/api/alembic/versions/20260313_0015_execution_budget_lifecycle.py new file mode 100644 index 0000000..ffacd8c --- /dev/null +++ b/apps/api/alembic/versions/20260313_0015_execution_budget_lifecycle.py @@ -0,0 +1,80 @@ +"""Add execution budget lifecycle controls.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0015" +down_revision = "20260313_0014" +branch_labels = None +depends_on = None + +_UPGRADE_STATEMENTS = ( + """ + ALTER TABLE execution_budgets + ADD COLUMN status text NOT NULL DEFAULT 'active', + ADD COLUMN deactivated_at timestamptz, + ADD COLUMN superseded_by_budget_id uuid REFERENCES execution_budgets(id) ON DELETE SET NULL DEFERRABLE INITIALLY DEFERRED, + ADD COLUMN supersedes_budget_id uuid REFERENCES execution_budgets(id) ON DELETE SET NULL DEFERRABLE INITIALLY DEFERRED; + """, + """ + ALTER TABLE execution_budgets + ADD CONSTRAINT execution_budgets_status_check + CHECK (status IN ('active', 'inactive', 'superseded')), + ADD CONSTRAINT execution_budgets_lifecycle_state_check + CHECK ( + (status = 'active' AND deactivated_at IS NULL AND superseded_by_budget_id IS NULL) + OR (status = 'inactive' AND deactivated_at IS NOT NULL AND superseded_by_budget_id IS NULL) + OR (status = 'superseded' AND deactivated_at IS NOT NULL AND superseded_by_budget_id IS NOT NULL) + ), + ADD CONSTRAINT execution_budgets_supersedes_budget_unique + UNIQUE (supersedes_budget_id); + """, + """ + CREATE INDEX execution_budgets_user_status_created_idx + ON execution_budgets (user_id, status, created_at, id); + """, + """ + CREATE UNIQUE INDEX execution_budgets_one_active_scope_idx + ON execution_budgets ( + user_id, + COALESCE(tool_key, ''), + COALESCE(domain_hint, '') + ) + WHERE status = 'active'; + """, + "GRANT SELECT, INSERT, UPDATE ON execution_budgets TO alicebot_app", +) + +_DOWNGRADE_STATEMENTS = ( + "REVOKE UPDATE ON execution_budgets FROM alicebot_app", + "DROP INDEX IF EXISTS execution_budgets_one_active_scope_idx", + "DROP INDEX IF EXISTS execution_budgets_user_status_created_idx", + """ + ALTER TABLE execution_budgets + DROP CONSTRAINT IF EXISTS execution_budgets_supersedes_budget_unique, + DROP CONSTRAINT IF EXISTS execution_budgets_lifecycle_state_check, + DROP CONSTRAINT IF EXISTS execution_budgets_status_check; + """, + """ + ALTER TABLE execution_budgets + DROP COLUMN IF EXISTS supersedes_budget_id, + DROP COLUMN IF EXISTS superseded_by_budget_id, + DROP COLUMN IF EXISTS deactivated_at, + DROP COLUMN IF EXISTS status; + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git 
a/apps/api/alembic/versions/20260313_0016_execution_budget_rolling_window.py b/apps/api/alembic/versions/20260313_0016_execution_budget_rolling_window.py new file mode 100644 index 0000000..31a842c --- /dev/null +++ b/apps/api/alembic/versions/20260313_0016_execution_budget_rolling_window.py @@ -0,0 +1,47 @@ +"""Add optional rolling-window execution budget support.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0016" +down_revision = "20260313_0015" +branch_labels = None +depends_on = None + +_UPGRADE_STATEMENTS = ( + """ + ALTER TABLE execution_budgets + ADD COLUMN rolling_window_seconds integer; + """, + """ + ALTER TABLE execution_budgets + ADD CONSTRAINT execution_budgets_rolling_window_seconds_check + CHECK (rolling_window_seconds IS NULL OR rolling_window_seconds > 0); + """, +) + +_DOWNGRADE_STATEMENTS = ( + """ + ALTER TABLE execution_budgets + DROP CONSTRAINT IF EXISTS execution_budgets_rolling_window_seconds_check; + """, + """ + ALTER TABLE execution_budgets + DROP COLUMN IF EXISTS rolling_window_seconds; + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0017_tasks_lifecycle_records.py b/apps/api/alembic/versions/20260313_0017_tasks_lifecycle_records.py new file mode 100644 index 0000000..f00f07c --- /dev/null +++ b/apps/api/alembic/versions/20260313_0017_tasks_lifecycle_records.py @@ -0,0 +1,112 @@ +"""Add durable task records with deterministic lifecycle status.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0017" +down_revision = "20260313_0016" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("tasks",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE tasks ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + thread_id uuid NOT NULL, + tool_id uuid NOT NULL, + status text NOT NULL, + request jsonb NOT NULL, + tool jsonb NOT NULL, + latest_approval_id uuid, + latest_execution_id uuid, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT tasks_thread_user_fk + FOREIGN KEY (thread_id, user_id) + REFERENCES threads(id, user_id) + ON DELETE CASCADE, + CONSTRAINT tasks_tool_user_fk + FOREIGN KEY (tool_id, user_id) + REFERENCES tools(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tasks_latest_approval_user_fk + FOREIGN KEY (latest_approval_id, user_id) + REFERENCES approvals(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tasks_latest_execution_user_fk + FOREIGN KEY (latest_execution_id, user_id) + REFERENCES tool_executions(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT tasks_status_check + CHECK (status IN ('pending_approval', 'approved', 'executed', 'denied', 'blocked')), + CONSTRAINT tasks_request_object_check + CHECK (jsonb_typeof(request) = 'object'), + CONSTRAINT tasks_tool_object_check + CHECK (jsonb_typeof(tool) = 'object'), + CONSTRAINT tasks_pending_approval_link_check + CHECK (status <> 'pending_approval' OR latest_approval_id IS NOT NULL), + CONSTRAINT tasks_execution_link_check + CHECK ( + ( + status IN ('executed', 'blocked') + AND latest_execution_id IS NOT NULL + ) + OR ( + status NOT IN ('executed', 'blocked') + AND latest_execution_id IS NULL + ) 
+ ) + ); + + CREATE INDEX tasks_user_created_idx + ON tasks (user_id, created_at, id); + + CREATE UNIQUE INDEX tasks_latest_approval_unique_idx + ON tasks (user_id, latest_approval_id) + WHERE latest_approval_id IS NOT NULL; + + CREATE UNIQUE INDEX tasks_latest_execution_unique_idx + ON tasks (user_id, latest_execution_id) + WHERE latest_execution_id IS NOT NULL; + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON tasks TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY tasks_is_owner ON tasks + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS tasks", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0018_task_steps.py b/apps/api/alembic/versions/20260313_0018_task_steps.py new file mode 100644 index 0000000..9467472 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0018_task_steps.py @@ -0,0 +1,93 @@ +"""Add durable task-step review records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0018" +down_revision = "20260313_0017" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("task_steps",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE task_steps ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + task_id uuid NOT NULL, + sequence_no integer NOT NULL, + kind text NOT NULL, + status text NOT NULL, + request jsonb NOT NULL, + outcome jsonb NOT NULL, + trace_id uuid NOT NULL, + trace_kind text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT task_steps_task_user_fk + FOREIGN KEY (task_id, user_id) + REFERENCES tasks(id, user_id) + ON DELETE CASCADE, + CONSTRAINT task_steps_trace_user_fk + FOREIGN KEY (trace_id, user_id) + REFERENCES traces(id, user_id) + ON DELETE RESTRICT, + CONSTRAINT task_steps_sequence_no_check + CHECK (sequence_no > 0), + CONSTRAINT task_steps_kind_check + CHECK (kind IN ('governed_request')), + CONSTRAINT task_steps_status_check + CHECK (status IN ('created', 'approved', 'executed', 'blocked', 'denied')), + CONSTRAINT task_steps_request_object_check + CHECK (jsonb_typeof(request) = 'object'), + CONSTRAINT task_steps_outcome_object_check + CHECK (jsonb_typeof(outcome) = 'object'), + CONSTRAINT task_steps_trace_kind_nonempty_check + CHECK (length(trace_kind) > 0) + ); + + CREATE UNIQUE INDEX task_steps_task_sequence_idx + ON task_steps (user_id, task_id, sequence_no); + + CREATE INDEX task_steps_user_created_idx + ON task_steps (user_id, created_at, id); + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT, UPDATE ON task_steps TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY task_steps_is_owner ON task_steps + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = 
app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS task_steps", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0019_task_step_lineage.py b/apps/api/alembic/versions/20260313_0019_task_step_lineage.py new file mode 100644 index 0000000..b0d98a5 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0019_task_step_lineage.py @@ -0,0 +1,58 @@ +"""Add explicit lineage fields for manual task-step continuation.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0019" +down_revision = "20260313_0018" +branch_labels = None +depends_on = None + +_UPGRADE_SCHEMA_STATEMENT = """ + ALTER TABLE task_steps + ADD COLUMN parent_step_id uuid, + ADD COLUMN source_approval_id uuid, + ADD COLUMN source_execution_id uuid, + ADD CONSTRAINT task_steps_parent_step_user_fk + FOREIGN KEY (parent_step_id, user_id) + REFERENCES task_steps(id, user_id) + ON DELETE RESTRICT, + ADD CONSTRAINT task_steps_source_approval_user_fk + FOREIGN KEY (source_approval_id, user_id) + REFERENCES approvals(id, user_id) + ON DELETE RESTRICT, + ADD CONSTRAINT task_steps_source_execution_user_fk + FOREIGN KEY (source_execution_id, user_id) + REFERENCES tool_executions(id, user_id) + ON DELETE RESTRICT, + ADD CONSTRAINT task_steps_parent_step_not_self_check + CHECK (parent_step_id IS NULL OR parent_step_id <> id); + """ + +_DOWNGRADE_STATEMENTS = ( + """ + ALTER TABLE task_steps + DROP CONSTRAINT task_steps_parent_step_not_self_check, + DROP CONSTRAINT task_steps_source_execution_user_fk, + DROP CONSTRAINT task_steps_source_approval_user_fk, + DROP CONSTRAINT task_steps_parent_step_user_fk, + DROP COLUMN source_execution_id, + DROP COLUMN source_approval_id, + DROP COLUMN parent_step_id; + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0020_approval_task_step_linkage.py b/apps/api/alembic/versions/20260313_0020_approval_task_step_linkage.py new file mode 100644 index 0000000..fe8a270 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0020_approval_task_step_linkage.py @@ -0,0 +1,41 @@ +"""Link approvals directly to their durable task step.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0020" +down_revision = "20260313_0019" +branch_labels = None +depends_on = None + +_UPGRADE_SCHEMA_STATEMENT = """ + ALTER TABLE approvals + ADD COLUMN task_step_id uuid, + ADD CONSTRAINT approvals_task_step_user_fk + FOREIGN KEY (task_step_id, user_id) + REFERENCES task_steps(id, user_id) + ON DELETE RESTRICT; + """ + +_DOWNGRADE_STATEMENTS = ( + """ + ALTER TABLE approvals + DROP CONSTRAINT approvals_task_step_user_fk, + DROP COLUMN task_step_id; + """, +) + + +def 
_execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0021_tool_execution_task_step_linkage.py b/apps/api/alembic/versions/20260313_0021_tool_execution_task_step_linkage.py new file mode 100644 index 0000000..6a26d41 --- /dev/null +++ b/apps/api/alembic/versions/20260313_0021_tool_execution_task_step_linkage.py @@ -0,0 +1,81 @@ +"""Link tool executions directly to their durable task step.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0021" +down_revision = "20260313_0020" +branch_labels = None +depends_on = None + +_UPGRADE_STATEMENTS = ( + """ + ALTER TABLE tool_executions + ADD COLUMN task_step_id uuid; + """, + """ + UPDATE tool_executions AS executions + SET task_step_id = COALESCE( + approvals.task_step_id, + ( + SELECT task_steps.id + FROM task_steps + WHERE task_steps.user_id = executions.user_id + AND task_steps.outcome ->> 'approval_id' = approvals.id::text + ORDER BY task_steps.created_at ASC, task_steps.id ASC + LIMIT 1 + ) + ) + FROM approvals + WHERE approvals.id = executions.approval_id + AND approvals.user_id = executions.user_id; + """, + """ + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 + FROM tool_executions + WHERE task_step_id IS NULL + ) THEN + RAISE EXCEPTION + 'tool_executions.task_step_id backfill failed for existing rows'; + END IF; + END; + $$; + """, + """ + ALTER TABLE tool_executions + ADD CONSTRAINT tool_executions_task_step_user_fk + FOREIGN KEY (task_step_id, user_id) + REFERENCES task_steps(id, user_id) + ON DELETE RESTRICT; + """, + """ + ALTER TABLE tool_executions + ALTER COLUMN task_step_id SET NOT NULL; + """, +) + +_DOWNGRADE_STATEMENTS = ( + """ + ALTER TABLE tool_executions + DROP CONSTRAINT tool_executions_task_step_user_fk, + DROP COLUMN task_step_id; + """, +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def upgrade() -> None: + _execute_statements(_UPGRADE_STATEMENTS) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/alembic/versions/20260313_0022_task_workspaces.py b/apps/api/alembic/versions/20260313_0022_task_workspaces.py new file mode 100644 index 0000000..626224f --- /dev/null +++ b/apps/api/alembic/versions/20260313_0022_task_workspaces.py @@ -0,0 +1,77 @@ +"""Add user-scoped task workspace records.""" + +from __future__ import annotations + +from alembic import op + + +revision = "20260313_0022" +down_revision = "20260313_0021" +branch_labels = None +depends_on = None + +_RLS_TABLES = ("task_workspaces",) + +_UPGRADE_SCHEMA_STATEMENT = """ + CREATE TABLE task_workspaces ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + task_id uuid NOT NULL, + status text NOT NULL, + local_path text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + UNIQUE (id, user_id), + CONSTRAINT task_workspaces_task_user_fk + FOREIGN KEY (task_id, user_id) + REFERENCES tasks(id, user_id) + ON DELETE CASCADE, + CONSTRAINT task_workspaces_status_check + CHECK (status IN ('active')), + CONSTRAINT task_workspaces_local_path_nonempty_check + CHECK (length(local_path) > 0) + ); + + CREATE INDEX task_workspaces_user_created_idx + ON 
task_workspaces (user_id, created_at, id); + + CREATE UNIQUE INDEX task_workspaces_active_task_idx + ON task_workspaces (user_id, task_id) + WHERE status = 'active'; + """ + +_UPGRADE_GRANT_STATEMENTS = ( + "GRANT SELECT, INSERT ON task_workspaces TO alicebot_app", +) + +_UPGRADE_POLICY_STATEMENT = """ + CREATE POLICY task_workspaces_is_owner ON task_workspaces + USING (user_id = app.current_user_id()) + WITH CHECK (user_id = app.current_user_id()); + """ + +_DOWNGRADE_STATEMENTS = ( + "DROP TABLE IF EXISTS task_workspaces", +) + + +def _execute_statements(statements: tuple[str, ...]) -> None: + for statement in statements: + op.execute(statement) + + +def _enable_row_level_security() -> None: + for table_name in _RLS_TABLES: + op.execute(f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table_name} FORCE ROW LEVEL SECURITY") + + +def upgrade() -> None: + op.execute(_UPGRADE_SCHEMA_STATEMENT) + _execute_statements(_UPGRADE_GRANT_STATEMENTS) + _enable_row_level_security() + op.execute(_UPGRADE_POLICY_STATEMENT) + + +def downgrade() -> None: + _execute_statements(_DOWNGRADE_STATEMENTS) diff --git a/apps/api/src/alicebot_api/__init__.py b/apps/api/src/alicebot_api/__init__.py new file mode 100644 index 0000000..39a8a6c --- /dev/null +++ b/apps/api/src/alicebot_api/__init__.py @@ -0,0 +1,2 @@ +"""AliceBot foundation API package.""" + diff --git a/apps/api/src/alicebot_api/approvals.py b/apps/api/src/alicebot_api/approvals.py new file mode 100644 index 0000000..7ef1827 --- /dev/null +++ b/apps/api/src/alicebot_api/approvals.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +from typing import cast +from uuid import UUID + +from alicebot_api.contracts import ( + APPROVAL_LIST_ORDER, + APPROVAL_REQUEST_VERSION_V0, + APPROVAL_RESOLUTION_VERSION_V0, + TRACE_KIND_APPROVAL_REQUEST, + TRACE_KIND_APPROVAL_RESOLUTION, + ApprovalApproveInput, + ApprovalDetailResponse, + ApprovalListResponse, + ApprovalListSummary, + ApprovalRecord, + ApprovalRejectInput, + ApprovalResolutionAction, + ApprovalResolutionOutcome, + ApprovalResolutionRecord, + ApprovalResolutionRequestTracePayload, + ApprovalResolutionResponse, + ApprovalResolutionStateTracePayload, + ApprovalResolutionSummaryTracePayload, + ApprovalRequestCreateInput, + ApprovalRequestCreateResponse, + ApprovalRequestTraceSummary, + ApprovalRoutingRecord, + TaskCreateInput, + TaskStepCreateInput, + ToolRoutingRequestInput, +) +from alicebot_api.store import ApprovalRow, ContinuityStore +from alicebot_api.tasks import ( + DEFAULT_TASK_STEP_KIND, + DEFAULT_TASK_STEP_SEQUENCE_NO, + create_task_step_for_governed_request, + create_task_for_governed_request, + sync_task_step_with_approval, + task_step_lifecycle_trace_events, + task_step_outcome_snapshot, + task_step_status_for_routing_decision, + sync_task_with_approval, + task_lifecycle_trace_events, + task_status_for_routing_decision, + validate_linked_task_step_for_approval, +) +from alicebot_api.tools import route_tool_invocation + + +class ApprovalNotFoundError(LookupError): + """Raised when an approval record is not visible inside the current user scope.""" + + +class ApprovalResolutionConflictError(RuntimeError): + """Raised when a visible approval record is no longer pending.""" + + +def _serialize_resolution(row: ApprovalRow) -> ApprovalResolutionRecord | None: + if row["resolved_at"] is None or row["resolved_by_user_id"] is None: + return None + return { + "resolved_at": row["resolved_at"].isoformat(), + "resolved_by_user_id": str(row["resolved_by_user_id"]), + 
} + + +def serialize_approval_row(row: ApprovalRow) -> ApprovalRecord: + return { + "id": str(row["id"]), + "thread_id": str(row["thread_id"]), + "task_step_id": None if row["task_step_id"] is None else str(row["task_step_id"]), + "status": cast(str, row["status"]), + "request": cast(dict[str, object], row["request"]), + "tool": cast(dict[str, object], row["tool"]), + "routing": cast(ApprovalRoutingRecord, row["routing"]), + "created_at": row["created_at"].isoformat(), + "resolution": _serialize_resolution(row), + } + + +_serialize_approval = serialize_approval_row + + +def _append_trace_events( + store: ContinuityStore, + *, + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> None: + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace_id, + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + +def _resolution_outcome( + *, + requested_action: ApprovalResolutionAction, + current_status: str, +) -> ApprovalResolutionOutcome: + if ( + requested_action == "approve" + and current_status == "approved" + ) or ( + requested_action == "reject" + and current_status == "rejected" + ): + return "duplicate_rejected" + return "conflict_rejected" + + +def _resolution_error( + approval_id: UUID, + *, + requested_action: ApprovalResolutionAction, + current_status: str, +) -> ApprovalResolutionConflictError: + if ( + requested_action == "approve" + and current_status == "approved" + ) or ( + requested_action == "reject" + and current_status == "rejected" + ): + return ApprovalResolutionConflictError(f"approval {approval_id} was already {current_status}") + + requested_status = "approved" if requested_action == "approve" else "rejected" + return ApprovalResolutionConflictError( + f"approval {approval_id} was already {current_status} and cannot be {requested_status}" + ) + + +def _resolve_approval( + store: ContinuityStore, + *, + user_id: UUID, + approval_id: UUID, + requested_action: ApprovalResolutionAction, + resolved_status: str, +) -> ApprovalResolutionResponse: + del user_id + + approval = store.get_approval_optional(approval_id) + if approval is None: + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + validate_linked_task_step_for_approval( + store, + approval_id=approval_id, + task_step_id=cast(UUID | None, approval["task_step_id"]), + ) + + previous_status = cast(str, approval["status"]) + current = approval + outcome: ApprovalResolutionOutcome + + if approval["status"] == "pending": + resolved = store.resolve_approval_optional( + approval_id=approval_id, + status=resolved_status, + ) + if resolved is None: + current = store.get_approval_optional(approval_id) + if current is None: + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + outcome = _resolution_outcome( + requested_action=requested_action, + current_status=cast(str, current["status"]), + ) + else: + current = resolved + outcome = "resolved" + else: + outcome = _resolution_outcome( + requested_action=requested_action, + current_status=previous_status, + ) + + trace = store.create_trace( + user_id=current["user_id"], + thread_id=current["thread_id"], + kind=TRACE_KIND_APPROVAL_RESOLUTION, + compiler_version=APPROVAL_RESOLUTION_VERSION_V0, + status="completed", + limits={ + "order": list(APPROVAL_LIST_ORDER), + "requested_action": requested_action, + "outcome": outcome, + }, + ) + + resolution = _serialize_resolution(current) + linked_task_step_id = None if current["task_step_id"] is None else 
str(current["task_step_id"]) + request_payload: ApprovalResolutionRequestTracePayload = { + "approval_id": str(approval_id), + "task_step_id": linked_task_step_id, + "requested_action": requested_action, + } + state_payload: ApprovalResolutionStateTracePayload = { + "approval_id": str(current["id"]), + "task_step_id": linked_task_step_id, + "requested_action": requested_action, + "previous_status": previous_status, + "outcome": outcome, + "current_status": cast(str, current["status"]), + "resolved_at": None if resolution is None else resolution["resolved_at"], + "resolved_by_user_id": None if resolution is None else resolution["resolved_by_user_id"], + } + summary_payload: ApprovalResolutionSummaryTracePayload = { + "approval_id": str(current["id"]), + "task_step_id": linked_task_step_id, + "requested_action": requested_action, + "outcome": outcome, + "final_status": cast(str, current["status"]), + } + task_transition = sync_task_with_approval( + store, + approval_id=current["id"], + approval_status=cast(str, current["status"]), + ) + task_step_transition = sync_task_step_with_approval( + store, + approval_id=current["id"], + task_step_id=cast(UUID | None, current["task_step_id"]), + approval_status=cast(str, current["status"]), + trace_id=trace["id"], + trace_kind=TRACE_KIND_APPROVAL_RESOLUTION, + ) + trace_events: list[tuple[str, dict[str, object]]] = [ + ("approval.resolution.request", cast(dict[str, object], request_payload)), + ("approval.resolution.state", cast(dict[str, object], state_payload)), + ("approval.resolution.summary", cast(dict[str, object], summary_payload)), + ] + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="approval_resolution", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step_transition.task_step, + previous_status=task_step_transition.previous_status, + source="approval_resolution", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + if outcome != "resolved": + raise _resolution_error( + approval_id, + requested_action=requested_action, + current_status=cast(str, current["status"]), + ) + + return { + "approval": _serialize_approval(current), + "trace": { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + }, + } + + +def submit_approval_request( + store: ContinuityStore, + *, + user_id: UUID, + request: ApprovalRequestCreateInput, +) -> ApprovalRequestCreateResponse: + routing = route_tool_invocation( + store, + user_id=user_id, + request=ToolRoutingRequestInput( + thread_id=request.thread_id, + tool_id=request.tool_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ), + ) + + thread = store.get_thread_optional(request.thread_id) + if thread is None: + raise RuntimeError("validated thread disappeared before approval request trace creation") + + approval_persist_requested = routing["decision"] == "approval_required" + approval = None + approval_created = False + if routing["decision"] == "approval_required": + approval_row = store.create_approval( + thread_id=request.thread_id, + tool_id=request.tool_id, + task_step_id=None, + status="pending", + request=routing["request"], + tool=routing["tool"], + routing={ + "decision": routing["decision"], + "reasons": routing["reasons"], + "trace": routing["trace"], + }, + routing_trace_id=UUID(routing["trace"]["trace_id"]), + ) + 
approval = _serialize_approval(approval_row) + approval_created = True + + task = create_task_for_governed_request( + store, + request=TaskCreateInput( + thread_id=request.thread_id, + tool_id=request.tool_id, + status=task_status_for_routing_decision(routing["decision"]), + request=routing["request"], + tool=routing["tool"], + latest_approval_id=None if approval is None else UUID(approval["id"]), + ), + )["task"] + + trace = store.create_trace( + user_id=thread["user_id"], + thread_id=thread["id"], + kind=TRACE_KIND_APPROVAL_REQUEST, + compiler_version=APPROVAL_REQUEST_VERSION_V0, + status="completed", + limits={ + "order": list(APPROVAL_LIST_ORDER), + "persisted": approval_persist_requested, + }, + ) + task_step = create_task_step_for_governed_request( + store, + request=TaskStepCreateInput( + task_id=UUID(task["id"]), + sequence_no=DEFAULT_TASK_STEP_SEQUENCE_NO, + kind=DEFAULT_TASK_STEP_KIND, + status=task_step_status_for_routing_decision(routing["decision"]), + request=routing["request"], + outcome=task_step_outcome_snapshot( + routing_decision=routing["decision"], + approval_id=None if approval is None else approval["id"], + approval_status=None if approval is None else approval["status"], + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=trace["id"], + trace_kind=TRACE_KIND_APPROVAL_REQUEST, + ), + )["task_step"] + if approval is not None: + updated_approval = store.update_approval_task_step_optional( + approval_id=UUID(approval["id"]), + task_step_id=UUID(task_step["id"]), + ) + if updated_approval is None: + raise RuntimeError("approval disappeared while linking it to its originating task step") + approval = _serialize_approval(updated_approval) + + trace_events: list[tuple[str, dict[str, object]]] = [ + ("approval.request.request", request.as_payload()), + ( + "approval.request.routing", + { + "decision": routing["decision"], + "tool_id": routing["tool"]["id"], + "tool_key": routing["tool"]["tool_key"], + "tool_version": routing["tool"]["version"], + "routing_trace_id": routing["trace"]["trace_id"], + "routing_trace_event_count": routing["trace"]["trace_event_count"], + "reasons": routing["reasons"], + }, + ), + ( + "approval.request.persisted" if approval_created else "approval.request.skipped", + { + "approval_id": None if approval is None else approval["id"], + "task_step_id": None if approval is None else approval["task_step_id"], + "decision": routing["decision"], + "persisted": approval_created, + }, + ), + ( + "approval.request.summary", + { + "decision": routing["decision"], + "persisted": approval_created, + "approval_id": None if approval is None else approval["id"], + "task_step_id": None if approval is None else approval["task_step_id"], + }, + ), + ] + trace_events.extend( + task_lifecycle_trace_events( + task=task, + previous_status=None, + source="approval_request", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step, + previous_status=None, + source="approval_request", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + trace_summary: ApprovalRequestTraceSummary = { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + } + return { + "request": routing["request"], + "decision": routing["decision"], + "tool": routing["tool"], + "reasons": routing["reasons"], + "task": task, + "approval": approval, + "routing_trace": routing["trace"], + "trace": trace_summary, + } + + +def approve_approval_record( + store: ContinuityStore, + *, + 
user_id: UUID, + request: ApprovalApproveInput, +) -> ApprovalResolutionResponse: + return _resolve_approval( + store, + user_id=user_id, + approval_id=request.approval_id, + requested_action="approve", + resolved_status="approved", + ) + + +def reject_approval_record( + store: ContinuityStore, + *, + user_id: UUID, + request: ApprovalRejectInput, +) -> ApprovalResolutionResponse: + return _resolve_approval( + store, + user_id=user_id, + approval_id=request.approval_id, + requested_action="reject", + resolved_status="rejected", + ) + + +def list_approval_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ApprovalListResponse: + del user_id + + items = [_serialize_approval(row) for row in store.list_approvals()] + summary: ApprovalListSummary = { + "total_count": len(items), + "order": list(APPROVAL_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_approval_record( + store: ContinuityStore, + *, + user_id: UUID, + approval_id: UUID, +) -> ApprovalDetailResponse: + del user_id + + approval = store.get_approval_optional(approval_id) + if approval is None: + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + return {"approval": _serialize_approval(approval)} diff --git a/apps/api/src/alicebot_api/compiler.py b/apps/api/src/alicebot_api/compiler.py new file mode 100644 index 0000000..3626eed --- /dev/null +++ b/apps/api/src/alicebot_api/compiler.py @@ -0,0 +1,832 @@ +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from alicebot_api.contracts import ( + COMPILER_VERSION_V0, + CompilerDecision, + CompileContextSemanticRetrievalInput, + CompilerRunResult, + CompiledContextPack, + ContextCompilerLimits, + ContextPackHybridMemorySummary, + ContextPackMemory, + ContextPackMemorySummary, + HybridMemoryDecisionTracePayload, + MemorySelectionSource, + SEMANTIC_MEMORY_RETRIEVAL_ORDER, + SemanticMemoryRetrievalRequestInput, + TRACE_KIND_CONTEXT_COMPILE, + TraceEventRecord, + isoformat_or_none, +) +from alicebot_api.semantic_retrieval import validate_semantic_memory_retrieval_request +from alicebot_api.store import ( + ContinuityStore, + EntityEdgeRow, + EntityRow, + EventRow, + MemoryRow, + SemanticMemoryRetrievalRow, + SessionRow, + ThreadRow, + UserRow, +) + +SUMMARY_TRACE_EVENT_KIND = "context.summary" +_UNBOUNDED_SEMANTIC_RETRIEVAL_LIMIT = 2_147_483_647 +HYBRID_MEMORY_SOURCE_PRECEDENCE: list[MemorySelectionSource] = ["symbolic", "semantic"] +HYBRID_SYMBOLIC_ORDER = ["updated_at_asc", "created_at_asc", "id_asc"] + + +@dataclass(frozen=True, slots=True) +class CompiledTraceRun: + trace_id: str + context_pack: CompiledContextPack + trace_event_count: int + + +@dataclass(frozen=True, slots=True) +class CompiledMemorySection: + items: list[ContextPackMemory] + summary: ContextPackMemorySummary + decisions: list[CompilerDecision] + + +@dataclass(slots=True) +class HybridMemoryCandidate: + memory: MemoryRow + sources: list[MemorySelectionSource] + semantic_score: float | None = None + + +def _session_sort_key( + session: SessionRow, + latest_session_sequence: dict[UUID, int], +) -> tuple[int, str, str, str]: + latest_sequence = latest_session_sequence.get(session["id"], -1) + started_at = isoformat_or_none(session["started_at"]) or "" + created_at = session["created_at"].isoformat() + return (latest_sequence, started_at, created_at, str(session["id"])) + + +def _serialize_user(user: UserRow) -> dict[str, str | None]: + return { + "id": str(user["id"]), + "email": user["email"], + 
"display_name": user["display_name"], + "created_at": user["created_at"].isoformat(), + } + + +def _serialize_thread(thread: ThreadRow) -> dict[str, str]: + return { + "id": str(thread["id"]), + "title": thread["title"], + "created_at": thread["created_at"].isoformat(), + "updated_at": thread["updated_at"].isoformat(), + } + + +def _serialize_session(session: SessionRow) -> dict[str, str | None]: + return { + "id": str(session["id"]), + "status": session["status"], + "started_at": isoformat_or_none(session["started_at"]), + "ended_at": isoformat_or_none(session["ended_at"]), + "created_at": session["created_at"].isoformat(), + } + + +def _serialize_event(event: EventRow) -> dict[str, object]: + return { + "id": str(event["id"]), + "session_id": None if event["session_id"] is None else str(event["session_id"]), + "sequence_no": event["sequence_no"], + "kind": event["kind"], + "payload": event["payload"], + "created_at": event["created_at"].isoformat(), + } + + +def _memory_sort_key(memory: MemoryRow) -> tuple[str, str, str]: + return ( + memory["updated_at"].isoformat(), + memory["created_at"].isoformat(), + str(memory["id"]), + ) + + +def _serialize_memory(memory: MemoryRow) -> dict[str, object]: + return { + "id": str(memory["id"]), + "memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + "source_provenance": { + "sources": ["symbolic"], + "semantic_score": None, + }, + } + + +def _entity_sort_key(entity: EntityRow) -> tuple[str, str]: + return (entity["created_at"].isoformat(), str(entity["id"])) + + +def _serialize_entity(entity: EntityRow) -> dict[str, object]: + return { + "id": str(entity["id"]), + "entity_type": entity["entity_type"], + "name": entity["name"], + "source_memory_ids": entity["source_memory_ids"], + "created_at": entity["created_at"].isoformat(), + } + + +def _entity_edge_sort_key(edge: EntityEdgeRow) -> tuple[str, str]: + return (edge["created_at"].isoformat(), str(edge["id"])) + + +def _serialize_entity_edge(edge: EntityEdgeRow) -> dict[str, object]: + return { + "id": str(edge["id"]), + "from_entity_id": str(edge["from_entity_id"]), + "to_entity_id": str(edge["to_entity_id"]), + "relationship_type": edge["relationship_type"], + "valid_from": isoformat_or_none(edge["valid_from"]), + "valid_to": isoformat_or_none(edge["valid_to"]), + "source_memory_ids": edge["source_memory_ids"], + "created_at": edge["created_at"].isoformat(), + } + + +def _semantic_memory_sort_key(memory: SemanticMemoryRetrievalRow) -> tuple[float, str, str]: + return (-float(memory["score"]), memory["created_at"].isoformat(), str(memory["id"])) + + +def _semantic_deleted_memory_sort_key(memory: MemoryRow) -> tuple[str, str, str]: + return ( + memory["updated_at"].isoformat(), + memory["created_at"].isoformat(), + str(memory["id"]), + ) + + +def _empty_hybrid_memory_summary() -> ContextPackHybridMemorySummary: + return { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 0, + "semantic_selected_count": 0, + "merged_candidate_count": 0, + "deduplicated_count": 0, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": list(HYBRID_MEMORY_SOURCE_PRECEDENCE), + "symbolic_order": list(HYBRID_SYMBOLIC_ORDER), + "semantic_order": 
list(SEMANTIC_MEMORY_RETRIEVAL_ORDER), + } + + +def _hybrid_memory_decision_metadata( + *, + embedding_config_id: UUID | None, + memory_key: str, + status: str, + source_event_ids: list[str], + selected_sources: list[MemorySelectionSource], + semantic_score: float | None, +) -> HybridMemoryDecisionTracePayload: + return { + "embedding_config_id": None if embedding_config_id is None else str(embedding_config_id), + "memory_key": memory_key, + "status": status, + "source_event_ids": source_event_ids, + "selected_sources": list(selected_sources), + "semantic_score": semantic_score, + } + + +def _serialize_hybrid_memory(candidate: HybridMemoryCandidate) -> ContextPackMemory: + memory = candidate.memory + return { + "id": str(memory["id"]), + "memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + "source_provenance": { + "sources": list(candidate.sources), + "semantic_score": candidate.semantic_score, + }, + } + + +def _build_symbolic_memory_section( + *, + memories: list[MemoryRow], + limits: ContextCompilerLimits, +) -> CompiledMemorySection: + ordered_memories = sorted(memories, key=_memory_sort_key) + active_memories = [memory for memory in ordered_memories if memory["status"] == "active"] + deleted_memories = [memory for memory in ordered_memories if memory["status"] != "active"] + symbolic_candidates = active_memories[-limits.max_memories :] if limits.max_memories > 0 else [] + memory_candidates = [ + HybridMemoryCandidate(memory=memory, sources=["symbolic"]) + for memory in symbolic_candidates + ] + decisions: list[CompilerDecision] = [] + + for position, candidate in enumerate(memory_candidates, start=1): + decisions.append( + CompilerDecision( + "included", + "memory", + candidate.memory["id"], + "within_hybrid_memory_limit", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=None, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=None, + ), + ) + ) + + for position, memory in enumerate(deleted_memories, start=1): + decisions.append( + CompilerDecision( + "excluded", + "memory", + memory["id"], + "hybrid_memory_deleted", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=None, + memory_key=memory["memory_key"], + status=memory["status"], + source_event_ids=memory["source_event_ids"], + selected_sources=["symbolic"], + semantic_score=None, + ), + ) + ) + + included_items = [_serialize_hybrid_memory(candidate) for candidate in memory_candidates] + return CompiledMemorySection( + items=included_items, + summary={ + "candidate_count": len(memory_candidates) + len(deleted_memories), + "included_count": len(included_items), + "excluded_deleted_count": len(deleted_memories), + "excluded_limit_count": 0, + "hybrid_retrieval": { + **_empty_hybrid_memory_summary(), + "symbolic_selected_count": len(memory_candidates), + "merged_candidate_count": len(memory_candidates), + "included_symbolic_only_count": len(included_items), + }, + }, + decisions=decisions, + ) + + +def _compile_memory_section( + store: ContinuityStore, + *, + memories: list[MemoryRow], + limits: ContextCompilerLimits, + semantic_retrieval: CompileContextSemanticRetrievalInput | None, +) -> CompiledMemorySection: + if semantic_retrieval is None: + 
return _build_symbolic_memory_section(memories=memories, limits=limits) + + ordered_memories = sorted(memories, key=_memory_sort_key) + active_memories = [memory for memory in ordered_memories if memory["status"] == "active"] + deleted_memories = [memory for memory in ordered_memories if memory["status"] != "active"] + symbolic_candidates = active_memories[-limits.max_memories :] if limits.max_memories > 0 else [] + active_memories_by_id = {memory["id"]: memory for memory in active_memories} + + request = SemanticMemoryRetrievalRequestInput( + embedding_config_id=semantic_retrieval.embedding_config_id, + query_vector=semantic_retrieval.query_vector, + limit=semantic_retrieval.limit, + ) + _config, query_vector = validate_semantic_memory_retrieval_request(store, request=request) + ordered_semantic_candidates = sorted( + store.retrieve_semantic_memory_matches( + embedding_config_id=semantic_retrieval.embedding_config_id, + query_vector=query_vector, + limit=_UNBOUNDED_SEMANTIC_RETRIEVAL_LIMIT, + ), + key=_semantic_memory_sort_key, + ) + selected_semantic_candidates = ordered_semantic_candidates[: semantic_retrieval.limit] + + merged_candidates: list[HybridMemoryCandidate] = [ + HybridMemoryCandidate(memory=memory, sources=["symbolic"]) + for memory in symbolic_candidates + ] + merged_candidate_ids = {candidate.memory["id"] for candidate in merged_candidates} + deduplication_decisions: list[CompilerDecision] = [] + deduplicated_count = 0 + + for position, semantic_candidate in enumerate(selected_semantic_candidates, start=1): + memory = active_memories_by_id.get(semantic_candidate["id"], semantic_candidate) + if semantic_candidate["id"] in merged_candidate_ids: + deduplicated_count += 1 + for candidate in merged_candidates: + if candidate.memory["id"] != semantic_candidate["id"]: + continue + if "semantic" not in candidate.sources: + candidate.sources.append("semantic") + candidate.semantic_score = float(semantic_candidate["score"]) + deduplication_decisions.append( + CompilerDecision( + "included", + "memory", + semantic_candidate["id"], + "hybrid_memory_deduplicated", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=semantic_retrieval.embedding_config_id, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=candidate.semantic_score, + ), + ) + ) + break + continue + + merged_candidate_ids.add(semantic_candidate["id"]) + merged_candidates.append( + HybridMemoryCandidate( + memory=memory, + sources=["semantic"], + semantic_score=float(semantic_candidate["score"]), + ) + ) + + deleted_candidates = [ + HybridMemoryCandidate( + memory=memory, + sources=["symbolic"], + ) + for memory in sorted(deleted_memories, key=_semantic_deleted_memory_sort_key) + ] + + decisions = list(deduplication_decisions) + included_candidates = merged_candidates[: limits.max_memories] if limits.max_memories > 0 else [] + excluded_candidates = merged_candidates[limits.max_memories :] if limits.max_memories > 0 else merged_candidates + included_symbolic_only_count = 0 + included_semantic_only_count = 0 + included_dual_source_count = 0 + + for position, candidate in enumerate(merged_candidates, start=1): + if position <= limits.max_memories and limits.max_memories > 0: + if candidate.sources == ["symbolic"]: + included_symbolic_only_count += 1 + elif candidate.sources == ["semantic"]: + included_semantic_only_count += 1 + else: + 
included_dual_source_count += 1 + decisions.append( + CompilerDecision( + "included", + "memory", + candidate.memory["id"], + "within_hybrid_memory_limit", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=semantic_retrieval.embedding_config_id, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=candidate.semantic_score, + ), + ) + ) + continue + + decisions.append( + CompilerDecision( + "excluded", + "memory", + candidate.memory["id"], + "hybrid_memory_limit_exceeded", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=semantic_retrieval.embedding_config_id, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=candidate.semantic_score, + ), + ) + ) + + for position, candidate in enumerate(deleted_candidates, start=1): + decisions.append( + CompilerDecision( + "excluded", + "memory", + candidate.memory["id"], + "hybrid_memory_deleted", + position, + metadata=_hybrid_memory_decision_metadata( + embedding_config_id=semantic_retrieval.embedding_config_id, + memory_key=candidate.memory["memory_key"], + status=candidate.memory["status"], + source_event_ids=candidate.memory["source_event_ids"], + selected_sources=candidate.sources, + semantic_score=None, + ), + ) + ) + + return CompiledMemorySection( + items=[_serialize_hybrid_memory(candidate) for candidate in included_candidates], + summary={ + "candidate_count": len(merged_candidates) + len(deleted_candidates), + "included_count": len(included_candidates), + "excluded_deleted_count": len(deleted_candidates), + "excluded_limit_count": len(excluded_candidates), + "hybrid_retrieval": { + "requested": True, + "embedding_config_id": str(semantic_retrieval.embedding_config_id), + "query_vector_dimensions": len(query_vector), + "semantic_limit": semantic_retrieval.limit, + "symbolic_selected_count": len(symbolic_candidates), + "semantic_selected_count": len(selected_semantic_candidates), + "merged_candidate_count": len(merged_candidates), + "deduplicated_count": deduplicated_count, + "included_symbolic_only_count": included_symbolic_only_count, + "included_semantic_only_count": included_semantic_only_count, + "included_dual_source_count": included_dual_source_count, + "similarity_metric": "cosine_similarity", + "source_precedence": list(HYBRID_MEMORY_SOURCE_PRECEDENCE), + "symbolic_order": list(HYBRID_SYMBOLIC_ORDER), + "semantic_order": list(SEMANTIC_MEMORY_RETRIEVAL_ORDER), + }, + }, + decisions=decisions, + ) + + +def compile_continuity_context( + *, + user: UserRow, + thread: ThreadRow, + sessions: list[SessionRow], + events: list[EventRow], + memories: list[MemoryRow], + entities: list[EntityRow], + entity_edges: list[EntityEdgeRow], + limits: ContextCompilerLimits, + memory_section: CompiledMemorySection | None = None, +) -> CompilerRunResult: + latest_session_sequence: dict[UUID, int] = {} + for event in events: + session_id = event["session_id"] + if session_id is None: + continue + latest_session_sequence[session_id] = max( + latest_session_sequence.get(session_id, -1), + event["sequence_no"], + ) + + ordered_sessions = sorted( + sessions, + key=lambda session: _session_sort_key(session, latest_session_sequence), + ) + included_sessions = ordered_sessions[-limits.max_sessions :] if limits.max_sessions > 0 
else [] + included_session_ids = {session["id"] for session in included_sessions} + + decisions: list[CompilerDecision] = [ + CompilerDecision("included", "user", user["id"], "scope_user", 1), + CompilerDecision("included", "thread", thread["id"], "scope_thread", 1), + ] + + for position, session in enumerate(included_sessions, start=1): + decisions.append( + CompilerDecision( + "included", + "session", + session["id"], + "within_session_limit", + position, + ) + ) + + excluded_sessions = ordered_sessions[: max(len(ordered_sessions) - len(included_sessions), 0)] + for position, session in enumerate(excluded_sessions, start=1): + decisions.append( + CompilerDecision( + "excluded", + "session", + session["id"], + "session_limit_exceeded", + position, + ) + ) + + eligible_events: list[EventRow] = [] + for event in events: + if event["session_id"] is not None and event["session_id"] not in included_session_ids: + decisions.append( + CompilerDecision( + "excluded", + "event", + event["id"], + "session_not_included", + event["sequence_no"], + ) + ) + continue + eligible_events.append(event) + + included_events = eligible_events[-limits.max_events :] if limits.max_events > 0 else [] + included_event_ids = {event["id"] for event in included_events} + + for event in eligible_events: + if event["id"] in included_event_ids: + decisions.append( + CompilerDecision( + "included", + "event", + event["id"], + "within_event_limit", + event["sequence_no"], + ) + ) + continue + + decisions.append( + CompilerDecision( + "excluded", + "event", + event["id"], + "event_limit_exceeded", + event["sequence_no"], + ) + ) + + resolved_memory_section = memory_section or _build_symbolic_memory_section( + memories=memories, + limits=limits, + ) + decisions.extend(resolved_memory_section.decisions) + ordered_entities = sorted(entities, key=_entity_sort_key) + included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] + included_entity_ids = {entity["id"] for entity in included_entities} + excluded_entity_limit_count = max(len(ordered_entities) - len(included_entities), 0) + + for position, entity in enumerate(ordered_entities, start=1): + if entity["id"] in included_entity_ids: + decisions.append( + CompilerDecision( + "included", + "entity", + entity["id"], + "within_entity_limit", + position, + metadata={ + "record_entity_type": entity["entity_type"], + "name": entity["name"], + "source_memory_ids": entity["source_memory_ids"], + }, + ) + ) + continue + + decisions.append( + CompilerDecision( + "excluded", + "entity", + entity["id"], + "entity_limit_exceeded", + position, + metadata={ + "record_entity_type": entity["entity_type"], + "name": entity["name"], + "source_memory_ids": entity["source_memory_ids"], + }, + ) + ) + + ordered_candidate_entity_edges = sorted( + [ + edge + for edge in entity_edges + if edge["from_entity_id"] in included_entity_ids + or edge["to_entity_id"] in included_entity_ids + ], + key=_entity_edge_sort_key, + ) + included_entity_edges = ( + ordered_candidate_entity_edges[-limits.max_entity_edges :] + if limits.max_entity_edges > 0 + else [] + ) + included_entity_edge_ids = {edge["id"] for edge in included_entity_edges} + excluded_entity_edge_limit_count = max( + len(ordered_candidate_entity_edges) - len(included_entity_edges), + 0, + ) + + for position, edge in enumerate(ordered_candidate_entity_edges, start=1): + attached_included_entity_ids = [ + str(entity_id) + for entity_id in (edge["from_entity_id"], edge["to_entity_id"]) + if entity_id in 
included_entity_ids + ] + metadata = { + "from_entity_id": str(edge["from_entity_id"]), + "to_entity_id": str(edge["to_entity_id"]), + "relationship_type": edge["relationship_type"], + "valid_from": isoformat_or_none(edge["valid_from"]), + "valid_to": isoformat_or_none(edge["valid_to"]), + "source_memory_ids": edge["source_memory_ids"], + "attached_included_entity_ids": attached_included_entity_ids, + } + if edge["id"] in included_entity_edge_ids: + decisions.append( + CompilerDecision( + "included", + "entity_edge", + edge["id"], + "within_entity_edge_limit", + position, + metadata=metadata, + ) + ) + continue + + decisions.append( + CompilerDecision( + "excluded", + "entity_edge", + edge["id"], + "entity_edge_limit_exceeded", + position, + metadata=metadata, + ) + ) + + trace_events = [decision.to_trace_event() for decision in decisions] + trace_events.append( + TraceEventRecord( + kind=SUMMARY_TRACE_EVENT_KIND, + payload={ + "included_session_count": len(included_sessions), + "excluded_session_count": len(excluded_sessions), + "included_event_count": len(included_events), + "excluded_event_count": len(events) - len(included_events), + "included_memory_count": resolved_memory_section.summary["included_count"], + "excluded_memory_count": ( + resolved_memory_section.summary["excluded_deleted_count"] + + resolved_memory_section.summary["excluded_limit_count"] + ), + "excluded_deleted_memory_count": resolved_memory_section.summary[ + "excluded_deleted_count" + ], + "excluded_memory_limit_count": resolved_memory_section.summary[ + "excluded_limit_count" + ], + "hybrid_memory_requested": resolved_memory_section.summary["hybrid_retrieval"][ + "requested" + ], + "hybrid_memory_candidate_count": resolved_memory_section.summary["candidate_count"], + "hybrid_memory_merged_candidate_count": resolved_memory_section.summary[ + "hybrid_retrieval" + ]["merged_candidate_count"], + "hybrid_memory_deduplicated_count": resolved_memory_section.summary[ + "hybrid_retrieval" + ]["deduplicated_count"], + "included_dual_source_memory_count": resolved_memory_section.summary[ + "hybrid_retrieval" + ]["included_dual_source_count"], + "included_entity_count": len(included_entities), + "excluded_entity_count": excluded_entity_limit_count, + "excluded_entity_limit_count": excluded_entity_limit_count, + "included_entity_edge_count": len(included_entity_edges), + "excluded_entity_edge_count": excluded_entity_edge_limit_count, + "excluded_entity_edge_limit_count": excluded_entity_edge_limit_count, + "compiler_version": COMPILER_VERSION_V0, + }, + ) + ) + + return CompilerRunResult( + context_pack={ + "compiler_version": COMPILER_VERSION_V0, + "scope": { + "user_id": str(user["id"]), + "thread_id": str(thread["id"]), + }, + "limits": { + "max_sessions": limits.max_sessions, + "max_events": limits.max_events, + "max_memories": limits.max_memories, + "max_entities": limits.max_entities, + "max_entity_edges": limits.max_entity_edges, + }, + "user": _serialize_user(user), + "thread": _serialize_thread(thread), + "sessions": [_serialize_session(session) for session in included_sessions], + "events": [_serialize_event(event) for event in included_events], + "memories": list(resolved_memory_section.items), + "memory_summary": resolved_memory_section.summary, + "entities": [_serialize_entity(entity) for entity in included_entities], + "entity_summary": { + "candidate_count": len(ordered_entities), + "included_count": len(included_entities), + "excluded_limit_count": excluded_entity_limit_count, + }, + "entity_edges": 
[_serialize_entity_edge(edge) for edge in included_entity_edges], + "entity_edge_summary": { + "anchor_entity_count": len(included_entities), + "candidate_count": len(ordered_candidate_entity_edges), + "included_count": len(included_entity_edges), + "excluded_limit_count": excluded_entity_edge_limit_count, + }, + }, + trace_events=trace_events, + ) + + +def compile_and_persist_trace( + store: ContinuityStore, + *, + user_id: UUID, + thread_id: UUID, + limits: ContextCompilerLimits, + semantic_retrieval: CompileContextSemanticRetrievalInput | None = None, +) -> CompiledTraceRun: + user = store.get_user(user_id) + thread = store.get_thread(thread_id) + sessions = store.list_thread_sessions(thread_id) + events = store.list_thread_events(thread_id) + memories = store.list_context_memories() + memory_section = _compile_memory_section( + store, + memories=memories, + limits=limits, + semantic_retrieval=semantic_retrieval, + ) + entities = store.list_entities() + ordered_entities = sorted(entities, key=_entity_sort_key) + included_entities = ordered_entities[-limits.max_entities :] if limits.max_entities > 0 else [] + entity_edges = store.list_entity_edges_for_entities([entity["id"] for entity in included_entities]) + compiler_run = compile_continuity_context( + user=user, + thread=thread, + sessions=sessions, + events=events, + memories=memories, + entities=entities, + entity_edges=entity_edges, + limits=limits, + memory_section=memory_section, + ) + trace = store.create_trace( + user_id=user_id, + thread_id=thread_id, + kind=TRACE_KIND_CONTEXT_COMPILE, + compiler_version=COMPILER_VERSION_V0, + status="completed", + limits=limits.as_payload(), + ) + + for sequence_no, trace_event in enumerate(compiler_run.trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=trace_event.kind, + payload=trace_event.payload, + ) + + return CompiledTraceRun( + trace_id=str(trace["id"]), + context_pack=compiler_run.context_pack, + trace_event_count=len(compiler_run.trace_events), + ) diff --git a/apps/api/src/alicebot_api/config.py b/apps/api/src/alicebot_api/config.py new file mode 100644 index 0000000..3f41bd7 --- /dev/null +++ b/apps/api/src/alicebot_api/config.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from functools import lru_cache +import os + +DEFAULT_APP_ENV = "development" +DEFAULT_APP_HOST = "127.0.0.1" +DEFAULT_APP_PORT = 8000 +DEFAULT_DATABASE_NAME = "alicebot" +DEFAULT_DATABASE_HOST = "localhost" +DEFAULT_DATABASE_PORT = 5432 +DEFAULT_DATABASE_URL = ( + f"postgresql://alicebot_app:alicebot_app@{DEFAULT_DATABASE_HOST}:" + f"{DEFAULT_DATABASE_PORT}/{DEFAULT_DATABASE_NAME}" +) +DEFAULT_DATABASE_ADMIN_URL = ( + f"postgresql://alicebot_admin:alicebot_admin@{DEFAULT_DATABASE_HOST}:" + f"{DEFAULT_DATABASE_PORT}/{DEFAULT_DATABASE_NAME}" +) +DEFAULT_REDIS_URL = f"redis://{DEFAULT_DATABASE_HOST}:6379/0" +DEFAULT_S3_ENDPOINT_URL = "http://localhost:9000" +DEFAULT_S3_ACCESS_KEY = "alicebot" +DEFAULT_S3_SECRET_KEY = "alicebot-secret" +DEFAULT_S3_BUCKET = "alicebot-local" +DEFAULT_HEALTHCHECK_TIMEOUT_SECONDS = 2 +DEFAULT_MODEL_PROVIDER = "openai_responses" +DEFAULT_MODEL_BASE_URL = "https://api.openai.com/v1" +DEFAULT_MODEL_NAME = "gpt-5-mini" +DEFAULT_MODEL_API_KEY = "" +DEFAULT_MODEL_TIMEOUT_SECONDS = 30 +DEFAULT_TASK_WORKSPACE_ROOT = "/tmp/alicebot/task-workspaces" + +Environment = Mapping[str, str] + + +def _get_env_value(env: Environment, key: str, default: str) -> 
str: + return env.get(key, default) + + +def _get_env_int(env: Environment, key: str, default: int) -> int: + raw_value = env.get(key) + if raw_value is None: + return default + + try: + return int(raw_value) + except ValueError as exc: + raise ValueError(f"{key} must be an integer") from exc + + +@dataclass(frozen=True) +class Settings: + app_env: str = DEFAULT_APP_ENV + app_host: str = DEFAULT_APP_HOST + app_port: int = DEFAULT_APP_PORT + database_url: str = DEFAULT_DATABASE_URL + database_admin_url: str = DEFAULT_DATABASE_ADMIN_URL + redis_url: str = DEFAULT_REDIS_URL + s3_endpoint_url: str = DEFAULT_S3_ENDPOINT_URL + s3_access_key: str = DEFAULT_S3_ACCESS_KEY + s3_secret_key: str = DEFAULT_S3_SECRET_KEY + s3_bucket: str = DEFAULT_S3_BUCKET + healthcheck_timeout_seconds: int = DEFAULT_HEALTHCHECK_TIMEOUT_SECONDS + model_provider: str = DEFAULT_MODEL_PROVIDER + model_base_url: str = DEFAULT_MODEL_BASE_URL + model_name: str = DEFAULT_MODEL_NAME + model_api_key: str = DEFAULT_MODEL_API_KEY + model_timeout_seconds: int = DEFAULT_MODEL_TIMEOUT_SECONDS + task_workspace_root: str = DEFAULT_TASK_WORKSPACE_ROOT + + @classmethod + def from_env(cls, env: Environment | None = None) -> "Settings": + current_env = os.environ if env is None else env + return cls( + app_env=_get_env_value(current_env, "APP_ENV", cls.app_env), + app_host=_get_env_value(current_env, "APP_HOST", cls.app_host), + app_port=_get_env_int(current_env, "APP_PORT", cls.app_port), + database_url=_get_env_value(current_env, "DATABASE_URL", cls.database_url), + database_admin_url=_get_env_value( + current_env, + "DATABASE_ADMIN_URL", + cls.database_admin_url, + ), + redis_url=_get_env_value(current_env, "REDIS_URL", cls.redis_url), + s3_endpoint_url=_get_env_value( + current_env, + "S3_ENDPOINT_URL", + cls.s3_endpoint_url, + ), + s3_access_key=_get_env_value(current_env, "S3_ACCESS_KEY", cls.s3_access_key), + s3_secret_key=_get_env_value(current_env, "S3_SECRET_KEY", cls.s3_secret_key), + s3_bucket=_get_env_value(current_env, "S3_BUCKET", cls.s3_bucket), + healthcheck_timeout_seconds=_get_env_int( + current_env, + "HEALTHCHECK_TIMEOUT_SECONDS", + cls.healthcheck_timeout_seconds, + ), + model_provider=_get_env_value(current_env, "MODEL_PROVIDER", cls.model_provider), + model_base_url=_get_env_value(current_env, "MODEL_BASE_URL", cls.model_base_url), + model_name=_get_env_value(current_env, "MODEL_NAME", cls.model_name), + model_api_key=_get_env_value(current_env, "MODEL_API_KEY", cls.model_api_key), + model_timeout_seconds=_get_env_int( + current_env, + "MODEL_TIMEOUT_SECONDS", + cls.model_timeout_seconds, + ), + task_workspace_root=_get_env_value( + current_env, + "TASK_WORKSPACE_ROOT", + cls.task_workspace_root, + ), + ) + + +@lru_cache(maxsize=1) +def get_settings() -> Settings: + return Settings.from_env() diff --git a/apps/api/src/alicebot_api/contracts.py b/apps/api/src/alicebot_api/contracts.py new file mode 100644 index 0000000..fc794c2 --- /dev/null +++ b/apps/api/src/alicebot_api/contracts.py @@ -0,0 +1,2080 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Literal, NotRequired, TypedDict +from uuid import UUID + +from alicebot_api.store import JsonObject, JsonValue + +DecisionKind = Literal["included", "excluded"] +AdmissionAction = Literal["NOOP", "ADD", "UPDATE", "DELETE"] +MemoryStatus = Literal["active", "deleted"] +MemoryReviewStatusFilter = Literal["active", "deleted", "all"] +MemoryReviewLabelValue = Literal["correct", "incorrect", 
"outdated", "insufficient_evidence"] +EntityType = Literal["person", "merchant", "product", "project", "routine"] +EmbeddingConfigStatus = Literal["active", "deprecated", "disabled"] +ConsentStatus = Literal["granted", "revoked"] +ApprovalStatus = Literal["pending", "approved", "rejected"] +ApprovalResolutionAction = Literal["approve", "reject"] +ApprovalResolutionOutcome = Literal["resolved", "duplicate_rejected", "conflict_rejected"] +TaskStatus = Literal["pending_approval", "approved", "executed", "denied", "blocked"] +TaskWorkspaceStatus = Literal["active"] +TaskLifecycleSource = Literal[ + "approval_request", + "approval_resolution", + "proxy_execution", + "task_step_continuation", + "task_step_sequence", + "task_step_transition", +] +TaskStepKind = Literal["governed_request"] +TaskStepStatus = Literal["created", "approved", "executed", "blocked", "denied"] +ProxyExecutionStatus = Literal["completed", "blocked"] +ExecutionBudgetStatus = Literal["active", "inactive", "superseded"] +ExecutionBudgetDecision = Literal["allow", "block"] +ExecutionBudgetDecisionReason = Literal["no_matching_budget", "within_budget", "budget_exceeded"] +ExecutionBudgetCountScope = Literal["lifetime", "rolling_window"] +ExecutionBudgetLifecycleAction = Literal["deactivate", "supersede"] +ExecutionBudgetLifecycleOutcome = Literal["deactivated", "superseded", "rejected"] +PolicyEffect = Literal["allow", "deny", "require_approval"] +PolicyEvaluationReasonCode = Literal[ + "matched_policy", + "policy_effect_allow", + "policy_effect_deny", + "policy_effect_require_approval", + "consent_missing", + "consent_revoked", + "no_matching_policy", +] +ToolMetadataVersion = Literal["tool_metadata_v0"] +ToolAllowlistReasonCode = Literal[ + "tool_metadata_matched", + "tool_action_unsupported", + "tool_scope_unsupported", + "tool_domain_mismatch", + "tool_risk_mismatch", + "matched_policy", + "policy_effect_allow", + "policy_effect_deny", + "policy_effect_require_approval", + "consent_missing", + "consent_revoked", + "no_matching_policy", +] +ToolAllowlistDecision = Literal["allowed", "denied", "approval_required"] +ToolRoutingDecision = Literal["ready", "denied", "approval_required"] +PromptSectionName = Literal["system", "developer", "context", "conversation"] +ModelProvider = Literal["openai_responses"] +ModelFinishReason = Literal["completed", "incomplete"] +ExplicitPreferencePattern = Literal[ + "i_like", + "i_dont_like", + "i_prefer", + "remember_that_i_like", + "remember_that_i_dont_like", + "remember_that_i_prefer", +] +MemorySelectionSource = Literal["symbolic", "semantic"] + +DEFAULT_MAX_SESSIONS = 3 +DEFAULT_MAX_EVENTS = 8 +DEFAULT_MAX_MEMORIES = 5 +DEFAULT_MAX_ENTITIES = 5 +DEFAULT_MAX_ENTITY_EDGES = 10 +DEFAULT_MEMORY_REVIEW_LIMIT = 20 +MAX_MEMORY_REVIEW_LIMIT = 100 +DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT = 5 +MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT = 50 +COMPILER_VERSION_V0 = "continuity_v0" +PROMPT_ASSEMBLY_VERSION_V0 = "prompt_assembly_v0" +RESPONSE_GENERATION_VERSION_V0 = "response_generation_v0" +TRACE_KIND_CONTEXT_COMPILE = "context.compile" +TRACE_KIND_RESPONSE_GENERATE = "response.generate" +MEMORY_REVIEW_ORDER = ["updated_at_desc", "created_at_desc", "id_desc"] +MEMORY_REVIEW_QUEUE_ORDER = ["updated_at_desc", "created_at_desc", "id_desc"] +MEMORY_REVISION_REVIEW_ORDER = ["sequence_no_asc"] +MEMORY_REVIEW_LABEL_VALUES = [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", +] +MEMORY_REVIEW_LABEL_ORDER = ["created_at_asc", "id_asc"] +ENTITY_TYPES = [ + "person", + "merchant", + "product", + 
"project", + "routine", +] +ENTITY_LIST_ORDER = ["created_at_asc", "id_asc"] +ENTITY_EDGE_LIST_ORDER = ["created_at_asc", "id_asc"] +EMBEDDING_CONFIG_LIST_ORDER = ["created_at_asc", "id_asc"] +MEMORY_EMBEDDING_LIST_ORDER = ["created_at_asc", "id_asc"] +SEMANTIC_MEMORY_RETRIEVAL_ORDER = ["score_desc", "created_at_asc", "id_asc"] +EMBEDDING_CONFIG_STATUSES = ["active", "deprecated", "disabled"] +CONSENT_STATUSES = ["granted", "revoked"] +CONSENT_LIST_ORDER = ["consent_key_asc", "created_at_asc", "id_asc"] +POLICY_EFFECTS = ["allow", "deny", "require_approval"] +POLICY_LIST_ORDER = ["priority_asc", "created_at_asc", "id_asc"] +POLICY_EVALUATION_VERSION_V0 = "policy_evaluation_v0" +TRACE_KIND_POLICY_EVALUATE = "policy.evaluate" +TOOL_METADATA_VERSION_V0 = "tool_metadata_v0" +TOOL_LIST_ORDER = ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"] +TOOL_ALLOWLIST_EVALUATION_VERSION_V0 = "tool_allowlist_evaluation_v0" +TRACE_KIND_TOOL_ALLOWLIST_EVALUATE = "tool.allowlist.evaluate" +TOOL_ROUTING_VERSION_V0 = "tool_routing_v0" +TRACE_KIND_TOOL_ROUTE = "tool.route" +APPROVAL_LIST_ORDER = ["created_at_asc", "id_asc"] +TASK_LIST_ORDER = ["created_at_asc", "id_asc"] +TASK_WORKSPACE_LIST_ORDER = ["created_at_asc", "id_asc"] +TASK_STEP_LIST_ORDER = ["sequence_no_asc", "created_at_asc", "id_asc"] +TOOL_EXECUTION_LIST_ORDER = ["executed_at_asc", "id_asc"] +EXECUTION_BUDGET_LIST_ORDER = ["created_at_asc", "id_asc"] +EXECUTION_BUDGET_MATCH_ORDER = ["specificity_desc", "created_at_asc", "id_asc"] +EXECUTION_BUDGET_STATUSES = ["active", "inactive", "superseded"] +TASK_STATUSES = ["pending_approval", "approved", "executed", "denied", "blocked"] +TASK_WORKSPACE_STATUSES = ["active"] +TASK_STEP_KINDS = ["governed_request"] +TASK_STEP_STATUSES = ["created", "approved", "executed", "blocked", "denied"] +APPROVAL_REQUEST_VERSION_V0 = "approval_request_v0" +TRACE_KIND_APPROVAL_REQUEST = "approval.request" +APPROVAL_RESOLUTION_VERSION_V0 = "approval_resolution_v0" +TRACE_KIND_APPROVAL_RESOLUTION = "approval.resolve" +TRACE_KIND_APPROVAL_RESOLVE = TRACE_KIND_APPROVAL_RESOLUTION +PROXY_EXECUTION_VERSION_V0 = "proxy_execution_v0" +TRACE_KIND_PROXY_EXECUTE = "tool.proxy.execute" +TASK_STEP_SEQUENCE_VERSION_V0 = "task_step_sequence_v0" +TRACE_KIND_TASK_STEP_SEQUENCE = "task.step.sequence" +TASK_STEP_CONTINUATION_VERSION_V0 = "task_step_continuation_v0" +TRACE_KIND_TASK_STEP_CONTINUATION = "task.step.continuation" +TASK_STEP_TRANSITION_VERSION_V0 = "task_step_transition_v0" +TRACE_KIND_TASK_STEP_TRANSITION = "task.step.transition" +EXECUTION_BUDGET_LIFECYCLE_VERSION_V0 = "execution_budget_lifecycle_v0" +TRACE_KIND_EXECUTION_BUDGET_LIFECYCLE = "execution_budget.lifecycle" + + +@dataclass(frozen=True, slots=True) +class ContextCompilerLimits: + max_sessions: int = DEFAULT_MAX_SESSIONS + max_events: int = DEFAULT_MAX_EVENTS + max_memories: int = DEFAULT_MAX_MEMORIES + max_entities: int = DEFAULT_MAX_ENTITIES + max_entity_edges: int = DEFAULT_MAX_ENTITY_EDGES + + def as_payload(self) -> JsonObject: + return { + "max_sessions": self.max_sessions, + "max_events": self.max_events, + "max_memories": self.max_memories, + "max_entities": self.max_entities, + "max_entity_edges": self.max_entity_edges, + } + + +@dataclass(frozen=True, slots=True) +class CompileContextSemanticRetrievalInput: + embedding_config_id: UUID + query_vector: tuple[float, ...] 
+ limit: int = DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "embedding_config_id": str(self.embedding_config_id), + "query_vector": [float(value) for value in self.query_vector], + "limit": self.limit, + } + + +@dataclass(frozen=True, slots=True) +class TraceCreate: + user_id: UUID + thread_id: UUID + kind: str + compiler_version: str + status: str + limits: ContextCompilerLimits + + +@dataclass(frozen=True, slots=True) +class TraceEventRecord: + kind: str + payload: JsonObject + + +@dataclass(frozen=True, slots=True) +class CompilerDecision: + kind: DecisionKind + entity_type: str + entity_id: UUID + reason: str + position: int + metadata: JsonObject | None = None + + def to_trace_event(self) -> TraceEventRecord: + payload: JsonObject = { + "entity_type": self.entity_type, + "entity_id": str(self.entity_id), + "reason": self.reason, + "position": self.position, + } + if self.metadata is not None: + payload.update(self.metadata) + return TraceEventRecord(kind=f"context.{self.kind}", payload=payload) + + +class ContextPackScope(TypedDict): + user_id: str + thread_id: str + + +class ContextPackLimits(TypedDict): + max_sessions: int + max_events: int + max_memories: int + max_entities: int + max_entity_edges: int + + +class ContextPackUser(TypedDict): + id: str + email: str + display_name: str | None + created_at: str + + +class ContextPackThread(TypedDict): + id: str + title: str + created_at: str + updated_at: str + + +class ContextPackSession(TypedDict): + id: str + status: str + started_at: str | None + ended_at: str | None + created_at: str + + +class ContextPackEvent(TypedDict): + id: str + session_id: str | None + sequence_no: int + kind: str + payload: JsonObject + created_at: str + + +class ContextPackMemory(TypedDict): + id: str + memory_key: str + value: JsonValue + status: MemoryStatus + source_event_ids: list[str] + created_at: str + updated_at: str + source_provenance: "ContextPackMemorySourceProvenance" + + +class ContextPackMemorySourceProvenance(TypedDict): + sources: list[MemorySelectionSource] + semantic_score: float | None + + +class ContextPackHybridMemorySummary(TypedDict): + requested: bool + embedding_config_id: str | None + query_vector_dimensions: int + semantic_limit: int + symbolic_selected_count: int + semantic_selected_count: int + merged_candidate_count: int + deduplicated_count: int + included_symbolic_only_count: int + included_semantic_only_count: int + included_dual_source_count: int + similarity_metric: Literal["cosine_similarity"] | None + source_precedence: list[MemorySelectionSource] + symbolic_order: list[str] + semantic_order: list[str] + + +class ContextPackMemorySummary(TypedDict): + candidate_count: int + included_count: int + excluded_deleted_count: int + excluded_limit_count: int + hybrid_retrieval: ContextPackHybridMemorySummary + + +class HybridMemoryDecisionTracePayload(TypedDict): + embedding_config_id: str | None + memory_key: str + status: MemoryStatus + source_event_ids: list[str] + selected_sources: list[MemorySelectionSource] + semantic_score: float | None + + +class ContextPackEntity(TypedDict): + id: str + entity_type: EntityType + name: str + source_memory_ids: list[str] + created_at: str + + +class ContextPackEntitySummary(TypedDict): + candidate_count: int + included_count: int + excluded_limit_count: int + + +class EntityDecisionTracePayload(TypedDict): + entity_type: str + entity_id: str + reason: str + position: int + record_entity_type: EntityType + name: str + source_memory_ids: 
list[str] + + +class ContextPackEntityEdge(TypedDict): + id: str + from_entity_id: str + to_entity_id: str + relationship_type: str + valid_from: str | None + valid_to: str | None + source_memory_ids: list[str] + created_at: str + + +class ContextPackEntityEdgeSummary(TypedDict): + anchor_entity_count: int + candidate_count: int + included_count: int + excluded_limit_count: int + + +class EntityEdgeDecisionTracePayload(TypedDict): + entity_type: str + entity_id: str + reason: str + position: int + from_entity_id: str + to_entity_id: str + relationship_type: str + valid_from: str | None + valid_to: str | None + source_memory_ids: list[str] + attached_included_entity_ids: list[str] + + +class CompiledContextPack(TypedDict): + compiler_version: str + scope: ContextPackScope + limits: ContextPackLimits + user: ContextPackUser + thread: ContextPackThread + sessions: list[ContextPackSession] + events: list[ContextPackEvent] + memories: list[ContextPackMemory] + memory_summary: ContextPackMemorySummary + entities: list[ContextPackEntity] + entity_summary: ContextPackEntitySummary + entity_edges: list[ContextPackEntityEdge] + entity_edge_summary: ContextPackEntityEdgeSummary + + +@dataclass(frozen=True, slots=True) +class CompilerRunResult: + context_pack: CompiledContextPack + trace_events: list[TraceEventRecord] + + +@dataclass(frozen=True, slots=True) +class PromptAssemblyInput: + context_pack: CompiledContextPack + system_instruction: str + developer_instruction: str + + +@dataclass(frozen=True, slots=True) +class PromptSection: + name: PromptSectionName + content: str + + +class PromptAssemblyTracePayload(TypedDict): + version: str + compile_trace_id: str + compiler_version: str + prompt_sha256: str + prompt_char_count: int + section_order: list[PromptSectionName] + section_characters: dict[PromptSectionName, int] + included_session_count: int + included_event_count: int + included_memory_count: int + included_entity_count: int + included_entity_edge_count: int + + +@dataclass(frozen=True, slots=True) +class PromptAssemblyResult: + sections: tuple[PromptSection, ...] 
+ prompt_text: str + prompt_sha256: str + trace_payload: PromptAssemblyTracePayload + + +class ModelInvocationRequestPayload(TypedDict): + provider: ModelProvider + model: str + tool_choice: Literal["none"] + tools: list[JsonObject] + store: bool + sections: list[PromptSectionName] + prompt: str + + +@dataclass(frozen=True, slots=True) +class ModelInvocationRequest: + provider: ModelProvider + model: str + prompt: PromptAssemblyResult + tool_choice: Literal["none"] = "none" + store: bool = False + + def as_payload(self) -> ModelInvocationRequestPayload: + return { + "provider": self.provider, + "model": self.model, + "tool_choice": self.tool_choice, + "tools": [], + "store": self.store, + "sections": [section.name for section in self.prompt.sections], + "prompt": self.prompt.prompt_text, + } + + +class ModelUsagePayload(TypedDict): + input_tokens: int | None + output_tokens: int | None + total_tokens: int | None + + +class ModelInvocationTracePayload(TypedDict): + provider: ModelProvider + model: str + tool_choice: Literal["none"] + tools_enabled: Literal[False] + response_id: str | None + finish_reason: ModelFinishReason + output_text_char_count: int + usage: ModelUsagePayload + error_message: str | None + + +@dataclass(frozen=True, slots=True) +class ModelInvocationResponse: + provider: ModelProvider + model: str + response_id: str | None + finish_reason: ModelFinishReason + output_text: str + usage: ModelUsagePayload + + def to_trace_payload(self, *, error_message: str | None = None) -> ModelInvocationTracePayload: + return { + "provider": self.provider, + "model": self.model, + "tool_choice": "none", + "tools_enabled": False, + "response_id": self.response_id, + "finish_reason": self.finish_reason, + "output_text_char_count": len(self.output_text), + "usage": self.usage, + "error_message": error_message, + } + + +class AssistantResponseModelRecord(TypedDict): + provider: ModelProvider + model: str + response_id: str | None + finish_reason: ModelFinishReason + usage: ModelUsagePayload + + +class AssistantResponsePromptRecord(TypedDict): + assembly_version: str + prompt_sha256: str + section_order: list[PromptSectionName] + + +class AssistantResponseEventPayload(TypedDict): + text: str + model: AssistantResponseModelRecord + prompt: AssistantResponsePromptRecord + + +class GeneratedAssistantRecord(TypedDict): + event_id: str + sequence_no: int + text: str + model_provider: ModelProvider + model: str + + +class ResponseTraceSummary(TypedDict): + compile_trace_id: str + compile_trace_event_count: int + response_trace_id: str + response_trace_event_count: int + + +class GenerateResponseSuccess(TypedDict): + assistant: GeneratedAssistantRecord + trace: ResponseTraceSummary + + +@dataclass(frozen=True, slots=True) +class MemoryCandidateInput: + memory_key: str + value: JsonValue | None + source_event_ids: tuple[UUID, ...] 
+ delete_requested: bool = False + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "memory_key": self.memory_key, + "source_event_ids": [str(source_event_id) for source_event_id in self.source_event_ids], + "delete_requested": self.delete_requested, + } + payload["value"] = self.value + return payload + + +@dataclass(frozen=True, slots=True) +class ExplicitPreferenceExtractionRequestInput: + source_event_id: UUID + + def as_payload(self) -> JsonObject: + return { + "source_event_id": str(self.source_event_id), + } + + +class ExtractedPreferenceCandidateRecord(TypedDict): + memory_key: str + value: JsonValue + source_event_ids: list[str] + delete_requested: bool + pattern: ExplicitPreferencePattern + subject_text: str + + +@dataclass(frozen=True, slots=True) +class EntityCreateInput: + entity_type: EntityType + name: str + source_memory_ids: tuple[UUID, ...] + + def as_payload(self) -> JsonObject: + return { + "entity_type": self.entity_type, + "name": self.name, + "source_memory_ids": [str(source_memory_id) for source_memory_id in self.source_memory_ids], + } + + +@dataclass(frozen=True, slots=True) +class EntityEdgeCreateInput: + from_entity_id: UUID + to_entity_id: UUID + relationship_type: str + valid_from: datetime | None + valid_to: datetime | None + source_memory_ids: tuple[UUID, ...] + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "from_entity_id": str(self.from_entity_id), + "to_entity_id": str(self.to_entity_id), + "relationship_type": self.relationship_type, + "source_memory_ids": [str(source_memory_id) for source_memory_id in self.source_memory_ids], + } + payload["valid_from"] = isoformat_or_none(self.valid_from) + payload["valid_to"] = isoformat_or_none(self.valid_to) + return payload + + +@dataclass(frozen=True, slots=True) +class EmbeddingConfigCreateInput: + provider: str + model: str + version: str + dimensions: int + status: EmbeddingConfigStatus + metadata: JsonObject + + def as_payload(self) -> JsonObject: + return { + "provider": self.provider, + "model": self.model, + "version": self.version, + "dimensions": self.dimensions, + "status": self.status, + "metadata": self.metadata, + } + + +@dataclass(frozen=True, slots=True) +class MemoryEmbeddingUpsertInput: + memory_id: UUID + embedding_config_id: UUID + vector: tuple[float, ...] + + def as_payload(self) -> JsonObject: + return { + "memory_id": str(self.memory_id), + "embedding_config_id": str(self.embedding_config_id), + "vector": [float(value) for value in self.vector], + } + + +@dataclass(frozen=True, slots=True) +class SemanticMemoryRetrievalRequestInput: + embedding_config_id: UUID + query_vector: tuple[float, ...] + limit: int = DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT + + def as_payload(self) -> JsonObject: + return { + "embedding_config_id": str(self.embedding_config_id), + "query_vector": [float(value) for value in self.query_vector], + "limit": self.limit, + } + + +@dataclass(frozen=True, slots=True) +class ConsentUpsertInput: + consent_key: str + status: ConsentStatus + metadata: JsonObject + + def as_payload(self) -> JsonObject: + return { + "consent_key": self.consent_key, + "status": self.status, + "metadata": self.metadata, + } + + +@dataclass(frozen=True, slots=True) +class PolicyCreateInput: + name: str + action: str + scope: str + effect: PolicyEffect + priority: int + active: bool + conditions: JsonObject + required_consents: tuple[str, ...] 
+ + def as_payload(self) -> JsonObject: + return { + "name": self.name, + "action": self.action, + "scope": self.scope, + "effect": self.effect, + "priority": self.priority, + "active": self.active, + "conditions": self.conditions, + "required_consents": list(self.required_consents), + } + + +@dataclass(frozen=True, slots=True) +class PolicyEvaluationRequestInput: + thread_id: UUID + action: str + scope: str + attributes: JsonObject + + def as_payload(self) -> JsonObject: + return { + "thread_id": str(self.thread_id), + "action": self.action, + "scope": self.scope, + "attributes": self.attributes, + } + + +@dataclass(frozen=True, slots=True) +class ToolCreateInput: + tool_key: str + name: str + description: str + version: str + metadata_version: ToolMetadataVersion = TOOL_METADATA_VERSION_V0 + active: bool = True + tags: tuple[str, ...] = field(default_factory=tuple) + action_hints: tuple[str, ...] = field(default_factory=tuple) + scope_hints: tuple[str, ...] = field(default_factory=tuple) + domain_hints: tuple[str, ...] = field(default_factory=tuple) + risk_hints: tuple[str, ...] = field(default_factory=tuple) + metadata: JsonObject = field(default_factory=dict) + + def as_payload(self) -> JsonObject: + return { + "tool_key": self.tool_key, + "name": self.name, + "description": self.description, + "version": self.version, + "metadata_version": self.metadata_version, + "active": self.active, + "tags": list(self.tags), + "action_hints": list(self.action_hints), + "scope_hints": list(self.scope_hints), + "domain_hints": list(self.domain_hints), + "risk_hints": list(self.risk_hints), + "metadata": self.metadata, + } + + +@dataclass(frozen=True, slots=True) +class ToolAllowlistEvaluationRequestInput: + thread_id: UUID + action: str + scope: str + domain_hint: str | None = None + risk_hint: str | None = None + attributes: JsonObject = field(default_factory=dict) + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "thread_id": str(self.thread_id), + "action": self.action, + "scope": self.scope, + "attributes": self.attributes, + } + payload["domain_hint"] = self.domain_hint + payload["risk_hint"] = self.risk_hint + return payload + + +@dataclass(frozen=True, slots=True) +class ToolRoutingRequestInput: + thread_id: UUID + tool_id: UUID + action: str + scope: str + domain_hint: str | None = None + risk_hint: str | None = None + attributes: JsonObject = field(default_factory=dict) + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "thread_id": str(self.thread_id), + "tool_id": str(self.tool_id), + "action": self.action, + "scope": self.scope, + "attributes": self.attributes, + } + payload["domain_hint"] = self.domain_hint + payload["risk_hint"] = self.risk_hint + return payload + + +@dataclass(frozen=True, slots=True) +class ApprovalRequestCreateInput: + thread_id: UUID + tool_id: UUID + action: str + scope: str + domain_hint: str | None = None + risk_hint: str | None = None + attributes: JsonObject = field(default_factory=dict) + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "thread_id": str(self.thread_id), + "tool_id": str(self.tool_id), + "action": self.action, + "scope": self.scope, + "attributes": self.attributes, + } + payload["domain_hint"] = self.domain_hint + payload["risk_hint"] = self.risk_hint + return payload + + +@dataclass(frozen=True, slots=True) +class ApprovalApproveInput: + approval_id: UUID + + def as_payload(self) -> JsonObject: + return { + "approval_id": str(self.approval_id), + "requested_action": "approve", + } + + 
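+# Illustrative sketch (comment only): the as_payload() helpers above each
+# serialize their request dataclass into a flat JSON object. With a
+# placeholder approval id (UUID(int=1) is assumed purely for illustration,
+# not taken from any real record):
+#
+#     ApprovalApproveInput(approval_id=UUID(int=1)).as_payload()
+#     == {
+#         "approval_id": "00000000-0000-0000-0000-000000000001",
+#         "requested_action": "approve",
+#     }
+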
+@dataclass(frozen=True, slots=True) +class ApprovalRejectInput: + approval_id: UUID + + def as_payload(self) -> JsonObject: + return { + "approval_id": str(self.approval_id), + "requested_action": "reject", + } + + +@dataclass(frozen=True, slots=True) +class ProxyExecutionRequestInput: + approval_id: UUID + + def as_payload(self) -> JsonObject: + return { + "approval_id": str(self.approval_id), + } + + +@dataclass(frozen=True, slots=True) +class ExecutionBudgetCreateInput: + max_completed_executions: int + tool_key: str | None = None + domain_hint: str | None = None + rolling_window_seconds: int | None = None + + def as_payload(self) -> JsonObject: + payload: JsonObject = { + "max_completed_executions": self.max_completed_executions, + } + payload["tool_key"] = self.tool_key + payload["domain_hint"] = self.domain_hint + payload["rolling_window_seconds"] = self.rolling_window_seconds + return payload + + +@dataclass(frozen=True, slots=True) +class ExecutionBudgetDeactivateInput: + thread_id: UUID + execution_budget_id: UUID + + def as_payload(self) -> JsonObject: + return { + "thread_id": str(self.thread_id), + "execution_budget_id": str(self.execution_budget_id), + "requested_action": "deactivate", + } + + +@dataclass(frozen=True, slots=True) +class ExecutionBudgetSupersedeInput: + thread_id: UUID + execution_budget_id: UUID + max_completed_executions: int + + def as_payload(self) -> JsonObject: + return { + "thread_id": str(self.thread_id), + "execution_budget_id": str(self.execution_budget_id), + "requested_action": "supersede", + "max_completed_executions": self.max_completed_executions, + } + + +class PersistedMemoryRecord(TypedDict): + id: str + user_id: str + memory_key: str + value: JsonValue + status: MemoryStatus + source_event_ids: list[str] + created_at: str + updated_at: str + deleted_at: str | None + + +class PersistedMemoryRevisionRecord(TypedDict): + id: str + user_id: str + memory_id: str + sequence_no: int + action: AdmissionAction + memory_key: str + previous_value: JsonValue | None + new_value: JsonValue | None + source_event_ids: list[str] + candidate: JsonObject + created_at: str + + +@dataclass(frozen=True, slots=True) +class AdmissionDecisionOutput: + action: AdmissionAction + reason: str + memory: PersistedMemoryRecord | None + revision: PersistedMemoryRevisionRecord | None + + +class ExplicitPreferenceAdmissionRecord(TypedDict): + decision: AdmissionAction + reason: str + memory: PersistedMemoryRecord | None + revision: PersistedMemoryRevisionRecord | None + + +class ExplicitPreferenceExtractionSummary(TypedDict): + source_event_id: str + source_event_kind: str + candidate_count: int + admission_count: int + persisted_change_count: int + noop_count: int + + +class ExplicitPreferenceExtractionResponse(TypedDict): + candidates: list[ExtractedPreferenceCandidateRecord] + admissions: list[ExplicitPreferenceAdmissionRecord] + summary: ExplicitPreferenceExtractionSummary + + +class MemoryReviewRecord(TypedDict): + id: str + memory_key: str + value: JsonValue + status: MemoryStatus + source_event_ids: list[str] + created_at: str + updated_at: str + deleted_at: str | None + + +class MemoryReviewListSummary(TypedDict): + status: MemoryReviewStatusFilter + limit: int + returned_count: int + total_count: int + has_more: bool + order: list[str] + + +class MemoryReviewListResponse(TypedDict): + items: list[MemoryReviewRecord] + summary: MemoryReviewListSummary + + +class MemoryReviewDetailResponse(TypedDict): + memory: MemoryReviewRecord + + +class 
MemoryRevisionReviewRecord(TypedDict): + id: str + memory_id: str + sequence_no: int + action: AdmissionAction + memory_key: str + previous_value: JsonValue | None + new_value: JsonValue | None + source_event_ids: list[str] + created_at: str + + +class MemoryRevisionReviewListSummary(TypedDict): + memory_id: str + limit: int + returned_count: int + total_count: int + has_more: bool + order: list[str] + + +class MemoryRevisionReviewListResponse(TypedDict): + items: list[MemoryRevisionReviewRecord] + summary: MemoryRevisionReviewListSummary + + +class MemoryReviewLabelCounts(TypedDict): + correct: int + incorrect: int + outdated: int + insufficient_evidence: int + + +class MemoryReviewLabelRecord(TypedDict): + id: str + memory_id: str + reviewer_user_id: str + label: MemoryReviewLabelValue + note: str | None + created_at: str + + +class MemoryReviewLabelSummary(TypedDict): + memory_id: str + total_count: int + counts_by_label: MemoryReviewLabelCounts + order: list[str] + + +class MemoryReviewLabelCreateResponse(TypedDict): + label: MemoryReviewLabelRecord + summary: MemoryReviewLabelSummary + + +class MemoryReviewLabelListResponse(TypedDict): + items: list[MemoryReviewLabelRecord] + summary: MemoryReviewLabelSummary + + +class MemoryReviewQueueItem(TypedDict): + id: str + memory_key: str + value: JsonValue + status: Literal["active"] + source_event_ids: list[str] + created_at: str + updated_at: str + + +class MemoryReviewQueueSummary(TypedDict): + memory_status: Literal["active"] + review_state: Literal["unlabeled"] + limit: int + returned_count: int + total_count: int + has_more: bool + order: list[str] + + +class MemoryReviewQueueResponse(TypedDict): + items: list[MemoryReviewQueueItem] + summary: MemoryReviewQueueSummary + + +class MemoryEvaluationSummary(TypedDict): + total_memory_count: int + active_memory_count: int + deleted_memory_count: int + labeled_memory_count: int + unlabeled_memory_count: int + total_label_row_count: int + label_row_counts_by_value: MemoryReviewLabelCounts + label_value_order: list[MemoryReviewLabelValue] + + +class MemoryEvaluationSummaryResponse(TypedDict): + summary: MemoryEvaluationSummary + + +class EntityRecord(TypedDict): + id: str + entity_type: EntityType + name: str + source_memory_ids: list[str] + created_at: str + + +class EntityCreateResponse(TypedDict): + entity: EntityRecord + + +class EntityListSummary(TypedDict): + total_count: int + order: list[str] + + +class EntityListResponse(TypedDict): + items: list[EntityRecord] + summary: EntityListSummary + + +class EntityDetailResponse(TypedDict): + entity: EntityRecord + + +class EntityEdgeRecord(ContextPackEntityEdge): + pass + + +class EntityEdgeCreateResponse(TypedDict): + edge: EntityEdgeRecord + + +class EntityEdgeListSummary(TypedDict): + entity_id: str + total_count: int + order: list[str] + + +class EntityEdgeListResponse(TypedDict): + items: list[EntityEdgeRecord] + summary: EntityEdgeListSummary + + +class EmbeddingConfigRecord(TypedDict): + id: str + provider: str + model: str + version: str + dimensions: int + status: EmbeddingConfigStatus + metadata: JsonObject + created_at: str + + +class EmbeddingConfigCreateResponse(TypedDict): + embedding_config: EmbeddingConfigRecord + + +class EmbeddingConfigListSummary(TypedDict): + total_count: int + order: list[str] + + +class EmbeddingConfigListResponse(TypedDict): + items: list[EmbeddingConfigRecord] + summary: EmbeddingConfigListSummary + + +class MemoryEmbeddingRecord(TypedDict): + id: str + memory_id: str + embedding_config_id: str + 
dimensions: int + vector: list[float] + created_at: str + updated_at: str + + +class MemoryEmbeddingUpsertResponse(TypedDict): + embedding: MemoryEmbeddingRecord + write_mode: Literal["created", "updated"] + + +class MemoryEmbeddingDetailResponse(TypedDict): + embedding: MemoryEmbeddingRecord + + +class MemoryEmbeddingListSummary(TypedDict): + memory_id: str + total_count: int + order: list[str] + + +class MemoryEmbeddingListResponse(TypedDict): + items: list[MemoryEmbeddingRecord] + summary: MemoryEmbeddingListSummary + + +class SemanticMemoryRetrievalResultItem(TypedDict): + memory_id: str + memory_key: str + value: JsonValue + source_event_ids: list[str] + created_at: str + updated_at: str + score: float + + +class SemanticMemoryRetrievalSummary(TypedDict): + embedding_config_id: str + limit: int + returned_count: int + similarity_metric: Literal["cosine_similarity"] + order: list[str] + + +class SemanticMemoryRetrievalResponse(TypedDict): + items: list[SemanticMemoryRetrievalResultItem] + summary: SemanticMemoryRetrievalSummary + + +class ConsentRecord(TypedDict): + id: str + consent_key: str + status: ConsentStatus + metadata: JsonObject + created_at: str + updated_at: str + + +class ConsentUpsertResponse(TypedDict): + consent: ConsentRecord + write_mode: Literal["created", "updated"] + + +class ConsentListSummary(TypedDict): + total_count: int + order: list[str] + + +class ConsentListResponse(TypedDict): + items: list[ConsentRecord] + summary: ConsentListSummary + + +class PolicyRecord(TypedDict): + id: str + name: str + action: str + scope: str + effect: PolicyEffect + priority: int + active: bool + conditions: JsonObject + required_consents: list[str] + created_at: str + updated_at: str + + +class PolicyCreateResponse(TypedDict): + policy: PolicyRecord + + +class PolicyListSummary(TypedDict): + total_count: int + order: list[str] + + +class PolicyListResponse(TypedDict): + items: list[PolicyRecord] + summary: PolicyListSummary + + +class PolicyDetailResponse(TypedDict): + policy: PolicyRecord + + +class PolicyEvaluationReason(TypedDict): + code: PolicyEvaluationReasonCode + source: Literal["policy", "consent", "system"] + message: str + policy_id: str | None + consent_key: str | None + + +class PolicyEvaluationSummary(TypedDict): + action: str + scope: str + evaluated_policy_count: int + matched_policy_id: str | None + order: list[str] + + +class PolicyEvaluationTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class PolicyEvaluationResponse(TypedDict): + decision: PolicyEffect + matched_policy: PolicyRecord | None + reasons: list[PolicyEvaluationReason] + evaluation: PolicyEvaluationSummary + trace: PolicyEvaluationTraceSummary + + +class ToolRecord(TypedDict): + id: str + tool_key: str + name: str + description: str + version: str + metadata_version: ToolMetadataVersion + active: bool + tags: list[str] + action_hints: list[str] + scope_hints: list[str] + domain_hints: list[str] + risk_hints: list[str] + metadata: JsonObject + created_at: str + + +class ToolCreateResponse(TypedDict): + tool: ToolRecord + + +class ToolListSummary(TypedDict): + total_count: int + order: list[str] + + +class ToolListResponse(TypedDict): + items: list[ToolRecord] + summary: ToolListSummary + + +class ToolDetailResponse(TypedDict): + tool: ToolRecord + + +class ToolAllowlistReason(TypedDict): + code: ToolAllowlistReasonCode + source: Literal["tool", "policy", "consent", "system"] + message: str + tool_id: str | None + policy_id: str | None + consent_key: str | None + + +class 
ToolAllowlistDecisionRecord(TypedDict): + decision: ToolAllowlistDecision + tool: ToolRecord + reasons: list[ToolAllowlistReason] + + +class ToolAllowlistEvaluationSummary(TypedDict): + action: str + scope: str + domain_hint: str | None + risk_hint: str | None + evaluated_tool_count: int + allowed_count: int + denied_count: int + approval_required_count: int + order: list[str] + + +class ToolAllowlistTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ToolAllowlistEvaluationResponse(TypedDict): + allowed: list[ToolAllowlistDecisionRecord] + denied: list[ToolAllowlistDecisionRecord] + approval_required: list[ToolAllowlistDecisionRecord] + summary: ToolAllowlistEvaluationSummary + trace: ToolAllowlistTraceSummary + + +class ToolRoutingRequestRecord(TypedDict): + thread_id: str + tool_id: str + action: str + scope: str + domain_hint: str | None + risk_hint: str | None + attributes: JsonObject + + +class ToolRoutingRequestTracePayload(TypedDict): + thread_id: str + tool_id: str + action: str + scope: str + domain_hint: str | None + risk_hint: str | None + attributes: JsonObject + + +class ToolRoutingDecisionTracePayload(TypedDict): + tool_id: str + tool_key: str + tool_version: str + allowlist_decision: ToolAllowlistDecision + routing_decision: ToolRoutingDecision + matched_policy_id: str | None + reasons: list[ToolAllowlistReason] + + +class ToolRoutingSummaryTracePayload(TypedDict): + decision: ToolRoutingDecision + evaluated_tool_count: int + active_policy_count: int + consent_count: int + + +class ToolRoutingSummary(TypedDict): + thread_id: str + tool_id: str + action: str + scope: str + domain_hint: str | None + risk_hint: str | None + decision: ToolRoutingDecision + evaluated_tool_count: int + active_policy_count: int + consent_count: int + order: list[str] + + +class ToolRoutingTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ToolRoutingResponse(TypedDict): + request: ToolRoutingRequestRecord + decision: ToolRoutingDecision + tool: ToolRecord + reasons: list[ToolAllowlistReason] + summary: ToolRoutingSummary + trace: ToolRoutingTraceSummary + + +class ApprovalRoutingRecord(TypedDict): + decision: ToolRoutingDecision + reasons: list[ToolAllowlistReason] + trace: ToolRoutingTraceSummary + + +class ApprovalResolutionRecord(TypedDict): + resolved_at: str + resolved_by_user_id: str + + +class ApprovalRecord(TypedDict): + id: str + thread_id: str + task_step_id: str | None + status: ApprovalStatus + request: ToolRoutingRequestRecord + tool: ToolRecord + routing: ApprovalRoutingRecord + created_at: str + resolution: ApprovalResolutionRecord | None + + +class ApprovalRequestTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ApprovalResolutionTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ApprovalResolutionRequestTracePayload(TypedDict): + approval_id: str + task_step_id: str | None + requested_action: ApprovalResolutionAction + + +class ApprovalResolutionStateTracePayload(TypedDict): + approval_id: str + task_step_id: str | None + requested_action: ApprovalResolutionAction + previous_status: ApprovalStatus + outcome: ApprovalResolutionOutcome + current_status: ApprovalStatus + resolved_at: str | None + resolved_by_user_id: str | None + + +class ApprovalResolutionSummaryTracePayload(TypedDict): + approval_id: str + task_step_id: str | None + requested_action: ApprovalResolutionAction + outcome: ApprovalResolutionOutcome + final_status: ApprovalStatus + + +@dataclass(frozen=True, 
slots=True) +class TaskCreateInput: + thread_id: UUID + tool_id: UUID + status: TaskStatus + request: ToolRoutingRequestRecord + tool: ToolRecord + latest_approval_id: UUID | None = None + latest_execution_id: UUID | None = None + + +class TaskRecord(TypedDict): + id: str + thread_id: str + tool_id: str + status: TaskStatus + request: ToolRoutingRequestRecord + tool: ToolRecord + latest_approval_id: str | None + latest_execution_id: str | None + created_at: str + updated_at: str + + +class TaskCreateResponse(TypedDict): + task: TaskRecord + + +@dataclass(frozen=True, slots=True) +class TaskStepCreateInput: + task_id: UUID + sequence_no: int + kind: TaskStepKind + status: TaskStepStatus + request: ToolRoutingRequestRecord + outcome: "TaskStepOutcomeSnapshot" + trace_id: UUID + trace_kind: str + + +@dataclass(frozen=True, slots=True) +class TaskStepNextCreateInput: + task_id: UUID + kind: TaskStepKind + status: TaskStepStatus + request: ToolRoutingRequestRecord + outcome: "TaskStepOutcomeSnapshot" + lineage: "TaskStepLineageInput" + + +@dataclass(frozen=True, slots=True) +class TaskStepTransitionInput: + task_step_id: UUID + status: TaskStepStatus + outcome: "TaskStepOutcomeSnapshot" + + +@dataclass(frozen=True, slots=True) +class TaskStepLineageInput: + parent_step_id: UUID + source_approval_id: UUID | None = None + source_execution_id: UUID | None = None + + +class TaskListSummary(TypedDict): + total_count: int + order: list[str] + + +class TaskListResponse(TypedDict): + items: list[TaskRecord] + summary: TaskListSummary + + +class TaskDetailResponse(TypedDict): + task: TaskRecord + + +@dataclass(frozen=True, slots=True) +class TaskWorkspaceCreateInput: + task_id: UUID + status: TaskWorkspaceStatus + + +class TaskWorkspaceRecord(TypedDict): + id: str + task_id: str + status: TaskWorkspaceStatus + local_path: str + created_at: str + updated_at: str + + +class TaskWorkspaceCreateResponse(TypedDict): + workspace: TaskWorkspaceRecord + + +class TaskWorkspaceListSummary(TypedDict): + total_count: int + order: list[str] + + +class TaskWorkspaceListResponse(TypedDict): + items: list[TaskWorkspaceRecord] + summary: TaskWorkspaceListSummary + + +class TaskWorkspaceDetailResponse(TypedDict): + workspace: TaskWorkspaceRecord + + +class TaskStepTraceLink(TypedDict): + trace_id: str + trace_kind: str + + +class TaskStepOutcomeSnapshot(TypedDict): + routing_decision: ToolRoutingDecision + approval_id: str | None + approval_status: ApprovalStatus | None + execution_id: str | None + execution_status: ProxyExecutionStatus | None + blocked_reason: str | None + + +class TaskStepLineageRecord(TypedDict): + parent_step_id: str | None + source_approval_id: str | None + source_execution_id: str | None + + +class TaskStepRecord(TypedDict): + id: str + task_id: str + sequence_no: int + kind: TaskStepKind + status: TaskStepStatus + request: ToolRoutingRequestRecord + outcome: TaskStepOutcomeSnapshot + lineage: TaskStepLineageRecord + trace: TaskStepTraceLink + created_at: str + updated_at: str + + +class TaskStepCreateResponse(TypedDict): + task_step: TaskStepRecord + + +class TaskStepSequencingSummary(TypedDict): + task_id: str + total_count: int + latest_sequence_no: int | None + latest_status: TaskStepStatus | None + next_sequence_no: int + append_allowed: bool + order: list[str] + + +class TaskStepListSummary(TaskStepSequencingSummary): + pass + + +class TaskStepListResponse(TypedDict): + items: list[TaskStepRecord] + summary: TaskStepListSummary + + +class TaskStepDetailResponse(TypedDict): + task_step: 
TaskStepRecord + + +class TaskStepMutationTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class TaskStepNextCreateResponse(TypedDict): + task: TaskRecord + task_step: TaskStepRecord + sequencing: TaskStepSequencingSummary + trace: TaskStepMutationTraceSummary + + +class TaskStepTransitionResponse(TypedDict): + task: TaskRecord + task_step: TaskStepRecord + sequencing: TaskStepSequencingSummary + trace: TaskStepMutationTraceSummary + + +class TaskLifecycleStateTracePayload(TypedDict): + task_id: str + source: TaskLifecycleSource + previous_status: TaskStatus | None + current_status: TaskStatus + latest_approval_id: str | None + latest_execution_id: str | None + + +class TaskLifecycleSummaryTracePayload(TypedDict): + task_id: str + source: TaskLifecycleSource + final_status: TaskStatus + latest_approval_id: str | None + latest_execution_id: str | None + + +class TaskStepLifecycleStateTracePayload(TypedDict): + task_id: str + task_step_id: str + source: TaskLifecycleSource + sequence_no: int + kind: TaskStepKind + previous_status: TaskStepStatus | None + current_status: TaskStepStatus + trace: TaskStepTraceLink + + +class TaskStepLifecycleSummaryTracePayload(TypedDict): + task_id: str + task_step_id: str + source: TaskLifecycleSource + sequence_no: int + kind: TaskStepKind + final_status: TaskStepStatus + trace: TaskStepTraceLink + + +class TaskStepSequenceRequestTracePayload(TypedDict): + task_id: str + previous_task_step_id: str + previous_sequence_no: int + previous_status: TaskStepStatus + requested_kind: TaskStepKind + requested_status: TaskStepStatus + + +class TaskStepSequenceStateTracePayload(TypedDict): + task_id: str + previous_task_step_id: str + previous_sequence_no: int + previous_status: TaskStepStatus + task_step_id: str + assigned_sequence_no: int + kind: TaskStepKind + current_status: TaskStepStatus + + +class TaskStepSequenceSummaryTracePayload(TypedDict): + task_id: str + task_step_id: str + latest_sequence_no: int + next_sequence_no: int + append_allowed: bool + + +class TaskStepContinuationRequestTracePayload(TypedDict): + task_id: str + parent_task_step_id: str + parent_sequence_no: int + parent_status: TaskStepStatus + requested_kind: TaskStepKind + requested_status: TaskStepStatus + requested_source_approval_id: str | None + requested_source_execution_id: str | None + + +class TaskStepContinuationLineageTracePayload(TypedDict): + task_id: str + parent_task_step_id: str + parent_sequence_no: int + parent_status: TaskStepStatus + source_approval_id: str | None + source_execution_id: str | None + + +class TaskStepContinuationSummaryTracePayload(TypedDict): + task_id: str + task_step_id: str + latest_sequence_no: int + next_sequence_no: int + append_allowed: bool + lineage: TaskStepLineageRecord + + +class TaskStepTransitionRequestTracePayload(TypedDict): + task_id: str + task_step_id: str + sequence_no: int + previous_status: TaskStepStatus + requested_status: TaskStepStatus + + +class TaskStepTransitionStateTracePayload(TypedDict): + task_id: str + task_step_id: str + sequence_no: int + previous_status: TaskStepStatus + current_status: TaskStepStatus + allowed_next_statuses: list[TaskStepStatus] + trace: TaskStepTraceLink + + +class TaskStepTransitionSummaryTracePayload(TypedDict): + task_id: str + task_step_id: str + sequence_no: int + final_status: TaskStepStatus + parent_task_status: TaskStatus + trace: TaskStepTraceLink + + +class ApprovalRequestCreateResponse(TypedDict): + request: ToolRoutingRequestRecord + decision: ToolRoutingDecision + 
tool: ToolRecord + reasons: list[ToolAllowlistReason] + task: TaskRecord + approval: ApprovalRecord | None + routing_trace: ToolRoutingTraceSummary + trace: ApprovalRequestTraceSummary + + +class ApprovalListSummary(TypedDict): + total_count: int + order: list[str] + + +class ApprovalListResponse(TypedDict): + items: list[ApprovalRecord] + summary: ApprovalListSummary + + +class ApprovalDetailResponse(TypedDict): + approval: ApprovalRecord + + +class ApprovalResolutionResponse(TypedDict): + approval: ApprovalRecord + trace: ApprovalResolutionTraceSummary + + +class ExecutionBudgetRecord(TypedDict): + id: str + tool_key: str | None + domain_hint: str | None + max_completed_executions: int + rolling_window_seconds: int | None + status: ExecutionBudgetStatus + deactivated_at: str | None + superseded_by_budget_id: str | None + supersedes_budget_id: str | None + created_at: str + + +class ExecutionBudgetCreateResponse(TypedDict): + execution_budget: ExecutionBudgetRecord + + +class ExecutionBudgetListSummary(TypedDict): + total_count: int + order: list[str] + + +class ExecutionBudgetListResponse(TypedDict): + items: list[ExecutionBudgetRecord] + summary: ExecutionBudgetListSummary + + +class ExecutionBudgetDetailResponse(TypedDict): + execution_budget: ExecutionBudgetRecord + + +class ExecutionBudgetLifecycleTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ExecutionBudgetDeactivateResponse(TypedDict): + execution_budget: ExecutionBudgetRecord + trace: ExecutionBudgetLifecycleTraceSummary + + +class ExecutionBudgetSupersedeResponse(TypedDict): + superseded_budget: ExecutionBudgetRecord + replacement_budget: ExecutionBudgetRecord + trace: ExecutionBudgetLifecycleTraceSummary + + +class ExecutionBudgetDecisionRecord(TypedDict): + matched_budget_id: str | None + tool_key: str + domain_hint: str | None + budget_tool_key: str | None + budget_domain_hint: str | None + max_completed_executions: int | None + rolling_window_seconds: int | None + count_scope: ExecutionBudgetCountScope + window_started_at: str | None + completed_execution_count: int + projected_completed_execution_count: int + decision: ExecutionBudgetDecision + reason: ExecutionBudgetDecisionReason + order: list[str] + history_order: list[str] + + +class ExecutionBudgetLifecycleRequestTracePayload(TypedDict): + thread_id: str + execution_budget_id: str + requested_action: ExecutionBudgetLifecycleAction + replacement_max_completed_executions: int | None + + +class ExecutionBudgetLifecycleStateTracePayload(TypedDict): + execution_budget_id: str + requested_action: ExecutionBudgetLifecycleAction + previous_status: ExecutionBudgetStatus + current_status: ExecutionBudgetStatus + tool_key: str | None + domain_hint: str | None + max_completed_executions: int + rolling_window_seconds: int | None + deactivated_at: str | None + superseded_by_budget_id: str | None + supersedes_budget_id: str | None + replacement_budget_id: str | None + replacement_status: ExecutionBudgetStatus | None + replacement_max_completed_executions: int | None + replacement_rolling_window_seconds: int | None + rejection_reason: str | None + + +class ExecutionBudgetLifecycleSummaryTracePayload(TypedDict): + execution_budget_id: str + requested_action: ExecutionBudgetLifecycleAction + outcome: ExecutionBudgetLifecycleOutcome + replacement_budget_id: str | None + active_budget_id: str | None + + +@dataclass(frozen=True, slots=True) +class ToolExecutionCreateInput: + approval_id: UUID + task_step_id: UUID + thread_id: UUID + tool_id: UUID + trace_id: 
UUID + request_event_id: UUID | None + result_event_id: UUID | None + status: ProxyExecutionStatus + handler_key: str | None + request: ToolRoutingRequestRecord + tool: ToolRecord + result: "ToolExecutionResultRecord" + + +class ToolExecutionRecord(TypedDict): + id: str + approval_id: str + task_step_id: str + thread_id: str + tool_id: str + trace_id: str + request_event_id: str | None + result_event_id: str | None + status: ProxyExecutionStatus + handler_key: str | None + request: ToolRoutingRequestRecord + tool: ToolRecord + result: "ToolExecutionResultRecord" + executed_at: str + + +class ToolExecutionListSummary(TypedDict): + total_count: int + order: list[str] + + +class ToolExecutionListResponse(TypedDict): + items: list[ToolExecutionRecord] + summary: ToolExecutionListSummary + + +class ToolExecutionDetailResponse(TypedDict): + execution: ToolExecutionRecord + + +class ProxyExecutionRequestRecord(TypedDict): + approval_id: str + task_step_id: str + + +class ProxyExecutionRequestEventPayload(TypedDict): + approval_id: str + task_step_id: str + tool_id: str + tool_key: str + request: ToolRoutingRequestRecord + + +class ProxyExecutionResultRecord(TypedDict): + handler_key: str + status: Literal["completed"] + output: JsonObject + + +class ProxyExecutionResultEventPayload(TypedDict): + approval_id: str + task_step_id: str + tool_id: str + tool_key: str + handler_key: str + status: Literal["completed"] + output: JsonObject + + +class ToolExecutionResultRecord(TypedDict): + handler_key: str | None + status: ProxyExecutionStatus + output: JsonObject | None + reason: str | None + budget_decision: NotRequired[ExecutionBudgetDecisionRecord] + + +class ProxyExecutionEventSummary(TypedDict): + request_event_id: str + request_sequence_no: int + result_event_id: str + result_sequence_no: int + + +class ProxyExecutionTraceSummary(TypedDict): + trace_id: str + trace_event_count: int + + +class ProxyExecutionBudgetPrecheckTracePayload(ExecutionBudgetDecisionRecord): + pass + + +class ProxyExecutionApprovalTracePayload(TypedDict): + approval_id: str + task_step_id: str + approval_status: ApprovalStatus + eligible_for_execution: bool + + +class ProxyExecutionDispatchTracePayload(TypedDict): + approval_id: str + task_step_id: str + tool_id: str + tool_key: str + handler_key: str | None + dispatch_status: Literal["executed", "blocked"] + reason: str | None + result_status: ProxyExecutionStatus | None + output: JsonObject | None + + +class ProxyExecutionSummaryTracePayload(TypedDict): + approval_id: str + task_step_id: str + tool_id: str + tool_key: str + approval_status: ApprovalStatus + execution_status: Literal["completed", "blocked"] + handler_key: str | None + request_event_id: str | None + result_event_id: str | None + + +class ProxyExecutionResponse(TypedDict): + request: ProxyExecutionRequestRecord + approval: ApprovalRecord + tool: ToolRecord + result: ProxyExecutionResultRecord | ToolExecutionResultRecord + events: ProxyExecutionEventSummary | None + trace: ProxyExecutionTraceSummary + + +class ProxyExecutionBudgetBlockedResponse(TypedDict): + request: ProxyExecutionRequestRecord + approval: ApprovalRecord + tool: ToolRecord + result: ToolExecutionResultRecord + events: None + trace: ProxyExecutionTraceSummary + + +def isoformat_or_none(value: datetime | None) -> str | None: + if value is None: + return None + return value.isoformat() diff --git a/apps/api/src/alicebot_api/db.py b/apps/api/src/alicebot_api/db.py new file mode 100644 index 0000000..cc6e87b --- /dev/null +++ 
b/apps/api/src/alicebot_api/db.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from uuid import UUID + +import psycopg +from psycopg.rows import dict_row + +PING_DATABASE_SQL = "SELECT 1" +SET_CURRENT_USER_SQL = "SELECT set_config('app.current_user_id', %s, true)" +ConnectionRow = dict[str, object] +UserConnection = psycopg.Connection[ConnectionRow] + + +def ping_database(database_url: str, timeout_seconds: int) -> bool: + try: + with psycopg.connect(database_url, connect_timeout=timeout_seconds) as conn: + with conn.cursor() as cur: + cur.execute(PING_DATABASE_SQL) + cur.fetchone() + return True + except psycopg.Error: + return False + + +def set_current_user(conn: psycopg.Connection, user_id: UUID) -> None: + with conn.cursor() as cur: + cur.execute(SET_CURRENT_USER_SQL, (str(user_id),)) + + +@contextmanager +def user_connection(database_url: str, user_id: UUID) -> Iterator[UserConnection]: + with psycopg.connect(database_url, row_factory=dict_row) as conn: + with conn.transaction(): + set_current_user(conn, user_id) + yield conn diff --git a/apps/api/src/alicebot_api/embedding.py b/apps/api/src/alicebot_api/embedding.py new file mode 100644 index 0000000..5248197 --- /dev/null +++ b/apps/api/src/alicebot_api/embedding.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import math +from uuid import UUID + +import psycopg + +from alicebot_api.contracts import ( + EMBEDDING_CONFIG_LIST_ORDER, + MEMORY_EMBEDDING_LIST_ORDER, + EmbeddingConfigCreateInput, + EmbeddingConfigCreateResponse, + EmbeddingConfigListResponse, + EmbeddingConfigListSummary, + EmbeddingConfigRecord, + MemoryEmbeddingDetailResponse, + MemoryEmbeddingListResponse, + MemoryEmbeddingListSummary, + MemoryEmbeddingRecord, + MemoryEmbeddingUpsertInput, + MemoryEmbeddingUpsertResponse, +) +from alicebot_api.store import ContinuityStore, EmbeddingConfigRow, MemoryEmbeddingRow + + +class EmbeddingConfigValidationError(ValueError): + """Raised when an embedding-config request fails explicit validation.""" + + +class MemoryEmbeddingValidationError(ValueError): + """Raised when a memory-embedding request fails explicit validation.""" + + +class MemoryEmbeddingNotFoundError(LookupError): + """Raised when a requested memory embedding is not visible inside the current user scope.""" + + +def _duplicate_embedding_config_message( + *, + provider: str, + model: str, + version: str, +) -> str: + return ( + "embedding config already exists for provider/model/version under the user scope: " + f"{provider}/{model}/{version}" + ) + + +def _serialize_embedding_config(config: EmbeddingConfigRow) -> EmbeddingConfigRecord: + return { + "id": str(config["id"]), + "provider": config["provider"], + "model": config["model"], + "version": config["version"], + "dimensions": config["dimensions"], + "status": config["status"], + "metadata": config["metadata"], + "created_at": config["created_at"].isoformat(), + } + + +def _serialize_memory_embedding(embedding: MemoryEmbeddingRow) -> MemoryEmbeddingRecord: + return { + "id": str(embedding["id"]), + "memory_id": str(embedding["memory_id"]), + "embedding_config_id": str(embedding["embedding_config_id"]), + "dimensions": embedding["dimensions"], + "vector": [float(value) for value in embedding["vector"]], + "created_at": embedding["created_at"].isoformat(), + "updated_at": embedding["updated_at"].isoformat(), + } + + +def _validate_vector(vector: tuple[float, ...]) -> list[float]: + if not vector: + raise 
MemoryEmbeddingValidationError("vector must include at least one numeric value") + + normalized: list[float] = [] + for value in vector: + normalized_value = float(value) + if not math.isfinite(normalized_value): + raise MemoryEmbeddingValidationError("vector must contain only finite numeric values") + normalized.append(normalized_value) + + return normalized + + +def create_embedding_config_record( + store: ContinuityStore, + *, + user_id: UUID, + config: EmbeddingConfigCreateInput, +) -> EmbeddingConfigCreateResponse: + del user_id + + existing = store.get_embedding_config_by_identity_optional( + provider=config.provider, + model=config.model, + version=config.version, + ) + if existing is not None: + raise EmbeddingConfigValidationError( + _duplicate_embedding_config_message( + provider=config.provider, + model=config.model, + version=config.version, + ) + ) + + try: + created = store.create_embedding_config( + provider=config.provider, + model=config.model, + version=config.version, + dimensions=config.dimensions, + status=config.status, + metadata=config.metadata, + ) + except psycopg.errors.UniqueViolation as exc: + raise EmbeddingConfigValidationError( + _duplicate_embedding_config_message( + provider=config.provider, + model=config.model, + version=config.version, + ) + ) from exc + return {"embedding_config": _serialize_embedding_config(created)} + + +def list_embedding_config_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> EmbeddingConfigListResponse: + del user_id + + configs = store.list_embedding_configs() + items = [_serialize_embedding_config(config) for config in configs] + summary: EmbeddingConfigListSummary = { + "total_count": len(items), + "order": list(EMBEDDING_CONFIG_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def upsert_memory_embedding_record( + store: ContinuityStore, + *, + user_id: UUID, + request: MemoryEmbeddingUpsertInput, +) -> MemoryEmbeddingUpsertResponse: + del user_id + + memory = store.get_memory_optional(request.memory_id) + if memory is None: + raise MemoryEmbeddingValidationError( + f"memory_id must reference an existing memory owned by the user: {request.memory_id}" + ) + + config = store.get_embedding_config_optional(request.embedding_config_id) + if config is None: + raise MemoryEmbeddingValidationError( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{request.embedding_config_id}" + ) + + vector = _validate_vector(request.vector) + if len(vector) != config["dimensions"]: + raise MemoryEmbeddingValidationError( + "vector length must match embedding config dimensions " + f"({config['dimensions']}): {len(vector)}" + ) + + existing = store.get_memory_embedding_by_memory_and_config_optional( + memory_id=request.memory_id, + embedding_config_id=request.embedding_config_id, + ) + if existing is None: + created = store.create_memory_embedding( + memory_id=request.memory_id, + embedding_config_id=request.embedding_config_id, + dimensions=config["dimensions"], + vector=vector, + ) + return { + "embedding": _serialize_memory_embedding(created), + "write_mode": "created", + } + + updated = store.update_memory_embedding( + memory_embedding_id=existing["id"], + dimensions=config["dimensions"], + vector=vector, + ) + return { + "embedding": _serialize_memory_embedding(updated), + "write_mode": "updated", + } + + +def get_memory_embedding_record( + store: ContinuityStore, + *, + user_id: UUID, + memory_embedding_id: UUID, +) -> MemoryEmbeddingDetailResponse: + del user_id 
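+ # user_id is accepted for interface symmetry with the other record readers and is deliberately unused here; + # per-user visibility is presumed to come from the connection-level app.current_user_id setting applied in + # db.user_connection (see set_current_user), rather than from an explicit user filter inside this function.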
+ + embedding = store.get_memory_embedding_optional(memory_embedding_id) + if embedding is None: + raise MemoryEmbeddingNotFoundError(f"memory embedding {memory_embedding_id} was not found") + + return {"embedding": _serialize_memory_embedding(embedding)} + + +def list_memory_embedding_records( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, +) -> MemoryEmbeddingListResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryEmbeddingNotFoundError(f"memory {memory_id} was not found") + + embeddings = store.list_memory_embeddings_for_memory(memory_id) + items = [_serialize_memory_embedding(embedding) for embedding in embeddings] + summary: MemoryEmbeddingListSummary = { + "memory_id": str(memory_id), + "total_count": len(items), + "order": list(MEMORY_EMBEDDING_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } diff --git a/apps/api/src/alicebot_api/entity.py b/apps/api/src/alicebot_api/entity.py new file mode 100644 index 0000000..8e811eb --- /dev/null +++ b/apps/api/src/alicebot_api/entity.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from uuid import UUID + +from alicebot_api.contracts import ( + ENTITY_LIST_ORDER, + EntityCreateInput, + EntityCreateResponse, + EntityDetailResponse, + EntityListResponse, + EntityListSummary, + EntityRecord, +) +from alicebot_api.store import ContinuityStore, EntityRow + + +class EntityValidationError(ValueError): + """Raised when an entity create request fails explicit validation.""" + + +class EntityNotFoundError(LookupError): + """Raised when a requested entity is not visible inside the current user scope.""" + + +def _serialize_entity(entity: EntityRow) -> EntityRecord: + return { + "id": str(entity["id"]), + "entity_type": entity["entity_type"], + "name": entity["name"], + "source_memory_ids": entity["source_memory_ids"], + "created_at": entity["created_at"].isoformat(), + } + + +def _dedupe_source_memory_ids(source_memory_ids: tuple[UUID, ...]) -> tuple[UUID, ...]: + deduped: list[UUID] = [] + seen: set[UUID] = set() + for source_memory_id in source_memory_ids: + if source_memory_id in seen: + continue + seen.add(source_memory_id) + deduped.append(source_memory_id) + return tuple(deduped) + + +def _validate_source_memories(store: ContinuityStore, source_memory_ids: tuple[UUID, ...]) -> list[str]: + normalized_memory_ids = _dedupe_source_memory_ids(source_memory_ids) + if not normalized_memory_ids: + raise EntityValidationError( + "source_memory_ids must include at least one existing memory owned by the user" + ) + + source_memories = store.list_memories_by_ids(list(normalized_memory_ids)) + found_memory_ids = {memory["id"] for memory in source_memories} + missing_memory_ids = [ + str(source_memory_id) + for source_memory_id in normalized_memory_ids + if source_memory_id not in found_memory_ids + ] + if missing_memory_ids: + raise EntityValidationError( + "source_memory_ids must all reference existing memories owned by the user: " + + ", ".join(missing_memory_ids) + ) + + return [str(source_memory_id) for source_memory_id in normalized_memory_ids] + + +def create_entity_record( + store: ContinuityStore, + *, + user_id: UUID, + entity: EntityCreateInput, +) -> EntityCreateResponse: + del user_id + + source_memory_ids = _validate_source_memories(store, entity.source_memory_ids) + created = store.create_entity( + entity_type=entity.entity_type, + name=entity.name, + source_memory_ids=source_memory_ids, + ) + return {"entity": 
_serialize_entity(created)} + + +def list_entity_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> EntityListResponse: + del user_id + + entities = store.list_entities() + items = [_serialize_entity(entity) for entity in entities] + summary: EntityListSummary = { + "total_count": len(items), + "order": list(ENTITY_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_entity_record( + store: ContinuityStore, + *, + user_id: UUID, + entity_id: UUID, +) -> EntityDetailResponse: + del user_id + + entity = store.get_entity_optional(entity_id) + if entity is None: + raise EntityNotFoundError(f"entity {entity_id} was not found") + + return {"entity": _serialize_entity(entity)} diff --git a/apps/api/src/alicebot_api/entity_edge.py b/apps/api/src/alicebot_api/entity_edge.py new file mode 100644 index 0000000..84731a2 --- /dev/null +++ b/apps/api/src/alicebot_api/entity_edge.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from datetime import datetime +from uuid import UUID + +from alicebot_api.contracts import ( + ENTITY_EDGE_LIST_ORDER, + EntityEdgeCreateInput, + EntityEdgeCreateResponse, + EntityEdgeListResponse, + EntityEdgeListSummary, + EntityEdgeRecord, + isoformat_or_none, +) +from alicebot_api.entity import EntityNotFoundError +from alicebot_api.store import ContinuityStore, EntityEdgeRow + + +class EntityEdgeValidationError(ValueError): + """Raised when an entity-edge request fails explicit validation.""" + + +def _serialize_entity_edge(edge: EntityEdgeRow) -> EntityEdgeRecord: + return { + "id": str(edge["id"]), + "from_entity_id": str(edge["from_entity_id"]), + "to_entity_id": str(edge["to_entity_id"]), + "relationship_type": edge["relationship_type"], + "valid_from": isoformat_or_none(edge["valid_from"]), + "valid_to": isoformat_or_none(edge["valid_to"]), + "source_memory_ids": edge["source_memory_ids"], + "created_at": edge["created_at"].isoformat(), + } + + +def _dedupe_source_memory_ids(source_memory_ids: tuple[UUID, ...]) -> tuple[UUID, ...]: + deduped: list[UUID] = [] + seen: set[UUID] = set() + for source_memory_id in source_memory_ids: + if source_memory_id in seen: + continue + seen.add(source_memory_id) + deduped.append(source_memory_id) + return tuple(deduped) + + +def _validate_source_memories(store: ContinuityStore, source_memory_ids: tuple[UUID, ...]) -> list[str]: + normalized_memory_ids = _dedupe_source_memory_ids(source_memory_ids) + if not normalized_memory_ids: + raise EntityEdgeValidationError( + "source_memory_ids must include at least one existing memory owned by the user" + ) + + source_memories = store.list_memories_by_ids(list(normalized_memory_ids)) + found_memory_ids = {memory["id"] for memory in source_memories} + missing_memory_ids = [ + str(source_memory_id) + for source_memory_id in normalized_memory_ids + if source_memory_id not in found_memory_ids + ] + if missing_memory_ids: + raise EntityEdgeValidationError( + "source_memory_ids must all reference existing memories owned by the user: " + + ", ".join(missing_memory_ids) + ) + + return [str(source_memory_id) for source_memory_id in normalized_memory_ids] + + +def _validate_entity_exists( + store: ContinuityStore, + *, + field_name: str, + entity_id: UUID, +) -> None: + entity = store.get_entity_optional(entity_id) + if entity is None: + raise EntityEdgeValidationError( + f"{field_name} must reference an existing entity owned by the user: {entity_id}" + ) + + +def _validate_temporal_range(valid_from: datetime | None, valid_to: datetime | None) -> 
None: + if valid_from is not None and valid_to is not None and valid_to < valid_from: + raise EntityEdgeValidationError("valid_to must be greater than or equal to valid_from") + + +def create_entity_edge_record( + store: ContinuityStore, + *, + user_id: UUID, + edge: EntityEdgeCreateInput, +) -> EntityEdgeCreateResponse: + del user_id + + _validate_entity_exists(store, field_name="from_entity_id", entity_id=edge.from_entity_id) + _validate_entity_exists(store, field_name="to_entity_id", entity_id=edge.to_entity_id) + _validate_temporal_range(edge.valid_from, edge.valid_to) + source_memory_ids = _validate_source_memories(store, edge.source_memory_ids) + + created = store.create_entity_edge( + from_entity_id=edge.from_entity_id, + to_entity_id=edge.to_entity_id, + relationship_type=edge.relationship_type, + valid_from=edge.valid_from, + valid_to=edge.valid_to, + source_memory_ids=source_memory_ids, + ) + return {"edge": _serialize_entity_edge(created)} + + +def list_entity_edge_records( + store: ContinuityStore, + *, + user_id: UUID, + entity_id: UUID, +) -> EntityEdgeListResponse: + del user_id + + entity = store.get_entity_optional(entity_id) + if entity is None: + raise EntityNotFoundError(f"entity {entity_id} was not found") + + edges = store.list_entity_edges_for_entity(entity_id) + items = [_serialize_entity_edge(edge) for edge in edges] + summary: EntityEdgeListSummary = { + "entity_id": str(entity["id"]), + "total_count": len(items), + "order": list(ENTITY_EDGE_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } diff --git a/apps/api/src/alicebot_api/execution_budgets.py b/apps/api/src/alicebot_api/execution_budgets.py new file mode 100644 index 0000000..870bd13 --- /dev/null +++ b/apps/api/src/alicebot_api/execution_budgets.py @@ -0,0 +1,818 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import cast +from uuid import UUID, uuid4 + +import psycopg + +from alicebot_api.contracts import ( + EXECUTION_BUDGET_LIFECYCLE_VERSION_V0, + EXECUTION_BUDGET_LIST_ORDER, + EXECUTION_BUDGET_MATCH_ORDER, + EXECUTION_BUDGET_STATUSES, + TOOL_EXECUTION_LIST_ORDER, + TRACE_KIND_EXECUTION_BUDGET_LIFECYCLE, + ExecutionBudgetCreateInput, + ExecutionBudgetCreateResponse, + ExecutionBudgetDeactivateInput, + ExecutionBudgetDeactivateResponse, + ExecutionBudgetDecisionRecord, + ExecutionBudgetDetailResponse, + ExecutionBudgetLifecycleAction, + ExecutionBudgetLifecycleOutcome, + ExecutionBudgetLifecycleRequestTracePayload, + ExecutionBudgetLifecycleStateTracePayload, + ExecutionBudgetLifecycleSummaryTracePayload, + ExecutionBudgetListResponse, + ExecutionBudgetListSummary, + ExecutionBudgetRecord, + ExecutionBudgetSupersedeInput, + ExecutionBudgetSupersedeResponse, + ToolExecutionResultRecord, + ToolRecord, + ToolRoutingRequestRecord, +) +from alicebot_api.store import ContinuityStore, ExecutionBudgetRow, ToolExecutionRow + + +class ExecutionBudgetValidationError(ValueError): + """Raised when an execution-budget request fails explicit validation.""" + + +class ExecutionBudgetNotFoundError(LookupError): + """Raised when an execution budget is not visible inside the current user scope.""" + + +class ExecutionBudgetLifecycleError(RuntimeError): + """Raised when an execution budget lifecycle transition is invalid.""" + + +@dataclass(frozen=True, slots=True) +class ExecutionBudgetDecision: + record: ExecutionBudgetDecisionRecord + blocked_result: ToolExecutionResultRecord | None + + +def 
serialize_execution_budget_row(row: ExecutionBudgetRow) -> ExecutionBudgetRecord: + return { + "id": str(row["id"]), + "tool_key": row["tool_key"], + "domain_hint": row["domain_hint"], + "max_completed_executions": row["max_completed_executions"], + "rolling_window_seconds": row["rolling_window_seconds"], + "status": cast(str, row["status"]), + "deactivated_at": None if row["deactivated_at"] is None else row["deactivated_at"].isoformat(), + "superseded_by_budget_id": ( + None if row["superseded_by_budget_id"] is None else str(row["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if row["supersedes_budget_id"] is None else str(row["supersedes_budget_id"]) + ), + "created_at": row["created_at"].isoformat(), + } + + +def _validate_budget_scope(*, tool_key: str | None, domain_hint: str | None) -> None: + if tool_key is None and domain_hint is None: + raise ExecutionBudgetValidationError( + "execution budget requires at least one selector: tool_key or domain_hint" + ) + + +def _validate_rolling_window_seconds(rolling_window_seconds: int | None) -> None: + if rolling_window_seconds is not None and rolling_window_seconds <= 0: + raise ExecutionBudgetValidationError( + "rolling_window_seconds must be greater than 0 when provided" + ) + + +def _validate_lifecycle_thread(store: ContinuityStore, *, thread_id: UUID) -> dict[str, object]: + thread = store.get_thread_optional(thread_id) + if thread is None: + raise ExecutionBudgetValidationError( + "thread_id must reference an existing thread owned by the user" + ) + return cast(dict[str, object], thread) + + +def _append_trace_events( + store: ContinuityStore, + *, + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> None: + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace_id, + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + +def _trace_summary(trace_id: UUID, trace_events: list[tuple[str, dict[str, object]]]) -> dict[str, object]: + return { + "trace_id": str(trace_id), + "trace_event_count": len(trace_events), + } + + +def _active_budget_rows_for_scope( + store: ContinuityStore, + *, + tool_key: str | None, + domain_hint: str | None, +) -> list[ExecutionBudgetRow]: + rows = [ + row + for row in store.list_execution_budgets() + if row["tool_key"] == tool_key + and row["domain_hint"] == domain_hint + and cast(str, row["status"]) == "active" + ] + return sorted(rows, key=lambda row: (row["created_at"], row["id"])) + + +def _scope_label(*, tool_key: str | None, domain_hint: str | None) -> str: + return f"tool_key={tool_key!r}, domain_hint={domain_hint!r}" + + +def _duplicate_active_scope_message(*, tool_key: str | None, domain_hint: str | None) -> str: + return ( + "active execution budget already exists for selector scope " + f"{_scope_label(tool_key=tool_key, domain_hint=domain_hint)}" + ) + + +def _is_active_scope_uniqueness_error(exc: psycopg.Error) -> bool: + diag = getattr(exc, "diag", None) + return getattr(diag, "constraint_name", None) == "execution_budgets_one_active_scope_idx" + + +def _invalid_transition_error( + *, + row: ExecutionBudgetRow, + requested_action: ExecutionBudgetLifecycleAction, +) -> ExecutionBudgetLifecycleError: + return ExecutionBudgetLifecycleError( + f"execution budget {row['id']} is {row['status']} and cannot be {requested_action}d" + ) + + +def _record_lifecycle_trace( + store: ContinuityStore, + *, + thread: dict[str, object], + request_payload: ExecutionBudgetLifecycleRequestTracePayload, + 
state_payload: ExecutionBudgetLifecycleStateTracePayload, + summary_payload: ExecutionBudgetLifecycleSummaryTracePayload, + requested_action: ExecutionBudgetLifecycleAction, + outcome: ExecutionBudgetLifecycleOutcome, +) -> dict[str, object]: + trace = store.create_trace( + user_id=cast(UUID, thread["user_id"]), + thread_id=cast(UUID, thread["id"]), + kind=TRACE_KIND_EXECUTION_BUDGET_LIFECYCLE, + compiler_version=EXECUTION_BUDGET_LIFECYCLE_VERSION_V0, + status="completed", + limits={ + "order": list(EXECUTION_BUDGET_LIST_ORDER), + "match_order": list(EXECUTION_BUDGET_MATCH_ORDER), + "statuses": list(EXECUTION_BUDGET_STATUSES), + "requested_action": requested_action, + "outcome": outcome, + }, + ) + trace_events: list[tuple[str, dict[str, object]]] = [ + ("execution_budget.lifecycle.request", cast(dict[str, object], request_payload)), + ("execution_budget.lifecycle.state", cast(dict[str, object], state_payload)), + ("execution_budget.lifecycle.summary", cast(dict[str, object], summary_payload)), + ] + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + return _trace_summary(trace["id"], trace_events) + + +def create_execution_budget_record( + store: ContinuityStore, + *, + user_id: UUID, + request: ExecutionBudgetCreateInput, +) -> ExecutionBudgetCreateResponse: + del user_id + + _validate_budget_scope(tool_key=request.tool_key, domain_hint=request.domain_hint) + _validate_rolling_window_seconds(request.rolling_window_seconds) + if _active_budget_rows_for_scope( + store, + tool_key=request.tool_key, + domain_hint=request.domain_hint, + ): + raise ExecutionBudgetValidationError( + _duplicate_active_scope_message( + tool_key=request.tool_key, + domain_hint=request.domain_hint, + ) + ) + try: + row = store.create_execution_budget( + tool_key=request.tool_key, + domain_hint=request.domain_hint, + max_completed_executions=request.max_completed_executions, + rolling_window_seconds=request.rolling_window_seconds, + ) + except psycopg.IntegrityError as exc: + if _is_active_scope_uniqueness_error(exc): + raise ExecutionBudgetValidationError( + _duplicate_active_scope_message( + tool_key=request.tool_key, + domain_hint=request.domain_hint, + ) + ) from exc + raise + return {"execution_budget": serialize_execution_budget_row(row)} + + +def list_execution_budget_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ExecutionBudgetListResponse: + del user_id + + items = [serialize_execution_budget_row(row) for row in store.list_execution_budgets()] + summary: ExecutionBudgetListSummary = { + "total_count": len(items), + "order": list(EXECUTION_BUDGET_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_execution_budget_record( + store: ContinuityStore, + *, + user_id: UUID, + execution_budget_id: UUID, +) -> ExecutionBudgetDetailResponse: + del user_id + + row = store.get_execution_budget_optional(execution_budget_id) + if row is None: + raise ExecutionBudgetNotFoundError(f"execution budget {execution_budget_id} was not found") + return {"execution_budget": serialize_execution_budget_row(row)} + + +def deactivate_execution_budget_record( + store: ContinuityStore, + *, + user_id: UUID, + request: ExecutionBudgetDeactivateInput, +) -> ExecutionBudgetDeactivateResponse: + del user_id + + thread = _validate_lifecycle_thread(store, thread_id=request.thread_id) + row = store.get_execution_budget_optional(request.execution_budget_id) + if row is None: + raise ExecutionBudgetNotFoundError( + f"execution budget {request.execution_budget_id} was 
not found" + ) + + request_payload: ExecutionBudgetLifecycleRequestTracePayload = { + "thread_id": str(request.thread_id), + "execution_budget_id": str(request.execution_budget_id), + "requested_action": "deactivate", + "replacement_max_completed_executions": None, + } + + if cast(str, row["status"]) != "active": + error = _invalid_transition_error(row=row, requested_action="deactivate") + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(row["id"]), + "requested_action": "deactivate", + "previous_status": cast(str, row["status"]), + "current_status": cast(str, row["status"]), + "tool_key": row["tool_key"], + "domain_hint": row["domain_hint"], + "max_completed_executions": row["max_completed_executions"], + "rolling_window_seconds": row["rolling_window_seconds"], + "deactivated_at": ( + None if row["deactivated_at"] is None else row["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None if row["superseded_by_budget_id"] is None else str(row["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if row["supersedes_budget_id"] is None else str(row["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": None, + "replacement_rolling_window_seconds": None, + "rejection_reason": str(error), + }, + summary_payload={ + "execution_budget_id": str(row["id"]), + "requested_action": "deactivate", + "outcome": "rejected", + "replacement_budget_id": None, + "active_budget_id": None, + }, + requested_action="deactivate", + outcome="rejected", + ) + del trace + raise error + + updated = store.deactivate_execution_budget_optional(request.execution_budget_id) + if updated is None: + raise ExecutionBudgetLifecycleError( + f"execution budget {request.execution_budget_id} could not be deactivated" + ) + + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(updated["id"]), + "requested_action": "deactivate", + "previous_status": "active", + "current_status": cast(str, updated["status"]), + "tool_key": updated["tool_key"], + "domain_hint": updated["domain_hint"], + "max_completed_executions": updated["max_completed_executions"], + "rolling_window_seconds": updated["rolling_window_seconds"], + "deactivated_at": ( + None if updated["deactivated_at"] is None else updated["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None if updated["superseded_by_budget_id"] is None else str(updated["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if updated["supersedes_budget_id"] is None else str(updated["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": None, + "replacement_rolling_window_seconds": None, + "rejection_reason": None, + }, + summary_payload={ + "execution_budget_id": str(updated["id"]), + "requested_action": "deactivate", + "outcome": "deactivated", + "replacement_budget_id": None, + "active_budget_id": None, + }, + requested_action="deactivate", + outcome="deactivated", + ) + return { + "execution_budget": serialize_execution_budget_row(updated), + "trace": cast(dict[str, object], trace), + } + + +def supersede_execution_budget_record( + store: ContinuityStore, + *, + user_id: UUID, + request: ExecutionBudgetSupersedeInput, +) -> ExecutionBudgetSupersedeResponse: + del user_id + + thread = 
_validate_lifecycle_thread(store, thread_id=request.thread_id) + current = store.get_execution_budget_optional(request.execution_budget_id) + if current is None: + raise ExecutionBudgetNotFoundError( + f"execution budget {request.execution_budget_id} was not found" + ) + + request_payload: ExecutionBudgetLifecycleRequestTracePayload = { + "thread_id": str(request.thread_id), + "execution_budget_id": str(request.execution_budget_id), + "requested_action": "supersede", + "replacement_max_completed_executions": request.max_completed_executions, + } + + if cast(str, current["status"]) != "active": + error = _invalid_transition_error(row=current, requested_action="supersede") + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(current["id"]), + "requested_action": "supersede", + "previous_status": cast(str, current["status"]), + "current_status": cast(str, current["status"]), + "tool_key": current["tool_key"], + "domain_hint": current["domain_hint"], + "max_completed_executions": current["max_completed_executions"], + "rolling_window_seconds": current["rolling_window_seconds"], + "deactivated_at": ( + None if current["deactivated_at"] is None else current["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None if current["superseded_by_budget_id"] is None else str(current["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if current["supersedes_budget_id"] is None else str(current["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": request.max_completed_executions, + "replacement_rolling_window_seconds": current["rolling_window_seconds"], + "rejection_reason": str(error), + }, + summary_payload={ + "execution_budget_id": str(current["id"]), + "requested_action": "supersede", + "outcome": "rejected", + "replacement_budget_id": None, + "active_budget_id": str(current["id"]) if cast(str, current["status"]) == "active" else None, + }, + requested_action="supersede", + outcome="rejected", + ) + del trace + raise error + + active_scope_rows = _active_budget_rows_for_scope( + store, + tool_key=current["tool_key"], + domain_hint=current["domain_hint"], + ) + if [row["id"] for row in active_scope_rows] != [current["id"]]: + error = ExecutionBudgetLifecycleError( + "execution budget selector scope must have exactly one active budget to supersede: " + f"{_scope_label(tool_key=current['tool_key'], domain_hint=current['domain_hint'])}" + ) + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(current["id"]), + "requested_action": "supersede", + "previous_status": "active", + "current_status": "active", + "tool_key": current["tool_key"], + "domain_hint": current["domain_hint"], + "max_completed_executions": current["max_completed_executions"], + "rolling_window_seconds": current["rolling_window_seconds"], + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": ( + None if current["supersedes_budget_id"] is None else str(current["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": request.max_completed_executions, + "replacement_rolling_window_seconds": current["rolling_window_seconds"], + "rejection_reason": str(error), + }, + summary_payload={ + "execution_budget_id": str(current["id"]), + 
"requested_action": "supersede", + "outcome": "rejected", + "replacement_budget_id": None, + "active_budget_id": str(current["id"]), + }, + requested_action="supersede", + outcome="rejected", + ) + del trace + raise error + + replacement_budget_id = uuid4() + try: + with store.conn.transaction(): + superseded = store.supersede_execution_budget_optional( + execution_budget_id=request.execution_budget_id, + superseded_by_budget_id=replacement_budget_id, + ) + if superseded is None: + raise ExecutionBudgetLifecycleError( + f"execution budget {request.execution_budget_id} could not be superseded" + ) + replacement = store.create_execution_budget( + budget_id=replacement_budget_id, + tool_key=current["tool_key"], + domain_hint=current["domain_hint"], + max_completed_executions=request.max_completed_executions, + rolling_window_seconds=current["rolling_window_seconds"], + supersedes_budget_id=current["id"], + ) + except psycopg.IntegrityError as exc: + if _is_active_scope_uniqueness_error(exc): + error = ExecutionBudgetLifecycleError( + _duplicate_active_scope_message( + tool_key=current["tool_key"], + domain_hint=current["domain_hint"], + ) + ) + else: + raise + except ExecutionBudgetLifecycleError as exc: + error = exc + else: + error = None + + if error is not None: + current_state = store.get_execution_budget_optional(request.execution_budget_id) + if current_state is None: + raise ExecutionBudgetNotFoundError( + f"execution budget {request.execution_budget_id} was not found" + ) + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(current_state["id"]), + "requested_action": "supersede", + "previous_status": cast(str, current["status"]), + "current_status": cast(str, current_state["status"]), + "tool_key": current_state["tool_key"], + "domain_hint": current_state["domain_hint"], + "max_completed_executions": current_state["max_completed_executions"], + "rolling_window_seconds": current_state["rolling_window_seconds"], + "deactivated_at": ( + None + if current_state["deactivated_at"] is None + else current_state["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None + if current_state["superseded_by_budget_id"] is None + else str(current_state["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None + if current_state["supersedes_budget_id"] is None + else str(current_state["supersedes_budget_id"]) + ), + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": request.max_completed_executions, + "replacement_rolling_window_seconds": current["rolling_window_seconds"], + "rejection_reason": str(error), + }, + summary_payload={ + "execution_budget_id": str(current_state["id"]), + "requested_action": "supersede", + "outcome": "rejected", + "replacement_budget_id": None, + "active_budget_id": ( + str(current_state["id"]) + if cast(str, current_state["status"]) == "active" + else None + ), + }, + requested_action="supersede", + outcome="rejected", + ) + del trace + raise error + + trace = _record_lifecycle_trace( + store, + thread=thread, + request_payload=request_payload, + state_payload={ + "execution_budget_id": str(superseded["id"]), + "requested_action": "supersede", + "previous_status": "active", + "current_status": cast(str, superseded["status"]), + "tool_key": superseded["tool_key"], + "domain_hint": superseded["domain_hint"], + "max_completed_executions": superseded["max_completed_executions"], + "rolling_window_seconds": 
superseded["rolling_window_seconds"], + "deactivated_at": ( + None if superseded["deactivated_at"] is None else superseded["deactivated_at"].isoformat() + ), + "superseded_by_budget_id": ( + None if superseded["superseded_by_budget_id"] is None else str(superseded["superseded_by_budget_id"]) + ), + "supersedes_budget_id": ( + None if superseded["supersedes_budget_id"] is None else str(superseded["supersedes_budget_id"]) + ), + "replacement_budget_id": str(replacement["id"]), + "replacement_status": cast(str, replacement["status"]), + "replacement_max_completed_executions": replacement["max_completed_executions"], + "replacement_rolling_window_seconds": replacement["rolling_window_seconds"], + "rejection_reason": None, + }, + summary_payload={ + "execution_budget_id": str(superseded["id"]), + "requested_action": "supersede", + "outcome": "superseded", + "replacement_budget_id": str(replacement["id"]), + "active_budget_id": str(replacement["id"]), + }, + requested_action="supersede", + outcome="superseded", + ) + return { + "superseded_budget": serialize_execution_budget_row(superseded), + "replacement_budget": serialize_execution_budget_row(replacement), + "trace": cast(dict[str, object], trace), + } + + +def _budget_specificity(row: ExecutionBudgetRow) -> int: + return int(row["tool_key"] is not None) + int(row["domain_hint"] is not None) + + +def _matches_budget( + row: ExecutionBudgetRow, + *, + tool_key: str, + domain_hint: str | None, +) -> bool: + if row["tool_key"] is not None and row["tool_key"] != tool_key: + return False + if row["domain_hint"] is not None and row["domain_hint"] != domain_hint: + return False + return True + + +def _matching_budget_rows( + store: ContinuityStore, + *, + tool_key: str, + domain_hint: str | None, +) -> list[ExecutionBudgetRow]: + rows = [ + row + for row in store.list_execution_budgets() + if cast(str, row["status"]) == "active" + and _matches_budget(row, tool_key=tool_key, domain_hint=domain_hint) + ] + return sorted( + rows, + key=lambda row: (-_budget_specificity(row), row["created_at"], row["id"]), + ) + + +def _execution_matches_budget(row: ToolExecutionRow, budget: ExecutionBudgetRow) -> bool: + if cast(str, row["status"]) != "completed": + return False + + tool = cast(dict[str, object], row["tool"]) + request = cast(dict[str, object], row["request"]) + + if budget["tool_key"] is not None and tool.get("tool_key") != budget["tool_key"]: + return False + if budget["domain_hint"] is not None and request.get("domain_hint") != budget["domain_hint"]: + return False + return True + + +def _current_time(store: ContinuityStore) -> datetime: + current_time = getattr(store, "current_time", None) + if callable(current_time): + value = current_time() + if isinstance(value, datetime): + return value + return datetime.now(UTC) + + +def _window_started_at( + *, + evaluation_time: datetime, + rolling_window_seconds: int | None, +) -> datetime | None: + if rolling_window_seconds is None: + return None + return evaluation_time - timedelta(seconds=rolling_window_seconds) + + +def _counted_completed_execution_rows( + store: ContinuityStore, + *, + matched_budget: ExecutionBudgetRow, + evaluation_time: datetime, +) -> list[ToolExecutionRow]: + window_started_at = _window_started_at( + evaluation_time=evaluation_time, + rolling_window_seconds=matched_budget["rolling_window_seconds"], + ) + counted_rows: list[ToolExecutionRow] = [] + for row in store.list_tool_executions(): + execution_row = cast(ToolExecutionRow, row) + if not 
_execution_matches_budget(execution_row, matched_budget): + continue + if window_started_at is not None and execution_row["executed_at"] < window_started_at: + continue + counted_rows.append(execution_row) + return counted_rows + + +def _blocked_result( + decision: ExecutionBudgetDecisionRecord, +) -> ToolExecutionResultRecord: + matched_budget_id = decision["matched_budget_id"] + max_completed_executions = decision["max_completed_executions"] + projected_completed_execution_count = decision["projected_completed_execution_count"] + rolling_window_seconds = decision["rolling_window_seconds"] + if rolling_window_seconds is None: + reason = ( + f"execution budget {matched_budget_id} blocks execution: projected completed executions " + f"{projected_completed_execution_count} would exceed limit {max_completed_executions}" + ) + else: + reason = ( + f"execution budget {matched_budget_id} blocks execution: projected completed executions " + f"{projected_completed_execution_count} within rolling window {rolling_window_seconds} " + f"seconds would exceed limit {max_completed_executions}" + ) + return { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": reason, + "budget_decision": decision, + } + + +def evaluate_execution_budget( + store: ContinuityStore, + *, + tool: ToolRecord, + request: ToolRoutingRequestRecord, +) -> ExecutionBudgetDecision: + matching_budgets = _matching_budget_rows( + store, + tool_key=tool["tool_key"], + domain_hint=request["domain_hint"], + ) + matched_budget = matching_budgets[0] if matching_budgets else None + evaluation_time = _current_time(store) + window_started_at = ( + None + if matched_budget is None + else _window_started_at( + evaluation_time=evaluation_time, + rolling_window_seconds=matched_budget["rolling_window_seconds"], + ) + ) + completed_execution_count = 0 + projected_completed_execution_count = 1 + + if matched_budget is not None: + completed_execution_count = len( + _counted_completed_execution_rows( + store, + matched_budget=matched_budget, + evaluation_time=evaluation_time, + ) + ) + projected_completed_execution_count = completed_execution_count + 1 + + record: ExecutionBudgetDecisionRecord = { + "matched_budget_id": None if matched_budget is None else str(matched_budget["id"]), + "tool_key": tool["tool_key"], + "domain_hint": request["domain_hint"], + "budget_tool_key": None if matched_budget is None else matched_budget["tool_key"], + "budget_domain_hint": None if matched_budget is None else matched_budget["domain_hint"], + "max_completed_executions": ( + None if matched_budget is None else matched_budget["max_completed_executions"] + ), + "rolling_window_seconds": ( + None if matched_budget is None else matched_budget["rolling_window_seconds"] + ), + "count_scope": ( + "lifetime" + if matched_budget is None or matched_budget["rolling_window_seconds"] is None + else "rolling_window" + ), + "window_started_at": None if window_started_at is None else window_started_at.isoformat(), + "completed_execution_count": completed_execution_count, + "projected_completed_execution_count": projected_completed_execution_count, + "decision": "allow", + "reason": "no_matching_budget", + "order": list(EXECUTION_BUDGET_MATCH_ORDER), + "history_order": list(TOOL_EXECUTION_LIST_ORDER), + } + + if matched_budget is None: + return ExecutionBudgetDecision(record=record, blocked_result=None) + + if projected_completed_execution_count <= matched_budget["max_completed_executions"]: + record["reason"] = "within_budget" + return 
ExecutionBudgetDecision(record=record, blocked_result=None) + + record["decision"] = "block" + record["reason"] = "budget_exceeded" + blocked_result = _blocked_result(record) + return ExecutionBudgetDecision(record=record, blocked_result=blocked_result) diff --git a/apps/api/src/alicebot_api/executions.py b/apps/api/src/alicebot_api/executions.py new file mode 100644 index 0000000..5bb0740 --- /dev/null +++ b/apps/api/src/alicebot_api/executions.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import cast +from uuid import UUID + +from alicebot_api.contracts import ( + TOOL_EXECUTION_LIST_ORDER, + ToolExecutionDetailResponse, + ToolExecutionListResponse, + ToolExecutionListSummary, + ToolExecutionRecord, +) +from alicebot_api.store import ContinuityStore, ToolExecutionRow + + +class ToolExecutionNotFoundError(LookupError): + """Raised when an execution record is not visible inside the current user scope.""" + + +def serialize_tool_execution_row(row: ToolExecutionRow) -> ToolExecutionRecord: + return { + "id": str(row["id"]), + "approval_id": str(row["approval_id"]), + "task_step_id": str(row["task_step_id"]), + "thread_id": str(row["thread_id"]), + "tool_id": str(row["tool_id"]), + "trace_id": str(row["trace_id"]), + "request_event_id": None if row["request_event_id"] is None else str(row["request_event_id"]), + "result_event_id": None if row["result_event_id"] is None else str(row["result_event_id"]), + "status": cast(str, row["status"]), + "handler_key": row["handler_key"], + "request": cast(dict[str, object], row["request"]), + "tool": cast(dict[str, object], row["tool"]), + "result": cast(dict[str, object], row["result"]), + "executed_at": row["executed_at"].isoformat(), + } + + +def list_tool_execution_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ToolExecutionListResponse: + del user_id + + items = [serialize_tool_execution_row(row) for row in store.list_tool_executions()] + summary: ToolExecutionListSummary = { + "total_count": len(items), + "order": list(TOOL_EXECUTION_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_tool_execution_record( + store: ContinuityStore, + *, + user_id: UUID, + execution_id: UUID, +) -> ToolExecutionDetailResponse: + del user_id + + execution = store.get_tool_execution_optional(execution_id) + if execution is None: + raise ToolExecutionNotFoundError(f"tool execution {execution_id} was not found") + return {"execution": serialize_tool_execution_row(execution)} diff --git a/apps/api/src/alicebot_api/explicit_preferences.py b/apps/api/src/alicebot_api/explicit_preferences.py new file mode 100644 index 0000000..f451426 --- /dev/null +++ b/apps/api/src/alicebot_api/explicit_preferences.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import hashlib +import re +from collections.abc import Sequence +from typing import Literal +from uuid import UUID + +from alicebot_api.contracts import ( + AdmissionDecisionOutput, + ExplicitPreferenceAdmissionRecord, + ExplicitPreferenceExtractionRequestInput, + ExplicitPreferenceExtractionResponse, + ExplicitPreferenceExtractionSummary, + ExplicitPreferencePattern, + ExtractedPreferenceCandidateRecord, + MemoryCandidateInput, +) +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore, EventRow, JsonObject + +PreferenceKind = Literal["like", "dislike", "prefer"] +_DIRECT_PATTERNS: tuple[tuple[ExplicitPreferencePattern, PreferenceKind, re.Pattern[str]], ...] 
= ( + ("i_like", "like", re.compile(r"^i like (?P<subject>.+)$", re.IGNORECASE)), + ("i_dont_like", "dislike", re.compile(r"^i don't like (?P<subject>.+)$", re.IGNORECASE)), + ("i_prefer", "prefer", re.compile(r"^i prefer (?P<subject>.+)$", re.IGNORECASE)), +) +_REMEMBER_PREFIX = "remember that " +_TRAILING_PUNCTUATION = ".!?" +_MEMORY_KEY_PREFIX = "user.preference." +_MAX_MEMORY_KEY_LENGTH = 200 +_MEMORY_KEY_HASH_LENGTH = 12 +_MAX_SUBJECT_TOKENS = 6 +_ALLOWED_SUBJECT_TOKEN = re.compile(r"^[a-z0-9][a-z0-9+#&./+'-]*$", re.IGNORECASE) +_DISALLOWED_SUBJECT_PREFIX_TOKENS = { + "that", + "to", + "if", + "when", + "because", + "whether", + "we", + "you", + "they", + "he", + "she", + "it", + "there", + "this", +} +_REMEMBER_PATTERN_MAP: dict[ExplicitPreferencePattern, ExplicitPreferencePattern] = { + "i_like": "remember_that_i_like", + "i_dont_like": "remember_that_i_dont_like", + "i_prefer": "remember_that_i_prefer", + "remember_that_i_like": "remember_that_i_like", + "remember_that_i_dont_like": "remember_that_i_dont_like", + "remember_that_i_prefer": "remember_that_i_prefer", +} + + +class ExplicitPreferenceExtractionValidationError(ValueError): + """Raised when an explicit-preference extraction request is invalid.""" + + +def _normalize_whitespace(value: str) -> str: + return re.sub(r"\s+", " ", value).strip() + + +def _normalize_subject(subject: str) -> str: + normalized = _normalize_whitespace(subject) + normalized = normalized.rstrip(_TRAILING_PUNCTUATION).strip() + return normalized + + +def _canonicalize_subject_for_key(subject: str) -> str: + return subject.casefold() + + +def _subject_has_supported_shape(subject: str) -> bool: + tokens = subject.split(" ") + if not tokens or len(tokens) > _MAX_SUBJECT_TOKENS: + return False + + if tokens[0].casefold() in _DISALLOWED_SUBJECT_PREFIX_TOKENS: + return False + + return all(_ALLOWED_SUBJECT_TOKEN.fullmatch(token) is not None for token in tokens) + + +def _slugify_subject(subject: str, *, max_length: int) -> str: + slug = subject.casefold() + slug = slug.replace("'", "") + slug = re.sub(r"[^a-z0-9]+", "_", slug) + slug = slug.strip("_") + if len(slug) > max_length: + slug = slug[:max_length].rstrip("_") + return slug + + +def _build_memory_key(subject: str) -> str: + canonical_subject = _canonicalize_subject_for_key(subject) + digest = hashlib.sha256(canonical_subject.encode("utf-8")).hexdigest()[:_MEMORY_KEY_HASH_LENGTH] + max_slug_length = _MAX_MEMORY_KEY_LENGTH - len(_MEMORY_KEY_PREFIX) - len("__") - len(digest) + slug = _slugify_subject(canonical_subject, max_length=max_slug_length) + if not slug: + return f"{_MEMORY_KEY_PREFIX}{digest}" + return f"{_MEMORY_KEY_PREFIX}{slug}__{digest}" + + +def _build_candidate( + *, + source_event_id: UUID, + pattern: ExplicitPreferencePattern, + preference: PreferenceKind, + subject_text: str, +) -> ExtractedPreferenceCandidateRecord | None: + normalized_subject = _normalize_subject(subject_text) + if not normalized_subject: + return None + + if not _subject_has_supported_shape(normalized_subject): + return None + + value: JsonObject = { + "kind": "explicit_preference", + "preference": preference, + "text": normalized_subject, + } + return { + "memory_key": _build_memory_key(normalized_subject), + "value": value, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + "pattern": pattern, + "subject_text": normalized_subject, + } + + +def extract_explicit_preference_candidates( + *, + source_event_id: UUID, + text: str, +) -> list[ExtractedPreferenceCandidateRecord]: + normalized_text = 
_normalize_whitespace(text) + if not normalized_text: + return [] + + for pattern_name, preference, pattern in _DIRECT_PATTERNS: + match = pattern.fullmatch(normalized_text) + if match is not None: + candidate = _build_candidate( + source_event_id=source_event_id, + pattern=pattern_name, + preference=preference, + subject_text=match.group("subject"), + ) + return [] if candidate is None else [candidate] + + lowered_text = normalized_text.lower() + if lowered_text.startswith(_REMEMBER_PREFIX): + nested_text = normalized_text[len(_REMEMBER_PREFIX) :] + nested_candidates = extract_explicit_preference_candidates( + source_event_id=source_event_id, + text=nested_text, + ) + if not nested_candidates: + return [] + candidate = dict(nested_candidates[0]) + candidate["pattern"] = _REMEMBER_PATTERN_MAP[candidate["pattern"]] + return [candidate] + + return [] + + +def _get_single_source_event(store: ContinuityStore, source_event_id: UUID) -> EventRow: + events = store.list_events_by_ids([source_event_id]) + if not events: + raise ExplicitPreferenceExtractionValidationError( + "source_event_id must reference an existing message.user event owned by the user" + ) + return events[0] + + +def _extract_text_payload(event: EventRow) -> str: + if event["kind"] != "message.user": + raise ExplicitPreferenceExtractionValidationError( + "source_event_id must reference an existing message.user event owned by the user" + ) + + payload_text = event["payload"].get("text") + if not isinstance(payload_text, str): + raise ExplicitPreferenceExtractionValidationError( + "source_event_id must reference a message.user event with string payload.text" + ) + + return payload_text + + +def _serialize_admission(decision: AdmissionDecisionOutput) -> ExplicitPreferenceAdmissionRecord: + return { + "decision": decision.action, + "reason": decision.reason, + "memory": decision.memory, + "revision": decision.revision, + } + + +def _build_summary( + *, + source_event_id: UUID, + source_event_kind: str, + admissions: Sequence[ExplicitPreferenceAdmissionRecord], + candidates: Sequence[ExtractedPreferenceCandidateRecord], +) -> ExplicitPreferenceExtractionSummary: + noop_count = sum(1 for admission in admissions if admission["decision"] == "NOOP") + return { + "source_event_id": str(source_event_id), + "source_event_kind": source_event_kind, + "candidate_count": len(candidates), + "admission_count": len(admissions), + "persisted_change_count": len(admissions) - noop_count, + "noop_count": noop_count, + } + + +def extract_and_admit_explicit_preferences( + store: ContinuityStore, + *, + user_id: UUID, + request: ExplicitPreferenceExtractionRequestInput, +) -> ExplicitPreferenceExtractionResponse: + source_event = _get_single_source_event(store, request.source_event_id) + payload_text = _extract_text_payload(source_event) + candidates = extract_explicit_preference_candidates( + source_event_id=request.source_event_id, + text=payload_text, + ) + + admissions: list[ExplicitPreferenceAdmissionRecord] = [] + for candidate in candidates: + decision = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key=candidate["memory_key"], + value=candidate["value"], + source_event_ids=(request.source_event_id,), + delete_requested=candidate["delete_requested"], + ), + ) + admissions.append(_serialize_admission(decision)) + + return { + "candidates": list(candidates), + "admissions": admissions, + "summary": _build_summary( + source_event_id=request.source_event_id, + source_event_kind=source_event["kind"], + 
admissions=admissions, + candidates=candidates, + ), + } diff --git a/apps/api/src/alicebot_api/main.py b/apps/api/src/alicebot_api/main.py new file mode 100644 index 0000000..764812d --- /dev/null +++ b/apps/api/src/alicebot_api/main.py @@ -0,0 +1,1837 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Literal, TypedDict +from uuid import UUID +from fastapi import FastAPI, Query +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel, Field +from fastapi.responses import JSONResponse +from urllib.parse import urlsplit, urlunsplit + +from alicebot_api.compiler import compile_and_persist_trace +from alicebot_api.config import Settings, get_settings +from alicebot_api.contracts import ( + ApprovalApproveInput, + ApprovalRejectInput, + ApprovalRequestCreateInput, + ConsentStatus, + ConsentUpsertInput, + CompileContextSemanticRetrievalInput, + DEFAULT_MAX_EVENTS, + DEFAULT_MAX_ENTITY_EDGES, + DEFAULT_MAX_ENTITIES, + DEFAULT_MAX_MEMORIES, + DEFAULT_MEMORY_REVIEW_LIMIT, + DEFAULT_MAX_SESSIONS, + DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + MAX_MEMORY_REVIEW_LIMIT, + MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ContextCompilerLimits, + EmbeddingConfigStatus, + EmbeddingConfigCreateInput, + ExecutionBudgetCreateInput, + ExecutionBudgetDeactivateInput, + ExecutionBudgetSupersedeInput, + EntityEdgeCreateInput, + EntityCreateInput, + EntityType, + ExplicitPreferenceExtractionRequestInput, + MemoryCandidateInput, + MemoryEmbeddingUpsertInput, + MemoryReviewLabelValue, + MemoryReviewStatusFilter, + PolicyCreateInput, + PolicyEffect, + PolicyEvaluationRequestInput, + SemanticMemoryRetrievalRequestInput, + TOOL_METADATA_VERSION_V0, + ApprovalStatus, + ProxyExecutionStatus, + ToolAllowlistEvaluationRequestInput, + ProxyExecutionRequestInput, + TaskStepKind, + TaskStepLineageInput, + TaskStepNextCreateInput, + TaskStepStatus, + TaskStepTransitionInput, + TaskWorkspaceCreateInput, + ToolRoutingDecision, + ToolRoutingRequestInput, + ToolCreateInput, +) +from alicebot_api.approvals import ( + ApprovalNotFoundError, + ApprovalResolutionConflictError, + approve_approval_record, + get_approval_record, + list_approval_records, + reject_approval_record, + submit_approval_request, +) +from alicebot_api.db import ping_database, user_connection +from alicebot_api.executions import ( + ToolExecutionNotFoundError, + get_tool_execution_record, + list_tool_execution_records, +) +from alicebot_api.tasks import ( + TaskNotFoundError, + TaskStepApprovalLinkageError, + TaskStepExecutionLinkageError, + TaskStepLifecycleBoundaryError, + TaskStepSequenceError, + TaskStepNotFoundError, + TaskStepTransitionError, + create_next_task_step_record, + get_task_record, + get_task_step_record, + list_task_records, + list_task_step_records, + transition_task_step_record, +) +from alicebot_api.workspaces import ( + TaskWorkspaceAlreadyExistsError, + TaskWorkspaceNotFoundError, + TaskWorkspaceProvisioningError, + create_task_workspace_record, + get_task_workspace_record, + list_task_workspace_records, +) +from alicebot_api.execution_budgets import ( + ExecutionBudgetLifecycleError, + ExecutionBudgetNotFoundError, + ExecutionBudgetValidationError, + create_execution_budget_record, + deactivate_execution_budget_record, + get_execution_budget_record, + list_execution_budget_records, + supersede_execution_budget_record, +) +from alicebot_api.embedding import ( + EmbeddingConfigValidationError, + MemoryEmbeddingNotFoundError, + MemoryEmbeddingValidationError, + 
create_embedding_config_record, + get_memory_embedding_record, + list_embedding_config_records, + list_memory_embedding_records, + upsert_memory_embedding_record, +) +from alicebot_api.entity import ( + EntityNotFoundError, + EntityValidationError, + create_entity_record, + get_entity_record, + list_entity_records, +) +from alicebot_api.entity_edge import ( + EntityEdgeValidationError, + create_entity_edge_record, + list_entity_edge_records, +) +from alicebot_api.explicit_preferences import ( + ExplicitPreferenceExtractionValidationError, + extract_and_admit_explicit_preferences, +) +from alicebot_api.memory import ( + MemoryAdmissionValidationError, + MemoryReviewNotFoundError, + admit_memory_candidate, + create_memory_review_label_record, + get_memory_evaluation_summary, + get_memory_review_record, + list_memory_review_queue_records, + list_memory_review_label_records, + list_memory_review_records, + list_memory_revision_review_records, +) +from alicebot_api.policy import ( + PolicyEvaluationValidationError, + PolicyNotFoundError, + PolicyValidationError, + create_policy_record, + evaluate_policy_request, + get_policy_record, + list_consent_records, + list_policy_records, + upsert_consent_record, +) +from alicebot_api.tools import ( + ToolAllowlistValidationError, + ToolNotFoundError, + ToolRoutingValidationError, + ToolValidationError, + create_tool_record, + evaluate_tool_allowlist, + get_tool_record, + list_tool_records, + route_tool_invocation, +) +from alicebot_api.semantic_retrieval import ( + SemanticMemoryRetrievalValidationError, + retrieve_semantic_memory_records, +) +from alicebot_api.response_generation import ( + ResponseFailure, + generate_response, +) +from alicebot_api.proxy_execution import ( + ProxyExecutionApprovalStateError, + ProxyExecutionHandlerNotFoundError, + execute_approved_proxy_request, +) +from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError + + +app = FastAPI(title="AliceBot API", version="0.1.0") +HealthStatus = Literal["ok", "degraded"] +ServiceStatus = Literal["ok", "unreachable", "not_checked"] + + +class DatabaseServicePayload(TypedDict): + status: Literal["ok", "unreachable"] + + +class RedisServicePayload(TypedDict): + status: Literal["not_checked"] + url: str + + +class ObjectStorageServicePayload(TypedDict): + status: Literal["not_checked"] + endpoint_url: str + + +class HealthServicesPayload(TypedDict): + database: DatabaseServicePayload + redis: RedisServicePayload + object_storage: ObjectStorageServicePayload + + +class HealthcheckPayload(TypedDict): + status: HealthStatus + environment: str + services: HealthServicesPayload + + +class CompileContextSemanticRequest(BaseModel): + embedding_config_id: UUID + query_vector: list[float] = Field(min_length=1, max_length=20000) + limit: int = Field( + default=DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ge=1, + le=MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ) + + +class CompileContextRequest(BaseModel): + user_id: UUID + thread_id: UUID + max_sessions: int = Field(default=DEFAULT_MAX_SESSIONS, ge=0, le=25) + max_events: int = Field(default=DEFAULT_MAX_EVENTS, ge=0, le=200) + max_memories: int = Field(default=DEFAULT_MAX_MEMORIES, ge=0, le=50) + max_entities: int = Field(default=DEFAULT_MAX_ENTITIES, ge=0, le=50) + max_entity_edges: int = Field(default=DEFAULT_MAX_ENTITY_EDGES, ge=0, le=100) + semantic: CompileContextSemanticRequest | None = None + + +class GenerateResponseRequest(BaseModel): + user_id: UUID + thread_id: UUID + message: str = Field(min_length=1, max_length=8000) + 
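+ # The max_* bounds below repeat the limits declared on CompileContextRequest above; presumably this keeps + # a generation request from asking for a wider context window than the context compiler itself accepts.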
max_sessions: int = Field(default=DEFAULT_MAX_SESSIONS, ge=0, le=25) + max_events: int = Field(default=DEFAULT_MAX_EVENTS, ge=0, le=200) + max_memories: int = Field(default=DEFAULT_MAX_MEMORIES, ge=0, le=50) + max_entities: int = Field(default=DEFAULT_MAX_ENTITIES, ge=0, le=50) + max_entity_edges: int = Field(default=DEFAULT_MAX_ENTITY_EDGES, ge=0, le=100) + + +class AdmitMemoryRequest(BaseModel): + user_id: UUID + memory_key: str = Field(min_length=1, max_length=200) + value: object | None = None + source_event_ids: list[UUID] = Field(min_length=1) + delete_requested: bool = False + + +class ExtractExplicitPreferencesRequest(BaseModel): + user_id: UUID + source_event_id: UUID + + +class CreateMemoryReviewLabelRequest(BaseModel): + user_id: UUID + label: MemoryReviewLabelValue + note: str | None = Field(default=None, min_length=1, max_length=280) + + +class CreateEntityRequest(BaseModel): + user_id: UUID + entity_type: EntityType + name: str = Field(min_length=1, max_length=200) + source_memory_ids: list[UUID] = Field(min_length=1) + + +class CreateEntityEdgeRequest(BaseModel): + user_id: UUID + from_entity_id: UUID + to_entity_id: UUID + relationship_type: str = Field(min_length=1, max_length=100) + valid_from: datetime | None = None + valid_to: datetime | None = None + source_memory_ids: list[UUID] = Field(min_length=1) + + +class CreateEmbeddingConfigRequest(BaseModel): + user_id: UUID + provider: str = Field(min_length=1, max_length=100) + model: str = Field(min_length=1, max_length=200) + version: str = Field(min_length=1, max_length=100) + dimensions: int = Field(ge=1, le=20000) + status: EmbeddingConfigStatus = "active" + metadata: dict[str, object] = Field(default_factory=dict) + + +class UpsertMemoryEmbeddingRequest(BaseModel): + user_id: UUID + memory_id: UUID + embedding_config_id: UUID + vector: list[float] = Field(min_length=1, max_length=20000) + + +class RetrieveSemanticMemoriesRequest(BaseModel): + user_id: UUID + embedding_config_id: UUID + query_vector: list[float] = Field(min_length=1, max_length=20000) + limit: int = Field( + default=DEFAULT_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ge=1, + le=MAX_SEMANTIC_MEMORY_RETRIEVAL_LIMIT, + ) + + +class UpsertConsentRequest(BaseModel): + user_id: UUID + consent_key: str = Field(min_length=1, max_length=200) + status: ConsentStatus + metadata: dict[str, object] = Field(default_factory=dict) + + +class CreatePolicyRequest(BaseModel): + user_id: UUID + name: str = Field(min_length=1, max_length=200) + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + effect: PolicyEffect + priority: int = Field(ge=0, le=1000000) + active: bool = True + conditions: dict[str, object] = Field(default_factory=dict) + required_consents: list[str] = Field(default_factory=list) + + +class EvaluatePolicyRequest(BaseModel): + user_id: UUID + thread_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + attributes: dict[str, object] = Field(default_factory=dict) + + +class CreateToolRequest(BaseModel): + user_id: UUID + tool_key: str = Field(min_length=1, max_length=200) + name: str = Field(min_length=1, max_length=200) + description: str = Field(min_length=1, max_length=500) + version: str = Field(min_length=1, max_length=100) + metadata_version: str = Field(default=TOOL_METADATA_VERSION_V0, pattern=f"^{TOOL_METADATA_VERSION_V0}$") + active: bool = True + tags: list[str] = Field(default_factory=list) + action_hints: list[str] = 
Field(default_factory=list, min_length=1) + scope_hints: list[str] = Field(default_factory=list, min_length=1) + domain_hints: list[str] = Field(default_factory=list) + risk_hints: list[str] = Field(default_factory=list) + metadata: dict[str, object] = Field(default_factory=dict) + + +class EvaluateToolAllowlistRequest(BaseModel): + user_id: UUID + thread_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + risk_hint: str | None = Field(default=None, min_length=1, max_length=100) + attributes: dict[str, object] = Field(default_factory=dict) + + +class RouteToolRequest(BaseModel): + user_id: UUID + thread_id: UUID + tool_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + risk_hint: str | None = Field(default=None, min_length=1, max_length=100) + attributes: dict[str, object] = Field(default_factory=dict) + + +class CreateApprovalRequest(BaseModel): + user_id: UUID + thread_id: UUID + tool_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + risk_hint: str | None = Field(default=None, min_length=1, max_length=100) + attributes: dict[str, object] = Field(default_factory=dict) + + +class ResolveApprovalRequest(BaseModel): + user_id: UUID + + +class ExecuteApprovedProxyRequest(BaseModel): + user_id: UUID + + +class CreateTaskWorkspaceRequest(BaseModel): + user_id: UUID + + +class TaskStepRequestSnapshot(BaseModel): + thread_id: UUID + tool_id: UUID + action: str = Field(min_length=1, max_length=100) + scope: str = Field(min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + risk_hint: str | None = Field(default=None, min_length=1, max_length=100) + attributes: dict[str, object] = Field(default_factory=dict) + + +class TaskStepOutcomeRequest(BaseModel): + routing_decision: ToolRoutingDecision + approval_id: UUID | None = None + approval_status: ApprovalStatus | None = None + execution_id: UUID | None = None + execution_status: ProxyExecutionStatus | None = None + blocked_reason: str | None = Field(default=None, min_length=1, max_length=500) + + +class TaskStepLineageRequest(BaseModel): + parent_step_id: UUID + source_approval_id: UUID | None = None + source_execution_id: UUID | None = None + + +class CreateNextTaskStepRequest(BaseModel): + user_id: UUID + kind: TaskStepKind = "governed_request" + status: TaskStepStatus + request: TaskStepRequestSnapshot + outcome: TaskStepOutcomeRequest + lineage: TaskStepLineageRequest + + +class TransitionTaskStepRequest(BaseModel): + user_id: UUID + status: TaskStepStatus + outcome: TaskStepOutcomeRequest + + +class CreateExecutionBudgetRequest(BaseModel): + user_id: UUID + tool_key: str | None = Field(default=None, min_length=1, max_length=200) + domain_hint: str | None = Field(default=None, min_length=1, max_length=200) + max_completed_executions: int = Field(ge=1, le=1000000) + rolling_window_seconds: int | None = Field(default=None, ge=1) + + +class DeactivateExecutionBudgetRequest(BaseModel): + user_id: UUID + thread_id: UUID + + +class SupersedeExecutionBudgetRequest(BaseModel): + user_id: UUID + thread_id: UUID + max_completed_executions: int = Field(ge=1, le=1000000) + + +def 
redact_url_credentials(raw_url: str) -> str: + parsed = urlsplit(raw_url) + + if parsed.hostname is None or (parsed.username is None and parsed.password is None): + return raw_url + + hostname = parsed.hostname + if ":" in hostname and not hostname.startswith("["): + hostname = f"[{hostname}]" + + netloc = hostname + if parsed.port is not None: + netloc = f"{hostname}:{parsed.port}" + + return urlunsplit((parsed.scheme, netloc, parsed.path, parsed.query, parsed.fragment)) + + +def build_healthcheck_payload(settings: Settings, database_ok: bool) -> HealthcheckPayload: + status: HealthStatus = "ok" if database_ok else "degraded" + database_status: Literal["ok", "unreachable"] = "ok" if database_ok else "unreachable" + + return { + "status": status, + "environment": settings.app_env, + "services": { + "database": { + "status": database_status, + }, + "redis": { + "status": "not_checked", + "url": redact_url_credentials(settings.redis_url), + }, + "object_storage": { + "status": "not_checked", + "endpoint_url": settings.s3_endpoint_url, + }, + }, + } + + +@app.get("/healthz") +def healthcheck() -> JSONResponse: + settings = get_settings() + database_ok = ping_database( + settings.database_url, + settings.healthcheck_timeout_seconds, + ) + payload = build_healthcheck_payload(settings, database_ok) + status_code = 200 if payload["status"] == "ok" else 503 + return JSONResponse( + status_code=status_code, + content=payload, + ) + + +@app.post("/v0/context/compile") +def compile_context(request: CompileContextRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + result = compile_and_persist_trace( + ContinuityStore(conn), + user_id=request.user_id, + thread_id=request.thread_id, + limits=ContextCompilerLimits( + max_sessions=request.max_sessions, + max_events=request.max_events, + max_memories=request.max_memories, + max_entities=request.max_entities, + max_entity_edges=request.max_entity_edges, + ), + semantic_retrieval=( + None + if request.semantic is None + else CompileContextSemanticRetrievalInput( + embedding_config_id=request.semantic.embedding_config_id, + query_vector=tuple(request.semantic.query_vector), + limit=request.semantic.limit, + ) + ), + ) + except SemanticMemoryRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except ContinuityStoreInvariantError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder( + { + "trace_id": result.trace_id, + "trace_event_count": result.trace_event_count, + "context_pack": result.context_pack, + } + ), + ) + + +@app.post("/v0/responses") +def generate_assistant_response(request: GenerateResponseRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + result = generate_response( + store=ContinuityStore(conn), + settings=settings, + user_id=request.user_id, + thread_id=request.thread_id, + message_text=request.message, + limits=ContextCompilerLimits( + max_sessions=request.max_sessions, + max_events=request.max_events, + max_memories=request.max_memories, + max_entities=request.max_entities, + max_entity_edges=request.max_entity_edges, + ), + ) + except ContinuityStoreInvariantError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if isinstance(result, ResponseFailure): + return JSONResponse( + status_code=502, + 
content=jsonable_encoder( + { + "detail": result.detail, + "trace": result.trace, + } + ), + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(result), + ) + + +@app.post("/v0/memories/admit") +def admit_memory(request: AdmitMemoryRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + decision = admit_memory_candidate( + ContinuityStore(conn), + user_id=request.user_id, + candidate=MemoryCandidateInput( + memory_key=request.memory_key, + value=request.value, + source_event_ids=tuple(request.source_event_ids), + delete_requested=request.delete_requested, + ), + ) + except MemoryAdmissionValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder( + { + "decision": decision.action, + "reason": decision.reason, + "memory": decision.memory, + "revision": decision.revision, + } + ), + ) + + +@app.post("/v0/consents") +def upsert_consent(request: UpsertConsentRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = upsert_consent_record( + ContinuityStore(conn), + user_id=request.user_id, + consent=ConsentUpsertInput( + consent_key=request.consent_key, + status=request.status, + metadata=request.metadata, + ), + ) + except PolicyValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + status_code = 201 if payload["write_mode"] == "created" else 200 + return JSONResponse( + status_code=status_code, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/consents") +def list_consents(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_consent_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/policies") +def create_policy(request: CreatePolicyRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_policy_record( + ContinuityStore(conn), + user_id=request.user_id, + policy=PolicyCreateInput( + name=request.name, + action=request.action, + scope=request.scope, + effect=request.effect, + priority=request.priority, + active=request.active, + conditions=request.conditions, + required_consents=tuple(request.required_consents), + ), + ) + except PolicyValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/policies") +def list_policies(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_policy_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/policies/{policy_id}") +def get_policy(policy_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_policy_record( + ContinuityStore(conn), + user_id=user_id, + policy_id=policy_id, + ) + except PolicyNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + 
return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/policies/evaluate") +def evaluate_policy(request: EvaluatePolicyRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = evaluate_policy_request( + ContinuityStore(conn), + user_id=request.user_id, + request=PolicyEvaluationRequestInput( + thread_id=request.thread_id, + action=request.action, + scope=request.scope, + attributes=request.attributes, + ), + ) + except PolicyEvaluationValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tools") +def create_tool(request: CreateToolRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_tool_record( + ContinuityStore(conn), + user_id=request.user_id, + tool=ToolCreateInput( + tool_key=request.tool_key, + name=request.name, + description=request.description, + version=request.version, + metadata_version=request.metadata_version, + active=request.active, + tags=tuple(request.tags), + action_hints=tuple(request.action_hints), + scope_hints=tuple(request.scope_hints), + domain_hints=tuple(request.domain_hints), + risk_hints=tuple(request.risk_hints), + metadata=request.metadata, + ), + ) + except ToolValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tools") +def list_tools(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_tool_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tools/allowlist/evaluate") +def evaluate_tools_allowlist(request: EvaluateToolAllowlistRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = evaluate_tool_allowlist( + ContinuityStore(conn), + user_id=request.user_id, + request=ToolAllowlistEvaluationRequestInput( + thread_id=request.thread_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ), + ) + except ToolAllowlistValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tools/route") +def route_tool(request: RouteToolRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = route_tool_invocation( + ContinuityStore(conn), + user_id=request.user_id, + request=ToolRoutingRequestInput( + thread_id=request.thread_id, + tool_id=request.tool_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ), + ) + except ToolRoutingValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/approvals/requests") +def 
create_approval_request(request: CreateApprovalRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = submit_approval_request( + ContinuityStore(conn), + user_id=request.user_id, + request=ApprovalRequestCreateInput( + thread_id=request.thread_id, + tool_id=request.tool_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ), + ) + except ToolRoutingValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/approvals") +def list_approvals(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_approval_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/approvals/{approval_id}") +def get_approval(approval_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_approval_record( + ContinuityStore(conn), + user_id=user_id, + approval_id=approval_id, + ) + except ApprovalNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/approvals/{approval_id}/approve") +def approve_approval(approval_id: UUID, request: ResolveApprovalRequest) -> JSONResponse: + settings = get_settings() + resolution_error: ( + ApprovalResolutionConflictError | TaskStepApprovalLinkageError | TaskStepLifecycleBoundaryError | None + ) = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = approve_approval_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ApprovalApproveInput(approval_id=approval_id), + ) + except ( + ApprovalResolutionConflictError, + TaskStepApprovalLinkageError, + TaskStepLifecycleBoundaryError, + ) as exc: + resolution_error = exc + payload = None + except ApprovalNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if resolution_error is not None: + return JSONResponse(status_code=409, content={"detail": str(resolution_error)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/approvals/{approval_id}/reject") +def reject_approval(approval_id: UUID, request: ResolveApprovalRequest) -> JSONResponse: + settings = get_settings() + resolution_error: ( + ApprovalResolutionConflictError | TaskStepApprovalLinkageError | TaskStepLifecycleBoundaryError | None + ) = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = reject_approval_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ApprovalRejectInput(approval_id=approval_id), + ) + except ( + ApprovalResolutionConflictError, + TaskStepApprovalLinkageError, + TaskStepLifecycleBoundaryError, + ) as exc: + resolution_error = exc + payload = None + except ApprovalNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if resolution_error is not None: + return JSONResponse(status_code=409, content={"detail": str(resolution_error)}) + + 
return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/approvals/{approval_id}/execute") +def execute_approved_proxy(approval_id: UUID, request: ExecuteApprovedProxyRequest) -> JSONResponse: + settings = get_settings() + execution_error: ( + ProxyExecutionApprovalStateError + | ProxyExecutionHandlerNotFoundError + | TaskStepApprovalLinkageError + | TaskStepExecutionLinkageError + | None + ) = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = execute_approved_proxy_request( + ContinuityStore(conn), + user_id=request.user_id, + request=ProxyExecutionRequestInput(approval_id=approval_id), + ) + except ( + ProxyExecutionApprovalStateError, + ProxyExecutionHandlerNotFoundError, + TaskStepApprovalLinkageError, + TaskStepExecutionLinkageError, + ) as exc: + execution_error = exc + payload = None + except ApprovalNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if execution_error is not None: + return JSONResponse(status_code=409, content={"detail": str(execution_error)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tasks") +def list_tasks(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tasks/{task_id}") +def get_task(task_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_task_record( + ContinuityStore(conn), + user_id=user_id, + task_id=task_id, + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tasks/{task_id}/workspace") +def create_task_workspace(task_id: UUID, request: CreateTaskWorkspaceRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_task_workspace_record( + ContinuityStore(conn), + settings=settings, + user_id=request.user_id, + request=TaskWorkspaceCreateInput( + task_id=task_id, + status="active", + ), + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except (TaskWorkspaceAlreadyExistsError, TaskWorkspaceProvisioningError) as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-workspaces") +def list_task_workspaces(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_workspace_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-workspaces/{task_workspace_id}") +def get_task_workspace(task_workspace_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_task_workspace_record( + ContinuityStore(conn), + user_id=user_id, + task_workspace_id=task_workspace_id, + ) + 
except TaskWorkspaceNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tasks/{task_id}/steps") +def list_task_steps(task_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_task_step_records( + ContinuityStore(conn), + user_id=user_id, + task_id=task_id, + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/task-steps/{task_step_id}") +def get_task_step(task_step_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_task_step_record( + ContinuityStore(conn), + user_id=user_id, + task_step_id=task_step_id, + ) + except TaskStepNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/tasks/{task_id}/steps") +def create_next_task_step(task_id: UUID, request: CreateNextTaskStepRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_next_task_step_record( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskStepNextCreateInput( + task_id=task_id, + kind=request.kind, + status=request.status, + request=request.request.model_dump(mode="json"), + outcome=request.outcome.model_dump(mode="json"), + lineage=TaskStepLineageInput( + parent_step_id=request.lineage.parent_step_id, + source_approval_id=request.lineage.source_approval_id, + source_execution_id=request.lineage.source_execution_id, + ), + ), + ) + except TaskNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskStepSequenceError as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/task-steps/{task_step_id}/transition") +def transition_task_step(task_step_id: UUID, request: TransitionTaskStepRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = transition_task_step_record( + ContinuityStore(conn), + user_id=request.user_id, + request=TaskStepTransitionInput( + task_step_id=task_step_id, + status=request.status, + outcome=request.outcome.model_dump(mode="json"), + ), + ) + except TaskStepNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + except TaskStepTransitionError as exc: + return JSONResponse(status_code=409, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/execution-budgets") +def create_execution_budget(request: CreateExecutionBudgetRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_execution_budget_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ExecutionBudgetCreateInput( + tool_key=request.tool_key, + domain_hint=request.domain_hint, + 
max_completed_executions=request.max_completed_executions, + rolling_window_seconds=request.rolling_window_seconds, + ), + ) + except ExecutionBudgetValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/execution-budgets") +def list_execution_budgets(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_execution_budget_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/execution-budgets/{execution_budget_id}") +def get_execution_budget(execution_budget_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_execution_budget_record( + ContinuityStore(conn), + user_id=user_id, + execution_budget_id=execution_budget_id, + ) + except ExecutionBudgetNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/execution-budgets/{execution_budget_id}/deactivate") +def deactivate_execution_budget( + execution_budget_id: UUID, + request: DeactivateExecutionBudgetRequest, +) -> JSONResponse: + settings = get_settings() + lifecycle_error: ExecutionBudgetLifecycleError | None = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = deactivate_execution_budget_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ExecutionBudgetDeactivateInput( + thread_id=request.thread_id, + execution_budget_id=execution_budget_id, + ), + ) + except ExecutionBudgetLifecycleError as exc: + lifecycle_error = exc + payload = None + except ExecutionBudgetValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except ExecutionBudgetNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if lifecycle_error is not None: + return JSONResponse(status_code=409, content={"detail": str(lifecycle_error)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/execution-budgets/{execution_budget_id}/supersede") +def supersede_execution_budget( + execution_budget_id: UUID, + request: SupersedeExecutionBudgetRequest, +) -> JSONResponse: + settings = get_settings() + lifecycle_error: ExecutionBudgetLifecycleError | None = None + + try: + with user_connection(settings.database_url, request.user_id) as conn: + try: + payload = supersede_execution_budget_record( + ContinuityStore(conn), + user_id=request.user_id, + request=ExecutionBudgetSupersedeInput( + thread_id=request.thread_id, + execution_budget_id=execution_budget_id, + max_completed_executions=request.max_completed_executions, + ), + ) + except ExecutionBudgetLifecycleError as exc: + lifecycle_error = exc + payload = None + except ExecutionBudgetValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except ExecutionBudgetNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + if lifecycle_error is not None: + return JSONResponse(status_code=409, content={"detail": str(lifecycle_error)}) + + return JSONResponse( + status_code=200, + 
content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tool-executions") +def list_tool_executions(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_tool_execution_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tool-executions/{execution_id}") +def get_tool_execution(execution_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_tool_execution_record( + ContinuityStore(conn), + user_id=user_id, + execution_id=execution_id, + ) + except ToolExecutionNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/tools/{tool_id}") +def get_tool(tool_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_tool_record( + ContinuityStore(conn), + user_id=user_id, + tool_id=tool_id, + ) + except ToolNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/memories/extract-explicit-preferences") +def extract_explicit_preferences(request: ExtractExplicitPreferencesRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = extract_and_admit_explicit_preferences( + ContinuityStore(conn), + user_id=request.user_id, + request=ExplicitPreferenceExtractionRequestInput( + source_event_id=request.source_event_id, + ), + ) + except ExplicitPreferenceExtractionValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + except MemoryAdmissionValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories") +def list_memories( + user_id: UUID, + status: MemoryReviewStatusFilter = Query(default="active"), + limit: int = Query(default=DEFAULT_MEMORY_REVIEW_LIMIT, ge=1, le=MAX_MEMORY_REVIEW_LIMIT), +) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_review_records( + ContinuityStore(conn), + user_id=user_id, + status=status, + limit=limit, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/review-queue") +def list_memory_review_queue( + user_id: UUID, + limit: int = Query(default=DEFAULT_MEMORY_REVIEW_LIMIT, ge=1, le=MAX_MEMORY_REVIEW_LIMIT), +) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_review_queue_records( + ContinuityStore(conn), + user_id=user_id, + limit=limit, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/evaluation-summary") +def get_memories_evaluation_summary(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = get_memory_evaluation_summary( + ContinuityStore(conn), + 
user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/memories/semantic-retrieval") +def retrieve_semantic_memories(request: RetrieveSemanticMemoriesRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = retrieve_semantic_memory_records( + ContinuityStore(conn), + user_id=request.user_id, + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=request.embedding_config_id, + query_vector=tuple(request.query_vector), + limit=request.limit, + ), + ) + except SemanticMemoryRetrievalValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/{memory_id}") +def get_memory( + memory_id: UUID, + user_id: UUID, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_memory_review_record( + ContinuityStore(conn), + user_id=user_id, + memory_id=memory_id, + ) + except MemoryReviewNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/{memory_id}/revisions") +def list_memory_revisions( + memory_id: UUID, + user_id: UUID, + limit: int = Query(default=DEFAULT_MEMORY_REVIEW_LIMIT, ge=1, le=MAX_MEMORY_REVIEW_LIMIT), +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_revision_review_records( + ContinuityStore(conn), + user_id=user_id, + memory_id=memory_id, + limit=limit, + ) + except MemoryReviewNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/memories/{memory_id}/labels") +def create_memory_review_label( + memory_id: UUID, + request: CreateMemoryReviewLabelRequest, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_memory_review_label_record( + ContinuityStore(conn), + user_id=request.user_id, + memory_id=memory_id, + label=request.label, + note=request.note, + ) + except MemoryReviewNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/{memory_id}/labels") +def list_memory_review_labels( + memory_id: UUID, + user_id: UUID, +) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_review_label_records( + ContinuityStore(conn), + user_id=user_id, + memory_id=memory_id, + ) + except MemoryReviewNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/embedding-configs") +def create_embedding_config(request: CreateEmbeddingConfigRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_embedding_config_record( + ContinuityStore(conn), + user_id=request.user_id, + 
config=EmbeddingConfigCreateInput( + provider=request.provider, + model=request.model, + version=request.version, + dimensions=request.dimensions, + status=request.status, + metadata=request.metadata, + ), + ) + except EmbeddingConfigValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/embedding-configs") +def list_embedding_configs(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_embedding_config_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/memory-embeddings") +def upsert_memory_embedding(request: UpsertMemoryEmbeddingRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = upsert_memory_embedding_record( + ContinuityStore(conn), + user_id=request.user_id, + request=MemoryEmbeddingUpsertInput( + memory_id=request.memory_id, + embedding_config_id=request.embedding_config_id, + vector=tuple(request.vector), + ), + ) + except MemoryEmbeddingValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memories/{memory_id}/embeddings") +def list_memory_embeddings(memory_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_memory_embedding_records( + ContinuityStore(conn), + user_id=user_id, + memory_id=memory_id, + ) + except MemoryEmbeddingNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/memory-embeddings/{memory_embedding_id}") +def get_memory_embedding(memory_embedding_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_memory_embedding_record( + ContinuityStore(conn), + user_id=user_id, + memory_embedding_id=memory_embedding_id, + ) + except MemoryEmbeddingNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/entities") +def create_entity(request: CreateEntityRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_entity_record( + ContinuityStore(conn), + user_id=request.user_id, + entity=EntityCreateInput( + entity_type=request.entity_type, + name=request.name, + source_memory_ids=tuple(request.source_memory_ids), + ), + ) + except EntityValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.post("/v0/entity-edges") +def create_entity_edge(request: CreateEntityEdgeRequest) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, request.user_id) as conn: + payload = create_entity_edge_record( + ContinuityStore(conn), + user_id=request.user_id, + 
edge=EntityEdgeCreateInput( + from_entity_id=request.from_entity_id, + to_entity_id=request.to_entity_id, + relationship_type=request.relationship_type, + valid_from=request.valid_from, + valid_to=request.valid_to, + source_memory_ids=tuple(request.source_memory_ids), + ), + ) + except EntityEdgeValidationError as exc: + return JSONResponse(status_code=400, content={"detail": str(exc)}) + + return JSONResponse( + status_code=201, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/entities") +def list_entities(user_id: UUID) -> JSONResponse: + settings = get_settings() + + with user_connection(settings.database_url, user_id) as conn: + payload = list_entity_records( + ContinuityStore(conn), + user_id=user_id, + ) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/entities/{entity_id}/edges") +def list_entity_edges(entity_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = list_entity_edge_records( + ContinuityStore(conn), + user_id=user_id, + entity_id=entity_id, + ) + except EntityNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) + + +@app.get("/v0/entities/{entity_id}") +def get_entity(entity_id: UUID, user_id: UUID) -> JSONResponse: + settings = get_settings() + + try: + with user_connection(settings.database_url, user_id) as conn: + payload = get_entity_record( + ContinuityStore(conn), + user_id=user_id, + entity_id=entity_id, + ) + except EntityNotFoundError as exc: + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + return JSONResponse( + status_code=200, + content=jsonable_encoder(payload), + ) diff --git a/apps/api/src/alicebot_api/memory.py b/apps/api/src/alicebot_api/memory.py new file mode 100644 index 0000000..3c5ebc3 --- /dev/null +++ b/apps/api/src/alicebot_api/memory.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +from uuid import UUID + +from alicebot_api.contracts import ( + AdmissionDecisionOutput, + DEFAULT_MEMORY_REVIEW_LIMIT, + MEMORY_REVIEW_LABEL_ORDER, + MEMORY_REVIEW_LABEL_VALUES, + MEMORY_REVIEW_QUEUE_ORDER, + MEMORY_REVISION_REVIEW_ORDER, + MEMORY_REVIEW_ORDER, + MemoryCandidateInput, + MemoryEvaluationSummary, + MemoryEvaluationSummaryResponse, + MemoryReviewLabelCounts, + MemoryReviewLabelCreateResponse, + MemoryReviewLabelListResponse, + MemoryReviewLabelRecord, + MemoryReviewLabelSummary, + MemoryReviewLabelValue, + MemoryReviewQueueItem, + MemoryReviewQueueResponse, + MemoryReviewQueueSummary, + MemoryRevisionReviewListResponse, + MemoryRevisionReviewListSummary, + MemoryRevisionReviewRecord, + MemoryReviewDetailResponse, + MemoryReviewListResponse, + MemoryReviewListSummary, + MemoryReviewRecord, + MemoryReviewStatusFilter, + PersistedMemoryRecord, + PersistedMemoryRevisionRecord, + isoformat_or_none, +) +from alicebot_api.store import ContinuityStore, JsonObject, LabelCountRow, MemoryReviewLabelRow, MemoryRevisionRow, MemoryRow + + +class MemoryAdmissionValidationError(ValueError): + """Raised when an admission request fails explicit candidate validation.""" + + +class MemoryReviewNotFoundError(LookupError): + """Raised when a requested memory is not visible inside the current user scope.""" + + +def _serialize_memory(memory: MemoryRow) -> PersistedMemoryRecord: + return { + "id": str(memory["id"]), + "user_id": str(memory["user_id"]), + 
"memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + "deleted_at": isoformat_or_none(memory["deleted_at"]), + } + + +def _serialize_memory_revision(revision: MemoryRevisionRow) -> PersistedMemoryRevisionRecord: + return { + "id": str(revision["id"]), + "user_id": str(revision["user_id"]), + "memory_id": str(revision["memory_id"]), + "sequence_no": revision["sequence_no"], + "action": revision["action"], + "memory_key": revision["memory_key"], + "previous_value": revision["previous_value"], + "new_value": revision["new_value"], + "source_event_ids": revision["source_event_ids"], + "candidate": revision["candidate"], + "created_at": revision["created_at"].isoformat(), + } + + +def _serialize_memory_review(memory: MemoryRow) -> MemoryReviewRecord: + return { + "id": str(memory["id"]), + "memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + "deleted_at": isoformat_or_none(memory["deleted_at"]), + } + + +def _serialize_memory_review_queue_item(memory: MemoryRow) -> MemoryReviewQueueItem: + return { + "id": str(memory["id"]), + "memory_key": memory["memory_key"], + "value": memory["value"], + "status": memory["status"], + "source_event_ids": memory["source_event_ids"], + "created_at": memory["created_at"].isoformat(), + "updated_at": memory["updated_at"].isoformat(), + } + + +def _serialize_memory_revision_review(revision: MemoryRevisionRow) -> MemoryRevisionReviewRecord: + return { + "id": str(revision["id"]), + "memory_id": str(revision["memory_id"]), + "sequence_no": revision["sequence_no"], + "action": revision["action"], + "memory_key": revision["memory_key"], + "previous_value": revision["previous_value"], + "new_value": revision["new_value"], + "source_event_ids": revision["source_event_ids"], + "created_at": revision["created_at"].isoformat(), + } + + +def _serialize_memory_review_label(label: MemoryReviewLabelRow) -> MemoryReviewLabelRecord: + return { + "id": str(label["id"]), + "memory_id": str(label["memory_id"]), + "reviewer_user_id": str(label["user_id"]), + "label": label["label"], + "note": label["note"], + "created_at": label["created_at"].isoformat(), + } + + +def _empty_memory_review_label_counts() -> MemoryReviewLabelCounts: + return { + "correct": 0, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + } + + +def _summarize_memory_review_label_counts(rows: list[LabelCountRow]) -> MemoryReviewLabelCounts: + counts = _empty_memory_review_label_counts() + for row in rows: + label = row["label"] + if label in counts: + counts[label] = row["count"] + return counts + + +def _build_memory_review_label_summary( + *, + memory_id: UUID, + counts: MemoryReviewLabelCounts, +) -> MemoryReviewLabelSummary: + return { + "memory_id": str(memory_id), + "total_count": sum(counts.values()), + "counts_by_label": counts, + "order": list(MEMORY_REVIEW_LABEL_ORDER), + } + + +def _normalize_memory_status_filter(status: MemoryReviewStatusFilter) -> str | None: + if status == "all": + return None + return status + + +def list_memory_review_records( + store: ContinuityStore, + *, + user_id: UUID, + status: MemoryReviewStatusFilter = "active", + limit: int = DEFAULT_MEMORY_REVIEW_LIMIT, +) -> 
MemoryReviewListResponse: + del user_id + + normalized_status = _normalize_memory_status_filter(status) + total_count = store.count_memories(status=normalized_status) + memories = store.list_review_memories(status=normalized_status, limit=limit) + items = [_serialize_memory_review(memory) for memory in memories] + summary: MemoryReviewListSummary = { + "status": status, + "limit": limit, + "returned_count": len(items), + "total_count": total_count, + "has_more": len(items) < total_count, + "order": list(MEMORY_REVIEW_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def list_memory_review_queue_records( + store: ContinuityStore, + *, + user_id: UUID, + limit: int = DEFAULT_MEMORY_REVIEW_LIMIT, +) -> MemoryReviewQueueResponse: + del user_id + + total_count = store.count_unlabeled_review_memories() + memories = store.list_unlabeled_review_memories(limit=limit) + items = [_serialize_memory_review_queue_item(memory) for memory in memories] + summary: MemoryReviewQueueSummary = { + "memory_status": "active", + "review_state": "unlabeled", + "limit": limit, + "returned_count": len(items), + "total_count": total_count, + "has_more": len(items) < total_count, + "order": list(MEMORY_REVIEW_QUEUE_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_memory_review_record( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, +) -> MemoryReviewDetailResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryReviewNotFoundError(f"memory {memory_id} was not found") + + return { + "memory": _serialize_memory_review(memory), + } + + +def list_memory_revision_review_records( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, + limit: int = DEFAULT_MEMORY_REVIEW_LIMIT, +) -> MemoryRevisionReviewListResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryReviewNotFoundError(f"memory {memory_id} was not found") + + total_count = store.count_memory_revisions(memory_id) + revisions = store.list_memory_revisions(memory_id, limit=limit) + items = [_serialize_memory_revision_review(revision) for revision in revisions] + summary: MemoryRevisionReviewListSummary = { + "memory_id": str(memory["id"]), + "limit": limit, + "returned_count": len(items), + "total_count": total_count, + "has_more": len(items) < total_count, + "order": list(MEMORY_REVISION_REVIEW_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def create_memory_review_label_record( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, + label: MemoryReviewLabelValue, + note: str | None, +) -> MemoryReviewLabelCreateResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryReviewNotFoundError(f"memory {memory_id} was not found") + + created_label = store.create_memory_review_label( + memory_id=memory_id, + label=label, + note=note, + ) + counts = _summarize_memory_review_label_counts(store.list_memory_review_label_counts(memory_id)) + return { + "label": _serialize_memory_review_label(created_label), + "summary": _build_memory_review_label_summary(memory_id=memory_id, counts=counts), + } + + +def list_memory_review_label_records( + store: ContinuityStore, + *, + user_id: UUID, + memory_id: UUID, +) -> MemoryReviewLabelListResponse: + del user_id + + memory = store.get_memory_optional(memory_id) + if memory is None: + raise MemoryReviewNotFoundError(f"memory {memory_id} was not found") + + items 
= [_serialize_memory_review_label(label) for label in store.list_memory_review_labels(memory_id)] + counts = _summarize_memory_review_label_counts(store.list_memory_review_label_counts(memory_id)) + return { + "items": items, + "summary": _build_memory_review_label_summary(memory_id=memory_id, counts=counts), + } + + +def get_memory_evaluation_summary( + store: ContinuityStore, + *, + user_id: UUID, +) -> MemoryEvaluationSummaryResponse: + del user_id + + total_memory_count = store.count_memories() + active_memory_count = store.count_memories(status="active") + deleted_memory_count = store.count_memories(status="deleted") + labeled_memory_count = store.count_labeled_memories() + unlabeled_memory_count = store.count_unlabeled_memories() + label_row_counts = _summarize_memory_review_label_counts(store.list_all_memory_review_label_counts()) + summary: MemoryEvaluationSummary = { + "total_memory_count": total_memory_count, + "active_memory_count": active_memory_count, + "deleted_memory_count": deleted_memory_count, + "labeled_memory_count": labeled_memory_count, + "unlabeled_memory_count": unlabeled_memory_count, + "total_label_row_count": sum(label_row_counts.values()), + "label_row_counts_by_value": label_row_counts, + "label_value_order": list(MEMORY_REVIEW_LABEL_VALUES), + } + return { + "summary": summary, + } + + +def _dedupe_source_event_ids(source_event_ids: tuple[UUID, ...]) -> tuple[UUID, ...]: + deduped: list[UUID] = [] + seen: set[UUID] = set() + for source_event_id in source_event_ids: + if source_event_id in seen: + continue + seen.add(source_event_id) + deduped.append(source_event_id) + return tuple(deduped) + + +def _validate_source_events(store: ContinuityStore, source_event_ids: tuple[UUID, ...]) -> list[str]: + normalized_event_ids = _dedupe_source_event_ids(source_event_ids) + if not normalized_event_ids: + raise MemoryAdmissionValidationError( + "source_event_ids must include at least one existing event owned by the user" + ) + source_events = store.list_events_by_ids(list(normalized_event_ids)) + found_event_ids = {event["id"] for event in source_events} + missing_event_ids = [ + str(source_event_id) + for source_event_id in normalized_event_ids + if source_event_id not in found_event_ids + ] + if missing_event_ids: + raise MemoryAdmissionValidationError( + "source_event_ids must all reference existing events owned by the user: " + + ", ".join(missing_event_ids) + ) + return [str(source_event_id) for source_event_id in normalized_event_ids] + + +def _candidate_payload(candidate: MemoryCandidateInput) -> JsonObject: + return candidate.as_payload() + + +def admit_memory_candidate( + store: ContinuityStore, + *, + user_id: UUID, + candidate: MemoryCandidateInput, +) -> AdmissionDecisionOutput: + del user_id + + source_event_ids = _validate_source_events(store, candidate.source_event_ids) + existing_memory = store.get_memory_by_key(candidate.memory_key) + + noop_decision = AdmissionDecisionOutput( + action="NOOP", + reason="candidate_default_noop", + memory=None, + revision=None, + ) + + if candidate.delete_requested: + if existing_memory is None or existing_memory["status"] == "deleted": + return AdmissionDecisionOutput( + action=noop_decision.action, + reason="memory_not_found_for_delete", + memory=None if existing_memory is None else _serialize_memory(existing_memory), + revision=None, + ) + + memory = store.update_memory( + memory_id=existing_memory["id"], + value=existing_memory["value"], + status="deleted", + source_event_ids=source_event_ids, + ) + revision = 
store.append_memory_revision( + memory_id=memory["id"], + action="DELETE", + memory_key=memory["memory_key"], + previous_value=existing_memory["value"], + new_value=None, + source_event_ids=source_event_ids, + candidate=_candidate_payload(candidate), + ) + return AdmissionDecisionOutput( + action="DELETE", + reason="source_backed_delete", + memory=_serialize_memory(memory), + revision=_serialize_memory_revision(revision), + ) + + if candidate.value is None: + return AdmissionDecisionOutput( + action=noop_decision.action, + reason="candidate_value_missing", + memory=None if existing_memory is None else _serialize_memory(existing_memory), + revision=None, + ) + + if existing_memory is None: + memory = store.create_memory( + memory_key=candidate.memory_key, + value=candidate.value, + status="active", + source_event_ids=source_event_ids, + ) + revision = store.append_memory_revision( + memory_id=memory["id"], + action="ADD", + memory_key=memory["memory_key"], + previous_value=None, + new_value=candidate.value, + source_event_ids=source_event_ids, + candidate=_candidate_payload(candidate), + ) + return AdmissionDecisionOutput( + action="ADD", + reason="source_backed_add", + memory=_serialize_memory(memory), + revision=_serialize_memory_revision(revision), + ) + + if existing_memory["status"] == "active" and existing_memory["value"] == candidate.value: + return AdmissionDecisionOutput( + action=noop_decision.action, + reason="memory_unchanged", + memory=_serialize_memory(existing_memory), + revision=None, + ) + + memory = store.update_memory( + memory_id=existing_memory["id"], + value=candidate.value, + status="active", + source_event_ids=source_event_ids, + ) + revision = store.append_memory_revision( + memory_id=memory["id"], + action="UPDATE", + memory_key=memory["memory_key"], + previous_value=existing_memory["value"], + new_value=candidate.value, + source_event_ids=source_event_ids, + candidate=_candidate_payload(candidate), + ) + return AdmissionDecisionOutput( + action="UPDATE", + reason="source_backed_update", + memory=_serialize_memory(memory), + revision=_serialize_memory_revision(revision), + ) diff --git a/apps/api/src/alicebot_api/migrations.py b/apps/api/src/alicebot_api/migrations.py new file mode 100644 index 0000000..52a5d15 --- /dev/null +++ b/apps/api/src/alicebot_api/migrations.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from pathlib import Path + +from alembic.config import Config + + +PROJECT_ROOT = Path(__file__).resolve().parents[4] +ALEMBIC_INI_PATH = PROJECT_ROOT / "apps" / "api" / "alembic.ini" + + +def make_alembic_config(database_url: str | None = None) -> Config: + config = Config(str(ALEMBIC_INI_PATH)) + if database_url: + config.set_main_option("sqlalchemy.url", database_url) + return config + diff --git a/apps/api/src/alicebot_api/policy.py b/apps/api/src/alicebot_api/policy.py new file mode 100644 index 0000000..68988c4 --- /dev/null +++ b/apps/api/src/alicebot_api/policy.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from alicebot_api.contracts import ( + CONSENT_LIST_ORDER, + POLICY_EVALUATION_VERSION_V0, + POLICY_LIST_ORDER, + TRACE_KIND_POLICY_EVALUATE, + ConsentListResponse, + ConsentListSummary, + ConsentRecord, + ConsentUpsertInput, + ConsentUpsertResponse, + PolicyCreateInput, + PolicyCreateResponse, + PolicyDetailResponse, + PolicyEvaluationReason, + PolicyEvaluationRequestInput, + PolicyEvaluationResponse, + PolicyEvaluationSummary, + PolicyEvaluationTraceSummary, + 
PolicyListResponse, + PolicyListSummary, + PolicyRecord, + isoformat_or_none, +) +from alicebot_api.store import ConsentRow, ContinuityStore, PolicyRow + + +class PolicyValidationError(ValueError): + """Raised when a policy or consent request fails explicit validation.""" + + +class PolicyNotFoundError(LookupError): + """Raised when a requested policy is not visible inside the current user scope.""" + + +class PolicyEvaluationValidationError(ValueError): + """Raised when a policy-evaluation request fails explicit validation.""" + + +@dataclass(frozen=True, slots=True) +class PolicyEvaluationContext: + active_policies: tuple[PolicyRow, ...] + consents_by_key: dict[str, ConsentRow] + + +@dataclass(frozen=True, slots=True) +class PolicyEvaluationCoreDecision: + decision: str + matched_policy: PolicyRow | None + reasons: list[PolicyEvaluationReason] + + +def _serialize_consent(consent: ConsentRow) -> ConsentRecord: + return { + "id": str(consent["id"]), + "consent_key": consent["consent_key"], + "status": consent["status"], + "metadata": consent["metadata"], + "created_at": consent["created_at"].isoformat(), + "updated_at": consent["updated_at"].isoformat(), + } + + +def _serialize_policy(policy: PolicyRow) -> PolicyRecord: + return { + "id": str(policy["id"]), + "name": policy["name"], + "action": policy["action"], + "scope": policy["scope"], + "effect": policy["effect"], + "priority": policy["priority"], + "active": policy["active"], + "conditions": policy["conditions"], + "required_consents": policy["required_consents"], + "created_at": policy["created_at"].isoformat(), + "updated_at": policy["updated_at"].isoformat(), + } + + +def _dedupe_required_consents(required_consents: tuple[str, ...]) -> list[str]: + deduped: list[str] = [] + seen: set[str] = set() + for consent_key in required_consents: + if consent_key in seen: + continue + seen.add(consent_key) + deduped.append(consent_key) + return deduped + + +def _policy_matches(policy: PolicyRow, request: PolicyEvaluationRequestInput) -> bool: + if policy["action"] != request.action or policy["scope"] != request.scope: + return False + + conditions = policy["conditions"] + for key, expected_value in conditions.items(): + if key not in request.attributes: + return False + if request.attributes[key] != expected_value: + return False + + return True + + +def _build_reason( + *, + code: str, + source: str, + message: str, + policy_id: UUID | None = None, + consent_key: str | None = None, +) -> PolicyEvaluationReason: + return { + "code": code, + "source": source, + "message": message, + "policy_id": None if policy_id is None else str(policy_id), + "consent_key": consent_key, + } + + +def upsert_consent_record( + store: ContinuityStore, + *, + user_id: UUID, + consent: ConsentUpsertInput, +) -> ConsentUpsertResponse: + del user_id + + existing = store.get_consent_by_key_optional(consent.consent_key) + if existing is None: + created = store.create_consent( + consent_key=consent.consent_key, + status=consent.status, + metadata=consent.metadata, + ) + return { + "consent": _serialize_consent(created), + "write_mode": "created", + } + + updated = store.update_consent( + consent_id=existing["id"], + status=consent.status, + metadata=consent.metadata, + ) + return { + "consent": _serialize_consent(updated), + "write_mode": "updated", + } + + +def list_consent_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ConsentListResponse: + del user_id + + items = [_serialize_consent(consent) for consent in store.list_consents()] + summary: 
ConsentListSummary = { + "total_count": len(items), + "order": list(CONSENT_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def create_policy_record( + store: ContinuityStore, + *, + user_id: UUID, + policy: PolicyCreateInput, +) -> PolicyCreateResponse: + del user_id + + required_consents = _dedupe_required_consents(policy.required_consents) + created = store.create_policy( + name=policy.name, + action=policy.action, + scope=policy.scope, + effect=policy.effect, + priority=policy.priority, + active=policy.active, + conditions=policy.conditions, + required_consents=required_consents, + ) + return {"policy": _serialize_policy(created)} + + +def list_policy_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> PolicyListResponse: + del user_id + + items = [_serialize_policy(policy) for policy in store.list_policies()] + summary: PolicyListSummary = { + "total_count": len(items), + "order": list(POLICY_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_policy_record( + store: ContinuityStore, + *, + user_id: UUID, + policy_id: UUID, +) -> PolicyDetailResponse: + del user_id + + policy = store.get_policy_optional(policy_id) + if policy is None: + raise PolicyNotFoundError(f"policy {policy_id} was not found") + + return {"policy": _serialize_policy(policy)} + + +def load_policy_evaluation_context(store: ContinuityStore) -> PolicyEvaluationContext: + return PolicyEvaluationContext( + active_policies=tuple(store.list_active_policies()), + consents_by_key={consent["consent_key"]: consent for consent in store.list_consents()}, + ) + + +def evaluate_policy_against_context( + context: PolicyEvaluationContext, + *, + request: PolicyEvaluationRequestInput, +) -> PolicyEvaluationCoreDecision: + matched_policy = next( + (policy for policy in context.active_policies if _policy_matches(policy, request)), + None, + ) + + reasons: list[PolicyEvaluationReason] = [] + decision = "deny" + + if matched_policy is None: + reasons.append( + _build_reason( + code="no_matching_policy", + source="system", + message="No active policy matched the requested action, scope, and attributes.", + ) + ) + return PolicyEvaluationCoreDecision( + decision=decision, + matched_policy=None, + reasons=reasons, + ) + + reasons.append( + _build_reason( + code="matched_policy", + source="policy", + message=f"Matched policy '{matched_policy['name']}' at priority {matched_policy['priority']}.", + policy_id=matched_policy["id"], + ) + ) + + missing_or_revoked = False + for consent_key in matched_policy["required_consents"]: + consent = context.consents_by_key.get(consent_key) + if consent is None: + missing_or_revoked = True + reasons.append( + _build_reason( + code="consent_missing", + source="consent", + message=f"Required consent '{consent_key}' is missing.", + policy_id=matched_policy["id"], + consent_key=consent_key, + ) + ) + continue + if consent["status"] != "granted": + missing_or_revoked = True + reasons.append( + _build_reason( + code="consent_revoked", + source="consent", + message=f"Required consent '{consent_key}' is not granted (status={consent['status']}).", + policy_id=matched_policy["id"], + consent_key=consent_key, + ) + ) + + if not missing_or_revoked: + decision = matched_policy["effect"] + effect_code = { + "allow": "policy_effect_allow", + "deny": "policy_effect_deny", + "require_approval": "policy_effect_require_approval", + }[decision] + reasons.append( + _build_reason( + code=effect_code, + source="policy", + message=f"Policy effect resolved 
the decision to '{decision}'.", + policy_id=matched_policy["id"], + ) + ) + + return PolicyEvaluationCoreDecision( + decision=decision, + matched_policy=matched_policy, + reasons=reasons, + ) + + +def evaluate_policy_request( + store: ContinuityStore, + *, + user_id: UUID, + request: PolicyEvaluationRequestInput, +) -> PolicyEvaluationResponse: + del user_id + + thread = store.get_thread_optional(request.thread_id) + if thread is None: + raise PolicyEvaluationValidationError( + "thread_id must reference an existing thread owned by the user" + ) + + context = load_policy_evaluation_context(store) + core_decision = evaluate_policy_against_context( + context, + request=request, + ) + + trace = store.create_trace( + user_id=thread["user_id"], + thread_id=thread["id"], + kind=TRACE_KIND_POLICY_EVALUATE, + compiler_version=POLICY_EVALUATION_VERSION_V0, + status="completed", + limits={ + "order": list(POLICY_LIST_ORDER), + "active_policy_count": len(context.active_policies), + "consent_count": len(context.consents_by_key), + }, + ) + + trace_events = [ + ( + "policy.evaluate.request", + { + "thread_id": str(request.thread_id), + "action": request.action, + "scope": request.scope, + "attributes": request.attributes, + }, + ), + ( + "policy.evaluate.order", + { + "order": list(POLICY_LIST_ORDER), + "policy_ids": [str(policy["id"]) for policy in context.active_policies], + }, + ), + ( + "policy.evaluate.decision", + { + "decision": core_decision.decision, + "matched_policy_id": ( + None if core_decision.matched_policy is None else str(core_decision.matched_policy["id"]) + ), + "reasons": core_decision.reasons, + "evaluated_policy_count": len(context.active_policies), + "consent_states": { + consent_key: { + "status": consent["status"], + "updated_at": isoformat_or_none(consent["updated_at"]), + } + for consent_key, consent in context.consents_by_key.items() + }, + }, + ), + ] + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + evaluation: PolicyEvaluationSummary = { + "action": request.action, + "scope": request.scope, + "evaluated_policy_count": len(context.active_policies), + "matched_policy_id": ( + None if core_decision.matched_policy is None else str(core_decision.matched_policy["id"]) + ), + "order": list(POLICY_LIST_ORDER), + } + trace_summary: PolicyEvaluationTraceSummary = { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + } + return { + "decision": core_decision.decision, + "matched_policy": ( + None if core_decision.matched_policy is None else _serialize_policy(core_decision.matched_policy) + ), + "reasons": core_decision.reasons, + "evaluation": evaluation, + "trace": trace_summary, + } diff --git a/apps/api/src/alicebot_api/proxy_execution.py b/apps/api/src/alicebot_api/proxy_execution.py new file mode 100644 index 0000000..ce0a54c --- /dev/null +++ b/apps/api/src/alicebot_api/proxy_execution.py @@ -0,0 +1,557 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import cast +from uuid import UUID + +from alicebot_api.approvals import ApprovalNotFoundError, serialize_approval_row +from alicebot_api.contracts import ( + PROXY_EXECUTION_VERSION_V0, + EXECUTION_BUDGET_MATCH_ORDER, + TRACE_KIND_PROXY_EXECUTE, + ApprovalRecord, + ProxyExecutionApprovalTracePayload, + ProxyExecutionBudgetPrecheckTracePayload, + ProxyExecutionDispatchTracePayload, + ProxyExecutionEventSummary, + 
ProxyExecutionRequestEventPayload, + ProxyExecutionRequestInput, + ProxyExecutionResponse, + ProxyExecutionResultEventPayload, + ProxyExecutionResultRecord, + ProxyExecutionStatus, + ProxyExecutionSummaryTracePayload, + ProxyExecutionTraceSummary, + ToolRecord, + ToolExecutionCreateInput, + ToolExecutionResultRecord, + ToolRoutingRequestRecord, +) +from alicebot_api.execution_budgets import evaluate_execution_budget +from alicebot_api.store import ContinuityStore, JsonObject, ToolExecutionRow +from alicebot_api.tasks import ( + validate_linked_task_step_for_approval, + sync_task_step_with_execution, + sync_task_with_execution, + task_lifecycle_trace_events, + task_step_lifecycle_trace_events, +) + +PROXY_EXECUTION_REQUEST_EVENT_KIND = "tool.proxy.execution.request" +PROXY_EXECUTION_RESULT_EVENT_KIND = "tool.proxy.execution.result" + + +class ProxyExecutionApprovalStateError(RuntimeError): + """Raised when an approval is visible but not executable in its current state.""" + + +class ProxyExecutionHandlerNotFoundError(RuntimeError): + """Raised when an approved tool has no registered proxy handler.""" + + +ProxyHandler = Callable[[ToolRoutingRequestRecord, ToolRecord], ProxyExecutionResultRecord] + + +def _append_trace_events( + store: ContinuityStore, + *, + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> None: + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace_id, + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + +def _proxy_echo_handler( + request: ToolRoutingRequestRecord, + tool: ToolRecord, +) -> ProxyExecutionResultRecord: + output: JsonObject = { + "mode": "no_side_effect", + "tool_key": tool["tool_key"], + "action": request["action"], + "scope": request["scope"], + "domain_hint": request["domain_hint"], + "risk_hint": request["risk_hint"], + "attributes": request["attributes"], + } + return { + "handler_key": "proxy.echo", + "status": "completed", + "output": output, + } + + +REGISTERED_PROXY_HANDLERS: dict[str, ProxyHandler] = { + "proxy.echo": _proxy_echo_handler, +} + + +def registered_proxy_handler_keys() -> tuple[str, ...]: + return tuple(sorted(REGISTERED_PROXY_HANDLERS)) + + +def _trace_summary(trace_id: UUID, trace_events: list[tuple[str, dict[str, object]]]) -> ProxyExecutionTraceSummary: + return { + "trace_id": str(trace_id), + "trace_event_count": len(trace_events), + } + + +def _blocked_state_error(*, approval: ApprovalRecord) -> ProxyExecutionApprovalStateError: + return ProxyExecutionApprovalStateError( + f"approval {approval['id']} is {approval['status']} and cannot be executed" + ) + + +def _missing_handler_error(*, tool: ToolRecord) -> ProxyExecutionHandlerNotFoundError: + return ProxyExecutionHandlerNotFoundError( + f"tool '{tool['tool_key']}' has no registered proxy handler" + ) + + +def _tool_execution_result( + *, + handler_key: str | None, + status: ProxyExecutionStatus, + output: JsonObject | None, + reason: str | None, + budget_decision: dict[str, object] | None = None, +) -> ToolExecutionResultRecord: + payload: ToolExecutionResultRecord = { + "handler_key": handler_key, + "status": status, + "output": output, + "reason": reason, + } + if budget_decision is not None: + payload["budget_decision"] = cast(dict[str, object], budget_decision) + return payload + + +def _persist_tool_execution( + store: ContinuityStore, + *, + approval_row: dict[str, object], + task_step_id: UUID, + trace_id: UUID, + handler_key: str | None, + request: 
ToolRoutingRequestRecord, + tool: ToolRecord, + result: ToolExecutionResultRecord, + request_event_id: UUID | None, + result_event_id: UUID | None, +) -> ToolExecutionRow: + execution = ToolExecutionCreateInput( + approval_id=cast(UUID, approval_row["id"]), + task_step_id=task_step_id, + thread_id=cast(UUID, approval_row["thread_id"]), + tool_id=cast(UUID, approval_row["tool_id"]), + trace_id=trace_id, + request_event_id=request_event_id, + result_event_id=result_event_id, + status=result["status"], + handler_key=handler_key, + request=request, + tool=tool, + result=result, + ) + return store.create_tool_execution( + approval_id=execution.approval_id, + task_step_id=execution.task_step_id, + thread_id=execution.thread_id, + tool_id=execution.tool_id, + trace_id=execution.trace_id, + request_event_id=execution.request_event_id, + result_event_id=execution.result_event_id, + status=execution.status, + handler_key=execution.handler_key, + request=cast(JsonObject, execution.request), + tool=cast(JsonObject, execution.tool), + result=cast(JsonObject, execution.result), + ) + + +def execute_approved_proxy_request( + store: ContinuityStore, + *, + user_id: UUID, + request: ProxyExecutionRequestInput, +) -> ProxyExecutionResponse: + del user_id + + approval_row = store.get_approval_optional(request.approval_id) + if approval_row is None: + raise ApprovalNotFoundError(f"approval {request.approval_id} was not found") + _, linked_task_step = validate_linked_task_step_for_approval( + store, + approval_id=request.approval_id, + task_step_id=cast(UUID | None, approval_row["task_step_id"]), + ) + + approval = serialize_approval_row(approval_row) + linked_task_step_id = cast(str, approval["task_step_id"]) + tool = cast(ToolRecord, approval["tool"]) + routed_request = cast(ToolRoutingRequestRecord, approval["request"]) + handler = REGISTERED_PROXY_HANDLERS.get(tool["tool_key"]) + + trace = store.create_trace( + user_id=approval_row["user_id"], + thread_id=approval_row["thread_id"], + kind=TRACE_KIND_PROXY_EXECUTE, + compiler_version=PROXY_EXECUTION_VERSION_V0, + status="completed", + limits={ + "approval_status": approval["status"], + "enabled_handler_keys": list(registered_proxy_handler_keys()), + "budget_match_order": list(EXECUTION_BUDGET_MATCH_ORDER), + }, + ) + + approval_trace_payload: ProxyExecutionApprovalTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "approval_status": approval["status"], + "eligible_for_execution": approval["status"] == "approved", + } + + trace_events: list[tuple[str, dict[str, object]]] = [ + ( + "tool.proxy.execute.request", + { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + }, + ), + ("tool.proxy.execute.approval", cast(dict[str, object], approval_trace_payload)), + ] + + if approval["status"] != "approved": + error = _blocked_state_error(approval=approval) + dispatch_payload: ProxyExecutionDispatchTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": None, + "dispatch_status": "blocked", + "reason": str(error), + "result_status": None, + "output": None, + } + summary_payload: ProxyExecutionSummaryTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "approval_status": approval["status"], + "execution_status": "blocked", + "handler_key": None, + "request_event_id": None, + "result_event_id": None, + } + 
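+ # Non-approved approvals still record the dispatch/summary trace events below before the state error is raised, so the blocked attempt stays auditable.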
trace_events.extend( + [ + ("tool.proxy.execute.dispatch", cast(dict[str, object], dispatch_payload)), + ("tool.proxy.execute.summary", cast(dict[str, object], summary_payload)), + ] + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + raise error + + budget_decision = evaluate_execution_budget( + store, + tool=tool, + request=routed_request, + ) + budget_trace_payload: ProxyExecutionBudgetPrecheckTracePayload = budget_decision.record + trace_events.append( + ("tool.proxy.execute.budget", cast(dict[str, object], budget_trace_payload)) + ) + + if budget_decision.blocked_result is not None: + dispatch_payload: ProxyExecutionDispatchTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": None, + "dispatch_status": "blocked", + "reason": budget_decision.blocked_result["reason"], + "result_status": budget_decision.blocked_result["status"], + "output": budget_decision.blocked_result["output"], + } + summary_payload: ProxyExecutionSummaryTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "approval_status": approval["status"], + "execution_status": "blocked", + "handler_key": None, + "request_event_id": None, + "result_event_id": None, + } + trace_events.extend( + [ + ("tool.proxy.execute.dispatch", cast(dict[str, object], dispatch_payload)), + ("tool.proxy.execute.summary", cast(dict[str, object], summary_payload)), + ] + ) + execution = _persist_tool_execution( + store, + approval_row=cast(dict[str, object], approval_row), + task_step_id=cast(UUID, linked_task_step["id"]), + trace_id=trace["id"], + handler_key=None, + request=routed_request, + tool=tool, + result=budget_decision.blocked_result, + request_event_id=None, + result_event_id=None, + ) + task_transition = sync_task_with_execution( + store, + approval_id=cast(UUID, approval_row["id"]), + execution_id=execution["id"], + execution_status=execution["status"], + ) + task_step_transition = sync_task_step_with_execution( + store, + task_id=UUID(task_transition.task["id"]), + execution=execution, + trace_id=trace["id"], + trace_kind=TRACE_KIND_PROXY_EXECUTE, + ) + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="proxy_execution", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step_transition.task_step, + previous_status=task_step_transition.previous_status, + source="proxy_execution", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + return { + "request": { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + }, + "approval": approval, + "tool": tool, + "result": budget_decision.blocked_result, + "events": None, + "trace": _trace_summary(trace["id"], trace_events), + } + + if handler is None: + error = _missing_handler_error(tool=tool) + result = _tool_execution_result( + handler_key=None, + status="blocked", + output=None, + reason=str(error), + ) + dispatch_payload: ProxyExecutionDispatchTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": None, + "dispatch_status": "blocked", + "reason": str(error), + "result_status": result["status"], + "output": None, + } + summary_payload: ProxyExecutionSummaryTracePayload = { 
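+ # Missing proxy handler: the summary below is traced and a blocked execution row is persisted before the handler error is raised.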
+ "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "approval_status": approval["status"], + "execution_status": "blocked", + "handler_key": None, + "request_event_id": None, + "result_event_id": None, + } + trace_events.extend( + [ + ("tool.proxy.execute.dispatch", cast(dict[str, object], dispatch_payload)), + ("tool.proxy.execute.summary", cast(dict[str, object], summary_payload)), + ] + ) + execution = _persist_tool_execution( + store, + approval_row=cast(dict[str, object], approval_row), + task_step_id=cast(UUID, linked_task_step["id"]), + trace_id=trace["id"], + handler_key=None, + request=routed_request, + tool=tool, + result=result, + request_event_id=None, + result_event_id=None, + ) + task_transition = sync_task_with_execution( + store, + approval_id=cast(UUID, approval_row["id"]), + execution_id=execution["id"], + execution_status=execution["status"], + ) + task_step_transition = sync_task_step_with_execution( + store, + task_id=UUID(task_transition.task["id"]), + execution=execution, + trace_id=trace["id"], + trace_kind=TRACE_KIND_PROXY_EXECUTE, + ) + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="proxy_execution", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step_transition.task_step, + previous_status=task_step_transition.previous_status, + source="proxy_execution", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + raise error + + request_event_payload: ProxyExecutionRequestEventPayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "request": routed_request, + } + request_event = store.append_event( + approval_row["thread_id"], + None, + PROXY_EXECUTION_REQUEST_EVENT_KIND, + cast(JsonObject, request_event_payload), + ) + + result = handler(routed_request, tool) + result_event_payload: ProxyExecutionResultEventPayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": result["handler_key"], + "status": result["status"], + "output": result["output"], + } + result_event = store.append_event( + approval_row["thread_id"], + None, + PROXY_EXECUTION_RESULT_EVENT_KIND, + cast(JsonObject, result_event_payload), + ) + execution = _persist_tool_execution( + store, + approval_row=cast(dict[str, object], approval_row), + task_step_id=cast(UUID, linked_task_step["id"]), + trace_id=trace["id"], + handler_key=result["handler_key"], + request=routed_request, + tool=tool, + result=_tool_execution_result( + handler_key=result["handler_key"], + status=result["status"], + output=result["output"], + reason=None, + ), + request_event_id=request_event["id"], + result_event_id=result_event["id"], + ) + + events: ProxyExecutionEventSummary = { + "request_event_id": str(request_event["id"]), + "request_sequence_no": request_event["sequence_no"], + "result_event_id": str(result_event["id"]), + "result_sequence_no": result_event["sequence_no"], + } + dispatch_payload: ProxyExecutionDispatchTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "handler_key": result["handler_key"], + "dispatch_status": "executed", + "reason": None, + "result_status": result["status"], + "output": 
result["output"], + } + summary_payload: ProxyExecutionSummaryTracePayload = { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + "tool_id": tool["id"], + "tool_key": tool["tool_key"], + "approval_status": approval["status"], + "execution_status": "completed", + "handler_key": result["handler_key"], + "request_event_id": events["request_event_id"], + "result_event_id": events["result_event_id"], + } + trace_events.extend( + [ + ("tool.proxy.execute.dispatch", cast(dict[str, object], dispatch_payload)), + ("tool.proxy.execute.summary", cast(dict[str, object], summary_payload)), + ] + ) + task_transition = sync_task_with_execution( + store, + approval_id=cast(UUID, approval_row["id"]), + execution_id=execution["id"], + execution_status=execution["status"], + ) + task_step_transition = sync_task_step_with_execution( + store, + task_id=UUID(task_transition.task["id"]), + execution=execution, + trace_id=trace["id"], + trace_kind=TRACE_KIND_PROXY_EXECUTE, + ) + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="proxy_execution", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step_transition.task_step, + previous_status=task_step_transition.previous_status, + source="proxy_execution", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + return { + "request": { + "approval_id": approval["id"], + "task_step_id": linked_task_step_id, + }, + "approval": approval, + "tool": tool, + "result": result, + "events": events, + "trace": _trace_summary(trace["id"], trace_events), + } diff --git a/apps/api/src/alicebot_api/response_generation.py b/apps/api/src/alicebot_api/response_generation.py new file mode 100644 index 0000000..7652a5d --- /dev/null +++ b/apps/api/src/alicebot_api/response_generation.py @@ -0,0 +1,474 @@ +from __future__ import annotations + +from dataclasses import dataclass +import hashlib +import json +from typing import Any, TypedDict, cast +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen +from uuid import UUID + +from alicebot_api.compiler import compile_and_persist_trace +from alicebot_api.config import Settings +from alicebot_api.contracts import ( + AssistantResponseEventPayload, + CompiledContextPack, + ContextCompilerLimits, + GenerateResponseSuccess, + ModelInvocationRequest, + ModelInvocationResponse, + ModelUsagePayload, + PROMPT_ASSEMBLY_VERSION_V0, + PromptAssemblyInput, + PromptAssemblyResult, + PromptAssemblyTracePayload, + PromptSection, + RESPONSE_GENERATION_VERSION_V0, + ResponseTraceSummary, + TRACE_KIND_RESPONSE_GENERATE, + TraceEventRecord, +) +from alicebot_api.store import ContinuityStore, JsonObject + +PROMPT_TRACE_EVENT_KIND = "response.prompt.assembled" +MODEL_COMPLETED_TRACE_EVENT_KIND = "response.model.completed" +MODEL_FAILED_TRACE_EVENT_KIND = "response.model.failed" +SYSTEM_INSTRUCTION = ( + "You are AliceBot. Reply to the latest user message using the provided durable context. " + "If the context is insufficient, say so briefly instead of inventing facts." +) +DEVELOPER_INSTRUCTION = ( + "Treat the CONTEXT and CONVERSATION sections as authoritative durable state. " + "Do not call tools, do not describe hidden chain-of-thought, and keep the reply concise." 
+) + + +class ModelInvocationError(RuntimeError): + """Raised when the configured model provider cannot produce a response.""" + + +@dataclass(frozen=True, slots=True) +class ResponseFailure: + detail: str + trace: ResponseTraceSummary + + +class _OpenAIResponseContentItem(TypedDict, total=False): + type: str + text: str + + +class _OpenAIResponseOutputItem(TypedDict, total=False): + type: str + content: list[_OpenAIResponseContentItem] + + +class _OpenAIResponseUsage(TypedDict, total=False): + input_tokens: int | None + output_tokens: int | None + total_tokens: int | None + + +class _OpenAIResponsePayload(TypedDict, total=False): + id: str + status: str + output: list[_OpenAIResponseOutputItem] + usage: _OpenAIResponseUsage + + +def _deterministic_json(value: JsonObject | list[object]) -> str: + return json.dumps(value, sort_keys=True, ensure_ascii=True, separators=(",", ":")) + + +def _context_section_payload(context_pack: CompiledContextPack) -> JsonObject: + return { + "compiler_version": context_pack["compiler_version"], + "scope": context_pack["scope"], + "limits": context_pack["limits"], + "user": context_pack["user"], + "thread": context_pack["thread"], + "sessions": context_pack["sessions"], + "memories": context_pack["memories"], + "memory_summary": context_pack["memory_summary"], + "entities": context_pack["entities"], + "entity_summary": context_pack["entity_summary"], + "entity_edges": context_pack["entity_edges"], + "entity_edge_summary": context_pack["entity_edge_summary"], + } + + +def assemble_prompt( + *, + request: PromptAssemblyInput, + compile_trace_id: str, +) -> PromptAssemblyResult: + sections = ( + PromptSection(name="system", content=request.system_instruction), + PromptSection(name="developer", content=request.developer_instruction), + PromptSection( + name="context", + content=_deterministic_json(_context_section_payload(request.context_pack)), + ), + PromptSection( + name="conversation", + content=_deterministic_json({"events": request.context_pack["events"]}), + ), + ) + prompt_text = "\n\n".join( + f"[{section.name.upper()}]\n{section.content}" for section in sections + ) + prompt_sha256 = hashlib.sha256(prompt_text.encode("utf-8")).hexdigest() + trace_payload: PromptAssemblyTracePayload = { + "version": PROMPT_ASSEMBLY_VERSION_V0, + "compile_trace_id": compile_trace_id, + "compiler_version": request.context_pack["compiler_version"], + "prompt_sha256": prompt_sha256, + "prompt_char_count": len(prompt_text), + "section_order": [section.name for section in sections], + "section_characters": {section.name: len(section.content) for section in sections}, + "included_session_count": len(request.context_pack["sessions"]), + "included_event_count": len(request.context_pack["events"]), + "included_memory_count": len(request.context_pack["memories"]), + "included_entity_count": len(request.context_pack["entities"]), + "included_entity_edge_count": len(request.context_pack["entity_edges"]), + } + return PromptAssemblyResult( + sections=sections, + prompt_text=prompt_text, + prompt_sha256=prompt_sha256, + trace_payload=trace_payload, + ) + + +def _openai_input_message(role: str, content: str) -> JsonObject: + return { + "role": role, + "content": [{"type": "input_text", "text": content}], + } + + +def _build_openai_responses_payload(request: ModelInvocationRequest) -> JsonObject: + sections = {section.name: section.content for section in request.prompt.sections} + return { + "model": request.model, + "store": request.store, + "tool_choice": request.tool_choice, + 
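+ # Tool calling is intentionally disabled here: an empty tools list is sent so the provider cannot invoke tools from the response-generation path.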
"tools": [], + "input": [ + _openai_input_message("system", sections["system"]), + _openai_input_message("developer", sections["developer"]), + _openai_input_message("user", f"[CONTEXT]\n{sections['context']}"), + _openai_input_message("user", f"[CONVERSATION]\n{sections['conversation']}"), + ], + "text": {"format": {"type": "text"}}, + } + + +def _extract_output_text(response_payload: _OpenAIResponsePayload) -> str: + output_items = response_payload.get("output", []) + for output_item in output_items: + if output_item.get("type") != "message": + continue + for content_item in output_item.get("content", []): + if content_item.get("type") == "output_text": + text = content_item.get("text") + if isinstance(text, str) and text: + return text + raise ModelInvocationError("model response did not include assistant output text") + + +def _parse_usage(response_payload: _OpenAIResponsePayload) -> ModelUsagePayload: + usage = response_payload.get("usage", {}) + if not isinstance(usage, dict): + return {"input_tokens": None, "output_tokens": None, "total_tokens": None} + return { + "input_tokens": usage.get("input_tokens"), + "output_tokens": usage.get("output_tokens"), + "total_tokens": usage.get("total_tokens"), + } + + +def _parse_openai_response_payload(raw_payload: bytes) -> _OpenAIResponsePayload: + try: + parsed_payload = json.loads(raw_payload) + except json.JSONDecodeError as exc: + raise ModelInvocationError("model provider returned invalid JSON") from exc + + if not isinstance(parsed_payload, dict): + raise ModelInvocationError("model provider returned invalid JSON") + + return cast(_OpenAIResponsePayload, parsed_payload) + + +def _extract_http_error_detail(exc: HTTPError) -> str | None: + raw_body = exc.read().decode("utf-8", errors="replace") + try: + parsed_error = json.loads(raw_body) + except json.JSONDecodeError: + return None + + if not isinstance(parsed_error, dict): + return None + + error = parsed_error.get("error", {}) + if not isinstance(error, dict): + return None + + detail = error.get("message") + if isinstance(detail, str) and detail: + return detail + return None + + +def _build_model_http_request(*, settings: Settings, payload: JsonObject) -> Request: + endpoint = settings.model_base_url.rstrip("/") + "/responses" + return Request( + endpoint, + data=json.dumps(payload).encode("utf-8"), + headers={ + "Authorization": f"Bearer {settings.model_api_key}", + "Content-Type": "application/json", + }, + method="POST", + ) + + +def _model_failure_trace_payload( + *, + request: ModelInvocationRequest, + error_message: str, +) -> JsonObject: + return { + "provider": request.provider, + "model": request.model, + "tool_choice": "none", + "tools_enabled": False, + "response_id": None, + "finish_reason": "incomplete", + "output_text_char_count": 0, + "usage": { + "input_tokens": None, + "output_tokens": None, + "total_tokens": None, + }, + "error_message": error_message, + } + + +def _create_linked_response_trace( + *, + store: ContinuityStore, + user_id: UUID, + thread_id: UUID, + limits: ContextCompilerLimits, + compiled_trace_id: str, + compiled_trace_event_count: int, + status: str, + trace_events: list[TraceEventRecord], +) -> ResponseTraceSummary: + trace = _create_response_trace( + store=store, + user_id=user_id, + thread_id=thread_id, + limits=limits, + status=status, + trace_events=trace_events, + ) + trace["compile_trace_id"] = compiled_trace_id + trace["compile_trace_event_count"] = compiled_trace_event_count + return trace + + +def invoke_model( + *, + settings: Settings, + 
request: ModelInvocationRequest, +) -> ModelInvocationResponse: + if request.provider != "openai_responses": + raise ModelInvocationError(f"unsupported model provider: {request.provider}") + if not settings.model_api_key: + raise ModelInvocationError("MODEL_API_KEY is not configured") + + payload = _build_openai_responses_payload(request) + http_request = _build_model_http_request(settings=settings, payload=payload) + + try: + with urlopen(http_request, timeout=settings.model_timeout_seconds) as response: + raw_payload = response.read() + except HTTPError as exc: + detail = _extract_http_error_detail(exc) + if detail is not None: + raise ModelInvocationError(detail) from exc + raise ModelInvocationError(f"model provider returned HTTP {exc.code}") from exc + except URLError as exc: + raise ModelInvocationError(f"model provider request failed: {exc.reason}") from exc + + response_payload = _parse_openai_response_payload(raw_payload) + output_text = _extract_output_text(response_payload) + finish_reason = "completed" if response_payload.get("status") == "completed" else "incomplete" + return ModelInvocationResponse( + provider=request.provider, + model=request.model, + response_id=response_payload.get("id"), + finish_reason=finish_reason, + output_text=output_text, + usage=_parse_usage(response_payload), + ) + + +def build_assistant_response_payload( + *, + prompt: PromptAssemblyResult, + model_response: ModelInvocationResponse, +) -> AssistantResponseEventPayload: + return { + "text": model_response.output_text, + "model": { + "provider": model_response.provider, + "model": model_response.model, + "response_id": model_response.response_id, + "finish_reason": model_response.finish_reason, + "usage": model_response.usage, + }, + "prompt": { + "assembly_version": PROMPT_ASSEMBLY_VERSION_V0, + "prompt_sha256": prompt.prompt_sha256, + "section_order": [section.name for section in prompt.sections], + }, + } + + +def _create_response_trace( + *, + store: ContinuityStore, + user_id: UUID, + thread_id: UUID, + limits: ContextCompilerLimits, + status: str, + trace_events: list[TraceEventRecord], +) -> ResponseTraceSummary: + trace = store.create_trace( + user_id=user_id, + thread_id=thread_id, + kind=TRACE_KIND_RESPONSE_GENERATE, + compiler_version=RESPONSE_GENERATION_VERSION_V0, + status=status, + limits=limits.as_payload(), + ) + for sequence_no, trace_event in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=trace_event.kind, + payload=trace_event.payload, + ) + return { + "compile_trace_id": "", + "compile_trace_event_count": 0, + "response_trace_id": str(trace["id"]), + "response_trace_event_count": len(trace_events), + } + + +def generate_response( + *, + store: ContinuityStore, + settings: Settings, + user_id: UUID, + thread_id: UUID, + message_text: str, + limits: ContextCompilerLimits, +) -> GenerateResponseSuccess | ResponseFailure: + store.get_user(user_id) + store.get_thread(thread_id) + + store.append_event( + thread_id, + None, + "message.user", + {"text": message_text}, + ) + compiled_trace = compile_and_persist_trace( + store, + user_id=user_id, + thread_id=thread_id, + limits=limits, + ) + prompt = assemble_prompt( + request=PromptAssemblyInput( + context_pack=compiled_trace.context_pack, + system_instruction=SYSTEM_INSTRUCTION, + developer_instruction=DEVELOPER_INSTRUCTION, + ), + compile_trace_id=compiled_trace.trace_id, + ) + request = ModelInvocationRequest( + provider=settings.model_provider, # type: 
ignore[arg-type] + model=settings.model_name, + prompt=prompt, + ) + prompt_trace_event = TraceEventRecord( + kind=PROMPT_TRACE_EVENT_KIND, + payload=prompt.trace_payload, + ) + + try: + model_response = invoke_model(settings=settings, request=request) + except ModelInvocationError as exc: + trace = _create_linked_response_trace( + store=store, + user_id=user_id, + thread_id=thread_id, + limits=limits, + compiled_trace_id=compiled_trace.trace_id, + compiled_trace_event_count=compiled_trace.trace_event_count, + status="failed", + trace_events=[ + prompt_trace_event, + TraceEventRecord( + kind=MODEL_FAILED_TRACE_EVENT_KIND, + payload=_model_failure_trace_payload( + request=request, + error_message=str(exc), + ), + ), + ], + ) + return ResponseFailure(detail=str(exc), trace=trace) + + assistant_payload = build_assistant_response_payload( + prompt=prompt, + model_response=model_response, + ) + assistant_event = store.append_event( + thread_id, + None, + "message.assistant", + assistant_payload, + ) + trace = _create_linked_response_trace( + store=store, + user_id=user_id, + thread_id=thread_id, + limits=limits, + compiled_trace_id=compiled_trace.trace_id, + compiled_trace_event_count=compiled_trace.trace_event_count, + status="completed", + trace_events=[ + prompt_trace_event, + TraceEventRecord( + kind=MODEL_COMPLETED_TRACE_EVENT_KIND, + payload=model_response.to_trace_payload(), + ), + ], + ) + return { + "assistant": { + "event_id": str(assistant_event["id"]), + "sequence_no": assistant_event["sequence_no"], + "text": model_response.output_text, + "model_provider": model_response.provider, + "model": model_response.model, + }, + "trace": trace, + } diff --git a/apps/api/src/alicebot_api/semantic_retrieval.py b/apps/api/src/alicebot_api/semantic_retrieval.py new file mode 100644 index 0000000..5384e3d --- /dev/null +++ b/apps/api/src/alicebot_api/semantic_retrieval.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import math +from uuid import UUID + +from alicebot_api.contracts import ( + SEMANTIC_MEMORY_RETRIEVAL_ORDER, + SemanticMemoryRetrievalRequestInput, + SemanticMemoryRetrievalResponse, + SemanticMemoryRetrievalResultItem, + SemanticMemoryRetrievalSummary, +) +from alicebot_api.store import ContinuityStore, SemanticMemoryRetrievalRow + + +class SemanticMemoryRetrievalValidationError(ValueError): + """Raised when semantic memory retrieval fails explicit validation.""" + + +def _validate_query_vector(query_vector: tuple[float, ...]) -> list[float]: + if not query_vector: + raise SemanticMemoryRetrievalValidationError( + "query_vector must include at least one numeric value" + ) + + normalized: list[float] = [] + for value in query_vector: + normalized_value = float(value) + if not math.isfinite(normalized_value): + raise SemanticMemoryRetrievalValidationError( + "query_vector must contain only finite numeric values" + ) + normalized.append(normalized_value) + + return normalized + + +def validate_semantic_memory_retrieval_request( + store: ContinuityStore, + *, + request: SemanticMemoryRetrievalRequestInput, +) -> tuple[dict[str, object], list[float]]: + config = store.get_embedding_config_optional(request.embedding_config_id) + if config is None: + raise SemanticMemoryRetrievalValidationError( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{request.embedding_config_id}" + ) + + query_vector = _validate_query_vector(request.query_vector) + if len(query_vector) != config["dimensions"]: + raise SemanticMemoryRetrievalValidationError( 
+ "query_vector length must match embedding config dimensions " + f"({config['dimensions']}): {len(query_vector)}" + ) + + return config, query_vector + + +def serialize_semantic_memory_result_item( + row: SemanticMemoryRetrievalRow, +) -> SemanticMemoryRetrievalResultItem: + if row["status"] != "active": + raise SemanticMemoryRetrievalValidationError( + f"semantic retrieval only supports active memories: {row['id']}" + ) + + return { + "memory_id": str(row["id"]), + "memory_key": row["memory_key"], + "value": row["value"], + "source_event_ids": row["source_event_ids"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + "score": float(row["score"]), + } + + +def retrieve_semantic_memory_records( + store: ContinuityStore, + *, + user_id: UUID, + request: SemanticMemoryRetrievalRequestInput, +) -> SemanticMemoryRetrievalResponse: + del user_id + + _config, query_vector = validate_semantic_memory_retrieval_request(store, request=request) + + items = [ + serialize_semantic_memory_result_item(row) + for row in store.retrieve_semantic_memory_matches( + embedding_config_id=request.embedding_config_id, + query_vector=query_vector, + limit=request.limit, + ) + ] + summary: SemanticMemoryRetrievalSummary = { + "embedding_config_id": str(request.embedding_config_id), + "limit": request.limit, + "returned_count": len(items), + "similarity_metric": "cosine_similarity", + "order": list(SEMANTIC_MEMORY_RETRIEVAL_ORDER), + } + return { + "items": items, + "summary": summary, + } diff --git a/apps/api/src/alicebot_api/store.py b/apps/api/src/alicebot_api/store.py new file mode 100644 index 0000000..8c3b551 --- /dev/null +++ b/apps/api/src/alicebot_api/store.py @@ -0,0 +1,2713 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, TypedDict, TypeVar, cast +from uuid import UUID + +import psycopg +from psycopg.types.json import Jsonb + +JsonScalar = str | int | float | bool | None +JsonValue = JsonScalar | list["JsonValue"] | dict[str, "JsonValue"] +JsonObject = dict[str, JsonValue] +RowT = TypeVar("RowT") + + +class UserRow(TypedDict): + id: UUID + email: str + display_name: str | None + created_at: datetime + + +class ThreadRow(TypedDict): + id: UUID + user_id: UUID + title: str + created_at: datetime + updated_at: datetime + + +class SessionRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + status: str + started_at: datetime | None + ended_at: datetime | None + created_at: datetime + + +class EventRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + session_id: UUID | None + sequence_no: int + kind: str + payload: JsonObject + created_at: datetime + + +class TraceRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + kind: str + compiler_version: str + status: str + limits: JsonObject + created_at: datetime + + +class TraceEventRow(TypedDict): + id: UUID + user_id: UUID + trace_id: UUID + sequence_no: int + kind: str + payload: JsonObject + created_at: datetime + + +class MemoryRow(TypedDict): + id: UUID + user_id: UUID + memory_key: str + value: JsonValue + status: str + source_event_ids: list[str] + created_at: datetime + updated_at: datetime + deleted_at: datetime | None + + +class MemoryRevisionRow(TypedDict): + id: UUID + user_id: UUID + memory_id: UUID + sequence_no: int + action: str + memory_key: str + previous_value: JsonValue | None + new_value: JsonValue | None + source_event_ids: list[str] + candidate: JsonObject + created_at: datetime + + +class MemoryReviewLabelRow(TypedDict): 
+ id: UUID + user_id: UUID + memory_id: UUID + label: str + note: str | None + created_at: datetime + + +class EmbeddingConfigRow(TypedDict): + id: UUID + user_id: UUID + provider: str + model: str + version: str + dimensions: int + status: str + metadata: JsonObject + created_at: datetime + + +class MemoryEmbeddingRow(TypedDict): + id: UUID + user_id: UUID + memory_id: UUID + embedding_config_id: UUID + dimensions: int + vector: list[float] + created_at: datetime + updated_at: datetime + + +class SemanticMemoryRetrievalRow(TypedDict): + id: UUID + user_id: UUID + memory_key: str + value: JsonValue + status: str + source_event_ids: list[str] + created_at: datetime + updated_at: datetime + deleted_at: datetime | None + score: float + + +class EntityRow(TypedDict): + id: UUID + user_id: UUID + entity_type: str + name: str + source_memory_ids: list[str] + created_at: datetime + + +class EntityEdgeRow(TypedDict): + id: UUID + user_id: UUID + from_entity_id: UUID + to_entity_id: UUID + relationship_type: str + valid_from: datetime | None + valid_to: datetime | None + source_memory_ids: list[str] + created_at: datetime + + +class ConsentRow(TypedDict): + id: UUID + user_id: UUID + consent_key: str + status: str + metadata: JsonObject + created_at: datetime + updated_at: datetime + + +class PolicyRow(TypedDict): + id: UUID + user_id: UUID + name: str + action: str + scope: str + effect: str + priority: int + active: bool + conditions: JsonObject + required_consents: list[str] + created_at: datetime + updated_at: datetime + + +class ToolRow(TypedDict): + id: UUID + user_id: UUID + tool_key: str + name: str + description: str + version: str + metadata_version: str + active: bool + tags: list[str] + action_hints: list[str] + scope_hints: list[str] + domain_hints: list[str] + risk_hints: list[str] + metadata: JsonObject + created_at: datetime + + +class ApprovalRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + tool_id: UUID + task_step_id: UUID | None + status: str + request: JsonObject + tool: JsonObject + routing: JsonObject + routing_trace_id: UUID + created_at: datetime + resolved_at: datetime | None + resolved_by_user_id: UUID | None + + +class TaskRow(TypedDict): + id: UUID + user_id: UUID + thread_id: UUID + tool_id: UUID + status: str + request: JsonObject + tool: JsonObject + latest_approval_id: UUID | None + latest_execution_id: UUID | None + created_at: datetime + updated_at: datetime + + +class TaskWorkspaceRow(TypedDict): + id: UUID + user_id: UUID + task_id: UUID + status: str + local_path: str + created_at: datetime + updated_at: datetime + + +class TaskStepRow(TypedDict): + id: UUID + user_id: UUID + task_id: UUID + sequence_no: int + parent_step_id: UUID | None + source_approval_id: UUID | None + source_execution_id: UUID | None + kind: str + status: str + request: JsonObject + outcome: JsonObject + trace_id: UUID + trace_kind: str + created_at: datetime + updated_at: datetime + + +class ToolExecutionRow(TypedDict): + id: UUID + user_id: UUID + approval_id: UUID + task_step_id: UUID + thread_id: UUID + tool_id: UUID + trace_id: UUID + request_event_id: UUID | None + result_event_id: UUID | None + status: str + handler_key: str | None + request: JsonObject + tool: JsonObject + result: JsonObject + executed_at: datetime + + +class ExecutionBudgetRow(TypedDict): + id: UUID + user_id: UUID + tool_key: str | None + domain_hint: str | None + max_completed_executions: int + rolling_window_seconds: int | None + status: str + deactivated_at: datetime | None + 
superseded_by_budget_id: UUID | None + supersedes_budget_id: UUID | None + created_at: datetime + + +class CountRow(TypedDict): + count: int + + +class LabelCountRow(TypedDict): + label: str + count: int + + +INSERT_USER_SQL = """ + INSERT INTO users (id, email, display_name) + VALUES (%s, %s, %s) + RETURNING id, email, display_name, created_at + """ + +GET_USER_SQL = """ + SELECT id, email, display_name, created_at + FROM users + WHERE id = %s + """ + +INSERT_THREAD_SQL = """ + INSERT INTO threads (user_id, title) + VALUES (app.current_user_id(), %s) + RETURNING id, user_id, title, created_at, updated_at + """ + +GET_THREAD_SQL = """ + SELECT id, user_id, title, created_at, updated_at + FROM threads + WHERE id = %s + """ + +INSERT_SESSION_SQL = """ + INSERT INTO sessions (user_id, thread_id, status) + VALUES (app.current_user_id(), %s, %s) + RETURNING id, user_id, thread_id, status, started_at, ended_at, created_at + """ + +LIST_THREAD_SESSIONS_SQL = """ + SELECT id, user_id, thread_id, status, started_at, ended_at, created_at + FROM sessions + WHERE thread_id = %s + ORDER BY started_at ASC, created_at ASC, id ASC + """ + +LOCK_THREAD_EVENTS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 0))" +LOCK_TASK_STEPS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 2))" +LOCK_TASK_WORKSPACES_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 3))" + +INSERT_EVENT_SQL = """ + WITH next_sequence AS ( + SELECT COALESCE(MAX(sequence_no) + 1, 1) AS sequence_no + FROM events + WHERE thread_id = %s + AND user_id = app.current_user_id() + ) + INSERT INTO events (user_id, thread_id, session_id, sequence_no, kind, payload) + SELECT app.current_user_id(), %s, %s, next_sequence.sequence_no, %s, %s + FROM next_sequence + RETURNING id, user_id, thread_id, session_id, sequence_no, kind, payload, created_at + """ + +LIST_THREAD_EVENTS_SQL = """ + SELECT id, user_id, thread_id, session_id, sequence_no, kind, payload, created_at + FROM events + WHERE thread_id = %s + ORDER BY sequence_no ASC + """ + +LIST_EVENTS_BY_IDS_SQL = """ + SELECT id, user_id, thread_id, session_id, sequence_no, kind, payload, created_at + FROM events + WHERE id = ANY(%s) + ORDER BY sequence_no ASC + """ + +INSERT_TRACE_SQL = """ + INSERT INTO traces (user_id, thread_id, kind, compiler_version, status, limits) + VALUES (%s, %s, %s, %s, %s, %s) + RETURNING id, user_id, thread_id, kind, compiler_version, status, limits, created_at + """ + +GET_TRACE_SQL = """ + SELECT id, user_id, thread_id, kind, compiler_version, status, limits, created_at + FROM traces + WHERE id = %s + """ + +INSERT_TRACE_EVENT_SQL = """ + INSERT INTO trace_events (user_id, trace_id, sequence_no, kind, payload) + VALUES (app.current_user_id(), %s, %s, %s, %s) + RETURNING id, user_id, trace_id, sequence_no, kind, payload, created_at + """ + +LIST_TRACE_EVENTS_SQL = """ + SELECT id, user_id, trace_id, sequence_no, kind, payload, created_at + FROM trace_events + WHERE trace_id = %s + ORDER BY sequence_no ASC + """ + +INSERT_MEMORY_SQL = """ + INSERT INTO memories ( + user_id, + memory_key, + value, + status, + source_event_ids, + created_at, + updated_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, clock_timestamp(), clock_timestamp()) + RETURNING id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + """ + +GET_MEMORY_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE id = %s + """ + 
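+# NOTE: the memory read queries below do not filter on user_id directly; per-user scoping is assumed to be enforced by the database session identity (app.current_user_id()), e.g. via row-level security, while explicit ORDER BY clauses keep list reads deterministic.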
+LIST_MEMORIES_BY_IDS_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE id = ANY(%s) + ORDER BY created_at ASC, id ASC + """ + +GET_MEMORY_BY_KEY_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE memory_key = %s + """ + +LIST_MEMORIES_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + ORDER BY created_at ASC, id ASC + """ + +COUNT_MEMORIES_SQL = """ + SELECT COUNT(*) AS count + FROM memories + """ + +COUNT_MEMORIES_BY_STATUS_SQL = """ + SELECT COUNT(*) AS count + FROM memories + WHERE status = %s + """ + +COUNT_UNLABELED_REVIEW_MEMORIES_SQL = """ + SELECT COUNT(*) AS count + FROM memories + WHERE status = 'active' + AND NOT EXISTS ( + SELECT 1 + FROM memory_review_labels + WHERE memory_review_labels.memory_id = memories.id + ) + """ + +LIST_REVIEW_MEMORIES_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + ORDER BY updated_at DESC, created_at DESC, id DESC + LIMIT %s + """ + +LIST_REVIEW_MEMORIES_BY_STATUS_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE status = %s + ORDER BY updated_at DESC, created_at DESC, id DESC + LIMIT %s + """ + +LIST_UNLABELED_REVIEW_MEMORIES_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE status = 'active' + AND NOT EXISTS ( + SELECT 1 + FROM memory_review_labels + WHERE memory_review_labels.memory_id = memories.id + ) + ORDER BY updated_at DESC, created_at DESC, id DESC + LIMIT %s + """ + +LIST_CONTEXT_MEMORIES_SQL = """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + ORDER BY updated_at ASC, created_at ASC, id ASC + """ + +UPDATE_MEMORY_SQL = """ + UPDATE memories + SET value = %s, + status = %s, + source_event_ids = %s, + updated_at = clock_timestamp(), + deleted_at = CASE + WHEN %s = 'deleted' THEN clock_timestamp() + ELSE NULL + END + WHERE id = %s + RETURNING id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + """ + +LOCK_MEMORY_REVISIONS_SQL = "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 1))" + +INSERT_MEMORY_REVISION_SQL = """ + WITH next_sequence AS ( + SELECT COALESCE(MAX(sequence_no) + 1, 1) AS sequence_no + FROM memory_revisions + WHERE memory_id = %s + AND user_id = app.current_user_id() + ) + INSERT INTO memory_revisions ( + user_id, + memory_id, + sequence_no, + action, + memory_key, + previous_value, + new_value, + source_event_ids, + candidate + ) + SELECT + app.current_user_id(), + %s, + next_sequence.sequence_no, + %s, + %s, + %s, + %s, + %s, + %s + FROM next_sequence + RETURNING id, user_id, memory_id, sequence_no, action, memory_key, previous_value, new_value, source_event_ids, candidate, created_at + """ + +LIST_MEMORY_REVISIONS_SQL = """ + SELECT id, user_id, memory_id, sequence_no, action, memory_key, previous_value, new_value, source_event_ids, candidate, created_at + FROM memory_revisions + WHERE memory_id = %s + ORDER BY sequence_no ASC + """ + +COUNT_MEMORY_REVISIONS_SQL = """ + SELECT COUNT(*) AS count + FROM memory_revisions + WHERE memory_id = %s + """ + +LIST_LIMITED_MEMORY_REVISIONS_SQL = """ + SELECT id, 
user_id, memory_id, sequence_no, action, memory_key, previous_value, new_value, source_event_ids, candidate, created_at + FROM memory_revisions + WHERE memory_id = %s + ORDER BY sequence_no ASC + LIMIT %s + """ + +INSERT_MEMORY_REVIEW_LABEL_SQL = """ + INSERT INTO memory_review_labels (user_id, memory_id, label, note) + VALUES (app.current_user_id(), %s, %s, %s) + RETURNING id, user_id, memory_id, label, note, created_at + """ + +LIST_MEMORY_REVIEW_LABELS_SQL = """ + SELECT id, user_id, memory_id, label, note, created_at + FROM memory_review_labels + WHERE memory_id = %s + ORDER BY created_at ASC, id ASC + """ + +LIST_MEMORY_REVIEW_LABEL_COUNTS_SQL = """ + SELECT label, COUNT(*) AS count + FROM memory_review_labels + WHERE memory_id = %s + GROUP BY label + ORDER BY label ASC + """ + +COUNT_LABELED_MEMORIES_SQL = """ + SELECT COUNT(*) AS count + FROM memories + WHERE EXISTS ( + SELECT 1 + FROM memory_review_labels + WHERE memory_review_labels.memory_id = memories.id + ) + """ + +COUNT_UNLABELED_MEMORIES_SQL = """ + SELECT COUNT(*) AS count + FROM memories + WHERE NOT EXISTS ( + SELECT 1 + FROM memory_review_labels + WHERE memory_review_labels.memory_id = memories.id + ) + """ + +LIST_ALL_MEMORY_REVIEW_LABEL_COUNTS_SQL = """ + SELECT label, COUNT(*) AS count + FROM memory_review_labels + GROUP BY label + ORDER BY label ASC + """ + +INSERT_EMBEDDING_CONFIG_SQL = """ + INSERT INTO embedding_configs ( + user_id, + provider, + model, + version, + dimensions, + status, + metadata, + created_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, %s, %s, clock_timestamp()) + RETURNING id, user_id, provider, model, version, dimensions, status, metadata, created_at + """ + +GET_EMBEDDING_CONFIG_SQL = """ + SELECT id, user_id, provider, model, version, dimensions, status, metadata, created_at + FROM embedding_configs + WHERE id = %s + """ + +GET_EMBEDDING_CONFIG_BY_IDENTITY_SQL = """ + SELECT id, user_id, provider, model, version, dimensions, status, metadata, created_at + FROM embedding_configs + WHERE provider = %s + AND model = %s + AND version = %s + """ + +LIST_EMBEDDING_CONFIGS_SQL = """ + SELECT id, user_id, provider, model, version, dimensions, status, metadata, created_at + FROM embedding_configs + ORDER BY created_at ASC, id ASC + """ + +INSERT_MEMORY_EMBEDDING_SQL = """ + INSERT INTO memory_embeddings ( + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, clock_timestamp(), clock_timestamp()) + RETURNING + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + """ + +GET_MEMORY_EMBEDDING_SQL = """ + SELECT + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + FROM memory_embeddings + WHERE id = %s + """ + +GET_MEMORY_EMBEDDING_BY_MEMORY_AND_CONFIG_SQL = """ + SELECT + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + FROM memory_embeddings + WHERE memory_id = %s + AND embedding_config_id = %s + """ + +LIST_MEMORY_EMBEDDINGS_FOR_MEMORY_SQL = """ + SELECT + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + FROM memory_embeddings + WHERE memory_id = %s + ORDER BY created_at ASC, id ASC + """ + +LIST_MEMORY_EMBEDDINGS_FOR_CONFIG_SQL = """ + SELECT + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + FROM memory_embeddings + WHERE 
embedding_config_id = %s + ORDER BY created_at ASC, id ASC + """ + +UPDATE_MEMORY_EMBEDDING_SQL = """ + UPDATE memory_embeddings + SET dimensions = %s, + vector = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + memory_id, + embedding_config_id, + dimensions, + vector, + created_at, + updated_at + """ + +RETRIEVE_SEMANTIC_MEMORY_MATCHES_SQL = """ + SELECT + memories.id, + memories.user_id, + memories.memory_key, + memories.value, + memories.status, + memories.source_event_ids, + memories.created_at, + memories.updated_at, + memories.deleted_at, + 1 - ( + replace(memory_embeddings.vector::text, ' ', '')::vector <=> %s::vector + ) AS score + FROM memory_embeddings + JOIN memories + ON memories.id = memory_embeddings.memory_id + AND memories.user_id = memory_embeddings.user_id + WHERE memory_embeddings.embedding_config_id = %s + AND memory_embeddings.dimensions = %s + AND memories.status = 'active' + ORDER BY score DESC, memories.created_at ASC, memories.id ASC + LIMIT %s + """ + +INSERT_ENTITY_SQL = """ + INSERT INTO entities (user_id, entity_type, name, source_memory_ids, created_at) + VALUES (app.current_user_id(), %s, %s, %s, clock_timestamp()) + RETURNING id, user_id, entity_type, name, source_memory_ids, created_at + """ + +GET_ENTITY_SQL = """ + SELECT id, user_id, entity_type, name, source_memory_ids, created_at + FROM entities + WHERE id = %s + """ + +LIST_ENTITIES_SQL = """ + SELECT id, user_id, entity_type, name, source_memory_ids, created_at + FROM entities + ORDER BY created_at ASC, id ASC + """ + +INSERT_ENTITY_EDGE_SQL = """ + INSERT INTO entity_edges ( + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, %s, %s, clock_timestamp()) + RETURNING + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + """ + +LIST_ENTITY_EDGES_FOR_ENTITY_SQL = """ + SELECT + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + FROM entity_edges + WHERE from_entity_id = %s OR to_entity_id = %s + ORDER BY created_at ASC, id ASC + """ + +LIST_ENTITY_EDGES_FOR_ENTITIES_SQL = """ + SELECT + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + FROM entity_edges + WHERE from_entity_id = ANY(%s) OR to_entity_id = ANY(%s) + ORDER BY created_at ASC, id ASC + """ + +INSERT_CONSENT_SQL = """ + INSERT INTO consents ( + user_id, + consent_key, + status, + metadata, + created_at, + updated_at + ) + VALUES (app.current_user_id(), %s, %s, %s, clock_timestamp(), clock_timestamp()) + RETURNING id, user_id, consent_key, status, metadata, created_at, updated_at + """ + +GET_CONSENT_BY_KEY_SQL = """ + SELECT id, user_id, consent_key, status, metadata, created_at, updated_at + FROM consents + WHERE consent_key = %s + """ + +LIST_CONSENTS_SQL = """ + SELECT id, user_id, consent_key, status, metadata, created_at, updated_at + FROM consents + ORDER BY consent_key ASC, created_at ASC, id ASC + """ + +UPDATE_CONSENT_SQL = """ + UPDATE consents + SET status = %s, + metadata = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING id, user_id, consent_key, status, metadata, created_at, updated_at + """ + +INSERT_POLICY_SQL = """ + INSERT INTO policies ( + user_id, + name, + action, + scope, + effect, + priority, + active, + 
conditions, + required_consents, + created_at, + updated_at + ) + VALUES (app.current_user_id(), %s, %s, %s, %s, %s, %s, %s, %s, clock_timestamp(), clock_timestamp()) + RETURNING + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + """ + +GET_POLICY_SQL = """ + SELECT + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + FROM policies + WHERE id = %s + """ + +LIST_POLICIES_SQL = """ + SELECT + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + FROM policies + ORDER BY priority ASC, created_at ASC, id ASC + """ + +LIST_ACTIVE_POLICIES_SQL = """ + SELECT + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + FROM policies + WHERE active = TRUE + ORDER BY priority ASC, created_at ASC, id ASC + """ + +INSERT_TOOL_SQL = """ + INSERT INTO tools ( + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp() + ) + RETURNING + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + """ + +GET_TOOL_SQL = """ + SELECT + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + FROM tools + WHERE id = %s + """ + +LIST_TOOLS_SQL = """ + SELECT + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + FROM tools + ORDER BY tool_key ASC, version ASC, created_at ASC, id ASC + """ + +LIST_ACTIVE_TOOLS_SQL = """ + SELECT + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + FROM tools + WHERE active = TRUE + ORDER BY tool_key ASC, version ASC, created_at ASC, id ASC + """ + +INSERT_APPROVAL_SQL = """ + INSERT INTO approvals ( + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp() + ) + RETURNING + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + """ + +GET_APPROVAL_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + FROM approvals + WHERE id = %s + """ + +LIST_APPROVALS_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + FROM approvals + ORDER BY created_at ASC, id ASC + """ + +UPDATE_APPROVAL_RESOLUTION_SQL = """ + UPDATE approvals + SET status = 
%s, + resolved_at = clock_timestamp(), + resolved_by_user_id = app.current_user_id() + WHERE id = %s + AND status = 'pending' + RETURNING + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + """ + +UPDATE_APPROVAL_TASK_STEP_SQL = """ + UPDATE approvals + SET task_step_id = %s + WHERE id = %s + RETURNING + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + created_at, + resolved_at, + resolved_by_user_id + """ + +INSERT_TASK_SQL = """ + INSERT INTO tasks ( + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + """ + +GET_TASK_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + FROM tasks + WHERE id = %s + """ + +GET_TASK_BY_APPROVAL_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + FROM tasks + WHERE latest_approval_id = %s + """ + +LIST_TASKS_SQL = """ + SELECT + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + FROM tasks + ORDER BY created_at ASC, id ASC + """ + +UPDATE_TASK_STATUS_BY_APPROVAL_SQL = """ + UPDATE tasks + SET status = %s, + updated_at = clock_timestamp() + WHERE latest_approval_id = %s + RETURNING + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + """ + +UPDATE_TASK_EXECUTION_BY_APPROVAL_SQL = """ + UPDATE tasks + SET status = %s, + latest_execution_id = %s, + updated_at = clock_timestamp() + WHERE latest_approval_id = %s + RETURNING + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + """ + +UPDATE_TASK_STATUS_SQL = """ + UPDATE tasks + SET status = %s, + latest_approval_id = %s, + latest_execution_id = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id, + created_at, + updated_at + """ + +INSERT_TASK_WORKSPACE_SQL = """ + INSERT INTO task_workspaces ( + user_id, + task_id, + status, + local_path, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + """ + +GET_TASK_WORKSPACE_SQL = """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + WHERE id = %s + """ + +GET_ACTIVE_TASK_WORKSPACE_FOR_TASK_SQL = """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + WHERE task_id = %s + AND status = 'active' + ORDER BY created_at ASC, id ASC + LIMIT 1 + """ + +LIST_TASK_WORKSPACES_SQL = """ + SELECT + id, + user_id, + task_id, + status, + 
local_path, + created_at, + updated_at + FROM task_workspaces + ORDER BY created_at ASC, id ASC + """ + +INSERT_TASK_STEP_SQL = """ + INSERT INTO task_steps ( + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + """ + +GET_TASK_STEP_SQL = """ + SELECT + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + FROM task_steps + WHERE id = %s + """ + +GET_TASK_STEP_FOR_TASK_SEQUENCE_SQL = """ + SELECT + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + FROM task_steps + WHERE task_id = %s + AND sequence_no = %s + """ + +LIST_TASK_STEPS_FOR_TASK_SQL = """ + SELECT + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + FROM task_steps + WHERE task_id = %s + ORDER BY sequence_no ASC, created_at ASC, id ASC + """ + +UPDATE_TASK_STEP_FOR_TASK_SEQUENCE_SQL = """ + UPDATE task_steps + SET status = %s, + outcome = %s, + trace_id = %s, + trace_kind = %s, + updated_at = clock_timestamp() + WHERE task_id = %s + AND sequence_no = %s + RETURNING + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + """ + +UPDATE_TASK_STEP_SQL = """ + UPDATE task_steps + SET status = %s, + outcome = %s, + trace_id = %s, + trace_kind = %s, + updated_at = clock_timestamp() + WHERE id = %s + RETURNING + id, + user_id, + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + request, + outcome, + trace_id, + trace_kind, + created_at, + updated_at + """ + +INSERT_TOOL_EXECUTION_SQL = """ + INSERT INTO tool_executions ( + user_id, + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, + tool, + result, + executed_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + clock_timestamp() + ) + RETURNING + id, + user_id, + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, + tool, + result, + executed_at + """ + +GET_TOOL_EXECUTION_SQL = """ + SELECT + id, + user_id, + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, + tool, + result, + executed_at + FROM tool_executions + WHERE id = %s + """ + +LIST_TOOL_EXECUTIONS_SQL = """ + SELECT + id, + user_id, + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, 
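+ -- request, tool, and result below are JSONB snapshots captured when the execution was recorded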
+ tool, + result, + executed_at + FROM tool_executions + ORDER BY executed_at ASC, id ASC + """ + +INSERT_EXECUTION_BUDGET_SQL = """ + INSERT INTO execution_budgets ( + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + supersedes_budget_id + ) + VALUES ( + COALESCE(%s, gen_random_uuid()), + app.current_user_id(), + %s, + %s, + %s, + %s, + %s + ) + RETURNING + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + """ + +GET_EXECUTION_BUDGET_SQL = """ + SELECT + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + FROM execution_budgets + WHERE id = %s + """ + +LIST_EXECUTION_BUDGETS_SQL = """ + SELECT + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + FROM execution_budgets + ORDER BY created_at ASC, id ASC + """ + +DEACTIVATE_EXECUTION_BUDGET_SQL = """ + UPDATE execution_budgets + SET status = 'inactive', + deactivated_at = now() + WHERE id = %s + AND status = 'active' + RETURNING + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + """ + +SUPERSEDE_EXECUTION_BUDGET_SQL = """ + UPDATE execution_budgets + SET status = 'superseded', + deactivated_at = now(), + superseded_by_budget_id = %s + WHERE id = %s + AND status = 'active' + RETURNING + id, + user_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + status, + deactivated_at, + superseded_by_budget_id, + supersedes_budget_id, + created_at + """ + +UPDATE_EVENT_ERROR = "events are append-only and must be superseded by new records" +DELETE_EVENT_ERROR = "events are append-only and must not be deleted in place" +UPDATE_TRACE_EVENT_ERROR = "trace events are append-only and must be superseded by new records" +DELETE_TRACE_EVENT_ERROR = "trace events are append-only and must not be deleted in place" + + +class AppendOnlyViolation(RuntimeError): + """Raised when a caller attempts to mutate an immutable event.""" + + +class ContinuityStoreInvariantError(RuntimeError): + """Raised when a write query does not return the row its contract promises.""" + + +class ContinuityStore: + def __init__(self, conn: psycopg.Connection): + self.conn = conn + + def _fetch_one( + self, + operation_name: str, + query: str, + params: tuple[object, ...] | None = None, + ) -> RowT: + with self.conn.cursor() as cur: + cur.execute(query, params) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + f"{operation_name} did not return a row from the database", + ) + + return cast(RowT, row) + + def _fetch_all( + self, + query: str, + params: tuple[object, ...] | None = None, + ) -> list[RowT]: + with self.conn.cursor() as cur: + cur.execute(query, params) + return cast(list[RowT], list(cur.fetchall())) + + def _fetch_optional_one( + self, + query: str, + params: tuple[object, ...] | None = None, + ) -> RowT | None: + with self.conn.cursor() as cur: + cur.execute(query, params) + row = cur.fetchone() + return cast(RowT | None, row) + + def _fetch_count( + self, + query: str, + params: tuple[object, ...] 
| None = None, + ) -> int: + with self.conn.cursor() as cur: + cur.execute(query, params) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + "count query did not return a row from the database", + ) + + return cast(CountRow, row)["count"] + + @staticmethod + def _vector_literal(vector: list[float]) -> str: + return "[" + ",".join(repr(value) for value in vector) + "]" + + def create_user(self, user_id: UUID, email: str, display_name: str | None = None) -> UserRow: + return self._fetch_one( + "create_user", + INSERT_USER_SQL, + (user_id, email, display_name), + ) + + def get_user(self, user_id: UUID) -> UserRow: + return self._fetch_one("get_user", GET_USER_SQL, (user_id,)) + + def create_thread(self, title: str) -> ThreadRow: + return self._fetch_one("create_thread", INSERT_THREAD_SQL, (title,)) + + def get_thread(self, thread_id: UUID) -> ThreadRow: + return self._fetch_one("get_thread", GET_THREAD_SQL, (thread_id,)) + + def get_thread_optional(self, thread_id: UUID) -> ThreadRow | None: + return self._fetch_optional_one(GET_THREAD_SQL, (thread_id,)) + + def create_session(self, thread_id: UUID, status: str = "active") -> SessionRow: + return self._fetch_one("create_session", INSERT_SESSION_SQL, (thread_id, status)) + + def list_thread_sessions(self, thread_id: UUID) -> list[SessionRow]: + return self._fetch_all(LIST_THREAD_SESSIONS_SQL, (thread_id,)) + + def append_event( + self, + thread_id: UUID, + session_id: UUID | None, + kind: str, + payload: JsonObject, + ) -> EventRow: + with self.conn.cursor() as cur: + cur.execute(LOCK_THREAD_EVENTS_SQL, (str(thread_id),)) + cur.execute( + INSERT_EVENT_SQL, + (thread_id, thread_id, session_id, kind, Jsonb(payload)), + ) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + "append_event did not return a row from the database", + ) + + return cast(EventRow, row) + + def list_thread_events(self, thread_id: UUID) -> list[EventRow]: + return self._fetch_all(LIST_THREAD_EVENTS_SQL, (thread_id,)) + + def list_events_by_ids(self, event_ids: list[UUID]) -> list[EventRow]: + if not event_ids: + return [] + return self._fetch_all(LIST_EVENTS_BY_IDS_SQL, (event_ids,)) + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: JsonObject, + ) -> TraceRow: + return self._fetch_one( + "create_trace", + INSERT_TRACE_SQL, + (user_id, thread_id, kind, compiler_version, status, Jsonb(limits)), + ) + + def get_trace(self, trace_id: UUID) -> TraceRow: + return self._fetch_one("get_trace", GET_TRACE_SQL, (trace_id,)) + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: JsonObject, + ) -> TraceEventRow: + return self._fetch_one( + "append_trace_event", + INSERT_TRACE_EVENT_SQL, + (trace_id, sequence_no, kind, Jsonb(payload)), + ) + + def list_trace_events(self, trace_id: UUID) -> list[TraceEventRow]: + return self._fetch_all(LIST_TRACE_EVENTS_SQL, (trace_id,)) + + def create_memory( + self, + *, + memory_key: str, + value: JsonValue, + status: str, + source_event_ids: list[str], + ) -> MemoryRow: + return self._fetch_one( + "create_memory", + INSERT_MEMORY_SQL, + (memory_key, Jsonb(value), status, Jsonb(source_event_ids)), + ) + + def get_memory(self, memory_id: UUID) -> MemoryRow: + return self._fetch_one("get_memory", GET_MEMORY_SQL, (memory_id,)) + + def get_memory_optional(self, memory_id: UUID) -> MemoryRow | None: + return self._fetch_optional_one(GET_MEMORY_SQL, 
(memory_id,)) + + def list_memories_by_ids(self, memory_ids: list[UUID]) -> list[MemoryRow]: + if not memory_ids: + return [] + return self._fetch_all(LIST_MEMORIES_BY_IDS_SQL, (memory_ids,)) + + def get_memory_by_key(self, memory_key: str) -> MemoryRow | None: + return self._fetch_optional_one(GET_MEMORY_BY_KEY_SQL, (memory_key,)) + + def list_memories(self) -> list[MemoryRow]: + return self._fetch_all(LIST_MEMORIES_SQL) + + def count_memories(self, *, status: str | None = None) -> int: + if status is None: + return self._fetch_count(COUNT_MEMORIES_SQL) + return self._fetch_count(COUNT_MEMORIES_BY_STATUS_SQL, (status,)) + + def count_unlabeled_review_memories(self) -> int: + return self._fetch_count(COUNT_UNLABELED_REVIEW_MEMORIES_SQL) + + def list_review_memories(self, *, status: str | None = None, limit: int) -> list[MemoryRow]: + if status is None: + return self._fetch_all(LIST_REVIEW_MEMORIES_SQL, (limit,)) + return self._fetch_all(LIST_REVIEW_MEMORIES_BY_STATUS_SQL, (status, limit)) + + def list_unlabeled_review_memories(self, *, limit: int) -> list[MemoryRow]: + return self._fetch_all(LIST_UNLABELED_REVIEW_MEMORIES_SQL, (limit,)) + + def list_context_memories(self) -> list[MemoryRow]: + return self._fetch_all(LIST_CONTEXT_MEMORIES_SQL) + + def update_memory( + self, + *, + memory_id: UUID, + value: JsonValue, + status: str, + source_event_ids: list[str], + ) -> MemoryRow: + return self._fetch_one( + "update_memory", + UPDATE_MEMORY_SQL, + (Jsonb(value), status, Jsonb(source_event_ids), status, memory_id), + ) + + def append_memory_revision( + self, + *, + memory_id: UUID, + action: str, + memory_key: str, + previous_value: JsonValue | None, + new_value: JsonValue | None, + source_event_ids: list[str], + candidate: JsonObject, + ) -> MemoryRevisionRow: + with self.conn.cursor() as cur: + cur.execute(LOCK_MEMORY_REVISIONS_SQL, (str(memory_id),)) + cur.execute( + INSERT_MEMORY_REVISION_SQL, + ( + memory_id, + memory_id, + action, + memory_key, + Jsonb(previous_value), + Jsonb(new_value), + Jsonb(source_event_ids), + Jsonb(candidate), + ), + ) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + "append_memory_revision did not return a row from the database", + ) + + return cast(MemoryRevisionRow, row) + + def count_memory_revisions(self, memory_id: UUID) -> int: + return self._fetch_count(COUNT_MEMORY_REVISIONS_SQL, (memory_id,)) + + def list_memory_revisions( + self, + memory_id: UUID, + *, + limit: int | None = None, + ) -> list[MemoryRevisionRow]: + if limit is None: + return self._fetch_all(LIST_MEMORY_REVISIONS_SQL, (memory_id,)) + return self._fetch_all(LIST_LIMITED_MEMORY_REVISIONS_SQL, (memory_id, limit)) + + def create_memory_review_label( + self, + *, + memory_id: UUID, + label: str, + note: str | None, + ) -> MemoryReviewLabelRow: + return self._fetch_one( + "create_memory_review_label", + INSERT_MEMORY_REVIEW_LABEL_SQL, + (memory_id, label, note), + ) + + def list_memory_review_labels(self, memory_id: UUID) -> list[MemoryReviewLabelRow]: + return self._fetch_all(LIST_MEMORY_REVIEW_LABELS_SQL, (memory_id,)) + + def list_memory_review_label_counts(self, memory_id: UUID) -> list[LabelCountRow]: + return self._fetch_all(LIST_MEMORY_REVIEW_LABEL_COUNTS_SQL, (memory_id,)) + + def count_labeled_memories(self) -> int: + return self._fetch_count(COUNT_LABELED_MEMORIES_SQL) + + def count_unlabeled_memories(self) -> int: + return self._fetch_count(COUNT_UNLABELED_MEMORIES_SQL) + + def list_all_memory_review_label_counts(self) -> list[LabelCountRow]: + 
return self._fetch_all(LIST_ALL_MEMORY_REVIEW_LABEL_COUNTS_SQL) + + def create_embedding_config( + self, + *, + provider: str, + model: str, + version: str, + dimensions: int, + status: str, + metadata: JsonObject, + ) -> EmbeddingConfigRow: + return self._fetch_one( + "create_embedding_config", + INSERT_EMBEDDING_CONFIG_SQL, + (provider, model, version, dimensions, status, Jsonb(metadata)), + ) + + def get_embedding_config_optional(self, embedding_config_id: UUID) -> EmbeddingConfigRow | None: + return self._fetch_optional_one(GET_EMBEDDING_CONFIG_SQL, (embedding_config_id,)) + + def get_embedding_config_by_identity_optional( + self, + *, + provider: str, + model: str, + version: str, + ) -> EmbeddingConfigRow | None: + return self._fetch_optional_one( + GET_EMBEDDING_CONFIG_BY_IDENTITY_SQL, + (provider, model, version), + ) + + def list_embedding_configs(self) -> list[EmbeddingConfigRow]: + return self._fetch_all(LIST_EMBEDDING_CONFIGS_SQL) + + def create_memory_embedding( + self, + *, + memory_id: UUID, + embedding_config_id: UUID, + dimensions: int, + vector: list[float], + ) -> MemoryEmbeddingRow: + return self._fetch_one( + "create_memory_embedding", + INSERT_MEMORY_EMBEDDING_SQL, + (memory_id, embedding_config_id, dimensions, Jsonb(vector)), + ) + + def get_memory_embedding_optional(self, memory_embedding_id: UUID) -> MemoryEmbeddingRow | None: + return self._fetch_optional_one(GET_MEMORY_EMBEDDING_SQL, (memory_embedding_id,)) + + def get_memory_embedding_by_memory_and_config_optional( + self, + *, + memory_id: UUID, + embedding_config_id: UUID, + ) -> MemoryEmbeddingRow | None: + return self._fetch_optional_one( + GET_MEMORY_EMBEDDING_BY_MEMORY_AND_CONFIG_SQL, + (memory_id, embedding_config_id), + ) + + def list_memory_embeddings_for_memory(self, memory_id: UUID) -> list[MemoryEmbeddingRow]: + return self._fetch_all(LIST_MEMORY_EMBEDDINGS_FOR_MEMORY_SQL, (memory_id,)) + + def list_memory_embeddings_for_config( + self, + embedding_config_id: UUID, + ) -> list[MemoryEmbeddingRow]: + return self._fetch_all(LIST_MEMORY_EMBEDDINGS_FOR_CONFIG_SQL, (embedding_config_id,)) + + def update_memory_embedding( + self, + *, + memory_embedding_id: UUID, + dimensions: int, + vector: list[float], + ) -> MemoryEmbeddingRow: + return self._fetch_one( + "update_memory_embedding", + UPDATE_MEMORY_EMBEDDING_SQL, + (dimensions, Jsonb(vector), memory_embedding_id), + ) + + def retrieve_semantic_memory_matches( + self, + *, + embedding_config_id: UUID, + query_vector: list[float], + limit: int, + ) -> list[SemanticMemoryRetrievalRow]: + return self._fetch_all( + RETRIEVE_SEMANTIC_MEMORY_MATCHES_SQL, + ( + self._vector_literal(query_vector), + embedding_config_id, + len(query_vector), + limit, + ), + ) + + def create_entity( + self, + *, + entity_type: str, + name: str, + source_memory_ids: list[str], + ) -> EntityRow: + return self._fetch_one( + "create_entity", + INSERT_ENTITY_SQL, + (entity_type, name, Jsonb(source_memory_ids)), + ) + + def get_entity_optional(self, entity_id: UUID) -> EntityRow | None: + return self._fetch_optional_one(GET_ENTITY_SQL, (entity_id,)) + + def list_entities(self) -> list[EntityRow]: + return self._fetch_all(LIST_ENTITIES_SQL) + + def create_entity_edge( + self, + *, + from_entity_id: UUID, + to_entity_id: UUID, + relationship_type: str, + valid_from: datetime | None, + valid_to: datetime | None, + source_memory_ids: list[str], + ) -> EntityEdgeRow: + return self._fetch_one( + "create_entity_edge", + INSERT_ENTITY_EDGE_SQL, + ( + from_entity_id, + to_entity_id, + 
relationship_type, + valid_from, + valid_to, + Jsonb(source_memory_ids), + ), + ) + + def list_entity_edges_for_entity(self, entity_id: UUID) -> list[EntityEdgeRow]: + return self._fetch_all(LIST_ENTITY_EDGES_FOR_ENTITY_SQL, (entity_id, entity_id)) + + def list_entity_edges_for_entities(self, entity_ids: list[UUID]) -> list[EntityEdgeRow]: + if not entity_ids: + return [] + return self._fetch_all(LIST_ENTITY_EDGES_FOR_ENTITIES_SQL, (entity_ids, entity_ids)) + + def create_consent( + self, + *, + consent_key: str, + status: str, + metadata: JsonObject, + ) -> ConsentRow: + return self._fetch_one( + "create_consent", + INSERT_CONSENT_SQL, + (consent_key, status, Jsonb(metadata)), + ) + + def get_consent_by_key_optional(self, consent_key: str) -> ConsentRow | None: + return self._fetch_optional_one(GET_CONSENT_BY_KEY_SQL, (consent_key,)) + + def list_consents(self) -> list[ConsentRow]: + return self._fetch_all(LIST_CONSENTS_SQL) + + def update_consent( + self, + *, + consent_id: UUID, + status: str, + metadata: JsonObject, + ) -> ConsentRow: + return self._fetch_one( + "update_consent", + UPDATE_CONSENT_SQL, + (status, Jsonb(metadata), consent_id), + ) + + def create_policy( + self, + *, + name: str, + action: str, + scope: str, + effect: str, + priority: int, + active: bool, + conditions: JsonObject, + required_consents: list[str], + ) -> PolicyRow: + return self._fetch_one( + "create_policy", + INSERT_POLICY_SQL, + ( + name, + action, + scope, + effect, + priority, + active, + Jsonb(conditions), + Jsonb(required_consents), + ), + ) + + def get_policy_optional(self, policy_id: UUID) -> PolicyRow | None: + return self._fetch_optional_one(GET_POLICY_SQL, (policy_id,)) + + def list_policies(self) -> list[PolicyRow]: + return self._fetch_all(LIST_POLICIES_SQL) + + def list_active_policies(self) -> list[PolicyRow]: + return self._fetch_all(LIST_ACTIVE_POLICIES_SQL) + + def create_tool( + self, + *, + tool_key: str, + name: str, + description: str, + version: str, + metadata_version: str, + active: bool, + tags: list[str], + action_hints: list[str], + scope_hints: list[str], + domain_hints: list[str], + risk_hints: list[str], + metadata: JsonObject, + ) -> ToolRow: + return self._fetch_one( + "create_tool", + INSERT_TOOL_SQL, + ( + tool_key, + name, + description, + version, + metadata_version, + active, + Jsonb(tags), + Jsonb(action_hints), + Jsonb(scope_hints), + Jsonb(domain_hints), + Jsonb(risk_hints), + Jsonb(metadata), + ), + ) + + def get_tool_optional(self, tool_id: UUID) -> ToolRow | None: + return self._fetch_optional_one(GET_TOOL_SQL, (tool_id,)) + + def list_tools(self) -> list[ToolRow]: + return self._fetch_all(LIST_TOOLS_SQL) + + def list_active_tools(self) -> list[ToolRow]: + return self._fetch_all(LIST_ACTIVE_TOOLS_SQL) + + def create_approval( + self, + *, + thread_id: UUID, + tool_id: UUID, + task_step_id: UUID | None, + status: str, + request: JsonObject, + tool: JsonObject, + routing: JsonObject, + routing_trace_id: UUID, + ) -> ApprovalRow: + return self._fetch_one( + "create_approval", + INSERT_APPROVAL_SQL, + ( + thread_id, + tool_id, + task_step_id, + status, + Jsonb(request), + Jsonb(tool), + Jsonb(routing), + routing_trace_id, + ), + ) + + def get_approval_optional(self, approval_id: UUID) -> ApprovalRow | None: + return self._fetch_optional_one(GET_APPROVAL_SQL, (approval_id,)) + + def list_approvals(self) -> list[ApprovalRow]: + return self._fetch_all(LIST_APPROVALS_SQL) + + def resolve_approval_optional( + self, + *, + approval_id: UUID, + status: str, + ) -> 
ApprovalRow | None: + return self._fetch_optional_one( + UPDATE_APPROVAL_RESOLUTION_SQL, + (status, approval_id), + ) + + def update_approval_task_step_optional( + self, + *, + approval_id: UUID, + task_step_id: UUID, + ) -> ApprovalRow | None: + return self._fetch_optional_one( + UPDATE_APPROVAL_TASK_STEP_SQL, + (task_step_id, approval_id), + ) + + def create_task( + self, + *, + thread_id: UUID, + tool_id: UUID, + status: str, + request: JsonObject, + tool: JsonObject, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> TaskRow: + return self._fetch_one( + "create_task", + INSERT_TASK_SQL, + ( + thread_id, + tool_id, + status, + Jsonb(request), + Jsonb(tool), + latest_approval_id, + latest_execution_id, + ), + ) + + def get_task_optional(self, task_id: UUID) -> TaskRow | None: + return self._fetch_optional_one(GET_TASK_SQL, (task_id,)) + + def get_task_by_approval_optional(self, approval_id: UUID) -> TaskRow | None: + return self._fetch_optional_one(GET_TASK_BY_APPROVAL_SQL, (approval_id,)) + + def list_tasks(self) -> list[TaskRow]: + return self._fetch_all(LIST_TASKS_SQL) + + def update_task_status_by_approval_optional( + self, + *, + approval_id: UUID, + status: str, + ) -> TaskRow | None: + return self._fetch_optional_one( + UPDATE_TASK_STATUS_BY_APPROVAL_SQL, + (status, approval_id), + ) + + def update_task_execution_by_approval_optional( + self, + *, + approval_id: UUID, + latest_execution_id: UUID, + status: str, + ) -> TaskRow | None: + return self._fetch_optional_one( + UPDATE_TASK_EXECUTION_BY_APPROVAL_SQL, + (status, latest_execution_id, approval_id), + ) + + def update_task_status_optional( + self, + *, + task_id: UUID, + status: str, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> TaskRow | None: + return self._fetch_optional_one( + UPDATE_TASK_STATUS_SQL, + (status, latest_approval_id, latest_execution_id, task_id), + ) + + def lock_task_workspaces(self, task_id: UUID) -> None: + with self.conn.cursor() as cur: + cur.execute(LOCK_TASK_WORKSPACES_SQL, (str(task_id),)) + + def create_task_workspace( + self, + *, + task_id: UUID, + status: str, + local_path: str, + ) -> TaskWorkspaceRow: + return self._fetch_one( + "create_task_workspace", + INSERT_TASK_WORKSPACE_SQL, + (task_id, status, local_path), + ) + + def get_task_workspace_optional(self, task_workspace_id: UUID) -> TaskWorkspaceRow | None: + return self._fetch_optional_one(GET_TASK_WORKSPACE_SQL, (task_workspace_id,)) + + def get_active_task_workspace_for_task_optional(self, task_id: UUID) -> TaskWorkspaceRow | None: + return self._fetch_optional_one(GET_ACTIVE_TASK_WORKSPACE_FOR_TASK_SQL, (task_id,)) + + def list_task_workspaces(self) -> list[TaskWorkspaceRow]: + return self._fetch_all(LIST_TASK_WORKSPACES_SQL) + + def lock_task_steps(self, task_id: UUID) -> None: + with self.conn.cursor() as cur: + cur.execute(LOCK_TASK_STEPS_SQL, (str(task_id),)) + + def create_task_step( + self, + *, + task_id: UUID, + sequence_no: int, + parent_step_id: UUID | None = None, + source_approval_id: UUID | None = None, + source_execution_id: UUID | None = None, + kind: str, + status: str, + request: JsonObject, + outcome: JsonObject, + trace_id: UUID, + trace_kind: str, + ) -> TaskStepRow: + with self.conn.cursor() as cur: + cur.execute(LOCK_TASK_STEPS_SQL, (str(task_id),)) + cur.execute( + INSERT_TASK_STEP_SQL, + ( + task_id, + sequence_no, + parent_step_id, + source_approval_id, + source_execution_id, + kind, + status, + Jsonb(request), + Jsonb(outcome), + trace_id, + trace_kind, 
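+ # created_at and updated_at are assigned by clock_timestamp() inside INSERT_TASK_STEP_SQL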
+ ), + ) + row = cur.fetchone() + + if row is None: + raise ContinuityStoreInvariantError( + "create_task_step did not return a row from the database", + ) + + return cast(TaskStepRow, row) + + def get_task_step_optional(self, task_step_id: UUID) -> TaskStepRow | None: + return self._fetch_optional_one(GET_TASK_STEP_SQL, (task_step_id,)) + + def get_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + ) -> TaskStepRow | None: + return self._fetch_optional_one( + GET_TASK_STEP_FOR_TASK_SEQUENCE_SQL, + (task_id, sequence_no), + ) + + def list_task_steps_for_task(self, task_id: UUID) -> list[TaskStepRow]: + return self._fetch_all(LIST_TASK_STEPS_FOR_TASK_SQL, (task_id,)) + + def update_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + status: str, + outcome: JsonObject, + trace_id: UUID, + trace_kind: str, + ) -> TaskStepRow | None: + return self._fetch_optional_one( + UPDATE_TASK_STEP_FOR_TASK_SEQUENCE_SQL, + ( + status, + Jsonb(outcome), + trace_id, + trace_kind, + task_id, + sequence_no, + ), + ) + + def update_task_step_optional( + self, + *, + task_step_id: UUID, + status: str, + outcome: JsonObject, + trace_id: UUID, + trace_kind: str, + ) -> TaskStepRow | None: + return self._fetch_optional_one( + UPDATE_TASK_STEP_SQL, + ( + status, + Jsonb(outcome), + trace_id, + trace_kind, + task_step_id, + ), + ) + + def create_tool_execution( + self, + *, + approval_id: UUID, + task_step_id: UUID, + thread_id: UUID, + tool_id: UUID, + trace_id: UUID, + request_event_id: UUID | None, + result_event_id: UUID | None, + status: str, + handler_key: str | None, + request: JsonObject, + tool: JsonObject, + result: JsonObject, + ) -> ToolExecutionRow: + return self._fetch_one( + "create_tool_execution", + INSERT_TOOL_EXECUTION_SQL, + ( + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + Jsonb(request), + Jsonb(tool), + Jsonb(result), + ), + ) + + def get_tool_execution_optional(self, execution_id: UUID) -> ToolExecutionRow | None: + return self._fetch_optional_one(GET_TOOL_EXECUTION_SQL, (execution_id,)) + + def list_tool_executions(self) -> list[ToolExecutionRow]: + return self._fetch_all(LIST_TOOL_EXECUTIONS_SQL) + + def create_execution_budget( + self, + *, + budget_id: UUID | None = None, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, + supersedes_budget_id: UUID | None = None, + ) -> ExecutionBudgetRow: + return self._fetch_one( + "create_execution_budget", + INSERT_EXECUTION_BUDGET_SQL, + ( + budget_id, + tool_key, + domain_hint, + max_completed_executions, + rolling_window_seconds, + supersedes_budget_id, + ), + ) + + def get_execution_budget_optional(self, execution_budget_id: UUID) -> ExecutionBudgetRow | None: + return self._fetch_optional_one(GET_EXECUTION_BUDGET_SQL, (execution_budget_id,)) + + def list_execution_budgets(self) -> list[ExecutionBudgetRow]: + return self._fetch_all(LIST_EXECUTION_BUDGETS_SQL) + + def deactivate_execution_budget_optional( + self, + execution_budget_id: UUID, + ) -> ExecutionBudgetRow | None: + return self._fetch_optional_one(DEACTIVATE_EXECUTION_BUDGET_SQL, (execution_budget_id,)) + + def supersede_execution_budget_optional( + self, + *, + execution_budget_id: UUID, + superseded_by_budget_id: UUID, + ) -> ExecutionBudgetRow | None: + return self._fetch_optional_one( + SUPERSEDE_EXECUTION_BUDGET_SQL, + ( + 
superseded_by_budget_id, + execution_budget_id, + ), + ) + + def update_event(self, *_args: Any, **_kwargs: Any) -> None: + raise AppendOnlyViolation(UPDATE_EVENT_ERROR) + + def delete_event(self, *_args: Any, **_kwargs: Any) -> None: + raise AppendOnlyViolation(DELETE_EVENT_ERROR) + + def update_trace_event(self, *_args: Any, **_kwargs: Any) -> None: + raise AppendOnlyViolation(UPDATE_TRACE_EVENT_ERROR) + + def delete_trace_event(self, *_args: Any, **_kwargs: Any) -> None: + raise AppendOnlyViolation(DELETE_TRACE_EVENT_ERROR) diff --git a/apps/api/src/alicebot_api/tasks.py b/apps/api/src/alicebot_api/tasks.py new file mode 100644 index 0000000..da88e5f --- /dev/null +++ b/apps/api/src/alicebot_api/tasks.py @@ -0,0 +1,1170 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import cast +from uuid import UUID + +import psycopg + +from alicebot_api.contracts import ( + TASK_LIST_ORDER, + TASK_STEP_CONTINUATION_VERSION_V0, + TASK_STEP_LIST_ORDER, + TASK_STEP_TRANSITION_VERSION_V0, + TRACE_KIND_TASK_STEP_CONTINUATION, + TRACE_KIND_TASK_STEP_TRANSITION, + TaskCreateInput, + TaskCreateResponse, + TaskDetailResponse, + TaskLifecycleSource, + TaskLifecycleStateTracePayload, + TaskLifecycleSummaryTracePayload, + TaskListResponse, + TaskListSummary, + TaskRecord, + TaskStatus, + TaskStepCreateInput, + TaskStepCreateResponse, + TaskStepDetailResponse, + TaskStepContinuationLineageTracePayload, + TaskStepContinuationRequestTracePayload, + TaskStepContinuationSummaryTracePayload, + TaskStepLifecycleStateTracePayload, + TaskStepLifecycleSummaryTracePayload, + TaskStepLineageRecord, + TaskStepListSummary, + TaskStepListResponse, + TaskStepMutationTraceSummary, + TaskStepNextCreateInput, + TaskStepNextCreateResponse, + TaskStepOutcomeSnapshot, + TaskStepRecord, + TaskStepStatus, + TaskStepTransitionInput, + TaskStepTransitionRequestTracePayload, + TaskStepTransitionResponse, + TaskStepTransitionStateTracePayload, + TaskStepTransitionSummaryTracePayload, +) +from alicebot_api.store import ( + ContinuityStore, + ContinuityStoreInvariantError, + TaskRow, + TaskStepRow, + ToolExecutionRow, +) + +TASK_LIFECYCLE_STATE_EVENT_KIND = "task.lifecycle.state" +TASK_LIFECYCLE_SUMMARY_EVENT_KIND = "task.lifecycle.summary" +TASK_STEP_LIFECYCLE_STATE_EVENT_KIND = "task.step.lifecycle.state" +TASK_STEP_LIFECYCLE_SUMMARY_EVENT_KIND = "task.step.lifecycle.summary" +TASK_STEP_CONTINUATION_REQUEST_EVENT_KIND = "task.step.continuation.request" +TASK_STEP_CONTINUATION_LINEAGE_EVENT_KIND = "task.step.continuation.lineage" +TASK_STEP_CONTINUATION_SUMMARY_EVENT_KIND = "task.step.continuation.summary" +TASK_STEP_TRANSITION_REQUEST_EVENT_KIND = "task.step.transition.request" +TASK_STEP_TRANSITION_STATE_EVENT_KIND = "task.step.transition.state" +TASK_STEP_TRANSITION_SUMMARY_EVENT_KIND = "task.step.transition.summary" +DEFAULT_TASK_STEP_SEQUENCE_NO = 1 +DEFAULT_TASK_STEP_KIND = "governed_request" +TASK_STEP_APPENDABLE_STATUSES = frozenset({"executed", "blocked", "denied"}) +TASK_STEP_INITIAL_STATUSES = frozenset({"created", "approved", "denied"}) +TASK_STEP_STATUS_GRAPH: dict[TaskStepStatus, tuple[TaskStepStatus, ...]] = { + "created": ("approved", "denied"), + "approved": ("executed", "blocked"), + "executed": (), + "blocked": (), + "denied": (), +} + + +class TaskNotFoundError(LookupError): + """Raised when a task record is not visible inside the current user scope.""" + + +class TaskStepNotFoundError(LookupError): + """Raised when a task-step record is not visible inside the current user 
scope.""" + + +class TaskStepSequenceError(RuntimeError): + """Raised when a task-step append request violates deterministic sequencing rules.""" + + +class TaskStepTransitionError(RuntimeError): + """Raised when a task-step transition request violates the explicit status graph.""" + + +class TaskStepLifecycleBoundaryError(RuntimeError): + """Raised when first-step-only lifecycle helpers are routed a later-step context.""" + + +class TaskStepApprovalLinkageError(RuntimeError): + """Raised when approval resolution cannot validate its linked task step.""" + + +class TaskStepExecutionLinkageError(RuntimeError): + """Raised when execution synchronization cannot validate its linked task step.""" + + +@dataclass(frozen=True, slots=True) +class TaskTransitionResult: + task: TaskRecord + previous_status: TaskStatus | None + + +@dataclass(frozen=True, slots=True) +class TaskStepTransitionResult: + task_step: TaskStepRecord + previous_status: TaskStepStatus | None + + +def _append_trace_events( + store: ContinuityStore, + *, + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> None: + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace_id, + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + +def _trace_summary( + trace_id: UUID, + trace_events: list[tuple[str, dict[str, object]]], +) -> TaskStepMutationTraceSummary: + return { + "trace_id": str(trace_id), + "trace_event_count": len(trace_events), + } + + +def validate_linked_task_step_for_approval( + store: ContinuityStore, + *, + approval_id: UUID, + task_step_id: UUID | None, +) -> tuple[TaskRow, TaskStepRow]: + if task_step_id is None: + raise TaskStepApprovalLinkageError(f"approval {approval_id} is missing linked task_step_id") + + unlocked_task = store.get_task_by_approval_optional(approval_id) + if unlocked_task is None: + raise TaskStepApprovalLinkageError(f"approval {approval_id} is not linked to a visible task") + store.lock_task_steps(cast(UUID, unlocked_task["id"])) + + task = store.get_task_optional(cast(UUID, unlocked_task["id"])) + if task is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared during approval linkage validation" + ) + + task_step = store.get_task_step_optional(task_step_id) + if task_step is None: + raise TaskStepApprovalLinkageError( + f"approval {approval_id} references linked task step {task_step_id} that was not found" + ) + if task_step["task_id"] != task["id"]: + raise TaskStepApprovalLinkageError( + f"approval {approval_id} links task step {task_step_id} outside task {task['id']}" + ) + + outcome = cast(TaskStepOutcomeSnapshot, task_step["outcome"]) + if outcome["approval_id"] != str(approval_id): + raise TaskStepApprovalLinkageError( + f"approval {approval_id} is inconsistent with linked task step {task_step_id}" + ) + + return task, task_step + + +def validate_linked_task_step_for_execution( + store: ContinuityStore, + *, + task_id: UUID, + execution: ToolExecutionRow, +) -> TaskStepRow: + store.lock_task_steps(task_id) + + execution_id = cast(UUID, execution["id"]) + task_step_id = cast(UUID | None, execution["task_step_id"]) + if task_step_id is None: + raise TaskStepExecutionLinkageError( + f"tool execution {execution_id} is missing linked task_step_id" + ) + + task_step = store.get_task_step_optional(task_step_id) + if task_step is None: + raise TaskStepExecutionLinkageError( + f"tool execution {execution_id} references linked task step {task_step_id} that was not 
found" + ) + if task_step["task_id"] != task_id: + raise TaskStepExecutionLinkageError( + f"tool execution {execution_id} links task step {task_step_id} outside task {task_id}" + ) + + outcome = cast(TaskStepOutcomeSnapshot, task_step["outcome"]) + if outcome["approval_id"] != str(execution["approval_id"]): + raise TaskStepExecutionLinkageError( + f"tool execution {execution_id} is inconsistent with linked task step {task_step_id}" + ) + + return task_step + + +def serialize_task_row(row: TaskRow) -> TaskRecord: + return { + "id": str(row["id"]), + "thread_id": str(row["thread_id"]), + "tool_id": str(row["tool_id"]), + "status": cast(TaskStatus, row["status"]), + "request": cast(dict[str, object], row["request"]), + "tool": cast(dict[str, object], row["tool"]), + "latest_approval_id": None if row["latest_approval_id"] is None else str(row["latest_approval_id"]), + "latest_execution_id": None if row["latest_execution_id"] is None else str(row["latest_execution_id"]), + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + + +def serialize_task_step_row(row: TaskStepRow) -> TaskStepRecord: + return { + "id": str(row["id"]), + "task_id": str(row["task_id"]), + "sequence_no": row["sequence_no"], + "lineage": { + "parent_step_id": None if row["parent_step_id"] is None else str(row["parent_step_id"]), + "source_approval_id": ( + None if row["source_approval_id"] is None else str(row["source_approval_id"]) + ), + "source_execution_id": ( + None if row["source_execution_id"] is None else str(row["source_execution_id"]) + ), + }, + "kind": cast(str, row["kind"]), + "status": cast(TaskStepStatus, row["status"]), + "request": cast(dict[str, object], row["request"]), + "outcome": cast(TaskStepOutcomeSnapshot, row["outcome"]), + "trace": { + "trace_id": str(row["trace_id"]), + "trace_kind": row["trace_kind"], + }, + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + + +def task_status_for_routing_decision(decision: str) -> TaskStatus: + return { + "approval_required": "pending_approval", + "ready": "approved", + "denied": "denied", + }[decision] + + +def task_status_for_approval_status(approval_status: str) -> TaskStatus: + return { + "pending": "pending_approval", + "approved": "approved", + "rejected": "denied", + }[approval_status] + + +def next_task_status_for_approval( + *, + current_status: TaskStatus, + approval_status: str, +) -> TaskStatus: + if current_status in {"executed", "blocked"}: + return current_status + return task_status_for_approval_status(approval_status) + + +def task_status_for_execution_status(execution_status: str) -> TaskStatus: + return { + "completed": "executed", + "blocked": "blocked", + }[execution_status] + + +def task_status_for_step_status(step_status: TaskStepStatus) -> TaskStatus: + return { + "created": "pending_approval", + "approved": "approved", + "executed": "executed", + "blocked": "blocked", + "denied": "denied", + }[step_status] + + +def task_step_status_for_routing_decision(decision: str) -> TaskStepStatus: + return { + "approval_required": "created", + "ready": "approved", + "denied": "denied", + }[decision] + + +def task_step_status_for_approval_status(approval_status: str) -> TaskStepStatus: + return { + "pending": "created", + "approved": "approved", + "rejected": "denied", + }[approval_status] + + +def next_task_step_status_for_approval( + *, + current_status: TaskStepStatus, + approval_status: str, +) -> TaskStepStatus: + if current_status in {"executed", 
"blocked"}: + return current_status + return task_step_status_for_approval_status(approval_status) + + +def task_step_status_for_execution_status(execution_status: str) -> TaskStepStatus: + return { + "completed": "executed", + "blocked": "blocked", + }[execution_status] + + +def allowed_task_step_transitions(current_status: TaskStepStatus) -> list[TaskStepStatus]: + return list(TASK_STEP_STATUS_GRAPH[current_status]) + + +def task_step_outcome_snapshot( + *, + routing_decision: str, + approval_id: str | None, + approval_status: str | None, + execution_id: str | None, + execution_status: str | None, + blocked_reason: str | None, +) -> TaskStepOutcomeSnapshot: + return { + "routing_decision": cast(str, routing_decision), + "approval_id": approval_id, + "approval_status": cast(str | None, approval_status), + "execution_id": execution_id, + "execution_status": cast(str | None, execution_status), + "blocked_reason": blocked_reason, + } + + +def create_task_for_governed_request( + store: ContinuityStore, + *, + request: TaskCreateInput, +) -> TaskCreateResponse: + task = store.create_task( + thread_id=request.thread_id, + tool_id=request.tool_id, + status=request.status, + request=cast(dict[str, object], request.request), + tool=cast(dict[str, object], request.tool), + latest_approval_id=request.latest_approval_id, + latest_execution_id=request.latest_execution_id, + ) + return {"task": serialize_task_row(task)} + + +def create_task_step_for_governed_request( + store: ContinuityStore, + *, + request: TaskStepCreateInput, +) -> TaskStepCreateResponse: + task_step = store.create_task_step( + task_id=request.task_id, + sequence_no=request.sequence_no, + kind=request.kind, + status=request.status, + request=cast(dict[str, object], request.request), + outcome=cast(dict[str, object], request.outcome), + trace_id=request.trace_id, + trace_kind=request.trace_kind, + ) + return {"task_step": serialize_task_step_row(task_step)} + + +def _task_step_sequencing_summary( + *, + task_id: str, + items: list[TaskStepRecord], +) -> TaskStepListSummary: + latest = items[-1] if items else None + latest_status = None if latest is None else latest["status"] + latest_sequence_no = None if latest is None else latest["sequence_no"] + return { + "task_id": task_id, + "total_count": len(items), + "latest_sequence_no": latest_sequence_no, + "latest_status": latest_status, + "next_sequence_no": 1 if latest_sequence_no is None else latest_sequence_no + 1, + "append_allowed": latest_status in TASK_STEP_APPENDABLE_STATUSES if latest_status is not None else False, + "order": list(TASK_STEP_LIST_ORDER), + } + + +def _validated_optional_approval_id( + store: ContinuityStore, + *, + approval_id: str | None, + current_approval_id: UUID | None, + task: TaskRow, + require_existing: bool, + missing_error: str, + error_cls: type[TaskStepSequenceError] | type[TaskStepTransitionError], +) -> UUID | None: + def _approval_belongs_to_task(approval_uuid: UUID) -> bool: + if current_approval_id == approval_uuid: + return True + for task_step in store.list_task_steps_for_task(task["id"]): + outcome = cast(dict[str, object], task_step["outcome"]) + linked_approval_id = outcome.get("approval_id") + if linked_approval_id is not None and str(linked_approval_id) == str(approval_uuid): + return True + return False + + if approval_id is None: + if require_existing and current_approval_id is None: + raise error_cls(missing_error) + approval_uuid = current_approval_id + else: + approval_uuid = UUID(approval_id) + if not 
_approval_belongs_to_task(approval_uuid): + raise error_cls(f"approval {approval_uuid} does not belong to task {task['id']}") + + if approval_uuid is None: + return None + + approval_row = store.get_approval_optional(approval_uuid) + if approval_row is None: + raise error_cls(f"approval {approval_uuid} was not found") + return approval_uuid + + +def _validated_optional_execution_id( + store: ContinuityStore, + *, + execution_id: str | None, + current_execution_id: UUID | None, + task: TaskRow, + require_existing: bool, + missing_error: str, + error_cls: type[TaskStepSequenceError] | type[TaskStepTransitionError], +) -> UUID | None: + def _execution_belongs_to_task(execution_uuid: UUID) -> bool: + if current_execution_id == execution_uuid: + return True + for task_step in store.list_task_steps_for_task(task["id"]): + outcome = cast(dict[str, object], task_step["outcome"]) + linked_execution_id = outcome.get("execution_id") + if linked_execution_id is not None and str(linked_execution_id) == str(execution_uuid): + return True + return False + + if execution_id is None: + if require_existing and current_execution_id is None: + raise error_cls(missing_error) + execution_uuid = current_execution_id + else: + execution_uuid = UUID(execution_id) + if not _execution_belongs_to_task(execution_uuid): + raise error_cls(f"tool execution {execution_uuid} does not belong to task {task['id']}") + + if execution_uuid is None: + return None + + execution_row = store.get_tool_execution_optional(execution_uuid) + if execution_row is None: + raise error_cls(f"tool execution {execution_uuid} was not found") + return execution_uuid + + +def _validated_continuation_parent_step( + *, + task_id: UUID, + latest: TaskStepRecord, + existing_items: list[TaskStepRecord], + parent_step_id: UUID, +) -> TaskStepRecord: + parent_step = next( + ( + item + for item in existing_items + if item["id"] == str(parent_step_id) + ), + None, + ) + if parent_step is None: + raise TaskStepSequenceError(f"task step {parent_step_id} does not belong to task {task_id}") + if parent_step["id"] != latest["id"]: + raise TaskStepSequenceError( + f"task {task_id} continuation must reference latest step {latest['id']}; received {parent_step_id}" + ) + return parent_step + + +def _validated_continuation_lineage( + *, + parent_step: TaskStepRecord, + source_approval_id: UUID | None, + source_execution_id: UUID | None, +) -> TaskStepLineageRecord: + parent_outcome = parent_step["outcome"] + if source_approval_id is not None and parent_outcome["approval_id"] != str(source_approval_id): + raise TaskStepSequenceError( + f"approval {source_approval_id} is not linked from parent step {parent_step['id']}" + ) + if source_execution_id is not None and parent_outcome["execution_id"] != str(source_execution_id): + raise TaskStepSequenceError( + f"tool execution {source_execution_id} is not linked from parent step {parent_step['id']}" + ) + + return { + "parent_step_id": parent_step["id"], + "source_approval_id": None if source_approval_id is None else str(source_approval_id), + "source_execution_id": None if source_execution_id is None else str(source_execution_id), + } + + +def sync_task_with_task_step_status( + store: ContinuityStore, + *, + task_id: UUID, + task_step_status: TaskStepStatus, + linked_approval_id: UUID | None, + linked_execution_id: UUID | None, +) -> TaskTransitionResult: + current = store.get_task_optional(task_id) + if current is None: + raise ContinuityStoreInvariantError( + f"task {task_id} disappeared before task-step lifecycle 
synchronization" + ) + previous_status = cast(TaskStatus, current["status"]) + target_status = task_status_for_step_status(task_step_status) + latest_execution_id = ( + current["latest_execution_id"] if linked_execution_id is None else linked_execution_id + ) if target_status in {"executed", "blocked"} else None + updated = store.update_task_status_optional( + task_id=task_id, + status=target_status, + latest_approval_id=linked_approval_id, + latest_execution_id=latest_execution_id, + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"task {task_id} disappeared during task-step lifecycle synchronization" + ) + return TaskTransitionResult( + task=serialize_task_row(updated), + previous_status=previous_status, + ) + + +def sync_task_with_approval( + store: ContinuityStore, + *, + approval_id: UUID, + approval_status: str, +) -> TaskTransitionResult: + current = store.get_task_by_approval_optional(approval_id) + if current is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared before lifecycle synchronization" + ) + previous_status = cast(TaskStatus, current["status"]) + + updated = store.update_task_status_by_approval_optional( + approval_id=approval_id, + status=next_task_status_for_approval( + current_status=previous_status, + approval_status=approval_status, + ), + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared during lifecycle synchronization" + ) + + return TaskTransitionResult( + task=serialize_task_row(updated), + previous_status=previous_status, + ) + + +def sync_task_step_with_approval( + store: ContinuityStore, + *, + approval_id: UUID, + task_step_id: UUID | None, + approval_status: str, + trace_id: UUID, + trace_kind: str, +) -> TaskStepTransitionResult: + _, current = validate_linked_task_step_for_approval( + store, + approval_id=approval_id, + task_step_id=task_step_id, + ) + previous_status = cast(TaskStepStatus, current["status"]) + current_outcome = cast(TaskStepOutcomeSnapshot, current["outcome"]) + updated_outcome = task_step_outcome_snapshot( + routing_decision=current_outcome["routing_decision"], + approval_id=str(approval_id), + approval_status=approval_status, + execution_id=current_outcome["execution_id"], + execution_status=current_outcome["execution_status"], + blocked_reason=current_outcome["blocked_reason"], + ) + + updated = store.update_task_step_optional( + task_step_id=cast(UUID, current["id"]), + status=next_task_step_status_for_approval( + current_status=previous_status, + approval_status=approval_status, + ), + outcome=cast(dict[str, object], updated_outcome), + trace_id=trace_id, + trace_kind=trace_kind, + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"linked task step {current['id']} disappeared during approval lifecycle synchronization" + ) + + return TaskStepTransitionResult( + task_step=serialize_task_step_row(updated), + previous_status=previous_status, + ) + + +def sync_task_with_execution( + store: ContinuityStore, + *, + approval_id: UUID, + execution_id: UUID, + execution_status: str, +) -> TaskTransitionResult: + current = store.get_task_by_approval_optional(approval_id) + if current is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared before execution synchronization" + ) + previous_status = cast(TaskStatus, current["status"]) + + updated = store.update_task_execution_by_approval_optional( + approval_id=approval_id, + latest_execution_id=execution_id, + 
status=task_status_for_execution_status(execution_status), + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"task for approval {approval_id} disappeared during execution synchronization" + ) + + return TaskTransitionResult( + task=serialize_task_row(updated), + previous_status=previous_status, + ) + + +def sync_task_step_with_execution( + store: ContinuityStore, + *, + task_id: UUID, + execution: ToolExecutionRow, + trace_id: UUID, + trace_kind: str, +) -> TaskStepTransitionResult: + current = validate_linked_task_step_for_execution( + store, + task_id=task_id, + execution=execution, + ) + previous_status = cast(TaskStepStatus, current["status"]) + current_outcome = cast(TaskStepOutcomeSnapshot, current["outcome"]) + execution_result = cast(dict[str, object], execution["result"]) + updated_outcome = task_step_outcome_snapshot( + routing_decision=current_outcome["routing_decision"], + approval_id=current_outcome["approval_id"], + approval_status=current_outcome["approval_status"], + execution_id=str(execution["id"]), + execution_status=cast(str, execution["status"]), + blocked_reason=cast(str | None, execution_result.get("reason")), + ) + + updated = store.update_task_step_optional( + task_step_id=cast(UUID, current["id"]), + status=task_step_status_for_execution_status(cast(str, execution["status"])), + outcome=cast(dict[str, object], updated_outcome), + trace_id=trace_id, + trace_kind=trace_kind, + ) + if updated is None: + raise ContinuityStoreInvariantError( + f"linked task step {current['id']} disappeared during execution lifecycle synchronization" + ) + + return TaskStepTransitionResult( + task_step=serialize_task_step_row(updated), + previous_status=previous_status, + ) + + +def create_next_task_step_record( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskStepNextCreateInput, +) -> TaskStepNextCreateResponse: + del user_id + + task_row = store.get_task_optional(request.task_id) + if task_row is None: + raise TaskNotFoundError(f"task {request.task_id} was not found") + + store.lock_task_steps(request.task_id) + existing_items = [serialize_task_step_row(row) for row in store.list_task_steps_for_task(request.task_id)] + if not existing_items: + raise TaskStepSequenceError(f"task {request.task_id} has no existing steps and cannot append a next step") + + latest = existing_items[-1] + if latest["status"] not in TASK_STEP_APPENDABLE_STATUSES: + raise TaskStepSequenceError( + f"task {request.task_id} latest step {latest['id']} is {latest['status']} and cannot append a next step" + ) + if request.status not in TASK_STEP_INITIAL_STATUSES: + allowed = ", ".join(sorted(TASK_STEP_INITIAL_STATUSES)) + raise TaskStepSequenceError( + f"new task step for task {request.task_id} must start in one of {allowed}; received {request.status}" + ) + parent_step = _validated_continuation_parent_step( + task_id=request.task_id, + latest=latest, + existing_items=existing_items, + parent_step_id=request.lineage.parent_step_id, + ) + source_approval_id = _validated_optional_approval_id( + store, + approval_id=( + None if request.lineage.source_approval_id is None else str(request.lineage.source_approval_id) + ), + current_approval_id=None, + task=task_row, + require_existing=False, + missing_error="", + error_cls=TaskStepSequenceError, + ) + source_execution_id = _validated_optional_execution_id( + store, + execution_id=( + None if request.lineage.source_execution_id is None else str(request.lineage.source_execution_id) + ), + current_execution_id=None, + task=task_row, + 
require_existing=False, + missing_error="", + error_cls=TaskStepSequenceError, + ) + lineage = _validated_continuation_lineage( + parent_step=parent_step, + source_approval_id=source_approval_id, + source_execution_id=source_execution_id, + ) + linked_approval_id = _validated_optional_approval_id( + store, + approval_id=request.outcome["approval_id"], + current_approval_id=None, + task=task_row, + require_existing=False, + missing_error="", + error_cls=TaskStepSequenceError, + ) + linked_execution_id = _validated_optional_execution_id( + store, + execution_id=request.outcome["execution_id"], + current_execution_id=None, + task=task_row, + require_existing=False, + missing_error="", + error_cls=TaskStepSequenceError, + ) + + trace = store.create_trace( + user_id=task_row["user_id"], + thread_id=task_row["thread_id"], + kind=TRACE_KIND_TASK_STEP_CONTINUATION, + compiler_version=TASK_STEP_CONTINUATION_VERSION_V0, + status="completed", + limits={ + "order": list(TASK_STEP_LIST_ORDER), + "appendable_statuses": sorted(TASK_STEP_APPENDABLE_STATUSES), + "initial_statuses": sorted(TASK_STEP_INITIAL_STATUSES), + "parent_step_id": parent_step["id"], + "parent_sequence_no": parent_step["sequence_no"], + }, + ) + try: + created = store.create_task_step( + task_id=request.task_id, + sequence_no=latest["sequence_no"] + 1, + parent_step_id=request.lineage.parent_step_id, + source_approval_id=source_approval_id, + source_execution_id=source_execution_id, + kind=request.kind, + status=request.status, + request=cast(dict[str, object], request.request), + outcome=cast(dict[str, object], request.outcome), + trace_id=trace["id"], + trace_kind=TRACE_KIND_TASK_STEP_CONTINUATION, + ) + except psycopg.IntegrityError as exc: + raise TaskStepSequenceError( + f"task {request.task_id} next-step creation conflicted with a concurrent append" + ) from exc + task_step = serialize_task_step_row(created) + task_transition = sync_task_with_task_step_status( + store, + task_id=request.task_id, + task_step_status=request.status, + linked_approval_id=( + source_approval_id if request.status == "created" and linked_approval_id is None else linked_approval_id + ), + linked_execution_id=linked_execution_id, + ) + updated_items = [*existing_items, task_step] + sequencing = _task_step_sequencing_summary( + task_id=str(task_row["id"]), + items=updated_items, + ) + + request_payload: TaskStepContinuationRequestTracePayload = { + "task_id": str(task_row["id"]), + "parent_task_step_id": parent_step["id"], + "parent_sequence_no": parent_step["sequence_no"], + "parent_status": parent_step["status"], + "requested_kind": request.kind, + "requested_status": request.status, + "requested_source_approval_id": lineage["source_approval_id"], + "requested_source_execution_id": lineage["source_execution_id"], + } + lineage_payload: TaskStepContinuationLineageTracePayload = { + "task_id": str(task_row["id"]), + "parent_task_step_id": parent_step["id"], + "parent_sequence_no": parent_step["sequence_no"], + "parent_status": parent_step["status"], + "source_approval_id": lineage["source_approval_id"], + "source_execution_id": lineage["source_execution_id"], + } + summary_payload: TaskStepContinuationSummaryTracePayload = { + "task_id": str(task_row["id"]), + "task_step_id": task_step["id"], + "latest_sequence_no": task_step["sequence_no"], + "next_sequence_no": sequencing["next_sequence_no"], + "append_allowed": sequencing["append_allowed"], + "lineage": task_step["lineage"], + } + trace_events: list[tuple[str, dict[str, object]]] = [ + 
(TASK_STEP_CONTINUATION_REQUEST_EVENT_KIND, cast(dict[str, object], request_payload)), + (TASK_STEP_CONTINUATION_LINEAGE_EVENT_KIND, cast(dict[str, object], lineage_payload)), + (TASK_STEP_CONTINUATION_SUMMARY_EVENT_KIND, cast(dict[str, object], summary_payload)), + ] + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="task_step_continuation", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=task_step, + previous_status=None, + source="task_step_continuation", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + return { + "task": task_transition.task, + "task_step": task_step, + "sequencing": sequencing, + "trace": _trace_summary(trace["id"], trace_events), + } + + +def transition_task_step_record( + store: ContinuityStore, + *, + user_id: UUID, + request: TaskStepTransitionInput, +) -> TaskStepTransitionResponse: + del user_id + + step_row = store.get_task_step_optional(request.task_step_id) + if step_row is None: + raise TaskStepNotFoundError(f"task step {request.task_step_id} was not found") + + task_row = store.get_task_optional(step_row["task_id"]) + if task_row is None: + raise ContinuityStoreInvariantError( + f"task {step_row['task_id']} disappeared before task-step transition" + ) + + existing_items = [serialize_task_step_row(row) for row in store.list_task_steps_for_task(step_row["task_id"])] + latest = existing_items[-1] if existing_items else None + if latest is None: + raise ContinuityStoreInvariantError( + f"task {step_row['task_id']} has no visible steps during transition" + ) + if latest["id"] != str(step_row["id"]): + raise TaskStepTransitionError( + f"task step {request.task_step_id} is not the latest step on task {step_row['task_id']}" + ) + + previous_status = cast(TaskStepStatus, step_row["status"]) + allowed_next_statuses = allowed_task_step_transitions(previous_status) + if request.status not in allowed_next_statuses: + allowed = ", ".join(allowed_next_statuses) or "no further statuses" + raise TaskStepTransitionError( + f"task step {request.task_step_id} is {previous_status} and cannot transition to {request.status}; allowed: {allowed}" + ) + linked_approval_id = _validated_optional_approval_id( + store, + approval_id=request.outcome["approval_id"], + current_approval_id=task_row["latest_approval_id"], + task=task_row, + require_existing=request.status == "created", + missing_error=f"task {task_row['id']} cannot reflect created without an approval link", + error_cls=TaskStepTransitionError, + ) + linked_execution_id = _validated_optional_execution_id( + store, + execution_id=request.outcome["execution_id"], + current_execution_id=task_row["latest_execution_id"], + task=task_row, + require_existing=request.status in {"executed", "blocked"}, + missing_error=f"task {task_row['id']} cannot reflect {request.status} without an existing execution link", + error_cls=TaskStepTransitionError, + ) + + trace = store.create_trace( + user_id=task_row["user_id"], + thread_id=task_row["thread_id"], + kind=TRACE_KIND_TASK_STEP_TRANSITION, + compiler_version=TASK_STEP_TRANSITION_VERSION_V0, + status="completed", + limits={ + "order": list(TASK_STEP_LIST_ORDER), + "status_graph": {status: list(next_statuses) for status, next_statuses in TASK_STEP_STATUS_GRAPH.items()}, + "requested_status": request.status, + }, + ) + updated_row = store.update_task_step_for_task_sequence_optional( + task_id=step_row["task_id"], + 
sequence_no=step_row["sequence_no"], + status=request.status, + outcome=cast(dict[str, object], request.outcome), + trace_id=trace["id"], + trace_kind=TRACE_KIND_TASK_STEP_TRANSITION, + ) + if updated_row is None: + raise ContinuityStoreInvariantError( + f"task step {request.task_step_id} disappeared during transition" + ) + + updated_step = serialize_task_step_row(updated_row) + task_transition = sync_task_with_task_step_status( + store, + task_id=step_row["task_id"], + task_step_status=request.status, + linked_approval_id=linked_approval_id, + linked_execution_id=linked_execution_id, + ) + updated_items = [*existing_items[:-1], updated_step] + sequencing = _task_step_sequencing_summary( + task_id=str(task_row["id"]), + items=updated_items, + ) + + request_payload: TaskStepTransitionRequestTracePayload = { + "task_id": str(task_row["id"]), + "task_step_id": updated_step["id"], + "sequence_no": updated_step["sequence_no"], + "previous_status": previous_status, + "requested_status": request.status, + } + state_payload: TaskStepTransitionStateTracePayload = { + "task_id": str(task_row["id"]), + "task_step_id": updated_step["id"], + "sequence_no": updated_step["sequence_no"], + "previous_status": previous_status, + "current_status": updated_step["status"], + "allowed_next_statuses": allowed_next_statuses, + "trace": updated_step["trace"], + } + summary_payload: TaskStepTransitionSummaryTracePayload = { + "task_id": str(task_row["id"]), + "task_step_id": updated_step["id"], + "sequence_no": updated_step["sequence_no"], + "final_status": updated_step["status"], + "parent_task_status": task_transition.task["status"], + "trace": updated_step["trace"], + } + trace_events: list[tuple[str, dict[str, object]]] = [ + (TASK_STEP_TRANSITION_REQUEST_EVENT_KIND, cast(dict[str, object], request_payload)), + (TASK_STEP_TRANSITION_STATE_EVENT_KIND, cast(dict[str, object], state_payload)), + (TASK_STEP_TRANSITION_SUMMARY_EVENT_KIND, cast(dict[str, object], summary_payload)), + ] + trace_events.extend( + task_lifecycle_trace_events( + task=task_transition.task, + previous_status=task_transition.previous_status, + source="task_step_transition", + ) + ) + trace_events.extend( + task_step_lifecycle_trace_events( + task_step=updated_step, + previous_status=previous_status, + source="task_step_transition", + ) + ) + _append_trace_events(store, trace_id=trace["id"], trace_events=trace_events) + + return { + "task": task_transition.task, + "task_step": updated_step, + "sequencing": sequencing, + "trace": _trace_summary(trace["id"], trace_events), + } + + +def task_lifecycle_trace_events( + *, + task: TaskRecord, + previous_status: TaskStatus | None, + source: TaskLifecycleSource, +) -> list[tuple[str, dict[str, object]]]: + state_payload: TaskLifecycleStateTracePayload = { + "task_id": task["id"], + "source": source, + "previous_status": previous_status, + "current_status": task["status"], + "latest_approval_id": task["latest_approval_id"], + "latest_execution_id": task["latest_execution_id"], + } + summary_payload: TaskLifecycleSummaryTracePayload = { + "task_id": task["id"], + "source": source, + "final_status": task["status"], + "latest_approval_id": task["latest_approval_id"], + "latest_execution_id": task["latest_execution_id"], + } + return [ + (TASK_LIFECYCLE_STATE_EVENT_KIND, cast(dict[str, object], state_payload)), + (TASK_LIFECYCLE_SUMMARY_EVENT_KIND, cast(dict[str, object], summary_payload)), + ] + + +def task_step_lifecycle_trace_events( + *, + task_step: TaskStepRecord, + previous_status: TaskStepStatus 
| None, + source: TaskLifecycleSource, +) -> list[tuple[str, dict[str, object]]]: + state_payload: TaskStepLifecycleStateTracePayload = { + "task_id": task_step["task_id"], + "task_step_id": task_step["id"], + "source": source, + "sequence_no": task_step["sequence_no"], + "kind": task_step["kind"], + "previous_status": previous_status, + "current_status": task_step["status"], + "trace": task_step["trace"], + } + summary_payload: TaskStepLifecycleSummaryTracePayload = { + "task_id": task_step["task_id"], + "task_step_id": task_step["id"], + "source": source, + "sequence_no": task_step["sequence_no"], + "kind": task_step["kind"], + "final_status": task_step["status"], + "trace": task_step["trace"], + } + return [ + (TASK_STEP_LIFECYCLE_STATE_EVENT_KIND, cast(dict[str, object], state_payload)), + (TASK_STEP_LIFECYCLE_SUMMARY_EVENT_KIND, cast(dict[str, object], summary_payload)), + ] + + +def list_task_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> TaskListResponse: + del user_id + + items = [serialize_task_row(row) for row in store.list_tasks()] + summary: TaskListSummary = { + "total_count": len(items), + "order": list(TASK_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_task_record( + store: ContinuityStore, + *, + user_id: UUID, + task_id: UUID, +) -> TaskDetailResponse: + del user_id + + task = store.get_task_optional(task_id) + if task is None: + raise TaskNotFoundError(f"task {task_id} was not found") + return {"task": serialize_task_row(task)} + + +def list_task_step_records( + store: ContinuityStore, + *, + user_id: UUID, + task_id: UUID, +) -> TaskStepListResponse: + del user_id + + task = store.get_task_optional(task_id) + if task is None: + raise TaskNotFoundError(f"task {task_id} was not found") + + items = [serialize_task_step_row(row) for row in store.list_task_steps_for_task(task_id)] + summary = _task_step_sequencing_summary(task_id=str(task["id"]), items=items) + return { + "items": items, + "summary": summary, + } + + +def get_task_step_record( + store: ContinuityStore, + *, + user_id: UUID, + task_step_id: UUID, +) -> TaskStepDetailResponse: + del user_id + + task_step = store.get_task_step_optional(task_step_id) + if task_step is None: + raise TaskStepNotFoundError(f"task step {task_step_id} was not found") + return {"task_step": serialize_task_step_row(task_step)} diff --git a/apps/api/src/alicebot_api/tools.py b/apps/api/src/alicebot_api/tools.py new file mode 100644 index 0000000..0634990 --- /dev/null +++ b/apps/api/src/alicebot_api/tools.py @@ -0,0 +1,553 @@ +from __future__ import annotations + +from dataclasses import dataclass +from uuid import UUID + +from alicebot_api.contracts import ( + TOOL_ALLOWLIST_EVALUATION_VERSION_V0, + TOOL_ROUTING_VERSION_V0, + TOOL_LIST_ORDER, + TRACE_KIND_TOOL_ALLOWLIST_EVALUATE, + TRACE_KIND_TOOL_ROUTE, + PolicyEvaluationRequestInput, + ToolAllowlistDecisionRecord, + ToolAllowlistEvaluationRequestInput, + ToolAllowlistEvaluationResponse, + ToolAllowlistEvaluationSummary, + ToolAllowlistReason, + ToolAllowlistTraceSummary, + ToolRoutingDecision, + ToolRoutingDecisionTracePayload, + ToolRoutingRequestInput, + ToolRoutingRequestTracePayload, + ToolRoutingResponse, + ToolRoutingSummary, + ToolRoutingSummaryTracePayload, + ToolRoutingTraceSummary, + ToolCreateInput, + ToolCreateResponse, + ToolDetailResponse, + ToolListResponse, + ToolListSummary, + ToolRecord, + isoformat_or_none, +) +from alicebot_api.policy import ( + evaluate_policy_against_context, + load_policy_evaluation_context, 
+) +from alicebot_api.store import ContinuityStore, ToolRow + + +class ToolValidationError(ValueError): + """Raised when a tool-registry request fails explicit validation.""" + + +class ToolNotFoundError(LookupError): + """Raised when a requested tool is not visible inside the current user scope.""" + + +class ToolAllowlistValidationError(ValueError): + """Raised when a tool-allowlist evaluation request fails explicit validation.""" + + +class ToolRoutingValidationError(ValueError): + """Raised when a tool-routing request fails explicit validation.""" + + +@dataclass(frozen=True, slots=True) +class ToolClassificationResult: + decision: str + tool: ToolRecord + reasons: list[ToolAllowlistReason] + matched_policy_id: str | None + + +def _serialize_tool(tool: ToolRow) -> ToolRecord: + return { + "id": str(tool["id"]), + "tool_key": tool["tool_key"], + "name": tool["name"], + "description": tool["description"], + "version": tool["version"], + "metadata_version": tool["metadata_version"], + "active": tool["active"], + "tags": list(tool["tags"]), + "action_hints": list(tool["action_hints"]), + "scope_hints": list(tool["scope_hints"]), + "domain_hints": list(tool["domain_hints"]), + "risk_hints": list(tool["risk_hints"]), + "metadata": tool["metadata"], + "created_at": tool["created_at"].isoformat(), + } + + +def _build_tool_reason( + *, + code: str, + source: str, + message: str, + tool_id: UUID, + policy_id: str | None = None, + consent_key: str | None = None, +) -> ToolAllowlistReason: + return { + "code": code, + "source": source, + "message": message, + "tool_id": str(tool_id), + "policy_id": policy_id, + "consent_key": consent_key, + } + + +def _metadata_match_reasons( + *, + tool: ToolRow, + request: ToolAllowlistEvaluationRequestInput, +) -> tuple[bool, list[ToolAllowlistReason]]: + reasons: list[ToolAllowlistReason] = [] + matched = True + + if request.action not in tool["action_hints"]: + matched = False + reasons.append( + _build_tool_reason( + code="tool_action_unsupported", + source="tool", + message=f"Tool '{tool['tool_key']}' does not declare support for action '{request.action}'.", + tool_id=tool["id"], + ) + ) + + if request.scope not in tool["scope_hints"]: + matched = False + reasons.append( + _build_tool_reason( + code="tool_scope_unsupported", + source="tool", + message=f"Tool '{tool['tool_key']}' does not declare support for scope '{request.scope}'.", + tool_id=tool["id"], + ) + ) + + if request.domain_hint is not None and tool["domain_hints"] and request.domain_hint not in tool["domain_hints"]: + matched = False + reasons.append( + _build_tool_reason( + code="tool_domain_mismatch", + source="tool", + message=( + f"Tool '{tool['tool_key']}' does not declare domain hint '{request.domain_hint}'." 
+ ), + tool_id=tool["id"], + ) + ) + + if request.risk_hint is not None and tool["risk_hints"] and request.risk_hint not in tool["risk_hints"]: + matched = False + reasons.append( + _build_tool_reason( + code="tool_risk_mismatch", + source="tool", + message=f"Tool '{tool['tool_key']}' does not declare risk hint '{request.risk_hint}'.", + tool_id=tool["id"], + ) + ) + + if matched: + reasons.append( + _build_tool_reason( + code="tool_metadata_matched", + source="tool", + message="Tool metadata matched the requested action, scope, and optional hints.", + tool_id=tool["id"], + ) + ) + + return matched, reasons + + +def _policy_attributes( + *, + tool: ToolRow, + request: ToolAllowlistEvaluationRequestInput, +) -> dict[str, object]: + attributes: dict[str, object] = dict(request.attributes) + attributes["tool_key"] = tool["tool_key"] + attributes["tool_version"] = tool["version"] + attributes["metadata_version"] = tool["metadata_version"] + if request.domain_hint is not None: + attributes["domain_hint"] = request.domain_hint + if request.risk_hint is not None: + attributes["risk_hint"] = request.risk_hint + return attributes + + +def _classify_tool_request( + *, + tool: ToolRow, + request: ToolAllowlistEvaluationRequestInput, + policy_context, +) -> ToolClassificationResult: + metadata_matched, metadata_reasons = _metadata_match_reasons(tool=tool, request=request) + serialized_tool = _serialize_tool(tool) + + if not metadata_matched: + return ToolClassificationResult( + decision="denied", + tool=serialized_tool, + reasons=metadata_reasons, + matched_policy_id=None, + ) + + policy_decision = evaluate_policy_against_context( + policy_context, + request=PolicyEvaluationRequestInput( + thread_id=request.thread_id, + action=request.action, + scope=request.scope, + attributes=_policy_attributes(tool=tool, request=request), + ), + ) + reasons = metadata_reasons + [ + { + "code": reason["code"], + "source": reason["source"], + "message": reason["message"], + "tool_id": str(tool["id"]), + "policy_id": reason["policy_id"], + "consent_key": reason["consent_key"], + } + for reason in policy_decision.reasons + ] + return ToolClassificationResult( + decision={ + "allow": "allowed", + "deny": "denied", + "require_approval": "approval_required", + }[policy_decision.decision], + tool=serialized_tool, + reasons=reasons, + matched_policy_id=( + None if policy_decision.matched_policy is None else str(policy_decision.matched_policy["id"]) + ), + ) + + +def _decision_record_from_classification( + classification: ToolClassificationResult, +) -> ToolAllowlistDecisionRecord: + return { + "decision": classification.decision, + "tool": classification.tool, + "reasons": classification.reasons, + } + + +def _allowlist_trace_payload( + classification: ToolClassificationResult, +) -> dict[str, object]: + return { + "tool_id": classification.tool["id"], + "tool_key": classification.tool["tool_key"], + "tool_version": classification.tool["version"], + "decision": classification.decision, + "matched_policy_id": classification.matched_policy_id, + "reasons": classification.reasons, + } + + +def _allowlist_request_from_routing( + request: ToolRoutingRequestInput, +) -> ToolAllowlistEvaluationRequestInput: + return ToolAllowlistEvaluationRequestInput( + thread_id=request.thread_id, + action=request.action, + scope=request.scope, + domain_hint=request.domain_hint, + risk_hint=request.risk_hint, + attributes=request.attributes, + ) + + +def _routing_decision_from_allowlist(allowlist_decision: str) -> ToolRoutingDecision: + return { 
+ "allowed": "ready", + "denied": "denied", + "approval_required": "approval_required", + }[allowlist_decision] + + +def create_tool_record( + store: ContinuityStore, + *, + user_id: UUID, + tool: ToolCreateInput, +) -> ToolCreateResponse: + del user_id + + created = store.create_tool( + tool_key=tool.tool_key, + name=tool.name, + description=tool.description, + version=tool.version, + metadata_version=tool.metadata_version, + active=tool.active, + tags=list(tool.tags), + action_hints=list(tool.action_hints), + scope_hints=list(tool.scope_hints), + domain_hints=list(tool.domain_hints), + risk_hints=list(tool.risk_hints), + metadata=tool.metadata, + ) + return {"tool": _serialize_tool(created)} + + +def list_tool_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> ToolListResponse: + del user_id + + items = [_serialize_tool(tool) for tool in store.list_tools()] + summary: ToolListSummary = { + "total_count": len(items), + "order": list(TOOL_LIST_ORDER), + } + return { + "items": items, + "summary": summary, + } + + +def get_tool_record( + store: ContinuityStore, + *, + user_id: UUID, + tool_id: UUID, +) -> ToolDetailResponse: + del user_id + + tool = store.get_tool_optional(tool_id) + if tool is None: + raise ToolNotFoundError(f"tool {tool_id} was not found") + return {"tool": _serialize_tool(tool)} + + +def evaluate_tool_allowlist( + store: ContinuityStore, + *, + user_id: UUID, + request: ToolAllowlistEvaluationRequestInput, +) -> ToolAllowlistEvaluationResponse: + del user_id + + thread = store.get_thread_optional(request.thread_id) + if thread is None: + raise ToolAllowlistValidationError( + "thread_id must reference an existing thread owned by the user" + ) + + active_tools = store.list_active_tools() + policy_context = load_policy_evaluation_context(store) + + allowed: list[ToolAllowlistDecisionRecord] = [] + denied: list[ToolAllowlistDecisionRecord] = [] + approval_required: list[ToolAllowlistDecisionRecord] = [] + tool_trace_events: list[tuple[str, dict[str, object]]] = [] + + for tool in active_tools: + classification = _classify_tool_request( + tool=tool, + request=request, + policy_context=policy_context, + ) + decision_record = _decision_record_from_classification(classification) + + if classification.decision == "allowed": + allowed.append(decision_record) + elif classification.decision == "approval_required": + approval_required.append(decision_record) + else: + denied.append(decision_record) + + tool_trace_events.append( + ( + "tool.allowlist.decision", + _allowlist_trace_payload(classification), + ) + ) + + trace = store.create_trace( + user_id=thread["user_id"], + thread_id=thread["id"], + kind=TRACE_KIND_TOOL_ALLOWLIST_EVALUATE, + compiler_version=TOOL_ALLOWLIST_EVALUATION_VERSION_V0, + status="completed", + limits={ + "order": list(TOOL_LIST_ORDER), + "active_tool_count": len(active_tools), + "active_policy_count": len(policy_context.active_policies), + "consent_count": len(policy_context.consents_by_key), + }, + ) + + trace_events: list[tuple[str, dict[str, object]]] = [ + ( + "tool.allowlist.request", + { + "thread_id": str(request.thread_id), + "action": request.action, + "scope": request.scope, + "domain_hint": request.domain_hint, + "risk_hint": request.risk_hint, + "attributes": request.attributes, + }, + ), + ( + "tool.allowlist.order", + { + "order": list(TOOL_LIST_ORDER), + "tool_ids": [str(tool["id"]) for tool in active_tools], + }, + ), + *tool_trace_events, + ( + "tool.allowlist.summary", + { + "allowed_count": len(allowed), + "denied_count": 
len(denied), + "approval_required_count": len(approval_required), + }, + ), + ] + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + summary: ToolAllowlistEvaluationSummary = { + "action": request.action, + "scope": request.scope, + "domain_hint": request.domain_hint, + "risk_hint": request.risk_hint, + "evaluated_tool_count": len(active_tools), + "allowed_count": len(allowed), + "denied_count": len(denied), + "approval_required_count": len(approval_required), + "order": list(TOOL_LIST_ORDER), + } + trace_summary: ToolAllowlistTraceSummary = { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + } + return { + "allowed": allowed, + "denied": denied, + "approval_required": approval_required, + "summary": summary, + "trace": trace_summary, + } + + +def route_tool_invocation( + store: ContinuityStore, + *, + user_id: UUID, + request: ToolRoutingRequestInput, +) -> ToolRoutingResponse: + del user_id + + thread = store.get_thread_optional(request.thread_id) + if thread is None: + raise ToolRoutingValidationError( + "thread_id must reference an existing thread owned by the user" + ) + + tool = store.get_tool_optional(request.tool_id) + if tool is None or tool["active"] is not True: + raise ToolRoutingValidationError( + "tool_id must reference an existing active tool owned by the user" + ) + + policy_context = load_policy_evaluation_context(store) + classification = _classify_tool_request( + tool=tool, + request=_allowlist_request_from_routing(request), + policy_context=policy_context, + ) + routing_decision = _routing_decision_from_allowlist(classification.decision) + + trace = store.create_trace( + user_id=thread["user_id"], + thread_id=thread["id"], + kind=TRACE_KIND_TOOL_ROUTE, + compiler_version=TOOL_ROUTING_VERSION_V0, + status="completed", + limits={ + "order": list(TOOL_LIST_ORDER), + "evaluated_tool_count": 1, + "active_policy_count": len(policy_context.active_policies), + "consent_count": len(policy_context.consents_by_key), + }, + ) + + request_payload: ToolRoutingRequestTracePayload = request.as_payload() + decision_payload: ToolRoutingDecisionTracePayload = { + "tool_id": classification.tool["id"], + "tool_key": classification.tool["tool_key"], + "tool_version": classification.tool["version"], + "allowlist_decision": classification.decision, + "routing_decision": routing_decision, + "matched_policy_id": classification.matched_policy_id, + "reasons": classification.reasons, + } + summary_payload: ToolRoutingSummaryTracePayload = { + "decision": routing_decision, + "evaluated_tool_count": 1, + "active_policy_count": len(policy_context.active_policies), + "consent_count": len(policy_context.consents_by_key), + } + trace_events = [ + ("tool.route.request", request_payload), + ("tool.route.decision", decision_payload), + ("tool.route.summary", summary_payload), + ] + for sequence_no, (kind, payload) in enumerate(trace_events, start=1): + store.append_trace_event( + trace_id=trace["id"], + sequence_no=sequence_no, + kind=kind, + payload=payload, + ) + + summary: ToolRoutingSummary = { + "thread_id": str(request.thread_id), + "tool_id": classification.tool["id"], + "action": request.action, + "scope": request.scope, + "domain_hint": request.domain_hint, + "risk_hint": request.risk_hint, + "decision": routing_decision, + "evaluated_tool_count": 1, + "active_policy_count": len(policy_context.active_policies), + "consent_count": 
len(policy_context.consents_by_key), + "order": list(TOOL_LIST_ORDER), + } + trace_summary: ToolRoutingTraceSummary = { + "trace_id": str(trace["id"]), + "trace_event_count": len(trace_events), + } + return { + "request": request_payload, + "decision": routing_decision, + "tool": classification.tool, + "reasons": classification.reasons, + "summary": summary, + "trace": trace_summary, + } diff --git a/apps/api/src/alicebot_api/workspaces.py b/apps/api/src/alicebot_api/workspaces.py new file mode 100644 index 0000000..d058fb0 --- /dev/null +++ b/apps/api/src/alicebot_api/workspaces.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from pathlib import Path +from typing import cast +from uuid import UUID + +from alicebot_api.config import Settings +from alicebot_api.contracts import ( + TASK_WORKSPACE_LIST_ORDER, + TaskWorkspaceCreateInput, + TaskWorkspaceCreateResponse, + TaskWorkspaceDetailResponse, + TaskWorkspaceListResponse, + TaskWorkspaceRecord, + TaskWorkspaceStatus, +) +from alicebot_api.tasks import TaskNotFoundError +from alicebot_api.store import ContinuityStore, TaskWorkspaceRow + + +class TaskWorkspaceNotFoundError(LookupError): + """Raised when a task workspace record is not visible inside the current user scope.""" + + +class TaskWorkspaceAlreadyExistsError(RuntimeError): + """Raised when an active task workspace already exists for a task.""" + + +class TaskWorkspaceProvisioningError(RuntimeError): + """Raised when local workspace provisioning cannot satisfy rooted path rules.""" + + +def resolve_workspace_root(workspace_root: str) -> Path: + return Path(workspace_root).expanduser().resolve() + + +def build_task_workspace_path( + *, + workspace_root: Path, + user_id: UUID, + task_id: UUID, +) -> Path: + return workspace_root / str(user_id) / str(task_id) + + +def ensure_workspace_path_is_rooted( + *, + workspace_root: Path, + workspace_path: Path, +) -> None: + resolved_root = workspace_root.resolve() + resolved_path = workspace_path.resolve() + try: + resolved_path.relative_to(resolved_root) + except ValueError as exc: + raise TaskWorkspaceProvisioningError( + f"workspace path {resolved_path} escapes configured root {resolved_root}" + ) from exc + + +def serialize_task_workspace_row(row: TaskWorkspaceRow) -> TaskWorkspaceRecord: + return { + "id": str(row["id"]), + "task_id": str(row["task_id"]), + "status": cast(TaskWorkspaceStatus, row["status"]), + "local_path": row["local_path"], + "created_at": row["created_at"].isoformat(), + "updated_at": row["updated_at"].isoformat(), + } + + +def create_task_workspace_record( + store: ContinuityStore, + *, + settings: Settings, + user_id: UUID, + request: TaskWorkspaceCreateInput, +) -> TaskWorkspaceCreateResponse: + task = store.get_task_optional(request.task_id) + if task is None: + raise TaskNotFoundError(f"task {request.task_id} was not found") + + workspace_root = resolve_workspace_root(settings.task_workspace_root) + workspace_path = build_task_workspace_path( + workspace_root=workspace_root, + user_id=user_id, + task_id=request.task_id, + ) + ensure_workspace_path_is_rooted( + workspace_root=workspace_root, + workspace_path=workspace_path, + ) + + store.lock_task_workspaces(request.task_id) + existing_workspace = store.get_active_task_workspace_for_task_optional(request.task_id) + if existing_workspace is not None: + raise TaskWorkspaceAlreadyExistsError( + f"task {request.task_id} already has active workspace {existing_workspace['id']}" + ) + + try: + workspace_path.mkdir(parents=True, exist_ok=True) + except OSError 
as exc: + raise TaskWorkspaceProvisioningError( + f"workspace path {workspace_path} could not be provisioned" + ) from exc + + row = store.create_task_workspace( + task_id=request.task_id, + status=request.status, + local_path=str(workspace_path), + ) + return {"workspace": serialize_task_workspace_row(row)} + + +def list_task_workspace_records( + store: ContinuityStore, + *, + user_id: UUID, +) -> TaskWorkspaceListResponse: + del user_id + + items = [serialize_task_workspace_row(row) for row in store.list_task_workspaces()] + return { + "items": items, + "summary": { + "total_count": len(items), + "order": list(TASK_WORKSPACE_LIST_ORDER), + }, + } + + +def get_task_workspace_record( + store: ContinuityStore, + *, + user_id: UUID, + task_workspace_id: UUID, +) -> TaskWorkspaceDetailResponse: + del user_id + + row = store.get_task_workspace_optional(task_workspace_id) + if row is None: + raise TaskWorkspaceNotFoundError(f"task workspace {task_workspace_id} was not found") + return {"workspace": serialize_task_workspace_row(row)} diff --git a/apps/web/.gitkeep b/apps/web/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/apps/web/.gitkeep @@ -0,0 +1 @@ + diff --git a/apps/web/app/layout.tsx b/apps/web/app/layout.tsx new file mode 100644 index 0000000..ed6cafd --- /dev/null +++ b/apps/web/app/layout.tsx @@ -0,0 +1,10 @@ +export default function RootLayout({ + children, +}: Readonly<{ children: React.ReactNode }>) { + return ( + <html> + <body>{children}</body> + </html> + ); +} + diff --git a/apps/web/app/page.tsx b/apps/web/app/page.tsx new file mode 100644 index 0000000..7a46a7a --- /dev/null +++ b/apps/web/app/page.tsx @@ -0,0 +1,51 @@ +const milestones = [ + "API foundation and migrations", + "Continuity event store", + "Web dashboard shell", + "Worker orchestration", +]; + +export default function HomePage() { + return (
+    <main>
+      <header>
+        <h1>AliceBot Foundation</h1>
+        <p>Operational shell for the modular monolith</p>
+      </header>
+      <p>
+        The web app is intentionally minimal in this sprint. It exists to prove repository
+        structure while continuity, migrations, and safety primitives land in the API layer.
+      </p>
+      <ul>
+        {milestones.map((item) => (
+          <li key={item}>{item}</li>
+        ))}
+      </ul>
+    </main>
+ ); +} + diff --git a/apps/web/next-env.d.ts b/apps/web/next-env.d.ts new file mode 100644 index 0000000..dc86238 --- /dev/null +++ b/apps/web/next-env.d.ts @@ -0,0 +1,5 @@ +/// +/// + +// This file is managed by Next.js. + diff --git a/apps/web/next.config.mjs b/apps/web/next.config.mjs new file mode 100644 index 0000000..06cd07e --- /dev/null +++ b/apps/web/next.config.mjs @@ -0,0 +1,6 @@ +const nextConfig = { + reactStrictMode: true, +}; + +export default nextConfig; + diff --git a/apps/web/package.json b/apps/web/package.json new file mode 100644 index 0000000..7f5ec8b --- /dev/null +++ b/apps/web/package.json @@ -0,0 +1,25 @@ +{ + "name": "@alicebot/web", + "private": true, + "version": "0.1.0", + "scripts": { + "dev": "next dev", + "build": "next build", + "start": "next start", + "lint": "next lint" + }, + "dependencies": { + "next": "15.2.0", + "react": "19.0.0", + "react-dom": "19.0.0" + }, + "devDependencies": { + "@types/node": "22.13.10", + "@types/react": "19.0.10", + "@types/react-dom": "19.0.4", + "eslint": "9.22.0", + "eslint-config-next": "15.2.0", + "typescript": "5.8.2" + } +} + diff --git a/apps/web/tsconfig.json b/apps/web/tsconfig.json new file mode 100644 index 0000000..bbd3768 --- /dev/null +++ b/apps/web/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["dom", "dom.iterable", "es2022"], + "allowJs": false, + "skipLibCheck": true, + "strict": true, + "noEmit": true, + "module": "esnext", + "moduleResolution": "bundler", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "preserve", + "incremental": true + }, + "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx"], + "exclude": ["node_modules"] +} + diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..2066a2b --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,36 @@ +services: + postgres: + image: pgvector/pgvector:pg16 + container_name: alicebot-postgres + environment: + POSTGRES_USER: alicebot_admin + POSTGRES_PASSWORD: alicebot_admin + POSTGRES_DB: alicebot + ports: + - "127.0.0.1:5432:5432" + volumes: + - postgres-data:/var/lib/postgresql/data + - ./infra/postgres/init:/docker-entrypoint-initdb.d:ro + + redis: + image: redis:7-alpine + container_name: alicebot-redis + ports: + - "127.0.0.1:6379:6379" + + minio: + image: minio/minio:RELEASE.2025-02-28T09-55-16Z + container_name: alicebot-minio + command: server /data --console-address ":9001" + environment: + MINIO_ROOT_USER: alicebot + MINIO_ROOT_PASSWORD: alicebot-secret + ports: + - "127.0.0.1:9000:9000" + - "127.0.0.1:9001:9001" + volumes: + - minio-data:/data + +volumes: + postgres-data: + minio-data: diff --git a/docs/adr/.gitkeep b/docs/adr/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/adr/.gitkeep @@ -0,0 +1 @@ + diff --git a/docs/archive/.gitkeep b/docs/archive/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/archive/.gitkeep @@ -0,0 +1 @@ + diff --git a/docs/runbooks/.gitkeep b/docs/runbooks/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/runbooks/.gitkeep @@ -0,0 +1 @@ + diff --git a/infra/postgres/init/001_roles.sql b/infra/postgres/init/001_roles.sql new file mode 100644 index 0000000..78f9d49 --- /dev/null +++ b/infra/postgres/init/001_roles.sql @@ -0,0 +1,16 @@ +DO +$$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'alicebot_app') THEN + CREATE ROLE alicebot_app + LOGIN + PASSWORD 'alicebot_app' + NOSUPERUSER + NOCREATEDB + NOCREATEROLE + NOINHERIT; + END IF; 
+END +$$; + +GRANT CONNECT ON DATABASE alicebot TO alicebot_app; diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..51cf232 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=69", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "alicebot-foundation" +version = "0.1.0" +description = "Foundation scaffold for the AliceBot modular monolith." +requires-python = ">=3.12" +dependencies = [ + "alembic>=1.14,<2.0", + "fastapi>=0.115,<1.0", + "psycopg[binary]>=3.2,<4.0", + "sqlalchemy>=2.0,<3.0", + "uvicorn>=0.34,<1.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3,<9.0", +] + +[tool.setuptools.package-dir] +"" = "." + +[tool.setuptools.packages.find] +where = ["apps/api/src", "workers"] + +[tool.pytest.ini_options] +pythonpath = ["apps/api/src", "workers"] +testpaths = ["tests"] + diff --git a/scripts/.gitkeep b/scripts/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/scripts/.gitkeep @@ -0,0 +1 @@ + diff --git a/scripts/api_dev.sh b/scripts/api_dev.sh new file mode 100755 index 0000000..17ee7f4 --- /dev/null +++ b/scripts/api_dev.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" + +if [ -f "${REPO_ROOT}/.env" ]; then + set -a + . "${REPO_ROOT}/.env" + set +a +fi + +PYTHON_BIN="python3" +if [ -x "${REPO_ROOT}/.venv/bin/python" ]; then + PYTHON_BIN="${REPO_ROOT}/.venv/bin/python" +fi + +cd "${REPO_ROOT}" + +UVICORN_ARGS=( + --app-dir "${REPO_ROOT}/apps/api/src" + --host "${APP_HOST:-127.0.0.1}" + --port "${APP_PORT:-8000}" +) + +if [ "${APP_RELOAD:-true}" = "true" ]; then + UVICORN_ARGS+=(--reload) +fi + +exec "${PYTHON_BIN}" -m uvicorn alicebot_api.main:app "${UVICORN_ARGS[@]}" diff --git a/scripts/dev_up.sh b/scripts/dev_up.sh new file mode 100755 index 0000000..983ce1b --- /dev/null +++ b/scripts/dev_up.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" + +if [ -f "${REPO_ROOT}/.env" ]; then + set -a + . "${REPO_ROOT}/.env" + set +a +fi + +PYTHON_BIN="python3" +if [ -x "${REPO_ROOT}/.venv/bin/python" ]; then + PYTHON_BIN="${REPO_ROOT}/.venv/bin/python" +fi + +cd "${REPO_ROOT}" + +docker compose up -d + +"${PYTHON_BIN}" -c ' +import os +import sys +import time + +import psycopg + +database_url = os.getenv( + "DATABASE_ADMIN_URL", + "postgresql://alicebot_admin:alicebot_admin@localhost:5432/alicebot", +) +deadline = time.time() + 60 + +while time.time() < deadline: + try: + with psycopg.connect(database_url, connect_timeout=1) as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1 FROM pg_roles WHERE rolname = %s", ("alicebot_app",)) + if cur.fetchone() == (1,): + sys.exit(0) + except psycopg.Error: + pass + time.sleep(1) + +raise SystemExit("Timed out waiting for Postgres readiness and alicebot_app bootstrap") +' + +"${PYTHON_BIN}" -m alembic -c "${REPO_ROOT}/apps/api/alembic.ini" upgrade head diff --git a/scripts/migrate.sh b/scripts/migrate.sh new file mode 100755 index 0000000..ef2401b --- /dev/null +++ b/scripts/migrate.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" + +if [ -f "${REPO_ROOT}/.env" ]; then + set -a + . 
"${REPO_ROOT}/.env" + set +a +fi + +PYTHON_BIN="python3" +if [ -x "${REPO_ROOT}/.venv/bin/python" ]; then + PYTHON_BIN="${REPO_ROOT}/.venv/bin/python" +fi + +cd "${REPO_ROOT}" + +"${PYTHON_BIN}" -m alembic -c "${REPO_ROOT}/apps/api/alembic.ini" upgrade "${1:-head}" diff --git a/tests/.gitkeep b/tests/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/.gitkeep @@ -0,0 +1 @@ + diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..f413549 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from collections.abc import Iterator +import os +from urllib.parse import urlsplit, urlunsplit +from uuid import uuid4 + +from alembic import command +import psycopg +from psycopg import sql +import pytest + +from alicebot_api.migrations import make_alembic_config + + +DEFAULT_ADMIN_URL = "postgresql://alicebot_admin:alicebot_admin@localhost:5432/alicebot" +DEFAULT_APP_URL = "postgresql://alicebot_app:alicebot_app@localhost:5432/alicebot" + + +def swap_database_name(database_url: str, database_name: str) -> str: + parsed = urlsplit(database_url) + return urlunsplit((parsed.scheme, parsed.netloc, f"/{database_name}", parsed.query, parsed.fragment)) + + +@pytest.fixture +def database_urls() -> Iterator[dict[str, str]]: + admin_root_url = os.getenv("DATABASE_ADMIN_URL", DEFAULT_ADMIN_URL) + app_root_url = os.getenv("DATABASE_URL", DEFAULT_APP_URL) + database_name = f"alicebot_test_{uuid4().hex[:12]}" + admin_database_url = swap_database_name(admin_root_url, database_name) + app_database_url = swap_database_name(app_root_url, database_name) + + with psycopg.connect(admin_root_url, autocommit=True) as conn: + with conn.cursor() as cur: + cur.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(database_name))) + cur.execute( + sql.SQL("GRANT CONNECT, TEMPORARY ON DATABASE {} TO alicebot_app").format( + sql.Identifier(database_name) + ) + ) + + yield {"admin": admin_database_url, "app": app_database_url} + + with psycopg.connect(admin_root_url, autocommit=True) as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL("DROP DATABASE IF EXISTS {} WITH (FORCE)").format(sql.Identifier(database_name)) + ) + + +@pytest.fixture +def migrated_database_urls(database_urls: dict[str, str]) -> Iterator[dict[str, str]]: + config = make_alembic_config(database_urls["admin"]) + command.upgrade(config, "head") + yield database_urls diff --git a/tests/integration/test_approval_api.py b/tests/integration/test_approval_api.py new file mode 100644 index 0000000..0f6e980 --- /dev/null +++ b/tests/integration/test_approval_api.py @@ -0,0 +1,929 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return 
{"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Approval thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def test_approval_request_persists_record_for_approval_required_route( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + policy = store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + status, payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "ls"}, + }, + ) + + assert status == 200 + assert list(payload) == [ + "request", + "decision", + "tool", + "reasons", + "task", + "approval", + "routing_trace", + "trace", + ] + assert payload["decision"] == "approval_required" + assert payload["task"]["status"] == "pending_approval" + assert payload["task"]["latest_approval_id"] == payload["approval"]["id"] + assert payload["task"]["latest_execution_id"] is None + assert payload["approval"] is not None + assert payload["approval"]["status"] == "pending" + assert payload["approval"]["task_step_id"] is not None + assert payload["approval"]["resolution"] is None + assert payload["approval"]["request"] == payload["request"] + assert payload["approval"]["tool"] == payload["tool"] + assert payload["approval"]["routing"] == { + "decision": "approval_required", + "reasons": payload["reasons"], + "trace": payload["routing_trace"], + } + assert payload["reasons"][-1] == { + "code": "policy_effect_require_approval", + "source": "policy", + "message": "Policy effect 
resolved the decision to 'require_approval'.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + } + assert payload["routing_trace"]["trace_event_count"] == 3 + assert payload["trace"]["trace_event_count"] == 8 + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + approvals = store.list_approvals() + tasks = store.list_tasks() + task_steps = store.list_task_steps_for_task(tasks[0]["id"]) + approval_trace = store.get_trace(UUID(payload["trace"]["trace_id"])) + approval_trace_events = store.list_trace_events(UUID(payload["trace"]["trace_id"])) + + assert len(approvals) == 1 + assert len(tasks) == 1 + assert len(task_steps) == 1 + assert approvals[0]["id"] == UUID(payload["approval"]["id"]) + assert approvals[0]["task_step_id"] == task_steps[0]["id"] + assert tasks[0]["id"] == UUID(payload["task"]["id"]) + assert approval_trace["kind"] == "approval.request" + assert approval_trace["compiler_version"] == "approval_request_v0" + assert approval_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "persisted": True, + } + assert [event["kind"] for event in approval_trace_events] == [ + "approval.request.request", + "approval.request.routing", + "approval.request.persisted", + "approval.request.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert approval_trace_events[1]["payload"] == { + "decision": "approval_required", + "tool_id": str(tool["id"]), + "tool_key": "shell.exec", + "tool_version": "1.0.0", + "routing_trace_id": payload["routing_trace"]["trace_id"], + "routing_trace_event_count": 3, + "reasons": payload["reasons"], + } + assert approval_trace_events[4]["payload"] == { + "task_id": payload["task"]["id"], + "source": "approval_request", + "previous_status": None, + "current_status": "pending_approval", + "latest_approval_id": payload["approval"]["id"], + "latest_execution_id": None, + } + assert approval_trace_events[2]["payload"] == { + "approval_id": payload["approval"]["id"], + "task_step_id": payload["approval"]["task_step_id"], + "decision": "approval_required", + "persisted": True, + } + assert approval_trace_events[6]["payload"] == { + "task_id": payload["task"]["id"], + "task_step_id": str(task_steps[0]["id"]), + "source": "approval_request", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": None, + "current_status": "created", + "trace": { + "trace_id": payload["trace"]["trace_id"], + "trace_kind": "approval.request", + }, + } + + +def test_approval_request_does_not_create_records_for_ready_or_denied_routes( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + ready_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( 
+ tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + + ready_status, ready_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {}, + }, + ) + denied_status, denied_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(denied_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert ready_status == 200 + assert ready_payload["decision"] == "ready" + assert ready_payload["task"]["status"] == "approved" + assert ready_payload["approval"] is None + assert denied_status == 200 + assert denied_payload["decision"] == "denied" + assert denied_payload["task"]["status"] == "denied" + assert denied_payload["approval"] is None + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + approvals = store.list_approvals() + tasks = store.list_tasks() + ready_task_steps = store.list_task_steps_for_task(tasks[0]["id"]) + denied_task_steps = store.list_task_steps_for_task(tasks[1]["id"]) + + assert approvals == [] + assert [task["status"] for task in tasks] == ["approved", "denied"] + assert [task_step["status"] for task_step in ready_task_steps] == ["approved"] + assert [task_step["status"] for task_step in denied_task_steps] == ["denied"] + + +def test_approval_endpoints_list_and_detail_are_deterministic_and_user_scoped( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + first_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + second_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="2.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + first_status, first_payload = 
invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(first_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "pwd"}, + }, + ) + second_status, second_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(second_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "ls"}, + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/approvals", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/approvals/{second_payload['approval']['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/approvals", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/approvals/{first_payload['approval']['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + + assert first_status == 200 + assert second_status == 200 + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + first_payload["approval"]["id"], + second_payload["approval"]["id"], + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"approval": second_payload["approval"]} + + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"approval {first_payload['approval']['id']} was not found" + } + + +def test_approval_resolution_endpoints_update_reads_and_emit_trace( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + first_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + second_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="2.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + _, first_request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(first_tool["id"]), + "action": "tool.run", + "scope": 
"workspace", + "attributes": {"command": "pwd"}, + }, + ) + _, second_request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(second_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "ls"}, + }, + ) + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{first_request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + reject_status, reject_payload = invoke_request( + "POST", + f"/v0/approvals/{second_request_payload['approval']['id']}/reject", + payload={"user_id": str(owner['user_id'])}, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/approvals", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/approvals/{second_request_payload['approval']['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + + assert approve_status == 200 + assert list(approve_payload) == ["approval", "trace"] + assert approve_payload["approval"]["status"] == "approved" + assert approve_payload["approval"]["task_step_id"] == first_request_payload["approval"]["task_step_id"] + assert approve_payload["approval"]["resolution"] is not None + assert approve_payload["trace"]["trace_event_count"] == 7 + + assert reject_status == 200 + assert list(reject_payload) == ["approval", "trace"] + assert reject_payload["approval"]["status"] == "rejected" + assert reject_payload["approval"]["task_step_id"] == second_request_payload["approval"]["task_step_id"] + assert reject_payload["approval"]["resolution"] is not None + assert reject_payload["trace"]["trace_event_count"] == 7 + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + first_request_payload["approval"]["id"], + second_request_payload["approval"]["id"], + ] + assert [item["status"] for item in list_payload["items"]] == ["approved", "rejected"] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"approval": reject_payload["approval"]} + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + approve_trace = store.get_trace(UUID(approve_payload["trace"]["trace_id"])) + approve_trace_events = store.list_trace_events(UUID(approve_payload["trace"]["trace_id"])) + reject_trace = store.get_trace(UUID(reject_payload["trace"]["trace_id"])) + reject_trace_events = store.list_trace_events(UUID(reject_payload["trace"]["trace_id"])) + + assert approve_trace["kind"] == "approval.resolve" + assert approve_trace["compiler_version"] == "approval_resolution_v0" + assert approve_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "approve", + "outcome": "resolved", + } + assert [event["kind"] for event in approve_trace_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert approve_trace_events[1]["payload"]["current_status"] == "approved" + assert approve_trace_events[1]["payload"]["task_step_id"] == first_request_payload["approval"]["task_step_id"] + assert approve_trace_events[1]["payload"]["resolved_by_user_id"] == str(owner["user_id"]) + + assert 
reject_trace["kind"] == "approval.resolve" + assert reject_trace["compiler_version"] == "approval_resolution_v0" + assert reject_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "reject", + "outcome": "resolved", + } + assert [event["kind"] for event in reject_trace_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert reject_trace_events[1]["payload"]["current_status"] == "rejected" + assert reject_trace_events[1]["payload"]["task_step_id"] == second_request_payload["approval"]["task_step_id"] + assert reject_trace_events[1]["payload"]["resolved_by_user_id"] == str(owner["user_id"]) + + +def test_approval_resolution_rejects_duplicate_conflicting_and_cross_user_attempts( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + _, request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"command": "ls"}, + }, + ) + approval_id = request_payload["approval"]["id"] + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{approval_id}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + duplicate_status, duplicate_payload = invoke_request( + "POST", + f"/v0/approvals/{approval_id}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + conflict_status, conflict_payload = invoke_request( + "POST", + f"/v0/approvals/{approval_id}/reject", + payload={"user_id": str(owner["user_id"])}, + ) + intruder_status, intruder_payload = invoke_request( + "POST", + f"/v0/approvals/{approval_id}/reject", + payload={"user_id": str(intruder["user_id"])}, + ) + + assert first_approve_status == 200 + assert duplicate_status == 409 + assert duplicate_payload == {"detail": f"approval {approval_id} was already approved"} + assert conflict_status == 409 + assert conflict_payload == { + "detail": f"approval {approval_id} was already approved and cannot be rejected" + } + assert intruder_status == 404 + assert intruder_payload == {"detail": f"approval {approval_id} was not found"} + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + approval = store.get_approval_optional(UUID(approval_id)) + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, limits + FROM traces + WHERE 
thread_id = %s + AND kind = 'approval.resolve' + ORDER BY created_at ASC, id ASC + """, + (owner["thread_id"],), + ) + trace_rows = cur.fetchall() + duplicate_trace = trace_rows[-2] + conflict_trace = trace_rows[-1] + duplicate_events = store.list_trace_events(duplicate_trace["id"]) + conflict_events = store.list_trace_events(conflict_trace["id"]) + + assert approval is not None + assert approval["status"] == "approved" + assert duplicate_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "approve", + "outcome": "duplicate_rejected", + } + assert [event["kind"] for event in duplicate_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert duplicate_events[1]["payload"] == { + "approval_id": approval_id, + "task_step_id": str(approval["task_step_id"]), + "requested_action": "approve", + "previous_status": "approved", + "outcome": "duplicate_rejected", + "current_status": "approved", + "resolved_at": approval["resolved_at"].isoformat(), + "resolved_by_user_id": str(owner["user_id"]), + } + assert conflict_trace["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "reject", + "outcome": "conflict_rejected", + } + assert [event["kind"] for event in conflict_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert conflict_events[1]["payload"] == { + "approval_id": approval_id, + "task_step_id": str(approval["task_step_id"]), + "requested_action": "reject", + "previous_status": "approved", + "outcome": "conflict_rejected", + "current_status": "approved", + "resolved_at": approval["resolved_at"].isoformat(), + "resolved_by_user_id": str(owner["user_id"]), + } + + +def test_approval_resolution_rejects_inconsistent_linkage_without_mutating_task_steps( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner-boundary@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + store.create_policy( + name="Require proxy approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "proxy.echo"}, + required_consents=[], + ) + + _, request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "initial"}, + }, + ) + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + 
execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert execute_status == 200 + + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + step_list_status, step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + assert detail_status == 200 + assert step_list_status == 200 + initial_execution_id = detail_payload["task"]["latest_execution_id"] + assert initial_execution_id is not None + + create_step_status, create_step_payload = invoke_request( + "POST", + f"/v0/tasks/{request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "step-2"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": step_list_payload["items"][0]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + assert create_step_status == 201 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + conn.execute( + "UPDATE approvals SET task_step_id = %s WHERE id = %s", + ( + create_step_payload["task_step"]["id"], + request_payload["approval"]["id"], + ), + ) + + boundary_status, boundary_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + + assert boundary_status == 409 + assert boundary_payload == { + "detail": ( + f"approval {request_payload['approval']['id']} is inconsistent with linked task step " + f"{create_step_payload['task_step']['id']}" + ) + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + task = store.get_task_optional(UUID(request_payload["task"]["id"])) + task_steps = store.list_task_steps_for_task(UUID(request_payload["task"]["id"])) + approval = store.get_approval_optional(UUID(request_payload["approval"]["id"])) + approval_resolve_traces = store.conn.execute( + """ + SELECT id + FROM traces + WHERE thread_id = %s + AND kind = 'approval.resolve' + ORDER BY created_at ASC, id ASC + """, + (owner["thread_id"],), + ).fetchall() + + assert task is not None + assert approval is not None + assert task["status"] == "pending_approval" + assert task["latest_approval_id"] == UUID(request_payload["approval"]["id"]) + assert task["latest_execution_id"] is None + assert len(task_steps) == 2 + assert task_steps[0]["status"] == "executed" + assert task_steps[0]["trace_id"] == UUID(execute_payload["trace"]["trace_id"]) + assert task_steps[0]["outcome"]["execution_id"] == initial_execution_id + assert task_steps[1]["status"] == "created" + assert task_steps[1]["id"] == UUID(create_step_payload["task_step"]["id"]) + assert task_steps[1]["trace_kind"] == "task.step.continuation" + assert approval["status"] == "approved" + assert approval["task_step_id"] == UUID(create_step_payload["task_step"]["id"]) + assert 
len(approval_resolve_traces) == 1 diff --git a/tests/integration/test_context_compile.py b/tests/integration/test_context_compile.py new file mode 100644 index 0000000..f86bfe7 --- /dev/null +++ b/tests/integration/test_context_compile.py @@ -0,0 +1,890 @@ +from __future__ import annotations + +import json +from datetime import UTC, datetime +from typing import Any +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_compile_context(payload: dict[str, Any]) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": "/v0/context/compile", + "raw_path": b"/v0/context/compile", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_traceable_thread( + database_url: str, + *, + email: str = "owner@example.com", + display_name: str = "Owner", +) -> dict[str, object]: + user_id = uuid4() + included_edge_valid_from = datetime(2026, 3, 12, 10, 0, tzinfo=UTC) + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, display_name) + thread = store.create_thread("Context thread") + first_session = store.create_session(thread["id"], status="complete") + second_session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], first_session["id"], "message.user", {"text": "old"})["id"], + store.append_event(thread["id"], second_session["id"], "message.assistant", {"text": "newer"})["id"], + store.append_event(thread["id"], second_session["id"], "message.user", {"text": "newest"})["id"], + ] + breakfast_memory = store.create_memory( + memory_key="user.preference.breakfast", + value={"likes": "toast"}, + status="active", + source_event_ids=[str(event_ids[0])], + ) + coffee_memory = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + status="active", + source_event_ids=[str(event_ids[1])], + ) + deleted_memory = store.create_memory( + memory_key="user.preference.old", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_ids[1])], + ) + deleted_memory = store.update_memory( + memory_id=deleted_memory["id"], + value=deleted_memory["value"], + status="deleted", + source_event_ids=[str(event_ids[2])], + ) + person = store.create_entity( + entity_type="person", + name="Samir", + 
source_memory_ids=[str(breakfast_memory["id"])], + ) + merchant = store.create_entity( + entity_type="merchant", + name="Neighborhood Cafe", + source_memory_ids=[str(coffee_memory["id"])], + ) + project = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(breakfast_memory["id"]), str(coffee_memory["id"])], + ) + excluded_edge = store.create_entity_edge( + from_entity_id=person["id"], + to_entity_id=project["id"], + relationship_type="visited_by", + valid_from=None, + valid_to=None, + source_memory_ids=[str(breakfast_memory["id"])], + ) + included_edge = store.create_entity_edge( + from_entity_id=project["id"], + to_entity_id=merchant["id"], + relationship_type="depends_on", + valid_from=included_edge_valid_from, + valid_to=None, + source_memory_ids=[str(coffee_memory["id"])], + ) + ignored_when_project_only_edge = store.create_entity_edge( + from_entity_id=person["id"], + to_entity_id=merchant["id"], + relationship_type="introduced_to", + valid_from=None, + valid_to=None, + source_memory_ids=[str(breakfast_memory["id"])], + ) + entities = store.list_entities() + entity_edges = store.list_entity_edges_for_entities([person["id"], merchant["id"], project["id"]]) + + return { + "user_id": user_id, + "thread_id": thread["id"], + "event_ids": event_ids, + "memories": { + "breakfast": breakfast_memory, + "coffee": coffee_memory, + "deleted": deleted_memory, + }, + "entities": entities, + "entity_edges": entity_edges, + "project_only_candidate_edges": { + "excluded": excluded_edge, + "included": included_edge, + "ignored": ignored_when_project_only_edge, + }, + "included_edge_valid_from": included_edge_valid_from, + } + + +def seed_thread_with_updated_active_memory(database_url: str) -> dict[str, object]: + user_id = uuid4() + included_edge_valid_from = datetime(2026, 3, 12, 11, 0, tzinfo=UTC) + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Updated memory thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "baseline memory evidence"}, + )["id"], + store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "updated memory evidence"}, + )["id"], + ] + store.create_memory( + memory_key="user.preference.breakfast", + value={"likes": "toast"}, + status="active", + source_event_ids=[str(event_ids[0])], + ) + coffee_memory = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_ids[0])], + ) + store.update_memory( + memory_id=coffee_memory["id"], + value={"likes": "oat milk"}, + status="active", + source_event_ids=[str(event_ids[1])], + ) + routine = store.create_entity( + entity_type="routine", + name="Breakfast", + source_memory_ids=[str(coffee_memory["id"])], + ) + project = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(coffee_memory["id"])], + ) + included_edge = store.create_entity_edge( + from_entity_id=project["id"], + to_entity_id=routine["id"], + relationship_type="references", + valid_from=included_edge_valid_from, + valid_to=None, + source_memory_ids=[str(coffee_memory["id"])], + ) + store.create_entity_edge( + from_entity_id=routine["id"], + to_entity_id=routine["id"], + relationship_type="superseded_by", + valid_from=None, + valid_to=None, + 
source_memory_ids=[str(coffee_memory["id"])], + ) + entities = store.list_entities() + + return { + "user_id": user_id, + "thread_id": thread["id"], + "event_ids": event_ids, + "entities": entities, + "included_edge": included_edge, + "included_edge_valid_from": included_edge_valid_from, + } + + +def seed_embedding_config_for_user( + database_url: str, + *, + user_id: UUID, + dimensions: int = 3, +) -> UUID: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + config = store.create_embedding_config( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=dimensions, + status="active", + metadata={"task": "compile_semantic_retrieval"}, + ) + return config["id"] + + +def seed_memory_embedding_for_user( + database_url: str, + *, + user_id: UUID, + memory_id: UUID, + embedding_config_id: UUID, + vector: list[float], +) -> None: + with user_connection(database_url, user_id) as conn: + ContinuityStore(conn).create_memory_embedding( + memory_id=memory_id, + embedding_config_id=embedding_config_id, + dimensions=len(vector), + vector=vector, + ) + + +def test_compile_context_endpoint_persists_trace_and_trace_events(migrated_database_urls, monkeypatch) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + user_id = seeded["user_id"] + thread_id = seeded["thread_id"] + event_ids = seeded["event_ids"] + entities = seeded["entities"] + included_entity = entities[-1] + project_only_candidate_edges = seeded["project_only_candidate_edges"] + included_entity_edge = project_only_candidate_edges["included"] + excluded_entity_edge = project_only_candidate_edges["excluded"] + ignored_entity_edge = project_only_candidate_edges["ignored"] + included_edge_valid_from = seeded["included_edge_valid_from"] + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(user_id), + "thread_id": str(thread_id), + "max_sessions": 1, + "max_events": 1, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + } + ) + + assert status_code == 200 + assert payload["trace_event_count"] > 0 + assert payload["context_pack"]["limits"] == { + "max_sessions": 1, + "max_events": 1, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + } + assert [session["status"] for session in payload["context_pack"]["sessions"]] == ["active"] + assert [event["sequence_no"] for event in payload["context_pack"]["events"]] == [3] + assert payload["context_pack"]["memories"] == [ + { + "id": payload["context_pack"]["memories"][0]["id"], + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(event_ids[1])], + "created_at": payload["context_pack"]["memories"][0]["created_at"], + "updated_at": payload["context_pack"]["memories"][0]["updated_at"], + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ] + assert payload["context_pack"]["memory_summary"] == { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + 
"included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert payload["context_pack"]["entities"] == [ + { + "id": str(included_entity["id"]), + "entity_type": included_entity["entity_type"], + "name": included_entity["name"], + "source_memory_ids": included_entity["source_memory_ids"], + "created_at": included_entity["created_at"].isoformat(), + } + ] + assert payload["context_pack"]["entity_summary"] == { + "candidate_count": 3, + "included_count": 1, + "excluded_limit_count": 2, + } + assert payload["context_pack"]["entity_edges"] == [ + { + "id": str(included_entity_edge["id"]), + "from_entity_id": str(included_entity_edge["from_entity_id"]), + "to_entity_id": str(included_entity_edge["to_entity_id"]), + "relationship_type": included_entity_edge["relationship_type"], + "valid_from": included_edge_valid_from.isoformat(), + "valid_to": None, + "source_memory_ids": included_entity_edge["source_memory_ids"], + "created_at": payload["context_pack"]["entity_edges"][0]["created_at"], + } + ] + assert payload["context_pack"]["entity_edge_summary"] == { + "anchor_entity_count": 1, + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(trace_id) + trace_events = store.list_trace_events(trace_id) + + assert trace["thread_id"] == thread_id + assert trace["kind"] == "context.compile" + assert trace["limits"] == { + "max_sessions": 1, + "max_events": 1, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + } + assert trace_events[0]["kind"] == "context.included" + assert trace_events[-1]["kind"] == "context.summary" + assert any( + event["payload"]["reason"] == "session_limit_exceeded" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "event_limit_exceeded" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "hybrid_memory_deleted" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "within_hybrid_memory_limit" + and event["payload"]["memory_key"] == "user.preference.coffee" + and event["payload"]["selected_sources"] == ["symbolic"] + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "entity_limit_exceeded" + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "within_entity_limit" + and event["payload"]["name"] == included_entity["name"] + and event["payload"]["record_entity_type"] == included_entity["entity_type"] + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "entity_edge_limit_exceeded" + and event["payload"]["entity_id"] == str(excluded_entity_edge["id"]) + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "within_entity_edge_limit" + and event["payload"]["entity_id"] == str(included_entity_edge["id"]) + and event["payload"]["valid_from"] == included_edge_valid_from.isoformat() + for event in trace_events + if event["kind"] == 
"context.included" + ) + assert all( + event["payload"].get("entity_id") != str(ignored_entity_edge["id"]) + for event in trace_events + ) + assert trace_events[-1]["payload"]["included_memory_count"] == 1 + assert trace_events[-1]["payload"]["excluded_deleted_memory_count"] == 1 + assert trace_events[-1]["payload"]["excluded_memory_limit_count"] == 0 + assert trace_events[-1]["payload"]["hybrid_memory_requested"] is False + assert trace_events[-1]["payload"]["hybrid_memory_candidate_count"] == 2 + assert trace_events[-1]["payload"]["hybrid_memory_merged_candidate_count"] == 1 + assert trace_events[-1]["payload"]["hybrid_memory_deduplicated_count"] == 0 + assert trace_events[-1]["payload"]["included_entity_count"] == 1 + assert trace_events[-1]["payload"]["excluded_entity_limit_count"] == 2 + assert trace_events[-1]["payload"]["included_entity_edge_count"] == 1 + assert trace_events[-1]["payload"]["excluded_entity_edge_limit_count"] == 1 + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + with pytest.raises(psycopg.Error, match="append-only"): + cur.execute("UPDATE trace_events SET kind = 'mutated' WHERE trace_id = %s", (trace_id,)) + + +def test_compile_context_prefers_updated_active_memory_within_same_transaction( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_thread_with_updated_active_memory(migrated_database_urls["app"]) + user_id = seeded["user_id"] + thread_id = seeded["thread_id"] + event_ids = seeded["event_ids"] + entities = seeded["entities"] + excluded_entity = entities[0] + included_edge = seeded["included_edge"] + included_edge_valid_from = seeded["included_edge_valid_from"] + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(user_id), + "thread_id": str(thread_id), + "max_sessions": 1, + "max_events": 2, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + } + ) + + assert status_code == 200 + assert payload["trace_event_count"] > 0 + assert payload["context_pack"]["memories"] == [ + { + "id": payload["context_pack"]["memories"][0]["id"], + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(event_ids[1])], + "created_at": payload["context_pack"]["memories"][0]["created_at"], + "updated_at": payload["context_pack"]["memories"][0]["updated_at"], + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ] + assert payload["context_pack"]["memory_summary"] == { + "candidate_count": 1, + "included_count": 1, + "excluded_deleted_count": 0, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert payload["context_pack"]["entity_summary"] == { + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + } + assert payload["context_pack"]["entity_edges"] == [ + { + "id": str(included_edge["id"]), + 
"from_entity_id": str(included_edge["from_entity_id"]), + "to_entity_id": str(included_edge["to_entity_id"]), + "relationship_type": included_edge["relationship_type"], + "valid_from": included_edge_valid_from.isoformat(), + "valid_to": None, + "source_memory_ids": included_edge["source_memory_ids"], + "created_at": payload["context_pack"]["entity_edges"][0]["created_at"], + } + ] + assert payload["context_pack"]["entity_edge_summary"] == { + "anchor_entity_count": 1, + "candidate_count": 1, + "included_count": 1, + "excluded_limit_count": 0, + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], user_id) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "within_hybrid_memory_limit" + and event["payload"]["memory_key"] == "user.preference.coffee" + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "entity_limit_exceeded" + and event["payload"]["name"] == excluded_entity["name"] + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "within_entity_edge_limit" + and event["payload"]["entity_id"] == str(included_edge["id"]) + for event in trace_events + if event["kind"] == "context.included" + ) + + +def test_compile_context_endpoint_merges_hybrid_memory_provenance_and_trace_events( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + user_id = seeded["user_id"] + thread_id = seeded["thread_id"] + memories = seeded["memories"] + config_id = seed_embedding_config_for_user( + migrated_database_urls["app"], + user_id=user_id, + ) + seed_memory_embedding_for_user( + migrated_database_urls["app"], + user_id=user_id, + memory_id=memories["breakfast"]["id"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_memory_embedding_for_user( + migrated_database_urls["app"], + user_id=user_id, + memory_id=memories["coffee"]["id"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_memory_embedding_for_user( + migrated_database_urls["app"], + user_id=user_id, + memory_id=memories["deleted"]["id"], + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(user_id), + "thread_id": str(thread_id), + "max_sessions": 1, + "max_events": 1, + "max_memories": 1, + "max_entities": 1, + "max_entity_edges": 1, + "semantic": { + "embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 2, + }, + } + ) + + assert status_code == 200 + assert payload["trace_event_count"] > 0 + assert payload["context_pack"]["memories"] == [ + { + "id": str(memories["coffee"]["id"]), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": memories["coffee"]["source_event_ids"], + "created_at": memories["coffee"]["created_at"].isoformat(), + "updated_at": memories["coffee"]["updated_at"].isoformat(), + "source_provenance": { + "sources": ["symbolic", "semantic"], + "semantic_score": 1.0, + }, + } + ] + assert payload["context_pack"]["memory_summary"] == { + "candidate_count": 3, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 1, + "hybrid_retrieval": { + "requested": True, + 
"embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "semantic_limit": 2, + "symbolic_selected_count": 1, + "semantic_selected_count": 2, + "merged_candidate_count": 2, + "deduplicated_count": 1, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 1, + "similarity_metric": "cosine_similarity", + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + + trace_id = UUID(payload["trace_id"]) + with user_connection(migrated_database_urls["app"], user_id) as conn: + trace_events = ContinuityStore(conn).list_trace_events(trace_id) + + assert any( + event["payload"]["reason"] == "within_hybrid_memory_limit" + and event["payload"]["entity_id"] == str(memories["coffee"]["id"]) + and event["payload"]["embedding_config_id"] == str(config_id) + and event["payload"]["semantic_score"] == 1.0 + and event["payload"]["selected_sources"] == ["symbolic", "semantic"] + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "hybrid_memory_deduplicated" + and event["payload"]["entity_id"] == str(memories["coffee"]["id"]) + and event["payload"]["embedding_config_id"] == str(config_id) + and event["payload"]["semantic_score"] == 1.0 + for event in trace_events + if event["kind"] == "context.included" + ) + assert any( + event["payload"]["reason"] == "hybrid_memory_limit_exceeded" + and event["payload"]["entity_id"] == str(memories["breakfast"]["id"]) + and event["payload"]["embedding_config_id"] == str(config_id) + and event["payload"]["semantic_score"] == 1.0 + and event["payload"]["selected_sources"] == ["semantic"] + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert any( + event["payload"]["reason"] == "hybrid_memory_deleted" + and event["payload"]["entity_id"] == str(memories["deleted"]["id"]) + and event["payload"]["embedding_config_id"] == str(config_id) + and event["payload"]["semantic_score"] is None + and event["payload"]["selected_sources"] == ["symbolic"] + for event in trace_events + if event["kind"] == "context.excluded" + ) + assert trace_events[-1]["payload"]["hybrid_memory_requested"] is True + assert trace_events[-1]["payload"]["hybrid_memory_candidate_count"] == 3 + assert trace_events[-1]["payload"]["hybrid_memory_merged_candidate_count"] == 2 + assert trace_events[-1]["payload"]["hybrid_memory_deduplicated_count"] == 1 + assert trace_events[-1]["payload"]["included_dual_source_memory_count"] == 1 + + +def test_compile_context_semantic_validation_rejects_missing_config_dimension_mismatch_and_cross_user_access( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_traceable_thread(migrated_database_urls["app"]) + intruder = seed_traceable_thread( + migrated_database_urls["app"], + email="intruder@example.com", + display_name="Intruder", + ) + owner_config_id = seed_embedding_config_for_user( + migrated_database_urls["app"], + user_id=owner["user_id"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + missing_status, missing_payload = invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "semantic": { + "embedding_config_id": str(uuid4()), + "query_vector": [1.0, 0.0, 0.0], + "limit": 1, + }, + } + ) + mismatch_status, mismatch_payload = 
invoke_compile_context( + { + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "semantic": { + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0], + "limit": 1, + }, + } + ) + cross_user_status, cross_user_payload = invoke_compile_context( + { + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "semantic": { + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 1, + }, + } + ) + + assert missing_status == 400 + assert missing_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "query_vector length must match embedding config dimensions (3): 2" + assert cross_user_status == 400 + assert cross_user_payload["detail"] == ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{owner_config_id}" + ) + + +def test_traces_and_trace_events_respect_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + seeded = seed_traceable_thread(migrated_database_urls["app"]) + owner_id = seeded["user_id"] + thread_id = seeded["thread_id"] + owner_event_ids = seeded["event_ids"] + owner_entities = seeded["entities"] + owner_entity_edges = seeded["entity_edges"] + intruder_id = uuid4() + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + store = ContinuityStore(conn) + store.create_user(intruder_id, "intruder@example.com", "Intruder") + intruder_thread = store.create_thread("Intruder thread") + intruder_session = store.create_session(intruder_thread["id"], status="active") + intruder_event = store.append_event( + intruder_thread["id"], + intruder_session["id"], + "message.user", + {"text": "intruder memory"}, + ) + store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(intruder_event["id"])], + ) + intruder_memory = store.create_memory( + memory_key="user.preference.tea", + value={"likes": "green"}, + status="active", + source_event_ids=[str(intruder_event["id"])], + ) + store.create_entity( + entity_type="merchant", + name="Intruder Cafe", + source_memory_ids=[str(intruder_memory["id"])], + ) + intruder_project = store.create_entity( + entity_type="project", + name="Intruder Project", + source_memory_ids=[str(intruder_memory["id"])], + ) + store.create_entity_edge( + from_entity_id=intruder_project["id"], + to_entity_id=store.list_entities()[0]["id"], + relationship_type="hidden_from_owner", + valid_from=None, + valid_to=None, + source_memory_ids=[str(intruder_memory["id"])], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_compile_context( + { + "user_id": str(owner_id), + "thread_id": str(thread_id), + } + ) + + assert status_code == 200 + trace_id = UUID(payload["trace_id"]) + assert [memory["source_event_ids"] for memory in payload["context_pack"]["memories"]] == [ + [str(owner_event_ids[0])], + [str(owner_event_ids[1])], + ] + assert [memory["source_provenance"] for memory in payload["context_pack"]["memories"]] == [ + {"sources": ["symbolic"], "semantic_score": None}, + {"sources": ["symbolic"], "semantic_score": None}, + ] + assert [entity["id"] for entity in payload["context_pack"]["entities"]] == [ + str(entity["id"]) for entity in owner_entities + ] + assert [edge["id"] for edge in 
payload["context_pack"]["entity_edges"]] == [ + str(edge["id"]) for edge in owner_entity_edges + ] + + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + store = ContinuityStore(conn) + with conn.cursor() as cur: + cur.execute("SELECT COUNT(*) AS count FROM traces WHERE id = %s", (trace_id,)) + trace_count = cur.fetchone() + cur.execute("SELECT COUNT(*) AS count FROM trace_events WHERE trace_id = %s", (trace_id,)) + trace_event_count = cur.fetchone() + + assert trace_count["count"] == 0 + assert trace_event_count["count"] == 0 + assert store.list_trace_events(trace_id) == [] diff --git a/tests/integration/test_continuity_store.py b/tests/integration/test_continuity_store.py new file mode 100644 index 0000000..9561563 --- /dev/null +++ b/tests/integration/test_continuity_store.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor, TimeoutError +from uuid import uuid4 + +import psycopg +from psycopg.rows import dict_row +import pytest + +from alicebot_api.db import set_current_user, user_connection +from alicebot_api.store import ContinuityStore + + +def test_thread_session_and_event_persistence(migrated_database_urls): + user_id = uuid4() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + user = store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Starter thread") + session = store.create_session(thread["id"]) + first_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "hello"}, + ) + second_event = store.append_event( + thread["id"], + session["id"], + "message.assistant", + {"text": "hi"}, + ) + events = store.list_thread_events(thread["id"]) + + assert user["id"] == user_id + assert session["thread_id"] == thread["id"] + assert [first_event["sequence_no"], second_event["sequence_no"]] == [1, 2] + assert [event["kind"] for event in events] == ["message.user", "message.assistant"] + assert events[0]["payload"]["text"] == "hello" + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute( + "UPDATE events SET kind = 'message.mutated' WHERE id = %s", + (first_event["id"],), + ) + + +def test_event_deletes_are_rejected_at_database_level(migrated_database_urls): + user_id = uuid4() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Delete-protected thread") + session = store.create_session(thread["id"]) + event = store.append_event(thread["id"], session["id"], "message.user", {"text": "keep"}) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute("DELETE FROM events WHERE id = %s", (event["id"],)) + + +def test_continuity_rls_blocks_cross_user_access(migrated_database_urls): + owner_id = uuid4() + intruder_id = uuid4() + + with user_connection(migrated_database_urls["app"], owner_id) as owner_conn: + owner_store = ContinuityStore(owner_conn) + owner_store.create_user(owner_id, "owner@example.com", "Owner") + thread = owner_store.create_thread("Private thread") + session = owner_store.create_session(thread["id"]) + owner_store.append_event(thread["id"], session["id"], "message.user", {"text": "secret"}) + + with 
user_connection(migrated_database_urls["app"], intruder_id) as intruder_conn: + intruder_store = ContinuityStore(intruder_conn) + intruder_store.create_user(intruder_id, "intruder@example.com", "Intruder") + + with intruder_conn.cursor() as cur: + cur.execute("SELECT COUNT(*) AS count FROM users WHERE id = %s", (owner_id,)) + user_count_row = cur.fetchone() + cur.execute("SELECT COUNT(*) AS count FROM threads WHERE id = %s", (thread["id"],)) + thread_count_row = cur.fetchone() + cur.execute("SELECT COUNT(*) AS count FROM sessions WHERE id = %s", (session["id"],)) + session_count_row = cur.fetchone() + + visible_events = intruder_store.list_thread_events(thread["id"]) + + assert user_count_row["count"] == 0 + assert thread_count_row["count"] == 0 + assert session_count_row["count"] == 0 + assert visible_events == [] + + with pytest.raises(psycopg.Error): + intruder_store.append_event( + thread["id"], + None, + "message.user", + {"text": "tamper"}, + ) + + +def test_runtime_role_is_insert_select_only_for_continuity_tables(migrated_database_urls): + with psycopg.connect(migrated_database_urls["app"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT + has_table_privilege(current_user, 'users', 'SELECT'), + has_table_privilege(current_user, 'users', 'INSERT'), + has_table_privilege(current_user, 'users', 'UPDATE'), + has_table_privilege(current_user, 'threads', 'UPDATE'), + has_table_privilege(current_user, 'sessions', 'UPDATE'), + has_table_privilege(current_user, 'events', 'UPDATE'), + has_table_privilege(current_user, 'events', 'DELETE'), + has_table_privilege(current_user, 'traces', 'SELECT'), + has_table_privilege(current_user, 'traces', 'INSERT'), + has_table_privilege(current_user, 'traces', 'UPDATE'), + has_table_privilege(current_user, 'trace_events', 'SELECT'), + has_table_privilege(current_user, 'trace_events', 'INSERT'), + has_table_privilege(current_user, 'trace_events', 'UPDATE'), + has_table_privilege(current_user, 'trace_events', 'DELETE'), + has_table_privilege(current_user, 'consents', 'SELECT'), + has_table_privilege(current_user, 'consents', 'INSERT'), + has_table_privilege(current_user, 'consents', 'UPDATE'), + has_table_privilege(current_user, 'consents', 'DELETE'), + has_table_privilege(current_user, 'policies', 'SELECT'), + has_table_privilege(current_user, 'policies', 'INSERT'), + has_table_privilege(current_user, 'policies', 'UPDATE'), + has_table_privilege(current_user, 'policies', 'DELETE'), + has_table_privilege(current_user, 'tools', 'SELECT'), + has_table_privilege(current_user, 'tools', 'INSERT'), + has_table_privilege(current_user, 'tools', 'UPDATE'), + has_table_privilege(current_user, 'tools', 'DELETE') + """ + ) + ( + users_select, + users_insert, + users_update, + threads_update, + sessions_update, + events_update, + events_delete, + traces_select, + traces_insert, + traces_update, + trace_events_select, + trace_events_insert, + trace_events_update, + trace_events_delete, + consents_select, + consents_insert, + consents_update, + consents_delete, + policies_select, + policies_insert, + policies_update, + policies_delete, + tools_select, + tools_insert, + tools_update, + tools_delete, + ) = cur.fetchone() + + assert users_select is True + assert users_insert is True + assert users_update is False + assert threads_update is False + assert sessions_update is False + assert events_update is False + assert events_delete is False + assert traces_select is True + assert traces_insert is True + assert traces_update is False + assert 
trace_events_select is True + assert trace_events_insert is True + assert trace_events_update is False + assert trace_events_delete is False + assert consents_select is True + assert consents_insert is True + assert consents_update is True + assert consents_delete is False + assert policies_select is True + assert policies_insert is True + assert policies_update is False + assert policies_delete is False + assert tools_select is True + assert tools_insert is True + assert tools_update is False + assert tools_delete is False + + +def test_concurrent_event_appends_keep_monotonic_sequence_numbers(migrated_database_urls): + user_id = uuid4() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Concurrent thread") + session = store.create_session(thread["id"]) + + with ( + psycopg.connect(migrated_database_urls["app"], row_factory=dict_row) as first_conn, + psycopg.connect(migrated_database_urls["app"], row_factory=dict_row) as second_conn, + ): + set_current_user(first_conn, user_id) + set_current_user(second_conn, user_id) + + first_store = ContinuityStore(first_conn) + second_store = ContinuityStore(second_conn) + first_event = first_store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "first"}, + ) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + second_store.append_event, + thread["id"], + session["id"], + "message.assistant", + {"text": "second"}, + ) + + with pytest.raises(TimeoutError): + future.result(timeout=0.2) + + first_conn.commit() + second_event = future.result(timeout=5) + + second_conn.commit() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + events = store.list_thread_events(thread["id"]) + + assert [first_event["sequence_no"], second_event["sequence_no"]] == [1, 2] + assert [event["sequence_no"] for event in events] == [1, 2] diff --git a/tests/integration/test_embeddings_api.py b/tests/integration/test_embeddings_api.py new file mode 100644 index 0000000..974c2c5 --- /dev/null +++ b/tests/integration/test_embeddings_api.py @@ -0,0 +1,793 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": 
"http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user_with_memory(database_url: str, *, email: str) -> dict[str, object]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Embedding source thread") + session = store.create_session(thread["id"], status="active") + event_id = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "likes oat milk"}, + )["id"] + memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_id,), + ), + ) + + return { + "user_id": user_id, + "memory_id": UUID(memory.memory["id"]), + } + + +def seed_embedding_config( + database_url: str, + *, + user_id: UUID, + provider: str, + model: str, + version: str, + dimensions: int, +) -> UUID: + with user_connection(database_url, user_id) as conn: + created = ContinuityStore(conn).create_embedding_config( + provider=provider, + model=model, + version=version, + dimensions=dimensions, + status="active", + metadata={"task": "memory_retrieval"}, + ) + return created["id"] + + +def seed_memory_with_embedding( + database_url: str, + *, + user_id: UUID, + memory_key: str, + value: dict[str, object], + embedding_config_id: UUID, + vector: list[float], + delete_requested: bool = False, +) -> UUID: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + thread = store.create_thread(f"Semantic retrieval thread for {memory_key}") + session = store.create_session(thread["id"], status="active") + event_id = store.append_event( + thread["id"], + session["id"], + "message.user", + {"memory_key": memory_key, "value": value}, + )["id"] + admitted = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key=memory_key, + value=value, + source_event_ids=(event_id,), + ), + ) + memory_id = UUID(admitted.memory["id"]) + store.create_memory_embedding( + memory_id=memory_id, + embedding_config_id=embedding_config_id, + dimensions=len(vector), + vector=vector, + ) + if delete_requested: + delete_event_id = store.append_event( + thread["id"], + session["id"], + "message.user", + {"memory_key": memory_key, "delete_requested": True}, + )["id"] + admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key=memory_key, + value=None, + source_event_ids=(delete_event_id,), + delete_requested=True, + ), + ) + return memory_id + + +def test_embedding_config_endpoints_create_and_list_in_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-small", + version="2026-03-11", + 
dimensions=1536, + ) + seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + create_status, create_payload = invoke_request( + "POST", + "/v0/embedding-configs", + payload={ + "user_id": str(seeded["user_id"]), + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-13", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/embedding-configs", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert create_status == 201 + assert create_payload["embedding_config"]["provider"] == "openai" + assert create_payload["embedding_config"]["version"] == "2026-03-13" + assert list_status == 200 + assert list_payload["summary"] == { + "total_count": 3, + "order": ["created_at_asc", "id_asc"], + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + expected_configs = ContinuityStore(conn).list_embedding_configs() + + assert [item["id"] for item in list_payload["items"]] == [ + str(config["id"]) for config in expected_configs + ] + + +def test_embedding_config_create_rejects_duplicate_provider_model_version( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + first_status, first_payload = invoke_request( + "POST", + "/v0/embedding-configs", + payload={ + "user_id": str(seeded["user_id"]), + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + }, + ) + second_status, second_payload = invoke_request( + "POST", + "/v0/embedding-configs", + payload={ + "user_id": str(seeded["user_id"]), + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + }, + ) + + assert first_status == 201 + assert first_payload["embedding_config"]["version"] == "2026-03-12" + assert second_status == 400 + assert second_payload == { + "detail": ( + "embedding config already exists for provider/model/version under the user scope: " + "openai/text-embedding-3-large/2026-03-12" + ) + } + + +def test_memory_embedding_endpoints_persist_and_read_embeddings( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + first_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-small", + version="2026-03-11", + dimensions=3, + ) + second_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + first_write_status, first_write_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": 
str(seeded["user_id"]), + "memory_id": str(seeded["memory_id"]), + "embedding_config_id": str(first_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + second_write_status, second_write_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(seeded["user_id"]), + "memory_id": str(seeded["memory_id"]), + "embedding_config_id": str(second_config_id), + "vector": [0.4, 0.5, 0.6], + }, + ) + update_status, update_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(seeded["user_id"]), + "memory_id": str(seeded["memory_id"]), + "embedding_config_id": str(first_config_id), + "vector": [0.9, 0.8, 0.7], + }, + ) + list_status, list_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/embeddings", + query_params={"user_id": str(seeded["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/memory-embeddings/{first_write_payload['embedding']['id']}", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert first_write_status == 201 + assert first_write_payload["write_mode"] == "created" + assert second_write_status == 201 + assert second_write_payload["write_mode"] == "created" + assert update_status == 201 + assert update_payload["write_mode"] == "updated" + assert update_payload["embedding"]["id"] == first_write_payload["embedding"]["id"] + assert update_payload["embedding"]["vector"] == [0.9, 0.8, 0.7] + assert list_status == 200 + assert list_payload["summary"] == { + "memory_id": str(seeded["memory_id"]), + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload["embedding"]["id"] == first_write_payload["embedding"]["id"] + assert detail_payload["embedding"]["vector"] == [0.9, 0.8, 0.7] + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored = ContinuityStore(conn).list_memory_embeddings_for_memory(seeded["memory_id"]) + + assert [item["id"] for item in list_payload["items"]] == [ + str(embedding["id"]) for embedding in stored + ] + assert len(stored) == 2 + assert stored[0]["embedding_config_id"] == first_config_id + assert stored[0]["vector"] == [0.9, 0.8, 0.7] + assert stored[1]["embedding_config_id"] == second_config_id + + +def test_memory_embedding_writes_reject_invalid_references_dimension_mismatches_and_cross_user_refs( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_memory(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + missing_config_status, missing_config_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "memory_id": str(owner["memory_id"]), + "embedding_config_id": str(uuid4()), + "vector": [0.1, 0.2, 0.3], + }, + ) + missing_memory_status, missing_memory_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ 
+ "user_id": str(owner["user_id"]), + "memory_id": str(uuid4()), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + mismatch_status, mismatch_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "memory_id": str(owner["memory_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2], + }, + ) + cross_user_status, cross_user_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(intruder["user_id"]), + "memory_id": str(owner["memory_id"]), + "embedding_config_id": str(intruder_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + cross_user_config_status, cross_user_config_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(intruder["user_id"]), + "memory_id": str(intruder["memory_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + + assert missing_config_status == 400 + assert missing_config_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + assert missing_memory_status == 400 + assert missing_memory_payload["detail"].startswith( + "memory_id must reference an existing memory owned by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "vector length must match embedding config dimensions (3): 2" + assert cross_user_status == 400 + assert cross_user_payload["detail"] == ( + f"memory_id must reference an existing memory owned by the user: {owner['memory_id']}" + ) + assert cross_user_config_status == 400 + assert cross_user_config_payload["detail"] == ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{owner_config_id}" + ) + + +def test_embedding_reads_respect_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_memory(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + write_status, write_payload = invoke_request( + "POST", + "/v0/memory-embeddings", + payload={ + "user_id": str(owner["user_id"]), + "memory_id": str(owner["memory_id"]), + "embedding_config_id": str(owner_config_id), + "vector": [0.1, 0.2, 0.3], + }, + ) + config_list_status, config_list_payload = invoke_request( + "GET", + "/v0/embedding-configs", + query_params={"user_id": str(intruder["user_id"])}, + ) + list_status, list_payload = invoke_request( + "GET", + f"/v0/memories/{owner['memory_id']}/embeddings", + query_params={"user_id": str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/memory-embeddings/{write_payload['embedding']['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + + assert write_status == 201 + assert config_list_status == 200 + assert config_list_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["created_at_asc", "id_asc"], + }, + } + assert list_status == 404 + assert list_payload == {"detail": f"memory {owner['memory_id']} was not found"} + assert detail_status == 404 + 
assert detail_payload == { + "detail": f"memory embedding {write_payload['embedding']['id']} was not found" + } + + +def test_semantic_memory_retrieval_returns_deterministic_results_and_excludes_deleted_memories( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=seeded["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + first_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_key="user.preference.breakfast", + value={"likes": "porridge"}, + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + deleted_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_key="user.preference.deleted", + value={"likes": "hidden"}, + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + delete_requested=True, + ) + second_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_key="user.preference.lunch", + value={"likes": "ramen"}, + embedding_config_id=config_id, + vector=[1.0, 0.0, 0.0], + ) + third_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_key="user.preference.music", + value={"likes": "jazz"}, + embedding_config_id=config_id, + vector=[0.0, 1.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status, payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(seeded["user_id"]), + "embedding_config_id": str(config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 10, + }, + ) + + assert status == 200 + assert payload["summary"] == { + "embedding_config_id": str(config_id), + "limit": 10, + "returned_count": 3, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "created_at_asc", "id_asc"], + } + assert [item["memory_id"] for item in payload["items"]] == [ + str(first_memory_id), + str(second_memory_id), + str(third_memory_id), + ] + assert str(deleted_memory_id) not in {item["memory_id"] for item in payload["items"]} + assert payload["items"][0]["score"] == payload["items"][1]["score"] + assert payload["items"][0]["score"] > payload["items"][2]["score"] + assert set(payload["items"][0]) == { + "memory_id", + "memory_key", + "value", + "source_event_ids", + "created_at", + "updated_at", + "score", + } + + +def test_semantic_memory_retrieval_rejects_invalid_config_dimension_mismatch_and_cross_user_access( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_memory(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=owner["user_id"], + memory_key="user.preference.owner", + 
value={"likes": "oat milk"}, + embedding_config_id=owner_config_id, + vector=[1.0, 0.0, 0.0], + ) + seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=intruder["user_id"], + memory_key="user.preference.intruder", + value={"likes": "almond milk"}, + embedding_config_id=intruder_config_id, + vector=[1.0, 0.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + missing_status, missing_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(uuid4()), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + mismatch_status, mismatch_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0], + "limit": 5, + }, + ) + cross_user_status, cross_user_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(intruder["user_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + + assert missing_status == 400 + assert missing_payload["detail"].startswith( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + assert mismatch_status == 400 + assert mismatch_payload["detail"] == "query_vector length must match embedding config dimensions (3): 2" + assert cross_user_status == 400 + assert cross_user_payload["detail"] == ( + "embedding_config_id must reference an existing embedding config owned by the user: " + f"{owner_config_id}" + ) + + +def test_semantic_memory_retrieval_scopes_results_per_user( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_memory(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_memory(migrated_database_urls["app"], email="intruder@example.com") + owner_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=owner["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + intruder_config_id = seed_embedding_config( + migrated_database_urls["app"], + user_id=intruder["user_id"], + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + ) + owner_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=owner["user_id"], + memory_key="user.preference.owner.semantic", + value={"likes": "espresso"}, + embedding_config_id=owner_config_id, + vector=[1.0, 0.0, 0.0], + ) + intruder_memory_id = seed_memory_with_embedding( + migrated_database_urls["app"], + user_id=intruder["user_id"], + memory_key="user.preference.intruder.semantic", + value={"likes": "matcha"}, + embedding_config_id=intruder_config_id, + vector=[1.0, 0.0, 0.0], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + owner_status, owner_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(owner["user_id"]), + "embedding_config_id": str(owner_config_id), + "query_vector": [1.0, 0.0, 0.0], + "limit": 5, + }, + ) + intruder_status, intruder_payload = invoke_request( + "POST", + "/v0/memories/semantic-retrieval", + payload={ + "user_id": str(intruder["user_id"]), + "embedding_config_id": str(intruder_config_id), + "query_vector": 
[1.0, 0.0, 0.0], + "limit": 5, + }, + ) + + assert owner_status == 200 + assert [item["memory_id"] for item in owner_payload["items"]] == [str(owner_memory_id)] + assert intruder_status == 200 + assert [item["memory_id"] for item in intruder_payload["items"]] == [str(intruder_memory_id)] diff --git a/tests/integration/test_entities_api.py b/tests/integration/test_entities_api.py new file mode 100644 index 0000000..4236c1f --- /dev/null +++ b/tests/integration/test_entities_api.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user_with_source_memories(database_url: str, *, email: str) -> dict[str, object]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Entity source thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], session["id"], "message.user", {"text": "works on AliceBot"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "drinks oat milk"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "shops at cafe"})["id"], + ] + + first_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.project.current", + value={"name": "AliceBot"}, + source_event_ids=(event_ids[0],), + ), + ) + second_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_ids[1],), + ), + ) + third_memory = admit_memory_candidate( + store, 
+ user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.merchant", + value={"name": "Neighborhood Cafe"}, + source_event_ids=(event_ids[2],), + ), + ) + + return { + "user_id": user_id, + "memory_ids": [ + UUID(first_memory.memory["id"]), + UUID(second_memory.memory["id"]), + UUID(third_memory.memory["id"]), + ], + } + + +def test_create_entity_endpoint_persists_entity_backed_by_user_owned_source_memories( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "POST", + "/v0/entities", + payload={ + "user_id": str(seeded["user_id"]), + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(seeded["memory_ids"][0]), str(seeded["memory_ids"][1])], + }, + ) + + assert status_code == 201 + assert payload["entity"]["entity_type"] == "project" + assert payload["entity"]["name"] == "AliceBot" + assert payload["entity"]["source_memory_ids"] == [ + str(seeded["memory_ids"][0]), + str(seeded["memory_ids"][1]), + ] + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored_entities = ContinuityStore(conn).list_entities() + + assert len(stored_entities) == 1 + assert stored_entities[0]["id"] == UUID(payload["entity"]["id"]) + assert stored_entities[0]["entity_type"] == "project" + assert stored_entities[0]["name"] == "AliceBot" + assert stored_entities[0]["source_memory_ids"] == [ + str(seeded["memory_ids"][0]), + str(seeded["memory_ids"][1]), + ] + + +def test_entity_endpoints_list_and_get_entities_in_deterministic_user_scoped_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + created_entities = [ + store.create_entity( + entity_type="person", + name="Samir", + source_memory_ids=[str(seeded["memory_ids"][0])], + ), + store.create_entity( + entity_type="merchant", + name="Neighborhood Cafe", + source_memory_ids=[str(seeded["memory_ids"][2])], + ), + store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(seeded["memory_ids"][0]), str(seeded["memory_ids"][1])], + ), + ] + + list_status, list_payload = invoke_request( + "GET", + "/v0/entities", + query_params={"user_id": str(seeded["user_id"])}, + ) + + expected_entities = sorted(created_entities, key=lambda entity: (entity["created_at"], entity["id"])) + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [str(entity["id"]) for entity in expected_entities] + assert list_payload["summary"] == { + "total_count": 3, + "order": ["created_at_asc", "id_asc"], + } + + target_entity = expected_entities[1] + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/entities/{target_entity['id']}", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert detail_status == 200 + assert detail_payload == { + "entity": { + "id": str(target_entity["id"]), + "entity_type": target_entity["entity_type"], + "name": target_entity["name"], + "source_memory_ids": target_entity["source_memory_ids"], + 
"created_at": target_entity["created_at"].isoformat(), + } + } + + +def test_entity_endpoints_enforce_per_user_isolation_and_not_found_behavior( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user_with_source_memories(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + entity = ContinuityStore(conn).create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(owner["memory_ids"][0])], + ) + + list_status, list_payload = invoke_request( + "GET", + "/v0/entities", + query_params={"user_id": str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/entities/{entity['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + create_status, create_payload = invoke_request( + "POST", + "/v0/entities", + payload={ + "user_id": str(intruder["user_id"]), + "entity_type": "project", + "name": "Hidden Project", + "source_memory_ids": [str(owner["memory_ids"][0])], + }, + ) + + assert list_status == 200 + assert list_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["created_at_asc", "id_asc"], + }, + } + assert detail_status == 404 + assert detail_payload == { + "detail": f"entity {entity['id']} was not found", + } + assert create_status == 400 + assert create_payload["detail"].startswith( + "source_memory_ids must all reference existing memories owned by the user" + ) + + +def test_create_entity_endpoint_rejects_missing_source_memory_ids(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + missing_memory_id = uuid4() + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "POST", + "/v0/entities", + payload={ + "user_id": str(seeded["user_id"]), + "entity_type": "routine", + "name": "Morning Coffee", + "source_memory_ids": [str(missing_memory_id)], + }, + ) + + assert status_code == 400 + assert payload == { + "detail": "source_memory_ids must all reference existing memories owned by the user: " + f"{missing_memory_id}" + } diff --git a/tests/integration/test_entity_edges_api.py b/tests/integration/test_entity_edges_api.py new file mode 100644 index 0000000..d8ea5be --- /dev/null +++ b/tests/integration/test_entity_edges_api.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, 
object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user_with_source_memories(database_url: str, *, email: str) -> dict[str, object]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Entity edge source thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], session["id"], "message.user", {"text": "works on AliceBot"})["id"], + store.append_event( + thread["id"], session["id"], "message.user", {"text": "works with Neighborhood Cafe"} + )["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "coffee preference"})["id"], + ] + + first_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.project.current", + value={"name": "AliceBot"}, + source_event_ids=(event_ids[0],), + ), + ) + second_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.merchant", + value={"name": "Neighborhood Cafe"}, + source_event_ids=(event_ids[1],), + ), + ) + third_memory = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_ids[2],), + ), + ) + + return { + "user_id": user_id, + "memory_ids": [ + UUID(first_memory.memory["id"]), + UUID(second_memory.memory["id"]), + UUID(third_memory.memory["id"]), + ], + } + + +def seed_entities( + database_url: str, + *, + user_id: UUID, + memory_ids: list[UUID], +) -> dict[str, UUID]: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + person = store.create_entity( + entity_type="person", + name="Samir", + source_memory_ids=[str(memory_ids[2])], + ) + merchant = store.create_entity( + entity_type="merchant", + name="Neighborhood Cafe", + source_memory_ids=[str(memory_ids[1])], + ) + project = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(memory_ids[0])], + ) + + return { + "person": person["id"], + "merchant": merchant["id"], + "project": project["id"], + } + + +def test_create_entity_edge_endpoint_persists_user_scoped_edge_with_temporal_metadata( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + entities = seed_entities( + 
migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_ids=seeded["memory_ids"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "POST", + "/v0/entity-edges", + payload={ + "user_id": str(seeded["user_id"]), + "from_entity_id": str(entities["person"]), + "to_entity_id": str(entities["project"]), + "relationship_type": "works_on", + "valid_from": "2026-03-12T10:00:00+00:00", + "valid_to": "2026-03-12T12:00:00+00:00", + "source_memory_ids": [str(seeded["memory_ids"][0]), str(seeded["memory_ids"][2])], + }, + ) + + assert status_code == 201 + assert payload["edge"]["from_entity_id"] == str(entities["person"]) + assert payload["edge"]["to_entity_id"] == str(entities["project"]) + assert payload["edge"]["relationship_type"] == "works_on" + assert payload["edge"]["valid_from"] == "2026-03-12T10:00:00+00:00" + assert payload["edge"]["valid_to"] == "2026-03-12T12:00:00+00:00" + assert payload["edge"]["source_memory_ids"] == [ + str(seeded["memory_ids"][0]), + str(seeded["memory_ids"][2]), + ] + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored_edges = ContinuityStore(conn).list_entity_edges_for_entity(entities["person"]) + + assert len(stored_edges) == 1 + assert stored_edges[0]["id"] == UUID(payload["edge"]["id"]) + assert stored_edges[0]["relationship_type"] == "works_on" + assert stored_edges[0]["source_memory_ids"] == [ + str(seeded["memory_ids"][0]), + str(seeded["memory_ids"][2]), + ] + + +def test_entity_edge_list_endpoint_returns_incident_edges_in_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + entities = seed_entities( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_ids=seeded["memory_ids"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + created_edges = [ + store.create_entity_edge( + from_entity_id=entities["person"], + to_entity_id=entities["project"], + relationship_type="works_on", + valid_from=None, + valid_to=None, + source_memory_ids=[str(seeded["memory_ids"][0])], + ), + store.create_entity_edge( + from_entity_id=entities["merchant"], + to_entity_id=entities["project"], + relationship_type="supplies", + valid_from=None, + valid_to=None, + source_memory_ids=[str(seeded["memory_ids"][1])], + ), + store.create_entity_edge( + from_entity_id=entities["project"], + to_entity_id=entities["merchant"], + relationship_type="references", + valid_from=None, + valid_to=None, + source_memory_ids=[str(seeded["memory_ids"][2])], + ), + ] + + status_code, payload = invoke_request( + "GET", + f"/v0/entities/{entities['project']}/edges", + query_params={"user_id": str(seeded["user_id"])}, + ) + + expected_edges = sorted(created_edges, key=lambda edge: (edge["created_at"], edge["id"])) + + assert status_code == 200 + assert [item["id"] for item in payload["items"]] == [str(edge["id"]) for edge in expected_edges] + assert payload["summary"] == { + "entity_id": str(entities["project"]), + "total_count": 3, + "order": ["created_at_asc", "id_asc"], + } + + +def test_entity_edge_endpoints_enforce_per_user_isolation_and_reference_validation( + migrated_database_urls, + monkeypatch, +) 
-> None: + owner = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + owner_entities = seed_entities( + migrated_database_urls["app"], + user_id=owner["user_id"], + memory_ids=owner["memory_ids"], + ) + intruder = seed_user_with_source_memories(migrated_database_urls["app"], email="intruder@example.com") + intruder_entities = seed_entities( + migrated_database_urls["app"], + user_id=intruder["user_id"], + memory_ids=intruder["memory_ids"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + ContinuityStore(conn).create_entity_edge( + from_entity_id=owner_entities["person"], + to_entity_id=owner_entities["project"], + relationship_type="works_on", + valid_from=None, + valid_to=None, + source_memory_ids=[str(owner["memory_ids"][0])], + ) + + list_status, list_payload = invoke_request( + "GET", + f"/v0/entities/{owner_entities['project']}/edges", + query_params={"user_id": str(intruder['user_id'])}, + ) + entity_status, entity_payload = invoke_request( + "POST", + "/v0/entity-edges", + payload={ + "user_id": str(intruder["user_id"]), + "from_entity_id": str(owner_entities["person"]), + "to_entity_id": str(intruder_entities["project"]), + "relationship_type": "works_on", + "source_memory_ids": [str(intruder["memory_ids"][0])], + }, + ) + memory_status, memory_payload = invoke_request( + "POST", + "/v0/entity-edges", + payload={ + "user_id": str(intruder["user_id"]), + "from_entity_id": str(intruder_entities["person"]), + "to_entity_id": str(intruder_entities["project"]), + "relationship_type": "works_on", + "source_memory_ids": [str(owner["memory_ids"][0])], + }, + ) + + assert list_status == 404 + assert list_payload == { + "detail": f"entity {owner_entities['project']} was not found", + } + assert entity_status == 400 + assert entity_payload == { + "detail": "from_entity_id must reference an existing entity owned by the user: " + f"{owner_entities['person']}" + } + assert memory_status == 400 + assert memory_payload == { + "detail": "source_memory_ids must all reference existing memories owned by the user: " + f"{owner['memory_ids'][0]}" + } + + +def test_create_entity_edge_endpoint_rejects_invalid_temporal_range( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user_with_source_memories(migrated_database_urls["app"], email="owner@example.com") + entities = seed_entities( + migrated_database_urls["app"], + user_id=seeded["user_id"], + memory_ids=seeded["memory_ids"], + ) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "POST", + "/v0/entity-edges", + payload={ + "user_id": str(seeded["user_id"]), + "from_entity_id": str(entities["person"]), + "to_entity_id": str(entities["project"]), + "relationship_type": "works_on", + "valid_from": "2026-03-12T12:00:00+00:00", + "valid_to": "2026-03-12T10:00:00+00:00", + "source_memory_ids": [str(seeded["memory_ids"][0])], + }, + ) + + assert status_code == 400 + assert payload == { + "detail": "valid_to must be greater than or equal to valid_from", + } diff --git a/tests/integration/test_execution_budgets_api.py b/tests/integration/test_execution_budgets_api.py new file mode 100644 index 0000000..5fe3e4f --- /dev/null +++ b/tests/integration/test_execution_budgets_api.py @@ -0,0 +1,432 @@ +from __future__ import 
annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Budget lifecycle thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def create_budget( + *, + user_id: UUID, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, +) -> tuple[int, dict[str, Any]]: + payload: dict[str, Any] = { + "user_id": str(user_id), + "max_completed_executions": max_completed_executions, + } + if tool_key is not None: + payload["tool_key"] = tool_key + if domain_hint is not None: + payload["domain_hint"] = domain_hint + if rolling_window_seconds is not None: + payload["rolling_window_seconds"] = rolling_window_seconds + return invoke_request("POST", "/v0/execution-budgets", payload=payload) + + +def test_execution_budget_endpoints_create_list_and_get_in_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + second_status, second_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + rolling_window_seconds=3600, + ) + first_status, first_payload = create_budget( + user_id=owner["user_id"], + tool_key=None, + domain_hint="docs", + max_completed_executions=1, + ) + + list_status, list_payload = invoke_request( + 
"GET", + "/v0/execution-budgets", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{second_payload['execution_budget']['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/execution-budgets", + query_params={"user_id": str(intruder["user_id"])}, + ) + + assert first_status == 201 + assert second_status == 201 + assert second_payload["execution_budget"]["status"] == "active" + assert second_payload["execution_budget"]["deactivated_at"] is None + assert second_payload["execution_budget"]["rolling_window_seconds"] == 3600 + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + second_payload["execution_budget"]["id"], + first_payload["execution_budget"]["id"], + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"execution_budget": second_payload["execution_budget"]} + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{first_payload['execution_budget']['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"execution budget {first_payload['execution_budget']['id']} was not found" + } + + +def test_create_execution_budget_endpoint_requires_at_least_one_selector( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + status_code, payload = invoke_request( + "POST", + "/v0/execution-budgets", + payload={ + "user_id": str(owner["user_id"]), + "max_completed_executions": 1, + }, + ) + + assert status_code == 400 + assert payload == { + "detail": "execution budget requires at least one selector: tool_key or domain_hint" + } + + +def test_create_execution_budget_endpoint_rejects_duplicate_active_scope( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + first_status, _ = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ) + second_status, second_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + ) + + assert first_status == 201 + assert second_status == 400 + assert second_payload == { + "detail": "active execution budget already exists for selector scope tool_key='proxy.echo', domain_hint='docs'" + } + + +def test_deactivate_execution_budget_endpoint_updates_reads_and_emits_trace( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + create_status, 
create_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + assert create_status == 201 + + deactivate_status, deactivate_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/deactivate", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + }, + ) + + list_status, list_payload = invoke_request( + "GET", + "/v0/execution-budgets", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + isolated_status, isolated_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/deactivate", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + }, + ) + + assert deactivate_status == 200 + assert deactivate_payload["execution_budget"]["status"] == "inactive" + assert deactivate_payload["execution_budget"]["deactivated_at"] is not None + assert deactivate_payload["trace"]["trace_event_count"] == 3 + assert list_status == 200 + assert list_payload["items"][0] == deactivate_payload["execution_budget"] + assert detail_status == 200 + assert detail_payload == {"execution_budget": deactivate_payload["execution_budget"]} + assert isolated_status == 404 + assert isolated_payload == { + "detail": f"execution budget {create_payload['execution_budget']['id']} was not found" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(deactivate_payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(deactivate_payload["trace"]["trace_id"])) + + assert trace["kind"] == "execution_budget.lifecycle" + assert trace["compiler_version"] == "execution_budget_lifecycle_v0" + assert trace["limits"]["requested_action"] == "deactivate" + assert [event["kind"] for event in trace_events] == [ + "execution_budget.lifecycle.request", + "execution_budget.lifecycle.state", + "execution_budget.lifecycle.summary", + ] + assert trace_events[1]["payload"]["current_status"] == "inactive" + assert trace_events[2]["payload"]["outcome"] == "deactivated" + + +def test_supersede_execution_budget_endpoint_replaces_active_budget_and_emits_trace( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + create_status, create_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + rolling_window_seconds=1800, + ) + assert create_status == 201 + + supersede_status, supersede_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/supersede", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "max_completed_executions": 3, + }, + ) + + list_status, list_payload = invoke_request( + "GET", + "/v0/execution-budgets", + query_params={"user_id": str(owner["user_id"])}, + ) + original_detail_status, original_detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}", + query_params={"user_id": 
str(owner["user_id"])}, + ) + replacement_detail_status, replacement_detail_payload = invoke_request( + "GET", + f"/v0/execution-budgets/{supersede_payload['replacement_budget']['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + + assert supersede_status == 200 + assert supersede_payload["superseded_budget"]["status"] == "superseded" + assert supersede_payload["replacement_budget"]["status"] == "active" + assert supersede_payload["replacement_budget"]["rolling_window_seconds"] == 1800 + assert supersede_payload["replacement_budget"]["supersedes_budget_id"] == create_payload["execution_budget"]["id"] + assert supersede_payload["superseded_budget"]["superseded_by_budget_id"] == supersede_payload["replacement_budget"]["id"] + assert list_status == 200 + assert [item["status"] for item in list_payload["items"]] == ["superseded", "active"] + assert original_detail_status == 200 + assert original_detail_payload == {"execution_budget": supersede_payload["superseded_budget"]} + assert replacement_detail_status == 200 + assert replacement_detail_payload == {"execution_budget": supersede_payload["replacement_budget"]} + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(supersede_payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(supersede_payload["trace"]["trace_id"])) + + assert trace["limits"]["requested_action"] == "supersede" + assert trace["limits"]["outcome"] == "superseded" + assert trace_events[1]["payload"]["replacement_budget_id"] == supersede_payload["replacement_budget"]["id"] + assert trace_events[2]["payload"]["active_budget_id"] == supersede_payload["replacement_budget"]["id"] + + +def test_execution_budget_lifecycle_rejects_invalid_transition_deterministically( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + create_status, create_payload = create_budget( + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + assert create_status == 201 + + first_status, _ = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/deactivate", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + }, + ) + second_status, second_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{create_payload['execution_budget']['id']}/deactivate", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + }, + ) + + assert first_status == 200 + assert second_status == 409 + assert second_payload == { + "detail": f"execution budget {create_payload['execution_budget']['id']} is inactive and cannot be deactivated" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace_rows = store.conn.execute( + "SELECT id FROM traces WHERE kind = %s ORDER BY created_at ASC, id ASC", + ("execution_budget.lifecycle",), + ).fetchall() + rejected_trace_events = store.list_trace_events(trace_rows[-1]["id"]) + + assert rejected_trace_events[1]["payload"]["rejection_reason"] == second_payload["detail"] + assert rejected_trace_events[2]["payload"]["outcome"] == "rejected" + + +def test_execution_budget_active_scope_uniqueness_is_enforced_in_database( + migrated_database_urls, 
+) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ) + + with pytest.raises(psycopg.IntegrityError): + with conn.transaction(): + store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + ) diff --git a/tests/integration/test_explicit_preferences_api.py b/tests/integration/test_explicit_preferences_api.py new file mode 100644 index 0000000..67c4db2 --- /dev/null +++ b/tests/integration/test_explicit_preferences_api.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +import json +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.explicit_preferences import _build_memory_key +from alicebot_api.store import ContinuityStore + + +def invoke_extract_explicit_preferences(payload: dict[str, str]) -> tuple[int, dict[str, object]]: + messages: list[dict[str, object]] = [] + encoded_body = json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": "/v0/memories/extract-explicit-preferences", + "raw_path": b"/v0/memories/extract-explicit-preferences", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_explicit_preference_events(database_url: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Explicit preference extraction") + session = store.create_session(thread["id"], status="active") + like_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I like black coffee."}, + )["id"] + dislike_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I don't like black coffee."}, + )["id"] + unsupported_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I had coffee yesterday."}, + )["id"] + clause_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I prefer that we meet tomorrow."}, + )["id"] + cpp_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "I like C++."}, + )["id"] + csharp_event = store.append_event( + thread["id"], + session["id"], + "message.user", + 
{"text": "I like C#."}, + )["id"] + assistant_event = store.append_event( + thread["id"], + session["id"], + "message.assistant", + {"text": "I like black coffee."}, + )["id"] + + return { + "user_id": user_id, + "like_event_id": like_event, + "dislike_event_id": dislike_event, + "unsupported_event_id": unsupported_event, + "clause_event_id": clause_event, + "cpp_event_id": cpp_event, + "csharp_event_id": csharp_event, + "assistant_event_id": assistant_event, + } + + +def test_extract_explicit_preferences_endpoint_admits_supported_candidates_and_persists_revisions( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + memory_key = _build_memory_key("black coffee") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + add_status, add_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["like_event_id"]), + } + ) + update_status, update_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["dislike_event_id"]), + } + ) + + assert add_status == 200 + assert add_payload["candidates"] == [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(seeded["like_event_id"])], + "delete_requested": False, + "pattern": "i_like", + "subject_text": "black coffee", + } + ] + assert add_payload["admissions"][0]["decision"] == "ADD" + assert add_payload["summary"] == { + "source_event_id": str(seeded["like_event_id"]), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + } + + assert update_status == 200 + assert update_payload["candidates"] == [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "source_event_ids": [str(seeded["dislike_event_id"])], + "delete_requested": False, + "pattern": "i_dont_like", + "subject_text": "black coffee", + } + ] + assert update_payload["admissions"][0]["decision"] == "UPDATE" + assert update_payload["summary"] == { + "source_event_id": str(seeded["dislike_event_id"]), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + } + + memory_id = UUID(str(update_payload["admissions"][0]["memory"]["id"])) + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + memories = store.list_memories() + revisions = store.list_memory_revisions(memory_id) + + assert len(memories) == 1 + assert memories[0]["id"] == memory_id + assert memories[0]["memory_key"] == memory_key + assert memories[0]["value"] == { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + } + assert [revision["action"] for revision in revisions] == ["ADD", "UPDATE"] + assert revisions[0]["candidate"] == { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(seeded["like_event_id"])], + "delete_requested": False, + } + assert revisions[1]["candidate"] == { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + 
"source_event_ids": [str(seeded["dislike_event_id"])], + "delete_requested": False, + } + + +def test_extract_explicit_preferences_endpoint_returns_no_candidates_for_unsupported_text( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["unsupported_event_id"]), + } + ) + + assert status_code == 200 + assert payload == { + "candidates": [], + "admissions": [], + "summary": { + "source_event_id": str(seeded["unsupported_event_id"]), + "source_event_kind": "message.user", + "candidate_count": 0, + "admission_count": 0, + "persisted_change_count": 0, + "noop_count": 0, + }, + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_memories() == [] + + +def test_extract_explicit_preferences_endpoint_rejects_clause_style_tail( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["clause_event_id"]), + } + ) + + assert status_code == 200 + assert payload == { + "candidates": [], + "admissions": [], + "summary": { + "source_event_id": str(seeded["clause_event_id"]), + "source_event_kind": "message.user", + "candidate_count": 0, + "admission_count": 0, + "persisted_change_count": 0, + "noop_count": 0, + }, + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + assert store.list_memories() == [] + + +def test_extract_explicit_preferences_endpoint_keeps_symbol_subjects_in_distinct_memories( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + cpp_key = _build_memory_key("C++") + csharp_key = _build_memory_key("C#") + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + cpp_status, cpp_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["cpp_event_id"]), + } + ) + csharp_status, csharp_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["csharp_event_id"]), + } + ) + + assert cpp_status == 200 + assert cpp_payload["candidates"][0]["memory_key"] == cpp_key + assert cpp_payload["admissions"][0]["decision"] == "ADD" + assert csharp_status == 200 + assert csharp_payload["candidates"][0]["memory_key"] == csharp_key + assert csharp_payload["admissions"][0]["decision"] == "ADD" + assert cpp_key != csharp_key + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + memories = sorted(store.list_memories(), key=lambda memory: memory["memory_key"]) + + assert [memory["memory_key"] for memory in memories] == sorted([cpp_key, csharp_key]) + assert {memory["memory_key"]: memory["value"] for memory in memories} == { + cpp_key: { + "kind": "explicit_preference", + 
"preference": "like", + "text": "C++", + }, + csharp_key: { + "kind": "explicit_preference", + "preference": "like", + "text": "C#", + }, + } + + +def test_extract_explicit_preferences_endpoint_validates_source_event_and_user_scope( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_explicit_preference_events(migrated_database_urls["app"]) + intruder_id = uuid4() + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + ContinuityStore(conn).create_user(intruder_id, "intruder@example.com", "Intruder") + + assistant_status, assistant_payload = invoke_extract_explicit_preferences( + { + "user_id": str(seeded["user_id"]), + "source_event_id": str(seeded["assistant_event_id"]), + } + ) + intruder_status, intruder_payload = invoke_extract_explicit_preferences( + { + "user_id": str(intruder_id), + "source_event_id": str(seeded["like_event_id"]), + } + ) + + assert assistant_status == 400 + assert assistant_payload == { + "detail": "source_event_id must reference an existing message.user event owned by the user" + } + assert intruder_status == 400 + assert intruder_payload == { + "detail": "source_event_id must reference an existing message.user event owned by the user" + } + + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + store = ContinuityStore(conn) + assert store.list_memories() == [] diff --git a/tests/integration/test_healthcheck.py b/tests/integration/test_healthcheck.py new file mode 100644 index 0000000..47801f1 --- /dev/null +++ b/tests/integration/test_healthcheck.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +import socket +import subprocess +import time +from urllib import error, request + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings + + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def invoke_healthcheck() -> tuple[int, dict[str, object]]: + messages: list[dict[str, object]] = [] + + async def receive() -> dict[str, object]: + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "GET", + "scheme": "http", + "path": "/healthz", + "raw_path": b"/healthz", + "query_string": b"", + "headers": [], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def test_healthcheck_endpoint_returns_ok_response(monkeypatch) -> None: + settings = Settings( + app_env="test", + database_url="postgresql://db", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + healthcheck_timeout_seconds=2, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "ping_database", lambda *_args, **_kwargs: True) + + status_code, payload = invoke_healthcheck() + + assert status_code == 200 + assert payload["status"] == "ok" + assert 
payload["services"]["database"]["status"] == "ok" + assert payload["services"]["redis"]["status"] == "not_checked" + assert payload["services"]["redis"]["url"] == "redis://cache:6379/0" + assert payload["services"]["object_storage"]["status"] == "not_checked" + + +def test_healthcheck_endpoint_returns_degraded_response(monkeypatch) -> None: + settings = Settings( + app_env="test", + database_url="postgresql://db", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + healthcheck_timeout_seconds=2, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "ping_database", lambda *_args, **_kwargs: False) + + status_code, payload = invoke_healthcheck() + + assert status_code == 503 + assert payload["status"] == "degraded" + assert payload["services"]["database"]["status"] == "unreachable" + assert payload["services"]["redis"]["status"] == "not_checked" + assert payload["services"]["redis"]["url"] == "redis://cache:6379/0" + assert payload["services"]["object_storage"]["status"] == "not_checked" + + +def test_api_dev_script_serves_live_healthcheck() -> None: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + + env = os.environ.copy() + env.update( + { + "APP_HOST": "127.0.0.1", + "APP_PORT": str(port), + "APP_RELOAD": "false", + "APP_ENV": "test", + "DATABASE_URL": "postgresql://invalid:invalid@127.0.0.1:1/invalid", + "REDIS_URL": "redis://alicebot:supersecret@localhost:6379/0", + "HEALTHCHECK_TIMEOUT_SECONDS": "1", + } + ) + + process = subprocess.Popen( + ["/bin/bash", str(REPO_ROOT / "scripts" / "api_dev.sh")], + cwd=REPO_ROOT, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + payload: dict[str, object] | None = None + status_code: int | None = None + + try: + deadline = time.time() + 15 + url = f"http://127.0.0.1:{port}/healthz" + + while time.time() < deadline: + if process.poll() is not None: + stdout, stderr = process.communicate(timeout=1) + raise AssertionError( + "api_dev.sh exited before serving /healthz\n" + f"stdout:\n{stdout}\n" + f"stderr:\n{stderr}" + ) + + try: + with request.urlopen(url, timeout=0.5) as response: + status_code = response.status + payload = json.loads(response.read()) + break + except error.HTTPError as exc: + status_code = exc.code + payload = json.loads(exc.read()) + break + except OSError: + time.sleep(0.1) + else: + raise AssertionError("Timed out waiting for api_dev.sh to serve /healthz") + finally: + process.terminate() + try: + process.communicate(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.communicate(timeout=5) + + assert status_code == 503 + assert payload == { + "status": "degraded", + "environment": "test", + "services": { + "database": {"status": "unreachable"}, + "redis": {"status": "not_checked", "url": "redis://localhost:6379/0"}, + "object_storage": { + "status": "not_checked", + "endpoint_url": "http://localhost:9000", + }, + }, + } diff --git a/tests/integration/test_memory_admission.py b/tests/integration/test_memory_admission.py new file mode 100644 index 0000000..43c4e75 --- /dev/null +++ b/tests/integration/test_memory_admission.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import json +from typing import Any +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import 
Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_admit_memory(payload: dict[str, Any]) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": "/v0/memories/admit", + "raw_path": b"/v0/memories/admit", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_memory_evidence(database_url: str) -> tuple[UUID, list[UUID]]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Memory thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], session["id"], "message.user", {"text": "likes black coffee"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "actually likes oat milk"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "stop remembering coffee"})["id"], + ] + + return user_id, event_ids + + +def test_admit_memory_endpoint_returns_noop_and_persists_nothing_without_value( + migrated_database_urls, + monkeypatch, +) -> None: + user_id, event_ids = seed_memory_evidence(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": None, + "source_event_ids": [str(event_ids[0])], + } + ) + + assert status_code == 200 + assert payload == { + "decision": "NOOP", + "reason": "candidate_value_missing", + "memory": None, + "revision": None, + } + + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + assert store.list_memories() == [] + + +def test_admit_memory_endpoint_rejects_unknown_source_events(migrated_database_urls, monkeypatch) -> None: + user_id = uuid4() + + with user_connection(migrated_database_urls["app"], user_id) as conn: + ContinuityStore(conn).create_user(user_id, "owner@example.com", "Owner") + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "source_event_ids": [str(uuid4())], + } + ) + + assert status_code == 400 + assert payload["detail"].startswith( + 
"source_event_ids must all reference existing events owned by the user" + ) + + +def test_admit_memory_endpoint_persists_add_update_and_delete_revisions( + migrated_database_urls, + monkeypatch, +) -> None: + user_id, event_ids = seed_memory_evidence(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + add_status, add_payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "source_event_ids": [str(event_ids[0])], + } + ) + update_status, update_payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "source_event_ids": [str(event_ids[1])], + } + ) + delete_status, delete_payload = invoke_admit_memory( + { + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": None, + "source_event_ids": [str(event_ids[2])], + "delete_requested": True, + } + ) + + assert add_status == 200 + assert add_payload["decision"] == "ADD" + assert update_status == 200 + assert update_payload["decision"] == "UPDATE" + assert delete_status == 200 + assert delete_payload["decision"] == "DELETE" + + memory_id = UUID(delete_payload["memory"]["id"]) + with user_connection(migrated_database_urls["app"], user_id) as conn: + store = ContinuityStore(conn) + memories = store.list_memories() + revisions = store.list_memory_revisions(memory_id) + + assert len(memories) == 1 + assert memories[0]["id"] == memory_id + assert memories[0]["status"] == "deleted" + assert memories[0]["source_event_ids"] == [str(event_ids[2])] + assert [revision["sequence_no"] for revision in revisions] == [1, 2, 3] + assert [revision["action"] for revision in revisions] == ["ADD", "UPDATE", "DELETE"] + assert revisions[0]["new_value"] == {"likes": "black"} + assert revisions[1]["previous_value"] == {"likes": "black"} + assert revisions[1]["new_value"] == {"likes": "oat milk"} + assert revisions[2]["previous_value"] == {"likes": "oat milk"} + assert revisions[2]["new_value"] is None + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with conn.cursor() as cur: + with pytest.raises(psycopg.Error, match="append-only"): + cur.execute( + "UPDATE memory_revisions SET action = 'MUTATED' WHERE memory_id = %s", + (memory_id,), + ) + + +def test_memories_and_memory_revisions_respect_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner_id, event_ids = seed_memory_evidence(migrated_database_urls["app"]) + intruder_id = uuid4() + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_admit_memory( + { + "user_id": str(owner_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "source_event_ids": [str(event_ids[0])], + } + ) + + assert status_code == 200 + memory_id = UUID(payload["memory"]["id"]) + + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + store = ContinuityStore(conn) + store.create_user(intruder_id, "intruder@example.com", "Intruder") + with conn.cursor() as cur: + cur.execute("SELECT COUNT(*) AS count FROM memories WHERE id = %s", (memory_id,)) + memory_count = cur.fetchone() + cur.execute( + "SELECT COUNT(*) AS count FROM memory_revisions WHERE memory_id = %s", + (memory_id,), + ) + revision_count = cur.fetchone() + cur.execute( + "UPDATE memories SET status = 
'deleted' WHERE id = %s RETURNING id", + (memory_id,), + ) + updated_rows = cur.fetchall() + + assert memory_count["count"] == 0 + assert revision_count["count"] == 0 + assert updated_rows == [] + assert store.list_memories() == [] + assert store.list_memory_revisions(memory_id) == [] diff --git a/tests/integration/test_memory_review_api.py b/tests/integration/test_memory_review_api.py new file mode 100644 index 0000000..c096817 --- /dev/null +++ b/tests/integration/test_memory_review_api.py @@ -0,0 +1,526 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_review_memories(database_url: str) -> dict[str, str]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "reviewer@example.com", "Reviewer") + thread = store.create_thread("Memory review thread") + session = store.create_session(thread["id"], status="active") + event_ids = [ + store.append_event(thread["id"], session["id"], "message.user", {"text": "likes black coffee"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "likes salty snacks"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "reads science fiction"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "enjoys hiking"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "forget the snack preference"})["id"], + store.append_event(thread["id"], session["id"], "message.user", {"text": "actually likes oat milk"})["id"], + ] + + coffee = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "black"}, + 
source_event_ids=(event_ids[0],), + ), + ) + snack = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.snack", + value={"likes": "chips"}, + source_event_ids=(event_ids[1],), + ), + ) + book = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.book", + value={"genre": "science fiction"}, + source_event_ids=(event_ids[2],), + ), + ) + hobby = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.hobby", + value={"likes": "hiking"}, + source_event_ids=(event_ids[3],), + ), + ) + admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.snack", + value=None, + source_event_ids=(event_ids[4],), + delete_requested=True, + ), + ) + admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_ids[5],), + ), + ) + + return { + "user_id": str(user_id), + "coffee_memory_id": coffee.memory["id"], + "snack_memory_id": snack.memory["id"], + "book_memory_id": book.memory["id"], + "hobby_memory_id": hobby.memory["id"], + "coffee_add_event_id": str(event_ids[0]), + "coffee_update_event_id": str(event_ids[5]), + "book_add_event_id": str(event_ids[2]), + "hobby_add_event_id": str(event_ids[3]), + "snack_delete_event_id": str(event_ids[4]), + } + + +def seed_review_queue_state(database_url: str) -> dict[str, str]: + seeded = seed_review_memories(database_url) + + with user_connection(database_url, UUID(seeded["user_id"])) as conn: + store = ContinuityStore(conn) + store.create_memory_review_label( + memory_id=UUID(seeded["hobby_memory_id"]), + label="correct", + note="Already reviewed.", + ) + store.create_memory_review_label( + memory_id=UUID(seeded["snack_memory_id"]), + label="outdated", + note="Deleted memory remains part of evaluation counts only.", + ) + + return seeded + + +def seed_memory_evaluation_state(database_url: str) -> dict[str, str]: + seeded = seed_review_memories(database_url) + + with user_connection(database_url, UUID(seeded["user_id"])) as conn: + store = ContinuityStore(conn) + store.create_memory_review_label( + memory_id=UUID(seeded["coffee_memory_id"]), + label="correct", + note="Matches the latest coffee preference.", + ) + store.create_memory_review_label( + memory_id=UUID(seeded["coffee_memory_id"]), + label="insufficient_evidence", + note="One source event is still a thin basis.", + ) + store.create_memory_review_label( + memory_id=UUID(seeded["snack_memory_id"]), + label="outdated", + note="The deleted snack preference is superseded.", + ) + + return seeded + + +def test_list_memories_endpoint_returns_filtered_memories_with_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_review_memories(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "GET", + "/v0/memories", + query_params={ + "user_id": seeded["user_id"], + "status": "active", + "limit": "2", + }, + ) + + assert status_code == 200 + assert [item["id"] for item in payload["items"]] == [ + seeded["coffee_memory_id"], + seeded["hobby_memory_id"], + ] + assert payload["items"][0]["status"] == "active" + assert payload["items"][0]["value"] == {"likes": "oat milk"} + assert 
payload["items"][0]["source_event_ids"] == [seeded["coffee_update_event_id"]] + assert payload["summary"] == { + "status": "active", + "limit": 2, + "returned_count": 2, + "total_count": 3, + "has_more": True, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + } + + deleted_status, deleted_payload = invoke_request( + "GET", + "/v0/memories", + query_params={ + "user_id": seeded["user_id"], + "status": "deleted", + "limit": "5", + }, + ) + + assert deleted_status == 200 + assert deleted_payload["items"] == [ + { + "id": seeded["snack_memory_id"], + "memory_key": "user.preference.snack", + "value": {"likes": "chips"}, + "status": "deleted", + "source_event_ids": [seeded["snack_delete_event_id"]], + "created_at": deleted_payload["items"][0]["created_at"], + "updated_at": deleted_payload["items"][0]["updated_at"], + "deleted_at": deleted_payload["items"][0]["deleted_at"], + } + ] + assert deleted_payload["summary"] == { + "status": "deleted", + "limit": 5, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + } + + +def test_memory_review_endpoints_return_current_memory_and_revision_history( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_review_memories(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + memory_status, memory_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['coffee_memory_id']}", + query_params={"user_id": seeded["user_id"]}, + ) + revisions_status, revisions_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['coffee_memory_id']}/revisions", + query_params={"user_id": seeded["user_id"], "limit": "5"}, + ) + + assert memory_status == 200 + assert memory_payload["memory"]["id"] == seeded["coffee_memory_id"] + assert memory_payload["memory"]["memory_key"] == "user.preference.coffee" + assert memory_payload["memory"]["status"] == "active" + assert memory_payload["memory"]["value"] == {"likes": "oat milk"} + assert memory_payload["memory"]["source_event_ids"] == [seeded["coffee_update_event_id"]] + + assert revisions_status == 200 + assert [item["sequence_no"] for item in revisions_payload["items"]] == [1, 2] + assert [item["action"] for item in revisions_payload["items"]] == ["ADD", "UPDATE"] + assert revisions_payload["items"][0]["new_value"] == {"likes": "black"} + assert revisions_payload["items"][0]["source_event_ids"] == [seeded["coffee_add_event_id"]] + assert revisions_payload["items"][1]["previous_value"] == {"likes": "black"} + assert revisions_payload["items"][1]["new_value"] == {"likes": "oat milk"} + assert revisions_payload["items"][1]["source_event_ids"] == [seeded["coffee_update_event_id"]] + assert revisions_payload["summary"] == { + "memory_id": seeded["coffee_memory_id"], + "limit": 5, + "returned_count": 2, + "total_count": 2, + "has_more": False, + "order": ["sequence_no_asc"], + } + + +def test_memory_review_endpoints_enforce_per_user_isolation_and_not_found_behavior( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_review_memories(migrated_database_urls["app"]) + intruder_id = uuid4() + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + ContinuityStore(conn).create_user(intruder_id, "intruder@example.com", "Intruder") + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + list_status, list_payload = 
invoke_request( + "GET", + "/v0/memories", + query_params={"user_id": str(intruder_id), "status": "all", "limit": "10"}, + ) + memory_status, memory_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['coffee_memory_id']}", + query_params={"user_id": str(intruder_id)}, + ) + revisions_status, revisions_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['coffee_memory_id']}/revisions", + query_params={"user_id": str(intruder_id), "limit": "10"}, + ) + + assert list_status == 200 + assert list_payload == { + "items": [], + "summary": { + "status": "all", + "limit": 10, + "returned_count": 0, + "total_count": 0, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + assert memory_status == 404 + assert memory_payload == { + "detail": f"memory {seeded['coffee_memory_id']} was not found", + } + assert revisions_status == 404 + assert revisions_payload == { + "detail": f"memory {seeded['coffee_memory_id']} was not found", + } + + +def test_memory_review_queue_endpoint_returns_only_active_unlabeled_memories_in_deterministic_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_review_queue_state(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "GET", + "/v0/memories/review-queue", + query_params={ + "user_id": seeded["user_id"], + "limit": "2", + }, + ) + + assert status_code == 200 + assert payload == { + "items": [ + { + "id": seeded["coffee_memory_id"], + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [seeded["coffee_update_event_id"]], + "created_at": payload["items"][0]["created_at"], + "updated_at": payload["items"][0]["updated_at"], + }, + { + "id": seeded["book_memory_id"], + "memory_key": "user.preference.book", + "value": {"genre": "science fiction"}, + "status": "active", + "source_event_ids": [seeded["book_add_event_id"]], + "created_at": payload["items"][1]["created_at"], + "updated_at": payload["items"][1]["updated_at"], + }, + ], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 2, + "returned_count": 2, + "total_count": 2, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + +def test_memory_evaluation_summary_endpoint_returns_explicit_consistent_counts( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_evaluation_state(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "GET", + "/v0/memories/evaluation-summary", + query_params={"user_id": seeded["user_id"]}, + ) + + assert status_code == 200 + assert payload == { + "summary": { + "total_memory_count": 4, + "active_memory_count": 3, + "deleted_memory_count": 1, + "labeled_memory_count": 2, + "unlabeled_memory_count": 2, + "total_label_row_count": 3, + "label_row_counts_by_value": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 1, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } + + +def test_memory_review_queue_and_evaluation_summary_enforce_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_evaluation_state(migrated_database_urls["app"]) + 
intruder_id = uuid4() + with user_connection(migrated_database_urls["app"], intruder_id) as conn: + ContinuityStore(conn).create_user(intruder_id, "intruder@example.com", "Intruder") + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + queue_status, queue_payload = invoke_request( + "GET", + "/v0/memories/review-queue", + query_params={"user_id": str(intruder_id), "limit": "10"}, + ) + summary_status, summary_payload = invoke_request( + "GET", + "/v0/memories/evaluation-summary", + query_params={"user_id": str(intruder_id)}, + ) + + assert seeded["user_id"] != str(intruder_id) + assert queue_status == 200 + assert queue_payload == { + "items": [], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 10, + "returned_count": 0, + "total_count": 0, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + assert summary_status == 200 + assert summary_payload == { + "summary": { + "total_memory_count": 0, + "active_memory_count": 0, + "deleted_memory_count": 0, + "labeled_memory_count": 0, + "unlabeled_memory_count": 0, + "total_label_row_count": 0, + "label_row_counts_by_value": { + "correct": 0, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } diff --git a/tests/integration/test_memory_review_labels_api.py b/tests/integration/test_memory_review_labels_api.py new file mode 100644 index 0000000..1b184e8 --- /dev/null +++ b/tests/integration/test_memory_review_labels_api.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.db import user_connection +from alicebot_api.memory import admit_memory_candidate +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], 
json.loads(body) + + +def seed_memory_for_review_labels(database_url: str) -> dict[str, str]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, "reviewer@example.com", "Reviewer") + thread = store.create_thread("Memory review labels thread") + session = store.create_session(thread["id"], status="active") + event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "likes oat milk in coffee"}, + ) + decision = admit_memory_candidate( + store, + user_id=user_id, + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event["id"],), + ), + ) + + assert decision.memory is not None + return { + "user_id": str(user_id), + "memory_id": decision.memory["id"], + } + + +def seed_intruder(database_url: str) -> UUID: + intruder_id = uuid4() + with user_connection(database_url, intruder_id) as conn: + ContinuityStore(conn).create_user(intruder_id, "intruder@example.com", "Intruder") + return intruder_id + + +def test_memory_review_label_endpoints_create_and_list_labels_with_stable_summary_counts( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + first_status, first_payload = invoke_request( + "POST", + f"/v0/memories/{seeded['memory_id']}/labels", + payload={ + "user_id": seeded["user_id"], + "label": "correct", + "note": "Matches the latest admitted evidence.", + }, + ) + second_status, second_payload = invoke_request( + "POST", + f"/v0/memories/{seeded['memory_id']}/labels", + payload={ + "user_id": seeded["user_id"], + "label": "outdated", + "note": None, + }, + ) + list_status, list_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/labels", + query_params={"user_id": seeded["user_id"]}, + ) + + assert first_status == 201 + assert first_payload["label"]["memory_id"] == seeded["memory_id"] + assert first_payload["label"]["reviewer_user_id"] == seeded["user_id"] + assert first_payload["label"]["label"] == "correct" + assert first_payload["label"]["note"] == "Matches the latest admitted evidence." 
+ assert first_payload["summary"] == { + "memory_id": seeded["memory_id"], + "total_count": 1, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + } + + assert second_status == 201 + assert second_payload["label"]["label"] == "outdated" + assert second_payload["label"]["note"] is None + assert second_payload["summary"] == { + "memory_id": seeded["memory_id"], + "total_count": 2, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + } + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + first_payload["label"]["id"], + second_payload["label"]["id"], + ] + assert list_payload["summary"] == second_payload["summary"] + + +def test_memory_review_label_listing_uses_deterministic_created_at_then_id_order( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + with user_connection(migrated_database_urls["app"], UUID(seeded["user_id"])) as conn: + store = ContinuityStore(conn) + created_labels = [ + store.create_memory_review_label( + memory_id=UUID(seeded["memory_id"]), + label="incorrect", + note="Conflicts with the source event.", + ), + store.create_memory_review_label( + memory_id=UUID(seeded["memory_id"]), + label="insufficient_evidence", + note="The evidence is too weak.", + ), + store.create_memory_review_label( + memory_id=UUID(seeded["memory_id"]), + label="outdated", + note="Superseded by newer behavior.", + ), + ] + + status_code, payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/labels", + query_params={"user_id": seeded["user_id"]}, + ) + + expected_ids = [ + str(label["id"]) + for label in sorted( + created_labels, + key=lambda label: (label["created_at"], label["id"]), + ) + ] + + assert status_code == 200 + assert [item["id"] for item in payload["items"]] == expected_ids + assert payload["summary"] == { + "memory_id": seeded["memory_id"], + "total_count": 3, + "counts_by_label": { + "correct": 0, + "incorrect": 1, + "outdated": 1, + "insufficient_evidence": 1, + }, + "order": ["created_at_asc", "id_asc"], + } + + +def test_memory_review_label_list_returns_empty_items_and_zero_filled_summary_for_unlabeled_memory( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + status_code, payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/labels", + query_params={"user_id": seeded["user_id"]}, + ) + + assert status_code == 200 + assert payload == { + "items": [], + "summary": { + "memory_id": seeded["memory_id"], + "total_count": 0, + "counts_by_label": { + "correct": 0, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_memory_review_labels_reject_update_and_delete_at_database_level(migrated_database_urls) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + + with user_connection(migrated_database_urls["app"], UUID(seeded["user_id"])) as conn: + label = ContinuityStore(conn).create_memory_review_label( + 
memory_id=UUID(seeded["memory_id"]), + label="correct", + note="Initial review label.", + ) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute( + "UPDATE memory_review_labels SET label = 'incorrect' WHERE id = %s", + (label["id"],), + ) + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute( + "DELETE FROM memory_review_labels WHERE id = %s", + (label["id"],), + ) + + +def test_memory_review_label_endpoints_enforce_per_user_isolation_and_not_found_behavior( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_memory_for_review_labels(migrated_database_urls["app"]) + intruder_id = seed_intruder(migrated_database_urls["app"]) + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings(database_url=migrated_database_urls["app"]), + ) + + create_status, create_payload = invoke_request( + "POST", + f"/v0/memories/{seeded['memory_id']}/labels", + payload={ + "user_id": str(intruder_id), + "label": "incorrect", + "note": "Should not be able to label another user's memory.", + }, + ) + list_status, list_payload = invoke_request( + "GET", + f"/v0/memories/{seeded['memory_id']}/labels", + query_params={"user_id": str(intruder_id)}, + ) + + assert create_status == 404 + assert create_payload == {"detail": f"memory {seeded['memory_id']} was not found"} + assert list_status == 404 + assert list_payload == {"detail": f"memory {seeded['memory_id']} was not found"} diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py new file mode 100644 index 0000000..434645e --- /dev/null +++ b/tests/integration/test_migrations.py @@ -0,0 +1,798 @@ +from __future__ import annotations + +from alembic import command +import psycopg + +from alicebot_api.migrations import make_alembic_config + + +def test_tool_execution_task_step_linkage_migration_backfills_existing_rows(database_urls): + config = make_alembic_config(database_urls["admin"]) + user_id = "00000000-0000-0000-0000-000000000001" + thread_id = "00000000-0000-0000-0000-000000000002" + trace_id = "00000000-0000-0000-0000-000000000003" + tool_id = "00000000-0000-0000-0000-000000000004" + approval_id = "00000000-0000-0000-0000-000000000005" + task_id = "00000000-0000-0000-0000-000000000006" + task_step_id = "00000000-0000-0000-0000-000000000007" + execution_id = "00000000-0000-0000-0000-000000000008" + + command.upgrade(config, "20260313_0020") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO users (id, email, display_name) + VALUES (%s, 'migration@example.com', 'Migration User') + """, + (user_id,), + ) + cur.execute( + """ + INSERT INTO threads (id, user_id, title) + VALUES (%s, %s, 'Migration Thread') + """, + (thread_id, user_id), + ) + cur.execute( + """ + INSERT INTO traces ( + id, + user_id, + thread_id, + kind, + compiler_version, + status, + limits + ) + VALUES ( + %s, + %s, + %s, + 'migration.seed', + 'v0', + 'completed', + '{}'::jsonb + ) + """, + (trace_id, user_id, thread_id), + ) + cur.execute( + """ + INSERT INTO tools ( + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata + ) + VALUES ( + %s, + %s, + 'proxy.echo', + 'Proxy Echo', + 'Seed tool for migration coverage', + 
'1.0.0', + 'tool_metadata_v0', + TRUE, + '[]'::jsonb, + '[]'::jsonb, + '[]'::jsonb, + '[]'::jsonb, + '[]'::jsonb, + '{}'::jsonb + ) + """, + (tool_id, user_id), + ) + cur.execute( + """ + INSERT INTO approvals ( + id, + user_id, + thread_id, + tool_id, + task_step_id, + status, + request, + tool, + routing, + routing_trace_id, + resolved_at, + resolved_by_user_id + ) + VALUES ( + %s, + %s, + %s, + %s, + NULL, + 'approved', + '{"action":"echo"}'::jsonb, + '{"id":"tool"}'::jsonb, + '{"decision":"approval_required"}'::jsonb, + %s, + now(), + %s + ) + """, + (approval_id, user_id, thread_id, tool_id, trace_id, user_id), + ) + cur.execute( + """ + INSERT INTO tasks ( + id, + user_id, + thread_id, + tool_id, + status, + request, + tool, + latest_approval_id, + latest_execution_id + ) + VALUES ( + %s, + %s, + %s, + %s, + 'approved', + '{"action":"echo"}'::jsonb, + '{"id":"tool"}'::jsonb, + %s, + NULL + ) + """, + (task_id, user_id, thread_id, tool_id, approval_id), + ) + cur.execute( + """ + INSERT INTO task_steps ( + id, + user_id, + task_id, + sequence_no, + kind, + status, + request, + outcome, + trace_id, + trace_kind + ) + VALUES ( + %s, + %s, + %s, + 1, + 'governed_request', + 'approved', + '{"action":"echo"}'::jsonb, + '{"routing_decision":"approval_required","approval_id":"00000000-0000-0000-0000-000000000005","approval_status":"approved","execution_id":null,"execution_status":null,"blocked_reason":null}'::jsonb, + %s, + 'migration.seed' + ) + """, + (task_step_id, user_id, task_id, trace_id), + ) + cur.execute( + """ + INSERT INTO tool_executions ( + id, + user_id, + approval_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + status, + handler_key, + request, + tool, + result + ) + VALUES ( + %s, + %s, + %s, + %s, + %s, + %s, + NULL, + NULL, + 'blocked', + NULL, + '{"action":"echo"}'::jsonb, + '{"id":"tool"}'::jsonb, + '{"blocked_reason":"seed"}'::jsonb + ) + """, + (execution_id, user_id, approval_id, thread_id, tool_id, trace_id), + ) + conn.commit() + + command.upgrade(config, "head") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT task_step_id + FROM tool_executions + WHERE id = %s + """, + (execution_id,), + ) + row = cur.fetchone() + assert row is not None + assert str(row[0]) == task_step_id + cur.execute( + """ + SELECT is_nullable + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'tool_executions' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchone() == ("NO",) + + +def test_migrations_upgrade_and_downgrade(database_urls): + config = make_alembic_config(database_urls["admin"]) + + command.upgrade(config, "head") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.users')") + assert cur.fetchone()[0] == "users" + cur.execute("SELECT to_regclass('public.threads')") + assert cur.fetchone()[0] == "threads" + cur.execute("SELECT to_regclass('public.sessions')") + assert cur.fetchone()[0] == "sessions" + cur.execute("SELECT to_regclass('public.events')") + assert cur.fetchone()[0] == "events" + cur.execute("SELECT to_regclass('public.memories')") + assert cur.fetchone()[0] == "memories" + cur.execute("SELECT to_regclass('public.memory_revisions')") + assert cur.fetchone()[0] == "memory_revisions" + cur.execute("SELECT to_regclass('public.memory_review_labels')") + assert cur.fetchone()[0] == "memory_review_labels" + cur.execute("SELECT 
to_regclass('public.entities')") + assert cur.fetchone()[0] == "entities" + cur.execute("SELECT to_regclass('public.entity_edges')") + assert cur.fetchone()[0] == "entity_edges" + cur.execute("SELECT to_regclass('public.embedding_configs')") + assert cur.fetchone()[0] == "embedding_configs" + cur.execute("SELECT to_regclass('public.memory_embeddings')") + assert cur.fetchone()[0] == "memory_embeddings" + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] == "consents" + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] == "policies" + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] == "tools" + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] == "approvals" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [("task_step_id",)] + cur.execute("SELECT to_regclass('public.tasks')") + assert cur.fetchone()[0] == "tasks" + cur.execute("SELECT to_regclass('public.task_workspaces')") + assert cur.fetchone()[0] == "task_workspaces" + cur.execute("SELECT to_regclass('public.task_steps')") + assert cur.fetchone()[0] == "task_steps" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'task_steps' + AND column_name IN ( + 'parent_step_id', + 'source_approval_id', + 'source_execution_id' + ) + ORDER BY column_name + """ + ) + assert cur.fetchall() == [ + ("parent_step_id",), + ("source_approval_id",), + ("source_execution_id",), + ] + cur.execute("SELECT to_regclass('public.tool_executions')") + assert cur.fetchone()[0] == "tool_executions" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'tool_executions' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [("task_step_id",)] + cur.execute("SELECT to_regclass('public.execution_budgets')") + assert cur.fetchone()[0] == "execution_budgets" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'execution_budgets' + AND column_name IN ( + 'status', + 'deactivated_at', + 'superseded_by_budget_id', + 'supersedes_budget_id' + ) + ORDER BY column_name + """ + ) + assert cur.fetchall() == [ + ("deactivated_at",), + ("status",), + ("superseded_by_budget_id",), + ("supersedes_budget_id",), + ] + cur.execute( + """ + SELECT c.relname, c.relrowsecurity, c.relforcerowsecurity + FROM pg_class AS c + JOIN pg_namespace AS n + ON n.oid = c.relnamespace + WHERE n.nspname = 'public' + AND c.relname IN ( + 'users', + 'threads', + 'sessions', + 'events', + 'memories', + 'memory_revisions', + 'memory_review_labels', + 'entities', + 'entity_edges', + 'embedding_configs', + 'memory_embeddings', + 'consents', + 'policies', + 'tools', + 'approvals', + 'tasks', + 'task_workspaces', + 'task_steps', + 'execution_budgets', + 'tool_executions' + ) + ORDER BY c.relname + """ + ) + assert cur.fetchall() == [ + ("approvals", True, True), + ("consents", True, True), + ("embedding_configs", True, True), + ("entities", True, True), + ("entity_edges", True, True), + ("events", True, True), + ("execution_budgets", True, True), + ("memories", True, True), + ("memory_embeddings", True, True), + ("memory_review_labels", True, True), + ("memory_revisions", True, True), + 
("policies", True, True), + ("sessions", True, True), + ("task_steps", True, True), + ("task_workspaces", True, True), + ("tasks", True, True), + ("threads", True, True), + ("tool_executions", True, True), + ("tools", True, True), + ("users", True, True), + ] + cur.execute( + """ + SELECT tgname + FROM pg_trigger + WHERE tgrelid = 'events'::regclass + AND NOT tgisinternal + """ + ) + assert cur.fetchall() == [("events_append_only",)] + cur.execute( + """ + SELECT tgname + FROM pg_trigger + WHERE tgrelid = 'memory_revisions'::regclass + AND NOT tgisinternal + """ + ) + assert cur.fetchall() == [("memory_revisions_append_only",)] + cur.execute( + """ + SELECT tgname + FROM pg_trigger + WHERE tgrelid = 'memory_review_labels'::regclass + AND NOT tgisinternal + """ + ) + assert cur.fetchall() == [("memory_review_labels_append_only",)] + cur.execute( + """ + SELECT + has_table_privilege('alicebot_app', 'users', 'UPDATE'), + has_table_privilege('alicebot_app', 'threads', 'UPDATE'), + has_table_privilege('alicebot_app', 'sessions', 'UPDATE'), + has_table_privilege('alicebot_app', 'memories', 'UPDATE'), + has_table_privilege('alicebot_app', 'memory_revisions', 'UPDATE'), + has_table_privilege('alicebot_app', 'memory_revisions', 'DELETE'), + has_table_privilege('alicebot_app', 'memory_review_labels', 'UPDATE'), + has_table_privilege('alicebot_app', 'memory_review_labels', 'DELETE'), + has_table_privilege('alicebot_app', 'entities', 'UPDATE'), + has_table_privilege('alicebot_app', 'entities', 'DELETE'), + has_table_privilege('alicebot_app', 'entity_edges', 'UPDATE'), + has_table_privilege('alicebot_app', 'entity_edges', 'DELETE'), + has_table_privilege('alicebot_app', 'embedding_configs', 'UPDATE'), + has_table_privilege('alicebot_app', 'embedding_configs', 'DELETE'), + has_table_privilege('alicebot_app', 'memory_embeddings', 'UPDATE'), + has_table_privilege('alicebot_app', 'memory_embeddings', 'DELETE'), + has_table_privilege('alicebot_app', 'consents', 'UPDATE'), + has_table_privilege('alicebot_app', 'consents', 'DELETE'), + has_table_privilege('alicebot_app', 'policies', 'UPDATE'), + has_table_privilege('alicebot_app', 'policies', 'DELETE'), + has_table_privilege('alicebot_app', 'tools', 'UPDATE'), + has_table_privilege('alicebot_app', 'tools', 'DELETE'), + has_table_privilege('alicebot_app', 'approvals', 'UPDATE'), + has_table_privilege('alicebot_app', 'approvals', 'DELETE'), + has_table_privilege('alicebot_app', 'tasks', 'UPDATE'), + has_table_privilege('alicebot_app', 'tasks', 'DELETE'), + has_table_privilege('alicebot_app', 'task_workspaces', 'UPDATE'), + has_table_privilege('alicebot_app', 'task_workspaces', 'DELETE'), + has_table_privilege('alicebot_app', 'task_steps', 'UPDATE'), + has_table_privilege('alicebot_app', 'task_steps', 'DELETE'), + has_table_privilege('alicebot_app', 'execution_budgets', 'UPDATE'), + has_table_privilege('alicebot_app', 'execution_budgets', 'DELETE'), + has_table_privilege('alicebot_app', 'tool_executions', 'UPDATE'), + has_table_privilege('alicebot_app', 'tool_executions', 'DELETE') + """ + ) + assert cur.fetchone() == ( + False, + False, + False, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + False, + True, + False, + False, + False, + False, + False, + True, + False, + True, + False, + False, + False, + True, + False, + True, + False, + False, + False, + ) + + command.downgrade(config, "20260313_0021") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + 
cur.execute("SELECT to_regclass('public.task_workspaces')") + assert cur.fetchone()[0] is None + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [("task_step_id",)] + + command.downgrade(config, "20260313_0018") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [] + cur.execute("SELECT to_regclass('public.task_steps')") + assert cur.fetchone()[0] == "task_steps" + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'tool_executions' + AND column_name = 'task_step_id' + """ + ) + assert cur.fetchall() == [] + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'task_steps' + AND column_name IN ( + 'parent_step_id', + 'source_approval_id', + 'source_execution_id' + ) + ORDER BY column_name + """ + ) + assert cur.fetchall() == [] + cur.execute("SELECT to_regclass('public.tasks')") + assert cur.fetchone()[0] == "tasks" + + command.downgrade(config, "20260313_0017") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.task_steps')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tasks')") + assert cur.fetchone()[0] == "tasks" + + command.downgrade(config, "20260313_0014") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.execution_budgets')") + assert cur.fetchone()[0] == "execution_budgets" + cur.execute("SELECT to_regclass('public.tasks')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.task_steps')") + assert cur.fetchone()[0] is None + cur.execute( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'execution_budgets' + AND column_name IN ( + 'status', + 'deactivated_at', + 'superseded_by_budget_id', + 'supersedes_budget_id' + ) + ORDER BY column_name + """ + ) + assert cur.fetchall() == [] + cur.execute( + "SELECT has_table_privilege('alicebot_app', 'execution_budgets', 'UPDATE')" + ) + assert cur.fetchone()[0] is False + + command.downgrade(config, "20260313_0013") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.execution_budgets')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tool_executions')") + assert cur.fetchone()[0] == "tool_executions" + + command.downgrade(config, "20260312_0012") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.tool_executions')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] == "approvals" + + command.downgrade(config, "20260312_0011") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] == "approvals" + cur.execute( + """ + SELECT + has_table_privilege('alicebot_app', 
'approvals', 'UPDATE'), + EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'resolved_at' + ), + EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'approvals' + AND column_name = 'resolved_by_user_id' + ) + """ + ) + assert cur.fetchone() == ( + False, + False, + False, + ) + + command.downgrade(config, "20260312_0010") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] == "tools" + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] == "consents" + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] == "policies" + + command.downgrade(config, "20260312_0009") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] == "consents" + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] == "policies" + + command.downgrade(config, "20260312_0008") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.embedding_configs')") + assert cur.fetchone()[0] == "embedding_configs" + cur.execute("SELECT to_regclass('public.memory_embeddings')") + assert cur.fetchone()[0] == "memory_embeddings" + + command.downgrade(config, "20260312_0007") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.embedding_configs')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_embeddings')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.approvals')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memories')") + assert cur.fetchone()[0] == "memories" + cur.execute("SELECT to_regclass('public.entity_edges')") + assert cur.fetchone()[0] == "entity_edges" + + command.downgrade(config, "20260311_0003") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.memories')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_revisions')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_review_labels')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.entities')") + assert cur.fetchone()[0] is None + 
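# Descriptive note (added): at 20260311_0003 the entity graph edge table is expected to be gone along with the other memory and entity tables checked above. +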
cur.execute("SELECT to_regclass('public.entity_edges')") + assert cur.fetchone()[0] is None + cur.execute( + """ + SELECT + has_table_privilege('alicebot_app', 'users', 'UPDATE'), + has_table_privilege('alicebot_app', 'threads', 'UPDATE'), + has_table_privilege('alicebot_app', 'sessions', 'UPDATE') + """ + ) + # Revision 20260310_0001 already leaves the runtime role without UPDATE + # access, so downgrading from head must preserve that same privilege floor. + assert cur.fetchone() == (False, False, False) + + command.downgrade(config, "20260310_0001") + + command.downgrade(config, "base") + + with psycopg.connect(database_urls["admin"]) as conn: + with conn.cursor() as cur: + cur.execute("SELECT to_regclass('public.users')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.threads')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.sessions')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.events')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memories')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_revisions')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_review_labels')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.entities')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.entity_edges')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.embedding_configs')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.memory_embeddings')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.consents')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.policies')") + assert cur.fetchone()[0] is None + cur.execute("SELECT to_regclass('public.tools')") + assert cur.fetchone()[0] is None + cur.execute( + """ + SELECT extname + FROM pg_extension + WHERE extname IN ('pgcrypto', 'vector') + ORDER BY extname + """ + ) + assert [row[0] for row in cur.fetchall()] == ["pgcrypto", "vector"] diff --git a/tests/integration/test_policy_api.py b/tests/integration/test_policy_api.py new file mode 100644 index 0000000..0ae0b37 --- /dev/null +++ b/tests/integration/test_policy_api.py @@ -0,0 +1,424 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + 
"scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Policy thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def test_consent_endpoints_upsert_and_list_deterministically(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + first_status, first_payload = invoke_request( + "POST", + "/v0/consents", + payload={ + "user_id": str(seeded["user_id"]), + "consent_key": "email_marketing", + "status": "granted", + "metadata": {"source": "settings"}, + }, + ) + second_status, second_payload = invoke_request( + "POST", + "/v0/consents", + payload={ + "user_id": str(seeded["user_id"]), + "consent_key": "analytics_tracking", + "status": "revoked", + "metadata": {"source": "banner"}, + }, + ) + third_status, third_payload = invoke_request( + "POST", + "/v0/consents", + payload={ + "user_id": str(seeded["user_id"]), + "consent_key": "email_marketing", + "status": "revoked", + "metadata": {"source": "preferences"}, + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/consents", + query_params={"user_id": str(seeded["user_id"])}, + ) + + assert first_status == 201 + assert second_status == 201 + assert third_status == 200 + assert first_payload["write_mode"] == "created" + assert second_payload["write_mode"] == "created" + assert third_payload["write_mode"] == "updated" + assert third_payload["consent"]["id"] == first_payload["consent"]["id"] + assert list_status == 200 + assert [item["consent_key"] for item in list_payload["items"]] == [ + "analytics_tracking", + "email_marketing", + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["consent_key_asc", "created_at_asc", "id_asc"], + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + stored_consents = ContinuityStore(conn).list_consents() + + assert [consent["consent_key"] for consent in stored_consents] == [ + "analytics_tracking", + "email_marketing", + ] + assert stored_consents[1]["status"] == "revoked" + assert stored_consents[1]["metadata"] == {"source": "preferences"} + + +def test_policy_endpoints_create_list_and_get_in_priority_order(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + low_priority_status, low_priority_payload = invoke_request( + "POST", + "/v0/policies", + payload={ + "user_id": str(seeded["user_id"]), + "name": "Require approval for export", + "action": "memory.export", + 
"scope": "profile", + "effect": "require_approval", + "priority": 20, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing", "email_marketing"], + }, + ) + high_priority_status, high_priority_payload = invoke_request( + "POST", + "/v0/policies", + payload={ + "user_id": str(seeded["user_id"]), + "name": "Allow profile read", + "action": "memory.read", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {}, + "required_consents": [], + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/policies", + query_params={"user_id": str(seeded["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/policies/{low_priority_payload['policy']['id']}", + query_params={"user_id": str(seeded['user_id'])}, + ) + + assert low_priority_status == 201 + assert high_priority_status == 201 + assert low_priority_payload["policy"]["required_consents"] == ["email_marketing"] + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + high_priority_payload["policy"]["id"], + low_priority_payload["policy"]["id"], + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["priority_asc", "created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"policy": low_priority_payload["policy"]} + + +def test_policy_evaluation_allow_records_trace_events(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="email_marketing", + status="granted", + metadata={"source": "settings"}, + ) + created_policy = store.create_policy( + name="Allow export", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={"channel": "email"}, + required_consents=["email_marketing"], + ) + + status_code, payload = invoke_request( + "POST", + "/v0/policies/evaluate", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "action": "memory.export", + "scope": "profile", + "attributes": {"channel": "email"}, + }, + ) + + assert status_code == 200 + assert payload["decision"] == "allow" + assert payload["matched_policy"]["id"] == str(created_policy["id"]) + assert payload["evaluation"] == { + "action": "memory.export", + "scope": "profile", + "evaluated_policy_count": 1, + "matched_policy_id": str(created_policy["id"]), + "order": ["priority_asc", "created_at_asc", "id_asc"], + } + assert [reason["code"] for reason in payload["reasons"]] == [ + "matched_policy", + "policy_effect_allow", + ] + assert payload["trace"]["trace_event_count"] == 3 + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(payload["trace"]["trace_id"])) + + assert trace["kind"] == "policy.evaluate" + assert trace["compiler_version"] == "policy_evaluation_v0" + assert trace["limits"] == { + "order": ["priority_asc", "created_at_asc", "id_asc"], + "active_policy_count": 1, + "consent_count": 1, + } + assert [event["kind"] for event in trace_events] == [ + "policy.evaluate.request", + 
"policy.evaluate.order", + "policy.evaluate.decision", + ] + assert trace_events[2]["payload"]["decision"] == "allow" + assert trace_events[2]["payload"]["matched_policy_id"] == str(created_policy["id"]) + + +def test_policy_evaluation_denies_when_required_consent_is_missing(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + ContinuityStore(conn).create_policy( + name="Allow export with consent", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=["email_marketing"], + ) + + status_code, payload = invoke_request( + "POST", + "/v0/policies/evaluate", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "action": "memory.export", + "scope": "profile", + "attributes": {}, + }, + ) + + assert status_code == 200 + assert payload["decision"] == "deny" + assert [reason["code"] for reason in payload["reasons"]] == [ + "matched_policy", + "consent_missing", + ] + + +def test_policy_evaluation_returns_require_approval(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + created_policy = ContinuityStore(conn).create_policy( + name="Escalate export", + action="memory.export", + scope="profile", + effect="require_approval", + priority=10, + active=True, + conditions={}, + required_consents=[], + ) + + status_code, payload = invoke_request( + "POST", + "/v0/policies/evaluate", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "action": "memory.export", + "scope": "profile", + "attributes": {}, + }, + ) + + assert status_code == 200 + assert payload["decision"] == "require_approval" + assert payload["matched_policy"]["id"] == str(created_policy["id"]) + assert payload["reasons"][-1]["code"] == "policy_effect_require_approval" + + +def test_policy_and_consent_endpoints_enforce_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent(consent_key="email_marketing", status="granted", metadata={}) + owner_policy = store.create_policy( + name="Allow export", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=["email_marketing"], + ) + + consent_status, consent_payload = invoke_request( + "GET", + "/v0/consents", + query_params={"user_id": str(intruder["user_id"])}, + ) + policy_status, policy_payload = invoke_request( + "GET", + "/v0/policies", + query_params={"user_id": str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/policies/{owner_policy['id']}", + query_params={"user_id": 
str(intruder["user_id"])}, + ) + evaluation_status, evaluation_payload = invoke_request( + "POST", + "/v0/policies/evaluate", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "action": "memory.export", + "scope": "profile", + "attributes": {}, + }, + ) + + assert consent_status == 200 + assert consent_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["consent_key_asc", "created_at_asc", "id_asc"], + }, + } + assert policy_status == 200 + assert policy_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["priority_asc", "created_at_asc", "id_asc"], + }, + } + assert detail_status == 404 + assert detail_payload == {"detail": f"policy {owner_policy['id']} was not found"} + assert evaluation_status == 200 + assert evaluation_payload["decision"] == "deny" + assert evaluation_payload["matched_policy"] is None + assert evaluation_payload["reasons"] == [ + { + "code": "no_matching_policy", + "source": "system", + "message": "No active policy matched the requested action, scope, and attributes.", + "policy_id": None, + "consent_key": None, + } + ] diff --git a/tests/integration/test_proxy_execution_api.py b/tests/integration/test_proxy_execution_api.py new file mode 100644 index 0000000..755f5f3 --- /dev/null +++ b/tests/integration/test_proxy_execution_api.py @@ -0,0 +1,1478 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio +import psycopg + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Proxy execution thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def create_tool_and_policy( + 
database_url: str, + *, + user_id: UUID, + tool_key: str, +) -> UUID: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key=tool_key, + name="Proxy Tool", + description="Deterministic proxy tool.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + store.create_policy( + name=f"Require approval for {tool_key}", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": tool_key}, + required_consents=[], + ) + return tool["id"] + + +def create_pending_approval( + *, + user_id: UUID, + thread_id: UUID, + tool_id: UUID, +) -> tuple[int, dict[str, Any]]: + return invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(user_id), + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "hello", "count": 2}, + }, + ) + + +def create_execution_budget( + database_url: str, + *, + user_id: UUID, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, +) -> UUID: + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + budget = store.create_execution_budget( + tool_key=tool_key, + domain_hint=domain_hint, + max_completed_executions=max_completed_executions, + rolling_window_seconds=rolling_window_seconds, + supersedes_budget_id=None, + ) + return budget["id"] + + +def set_execution_executed_at( + admin_database_url: str, + *, + execution_id: UUID, + executed_at_sql: str, +) -> None: + with psycopg.connect(admin_database_url) as conn: + conn.execute( + f"UPDATE tool_executions SET executed_at = {executed_at_sql} WHERE id = %s", + (execution_id,), + ) + conn.commit() + + +def test_execute_approved_proxy_endpoint_executes_only_approved_requests( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + assert approve_payload["approval"]["status"] == "approved" + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert execute_status == 200 + assert list(execute_payload) == ["request", "approval", "tool", "result", "events", "trace"] + assert execute_payload["request"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_payload["approval"]["task_step_id"], + } + assert execute_payload["approval"]["id"] == create_payload["approval"]["id"] + assert execute_payload["approval"]["status"] == "approved" + assert execute_payload["tool"]["id"] == str(tool_id) + 
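# Descriptive note (added): the proxy.echo handler is expected to complete in no_side_effect mode and mirror the approved request attributes back in its output, as asserted below. +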
assert execute_payload["tool"]["tool_key"] == "proxy.echo" + assert execute_payload["result"] == { + "handler_key": "proxy.echo", + "status": "completed", + "output": { + "mode": "no_side_effect", + "tool_key": "proxy.echo", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello", "count": 2}, + }, + } + assert execute_payload["events"]["request_sequence_no"] == 1 + assert execute_payload["events"]["result_sequence_no"] == 2 + assert execute_payload["trace"]["trace_event_count"] == 9 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + thread_events = store.list_thread_events(owner["thread_id"]) + tasks = store.list_tasks() + task_steps = store.list_task_steps_for_task(tasks[0]["id"]) + tool_executions = store.list_tool_executions() + execute_trace = store.get_trace(UUID(execute_payload["trace"]["trace_id"])) + execute_trace_events = store.list_trace_events(UUID(execute_payload["trace"]["trace_id"])) + + assert [event["kind"] for event in thread_events] == [ + "tool.proxy.execution.request", + "tool.proxy.execution.result", + ] + assert len(tool_executions) == 1 + assert len(tasks) == 1 + assert len(task_steps) == 1 + assert tasks[0]["status"] == "executed" + assert tasks[0]["latest_execution_id"] == tool_executions[0]["id"] + assert task_steps[0]["status"] == "executed" + assert tool_executions[0]["approval_id"] == UUID(create_payload["approval"]["id"]) + assert tool_executions[0]["task_step_id"] == task_steps[0]["id"] + assert tool_executions[0]["thread_id"] == owner["thread_id"] + assert tool_executions[0]["tool_id"] == tool_id + assert tool_executions[0]["trace_id"] == UUID(execute_payload["trace"]["trace_id"]) + assert tool_executions[0]["handler_key"] == "proxy.echo" + assert tool_executions[0]["status"] == "completed" + assert tool_executions[0]["request"] == thread_events[0]["payload"]["request"] + assert tool_executions[0]["tool"]["tool_key"] == "proxy.echo" + assert tool_executions[0]["result"] == { + "handler_key": "proxy.echo", + "status": "completed", + "output": execute_payload["result"]["output"], + "reason": None, + } + assert thread_events[0]["payload"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_payload["approval"]["task_step_id"], + "tool_id": str(tool_id), + "tool_key": "proxy.echo", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello", "count": 2}, + }, + } + assert execute_trace["kind"] == "tool.proxy.execute" + assert execute_trace["compiler_version"] == "proxy_execution_v0" + assert execute_trace["limits"] == { + "approval_status": "approved", + "enabled_handler_keys": ["proxy.echo"], + "budget_match_order": ["specificity_desc", "created_at_asc", "id_asc"], + } + assert [event["kind"] for event in execute_trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert execute_trace_events[0]["payload"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_payload["approval"]["task_step_id"], + } + assert execute_trace_events[1]["payload"]["task_step_id"] == 
create_payload["approval"]["task_step_id"] + assert execute_trace_events[2]["payload"]["decision"] == "allow" + assert execute_trace_events[3]["payload"]["dispatch_status"] == "executed" + assert execute_trace_events[3]["payload"]["task_step_id"] == create_payload["approval"]["task_step_id"] + assert execute_trace_events[4]["payload"]["request_event_id"] == execute_payload["events"]["request_event_id"] + assert execute_trace_events[4]["payload"]["task_step_id"] == create_payload["approval"]["task_step_id"] + assert execute_trace_events[7]["payload"] == { + "task_id": create_payload["task"]["id"], + "task_step_id": str(task_steps[0]["id"]), + "source": "proxy_execution", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": "approved", + "current_status": "executed", + "trace": { + "trace_id": execute_payload["trace"]["trace_id"], + "trace_kind": "tool.proxy.execute", + }, + } + + +def test_execute_approved_proxy_endpoint_rejects_pending_approval( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert execute_status == 409 + assert execute_payload == { + "detail": f"approval {create_payload['approval']['id']} is pending and cannot be executed" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace_rows = store.conn.execute( + "SELECT id, kind, limits FROM traces WHERE kind = %s ORDER BY created_at ASC, id ASC", + ("tool.proxy.execute",), + ).fetchall() + trace_events = store.list_trace_events(trace_rows[-1]["id"]) + thread_events = store.list_thread_events(owner["thread_id"]) + + assert thread_events == [] + assert trace_rows[-1]["limits"] == { + "approval_status": "pending", + "enabled_handler_keys": ["proxy.echo"], + "budget_match_order": ["specificity_desc", "created_at_asc", "id_asc"], + } + assert trace_events[2]["payload"]["dispatch_status"] == "blocked" + assert trace_events[3]["payload"]["execution_status"] == "blocked" + + +def test_execute_approved_proxy_endpoint_rejects_rejected_approval( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + reject_status, reject_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/reject", + payload={"user_id": str(owner["user_id"])}, + ) + assert reject_status == 200 + assert reject_payload["approval"]["status"] == "rejected" + + execute_status, execute_payload = 
invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert execute_status == 409 + assert execute_payload == { + "detail": f"approval {create_payload['approval']['id']} is rejected and cannot be executed" + } + + +def test_execute_approved_proxy_endpoint_rejects_missing_handler( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.missing", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + assert approve_payload["approval"]["status"] == "approved" + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert execute_status == 409 + assert execute_payload == { + "detail": "tool 'proxy.missing' has no registered proxy handler" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + trace_rows = store.conn.execute( + "SELECT id FROM traces WHERE kind = %s ORDER BY created_at ASC, id ASC", + ("tool.proxy.execute",), + ).fetchall() + trace_events = store.list_trace_events(trace_rows[-1]["id"]) + tool_executions = store.list_tool_executions() + thread_events = store.list_thread_events(owner["thread_id"]) + + assert thread_events == [] + assert len(tool_executions) == 1 + assert tool_executions[0]["approval_id"] == UUID(create_payload["approval"]["id"]) + assert tool_executions[0]["task_step_id"] == UUID(create_payload["approval"]["task_step_id"]) + assert tool_executions[0]["handler_key"] is None + assert tool_executions[0]["status"] == "blocked" + assert tool_executions[0]["request_event_id"] is None + assert tool_executions[0]["result_event_id"] is None + assert tool_executions[0]["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + } + assert trace_events[2]["payload"]["decision"] == "allow" + assert trace_events[3]["payload"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_payload["approval"]["task_step_id"], + "tool_id": str(tool_id), + "tool_key": "proxy.missing", + "handler_key": None, + "dispatch_status": "blocked", + "reason": "tool 'proxy.missing' has no registered proxy handler", + "result_status": "blocked", + "output": None, + } + + list_status, list_payload = invoke_request( + "GET", + "/v0/tool-executions", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tool-executions/{tool_executions[0]['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + + assert list_status == 200 + assert list_payload["items"][0]["task_step_id"] == create_payload["approval"]["task_step_id"] + assert list_payload["items"][0]["status"] == "blocked" + assert 
list_payload["items"][0]["request_event_id"] is None + assert list_payload["items"][0]["result_event_id"] is None + assert list_payload["items"][0]["result"]["reason"] == "tool 'proxy.missing' has no registered proxy handler" + assert detail_status == 200 + assert detail_payload == {"execution": list_payload["items"][0]} + + +def test_execute_approved_proxy_endpoint_enforces_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + other_user = seed_user(migrated_database_urls["app"], email="other@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(other_user["user_id"])}, + ) + + assert execute_status == 404 + assert execute_payload == { + "detail": f"approval {create_payload['approval']['id']} was not found" + } + + +def test_execute_approved_proxy_endpoint_updates_the_explicitly_linked_later_step( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner-step-linkage@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + create_status, create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert create_status == 200 + + approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert execute_status == 200 + + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tasks/{create_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + step_list_status, step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{create_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + assert detail_status == 200 + assert step_list_status == 200 + initial_execution_id = detail_payload["task"]["latest_execution_id"] + assert initial_execution_id is not None + + create_step_status, create_step_payload = invoke_request( + "POST", + f"/v0/tasks/{create_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "step-2"}, + }, + "outcome": { + "routing_decision": "approval_required", 
+ "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": step_list_payload["items"][0]["id"], + "source_approval_id": create_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + assert create_step_status == 201 + + transition_status, transition_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_step_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "approved", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": create_payload["approval"]["id"], + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + }, + ) + assert transition_status == 200 + assert transition_payload["task_step"]["status"] == "approved" + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + relinked = store.update_approval_task_step_optional( + approval_id=UUID(create_payload["approval"]["id"]), + task_step_id=UUID(create_step_payload["task_step"]["id"]), + ) + assert relinked is not None + + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert second_execute_status == 200 + assert second_execute_payload["request"] == { + "approval_id": create_payload["approval"]["id"], + "task_step_id": create_step_payload["task_step"]["id"], + } + assert second_execute_payload["approval"]["task_step_id"] == create_step_payload["task_step"]["id"] + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + task = store.get_task_optional(UUID(create_payload["task"]["id"])) + task_steps = store.list_task_steps_for_task(UUID(create_payload["task"]["id"])) + tool_executions = store.list_tool_executions() + proxy_traces = store.conn.execute( + """ + SELECT id + FROM traces + WHERE thread_id = %s + AND kind = 'tool.proxy.execute' + ORDER BY created_at ASC, id ASC + """, + (owner["thread_id"],), + ).fetchall() + + assert task is not None + assert task["status"] == "executed" + assert task["latest_approval_id"] == UUID(create_payload["approval"]["id"]) + assert len(task_steps) == 2 + assert task_steps[0]["status"] == "executed" + assert task_steps[0]["trace_id"] == UUID(execute_payload["trace"]["trace_id"]) + assert task_steps[0]["outcome"]["execution_id"] == initial_execution_id + assert task_steps[1]["status"] == "executed" + assert task_steps[1]["id"] == UUID(create_step_payload["task_step"]["id"]) + assert task_steps[1]["trace_id"] == UUID(second_execute_payload["trace"]["trace_id"]) + assert task_steps[1]["outcome"]["approval_id"] == create_payload["approval"]["id"] + assert task_steps[1]["outcome"]["execution_status"] == "completed" + assert len(tool_executions) == 2 + assert task["latest_execution_id"] == tool_executions[1]["id"] + assert tool_executions[1]["task_step_id"] == UUID(create_step_payload["task_step"]["id"]) + assert task_steps[1]["outcome"]["execution_id"] == str(tool_executions[1]["id"]) + assert len(proxy_traces) == 2 + + +def test_execute_approved_proxy_endpoint_blocks_when_execution_budget_is_exceeded( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, 
"get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, first_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert second_execute_status == 200 + assert second_execute_payload["events"] is None + assert second_execute_payload["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": ( + f"execution budget {budget_id} blocks execution: projected completed executions " + "2 would exceed limit 1" + ), + "budget_decision": { + "matched_budget_id": str(budget_id), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + }, + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + stored_executions = store.list_tool_executions() + blocked_trace = store.get_trace(UUID(second_execute_payload["trace"]["trace_id"])) + blocked_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + thread_events = store.list_thread_events(owner["thread_id"]) + + list_status, list_payload = invoke_request( + "GET", + "/v0/tool-executions", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tool-executions/{stored_executions[1]['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + + assert len(stored_executions) == 2 + assert [row["status"] for row in stored_executions] == ["completed", "blocked"] + assert stored_executions[1]["task_step_id"] == UUID(second_execute_payload["request"]["task_step_id"]) + assert stored_executions[1]["result"] == second_execute_payload["result"] + assert 
stored_executions[1]["request_event_id"] is None + assert stored_executions[1]["result_event_id"] is None + assert [event["kind"] for event in thread_events] == [ + "tool.proxy.execution.request", + "tool.proxy.execution.result", + ] + assert blocked_trace["limits"] == { + "approval_status": "approved", + "enabled_handler_keys": ["proxy.echo"], + "budget_match_order": ["specificity_desc", "created_at_asc", "id_asc"], + } + assert [event["kind"] for event in blocked_trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert blocked_trace_events[0]["payload"] == second_execute_payload["request"] + assert blocked_trace_events[1]["payload"]["task_step_id"] == second_execute_payload["request"]["task_step_id"] + assert blocked_trace_events[2]["payload"] == second_execute_payload["result"]["budget_decision"] + assert blocked_trace_events[3]["payload"]["dispatch_status"] == "blocked" + assert blocked_trace_events[3]["payload"]["task_step_id"] == second_execute_payload["request"]["task_step_id"] + assert list_status == 200 + assert list_payload["items"][1]["task_step_id"] == second_execute_payload["request"]["task_step_id"] + assert [item["status"] for item in list_payload["items"]] == ["completed", "blocked"] + assert list_payload["items"][1]["result"] == second_execute_payload["result"] + assert detail_status == 200 + assert detail_payload == {"execution": list_payload["items"][1]} + + +def test_execute_approved_proxy_endpoint_allows_when_recent_history_is_within_rolling_window_limit( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + rolling_window_seconds=3600, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert 
second_execute_status == 200 + assert second_execute_payload["result"]["status"] == "completed" + assert second_execute_payload["events"] is not None + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + execute_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + + assert execute_trace_events[2]["payload"]["matched_budget_id"] is not None + assert execute_trace_events[2]["payload"]["rolling_window_seconds"] == 3600 + assert execute_trace_events[2]["payload"]["count_scope"] == "rolling_window" + assert execute_trace_events[2]["payload"]["window_started_at"] is not None + assert execute_trace_events[2]["payload"]["completed_execution_count"] == 1 + assert execute_trace_events[2]["payload"]["projected_completed_execution_count"] == 2 + assert execute_trace_events[2]["payload"]["decision"] == "allow" + assert execute_trace_events[2]["payload"]["reason"] == "within_budget" + assert execute_trace_events[2]["payload"]["history_order"] == ["executed_at_asc", "id_asc"] + + +def test_execute_approved_proxy_endpoint_blocks_when_recent_window_history_exceeds_limit( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + rolling_window_seconds=3600, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert second_execute_status == 200 + assert second_execute_payload["events"] is None + assert list(second_execute_payload["result"]) == [ + "handler_key", + "status", + "output", + "reason", + "budget_decision", + ] + assert second_execute_payload["result"]["handler_key"] is None + assert second_execute_payload["result"]["status"] == "blocked" + assert second_execute_payload["result"]["output"] is None + assert second_execute_payload["result"]["reason"] == ( + f"execution budget {budget_id} blocks execution: projected completed executions " + "2 within rolling window 3600 
seconds would exceed limit 1" + ) + assert second_execute_payload["result"]["budget_decision"]["matched_budget_id"] == str(budget_id) + assert second_execute_payload["result"]["budget_decision"]["rolling_window_seconds"] == 3600 + assert second_execute_payload["result"]["budget_decision"]["count_scope"] == "rolling_window" + assert second_execute_payload["result"]["budget_decision"]["window_started_at"] is not None + assert second_execute_payload["result"]["budget_decision"]["completed_execution_count"] == 1 + assert second_execute_payload["result"]["budget_decision"]["projected_completed_execution_count"] == 2 + assert second_execute_payload["result"]["budget_decision"]["history_order"] == [ + "executed_at_asc", + "id_asc", + ] + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + stored_executions = store.list_tool_executions() + blocked_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + + assert [row["status"] for row in stored_executions] == ["completed", "blocked"] + assert stored_executions[1]["task_step_id"] == UUID(second_execute_payload["request"]["task_step_id"]) + assert stored_executions[1]["result"] == second_execute_payload["result"] + assert blocked_trace_events[2]["payload"] == second_execute_payload["result"]["budget_decision"] + + +def test_execute_approved_proxy_endpoint_excludes_old_window_history_and_keeps_counts_user_scoped( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + other_user = seed_user(migrated_database_urls["app"], email="other@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + owner_tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + other_tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=other_user["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + rolling_window_seconds=60, + ) + + owner_first_status, owner_first_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=owner_tool_id, + ) + owner_second_status, owner_second_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=owner_tool_id, + ) + other_status, other_payload = create_pending_approval( + user_id=other_user["user_id"], + thread_id=other_user["thread_id"], + tool_id=other_tool_id, + ) + assert owner_first_status == 200 + assert owner_second_status == 200 + assert other_status == 200 + + for approval_payload, user_id in ( + (owner_first_payload, owner["user_id"]), + (owner_second_payload, owner["user_id"]), + (other_payload, other_user["user_id"]), + ): + approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{approval_payload['approval']['id']}/approve", + payload={"user_id": str(user_id)}, + ) + assert approve_status == 200 + + owner_first_execute_status, owner_first_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{owner_first_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + other_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{other_payload['approval']['id']}/execute", + payload={"user_id": 
str(other_user["user_id"])}, + ) + assert owner_first_execute_status == 200 + assert other_execute_status == 200 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + owner_first_execution_id = store.list_tool_executions()[0]["id"] + + set_execution_executed_at( + migrated_database_urls["admin"], + execution_id=owner_first_execution_id, + executed_at_sql="clock_timestamp() - interval '2 hours'", + ) + + owner_second_execute_status, owner_second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{owner_second_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert owner_second_execute_status == 200 + assert owner_second_execute_payload["result"]["status"] == "completed" + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + execute_trace_events = store.list_trace_events(UUID(owner_second_execute_payload["trace"]["trace_id"])) + + assert execute_trace_events[2]["payload"]["matched_budget_id"] == str(budget_id) + assert execute_trace_events[2]["payload"]["rolling_window_seconds"] == 60 + assert execute_trace_events[2]["payload"]["count_scope"] == "rolling_window" + assert execute_trace_events[2]["payload"]["window_started_at"] is not None + assert execute_trace_events[2]["payload"]["completed_execution_count"] == 0 + assert execute_trace_events[2]["payload"]["projected_completed_execution_count"] == 1 + assert execute_trace_events[2]["payload"]["reason"] == "within_budget" + + +def test_execute_approved_proxy_endpoint_ignores_deactivated_budget( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + deactivate_status, deactivate_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{budget_id}/deactivate", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + }, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + 
payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert deactivate_status == 200 + assert deactivate_payload["execution_budget"]["status"] == "inactive" + assert second_execute_status == 200 + assert second_execute_payload["result"]["status"] == "completed" + assert second_execute_payload["trace"]["trace_event_count"] == 9 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + execute_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + + assert execute_trace_events[2]["payload"] == { + "matched_budget_id": None, + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": None, + "budget_domain_hint": None, + "max_completed_executions": None, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 0, + "projected_completed_execution_count": 1, + "decision": "allow", + "reason": "no_matching_budget", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + + +def test_execute_approved_proxy_endpoint_uses_replacement_budget_after_supersession( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + budget_id = create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + + first_create_status, first_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_create_status, second_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_create_status == 200 + assert second_create_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + supersede_status, supersede_payload = invoke_request( + "POST", + f"/v0/execution-budgets/{budget_id}/supersede", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "max_completed_executions": 2, + }, + ) + second_execute_status, second_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{second_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + + assert first_execute_status == 200 + assert supersede_status == 200 + assert supersede_payload["superseded_budget"]["status"] == "superseded" + assert supersede_payload["replacement_budget"]["status"] == "active" + assert second_execute_status == 200 + assert second_execute_payload["result"]["status"] == "completed" + + with 
user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + execute_trace_events = store.list_trace_events(UUID(second_execute_payload["trace"]["trace_id"])) + + assert execute_trace_events[2]["payload"] == { + "matched_budget_id": supersede_payload["replacement_budget"]["id"], + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 2, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "allow", + "reason": "within_budget", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + + +def test_execute_approved_proxy_execution_budget_is_user_scoped( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + other_user = seed_user(migrated_database_urls["app"], email="other@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + owner_tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + other_tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=other_user["user_id"], + tool_key="proxy.echo", + ) + create_execution_budget( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + + owner_create_status, owner_create_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=owner_tool_id, + ) + other_create_status, other_create_payload = create_pending_approval( + user_id=other_user["user_id"], + thread_id=other_user["thread_id"], + tool_id=other_tool_id, + ) + assert owner_create_status == 200 + assert other_create_status == 200 + + owner_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{owner_create_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + other_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{other_create_payload['approval']['id']}/approve", + payload={"user_id": str(other_user["user_id"])}, + ) + assert owner_approve_status == 200 + assert other_approve_status == 200 + + owner_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{owner_create_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + other_execute_status, other_execute_payload = invoke_request( + "POST", + f"/v0/approvals/{other_create_payload['approval']['id']}/execute", + payload={"user_id": str(other_user["user_id"])}, + ) + + assert owner_execute_status == 200 + assert other_execute_status == 200 + assert other_execute_payload["result"]["status"] == "completed" + + +def test_tool_execution_review_endpoints_are_deterministic_and_user_scoped( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + tool_id = create_tool_and_policy( + migrated_database_urls["app"], + user_id=owner["user_id"], + tool_key="proxy.echo", + ) + + first_status, 
first_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + second_status, second_payload = create_pending_approval( + user_id=owner["user_id"], + thread_id=owner["thread_id"], + tool_id=tool_id, + ) + assert first_status == 200 + assert second_status == 200 + + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + assert second_approve_status == 200 + + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + second_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_execute_status == 200 + assert second_execute_status == 200 + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + stored_executions = store.list_tool_executions() + + list_status, list_payload = invoke_request( + "GET", + "/v0/tool-executions", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tool-executions/{stored_executions[1]['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/tool-executions", + query_params={"user_id": str(intruder["user_id"])}, + ) + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + str(stored_executions[0]["id"]), + str(stored_executions[1]["id"]), + ] + assert [item["task_step_id"] for item in list_payload["items"]] == [ + str(stored_executions[0]["task_step_id"]), + str(stored_executions[1]["task_step_id"]), + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["executed_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == { + "execution": next( + item for item in list_payload["items"] if item["id"] == str(stored_executions[1]["id"]) + ) + } + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["executed_at_asc", "id_asc"]}, + } + + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/tool-executions/{stored_executions[0]['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"tool execution {stored_executions[0]['id']} was not found" + } diff --git a/tests/integration/test_responses_api.py b/tests/integration/test_responses_api.py new file mode 100644 index 0000000..1ed051f --- /dev/null +++ b/tests/integration/test_responses_api.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import json +from uuid import UUID, uuid4 + +import anyio +import psycopg +import pytest + +import apps.api.src.alicebot_api.main as main_module +import alicebot_api.response_generation as response_generation_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def 
invoke_generate_response(payload: dict[str, object]) -> tuple[int, dict[str, object]]: + messages: list[dict[str, object]] = [] + encoded_body = json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": "/v0/responses", + "raw_path": b"/v0/responses", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_response_thread( + database_url: str, + *, + email: str = "owner@example.com", + display_name: str = "Owner", +) -> dict[str, object]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, display_name) + thread = store.create_thread("Response thread") + session = store.create_session(thread["id"], status="active") + prior_event = store.append_event( + thread["id"], + session["id"], + "message.user", + {"text": "Remember that I like oat milk."}, + ) + memory = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + status="active", + source_event_ids=[str(prior_event["id"])], + ) + + return { + "user_id": user_id, + "thread_id": thread["id"], + "session_id": session["id"], + "prior_event_id": prior_event["id"], + "memory_id": memory["id"], + } + + +def test_generate_response_persists_user_and_assistant_events_and_trace_metadata( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_response_thread(migrated_database_urls["app"]) + captured: dict[str, object] = {} + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + model_provider="openai_responses", + model_name="gpt-5-mini", + model_api_key="test-key", + ), + ) + + def fake_invoke_model(*, settings, request): + captured["settings"] = settings + captured["request_payload"] = request.as_payload() + return response_generation_module.ModelInvocationResponse( + provider="openai_responses", + model="gpt-5-mini", + response_id="resp_123", + finish_reason="completed", + output_text="You prefer oat milk.", + usage={"input_tokens": 20, "output_tokens": 6, "total_tokens": 26}, + ) + + monkeypatch.setattr(response_generation_module, "invoke_model", fake_invoke_model) + + status_code, payload = invoke_generate_response( + { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "message": "What do I usually take in coffee?", + } + ) + + assert status_code == 200 + assert payload["assistant"] == { + "event_id": payload["assistant"]["event_id"], + "sequence_no": 3, + "text": "You prefer oat milk.", + "model_provider": "openai_responses", + "model": "gpt-5-mini", + } + assert payload["trace"]["compile_trace_event_count"] > 0 
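+    # The compile trace and the response trace are recorded separately; the response
+    # trace is expected to hold exactly the prompt-assembled and model-completed
+    # events asserted further down.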
+ assert payload["trace"]["response_trace_event_count"] == 2 + assert captured["request_payload"]["tool_choice"] == "none" + assert captured["request_payload"]["tools"] == [] + assert captured["request_payload"]["store"] is False + assert captured["request_payload"]["sections"] == [ + "system", + "developer", + "context", + "conversation", + ] + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + events = store.list_thread_events(seeded["thread_id"]) + compile_trace = store.get_trace(UUID(payload["trace"]["compile_trace_id"])) + response_trace = store.get_trace(UUID(payload["trace"]["response_trace_id"])) + response_trace_events = store.list_trace_events(UUID(payload["trace"]["response_trace_id"])) + + assert [event["sequence_no"] for event in events] == [1, 2, 3] + assert [event["kind"] for event in events] == [ + "message.user", + "message.user", + "message.assistant", + ] + assert events[1]["payload"] == {"text": "What do I usually take in coffee?"} + assert events[2]["payload"] == { + "text": "You prefer oat milk.", + "model": { + "provider": "openai_responses", + "model": "gpt-5-mini", + "response_id": "resp_123", + "finish_reason": "completed", + "usage": {"input_tokens": 20, "output_tokens": 6, "total_tokens": 26}, + }, + "prompt": { + "assembly_version": "prompt_assembly_v0", + "prompt_sha256": events[2]["payload"]["prompt"]["prompt_sha256"], + "section_order": ["system", "developer", "context", "conversation"], + }, + } + assert compile_trace["kind"] == "context.compile" + assert response_trace["kind"] == "response.generate" + assert response_trace["compiler_version"] == "response_generation_v0" + assert [event["kind"] for event in response_trace_events] == [ + "response.prompt.assembled", + "response.model.completed", + ] + assert response_trace_events[0]["payload"]["compile_trace_id"] == payload["trace"]["compile_trace_id"] + assert response_trace_events[1]["payload"] == { + "provider": "openai_responses", + "model": "gpt-5-mini", + "tool_choice": "none", + "tools_enabled": False, + "response_id": "resp_123", + "finish_reason": "completed", + "output_text_char_count": len("You prefer oat milk."), + "usage": {"input_tokens": 20, "output_tokens": 6, "total_tokens": 26}, + "error_message": None, + } + + with psycopg.connect(migrated_database_urls["admin"]) as conn: + with pytest.raises(psycopg.Error, match="append-only"): + with conn.cursor() as cur: + cur.execute( + "UPDATE events SET kind = 'message.mutated' WHERE id = %s", + (UUID(payload["assistant"]["event_id"]),), + ) + + +def test_generate_response_returns_clean_failure_without_persisting_assistant_event( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_response_thread(migrated_database_urls["app"]) + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + model_provider="openai_responses", + model_name="gpt-5-mini", + model_api_key="test-key", + ), + ) + monkeypatch.setattr( + response_generation_module, + "invoke_model", + lambda **_kwargs: (_ for _ in ()).throw( + response_generation_module.ModelInvocationError("upstream timeout") + ), + ) + + status_code, payload = invoke_generate_response( + { + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "message": "What do I usually take in coffee?", + } + ) + + assert status_code == 502 + assert payload["detail"] == "upstream timeout" + assert payload["trace"]["compile_trace_event_count"] > 0 + assert 
payload["trace"]["response_trace_event_count"] == 2 + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + events = store.list_thread_events(seeded["thread_id"]) + response_trace_events = store.list_trace_events(UUID(payload["trace"]["response_trace_id"])) + + assert [event["sequence_no"] for event in events] == [1, 2] + assert [event["kind"] for event in events] == ["message.user", "message.user"] + assert events[-1]["payload"] == {"text": "What do I usually take in coffee?"} + assert [event["kind"] for event in response_trace_events] == [ + "response.prompt.assembled", + "response.model.failed", + ] + assert response_trace_events[1]["payload"] == { + "provider": "openai_responses", + "model": "gpt-5-mini", + "tool_choice": "none", + "tools_enabled": False, + "response_id": None, + "finish_reason": "incomplete", + "output_text_char_count": 0, + "usage": {"input_tokens": None, "output_tokens": None, "total_tokens": None}, + "error_message": "upstream timeout", + } + + +def test_generate_response_respects_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + owner = seed_response_thread(migrated_database_urls["app"]) + intruder = seed_response_thread( + migrated_database_urls["app"], + email="intruder@example.com", + display_name="Intruder", + ) + captured = {"invoke_model_called": False} + + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + model_provider="openai_responses", + model_name="gpt-5-mini", + model_api_key="test-key", + ), + ) + + def fake_invoke_model(**_kwargs): + captured["invoke_model_called"] = True + raise AssertionError("invoke_model should not be called for cross-user access") + + monkeypatch.setattr(response_generation_module, "invoke_model", fake_invoke_model) + + status_code, payload = invoke_generate_response( + { + "user_id": str(intruder["user_id"]), + "thread_id": str(owner["thread_id"]), + "message": "Tell me their preferences.", + } + ) + + assert status_code == 404 + assert payload == {"detail": "get_thread did not return a row from the database"} + assert captured["invoke_model_called"] is False + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + owner_events = ContinuityStore(conn).list_thread_events(owner["thread_id"]) + + assert [event["sequence_no"] for event in owner_events] == [1] diff --git a/tests/integration/test_task_workspaces_api.py b/tests/integration/test_task_workspaces_api.py new file mode 100644 index 0000000..31aa9d5 --- /dev/null +++ b/tests/integration/test_task_workspaces_api.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return 
{"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_task(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Workspace thread") + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + task = store.create_task( + thread_id=thread["id"], + tool_id=tool["id"], + status="approved", + request={ + "thread_id": str(thread["id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + tool={ + "id": str(tool["id"]), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": tool["created_at"].isoformat(), + }, + latest_approval_id=None, + latest_execution_id=None, + ) + + return { + "user_id": user_id, + "task_id": task["id"], + } + + +def test_task_workspace_endpoints_provision_read_isolate_and_reject_duplicates( + migrated_database_urls, + monkeypatch, + tmp_path, +) -> None: + owner = seed_task(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_task(migrated_database_urls["app"], email="intruder@example.com") + workspace_root = tmp_path / "task-workspaces" + monkeypatch.setattr( + main_module, + "get_settings", + lambda: Settings( + database_url=migrated_database_urls["app"], + task_workspace_root=str(workspace_root), + ), + ) + + create_status, create_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/task-workspaces", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/task-workspaces/{create_payload['workspace']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + duplicate_status, duplicate_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(owner["user_id"])}, + ) + 
isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/task-workspaces", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/task-workspaces/{create_payload['workspace']['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_create_status, isolated_create_payload = invoke_request( + "POST", + f"/v0/tasks/{owner['task_id']}/workspace", + payload={"user_id": str(intruder["user_id"])}, + ) + + expected_path = (workspace_root / str(owner["user_id"]) / str(owner["task_id"])).resolve() + + assert create_status == 201 + assert create_payload["workspace"] == { + "id": create_payload["workspace"]["id"], + "task_id": str(owner["task_id"]), + "status": "active", + "local_path": str(expected_path), + "created_at": create_payload["workspace"]["created_at"], + "updated_at": create_payload["workspace"]["updated_at"], + } + assert Path(create_payload["workspace"]["local_path"]).is_dir() + + assert list_status == 200 + assert list_payload == { + "items": [create_payload["workspace"]], + "summary": {"total_count": 1, "order": ["created_at_asc", "id_asc"]}, + } + + assert detail_status == 200 + assert detail_payload == {"workspace": create_payload["workspace"]} + + assert duplicate_status == 409 + assert duplicate_payload == { + "detail": f"task {owner['task_id']} already has active workspace {create_payload['workspace']['id']}" + } + + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"task workspace {create_payload['workspace']['id']} was not found" + } + + assert isolated_create_status == 404 + assert isolated_create_payload == {"detail": f"task {owner['task_id']} was not found"} diff --git a/tests/integration/test_tasks_api.py b/tests/integration/test_tasks_api.py new file mode 100644 index 0000000..2987567 --- /dev/null +++ b/tests/integration/test_tasks_api.py @@ -0,0 +1,946 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + 
anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Task thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def test_task_endpoints_list_detail_lifecycle_and_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + approval_tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + ready_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + store.create_policy( + name="Require proxy approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "proxy.echo"}, + required_consents=[], + ) + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=20, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + + pending_status, pending_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(approval_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "hello"}, + }, + ) + assert pending_status == 200 + assert pending_payload["task"]["status"] == "pending_approval" + + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{pending_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + assert approve_payload["approval"]["status"] == "approved" + + execute_status, execute_payload = 
invoke_request( + "POST", + f"/v0/approvals/{pending_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert execute_status == 200 + assert execute_payload["result"]["status"] == "completed" + + ready_status, ready_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {}, + }, + ) + assert ready_status == 200 + assert ready_payload["task"]["status"] == "approved" + + denied_status, denied_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(denied_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + assert denied_status == 200 + assert denied_payload["task"]["status"] == "denied" + + list_status, list_payload = invoke_request( + "GET", + "/v0/tasks", + query_params={"user_id": str(owner["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tasks/{pending_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + step_list_status, step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{pending_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + step_detail_status, step_detail_payload = invoke_request( + "GET", + f"/v0/task-steps/{step_list_payload['items'][0]['id']}", + query_params={"user_id": str(owner['user_id'])}, + ) + isolated_list_status, isolated_list_payload = invoke_request( + "GET", + "/v0/tasks", + query_params={"user_id": str(intruder["user_id"])}, + ) + isolated_detail_status, isolated_detail_payload = invoke_request( + "GET", + f"/v0/tasks/{pending_payload['task']['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + isolated_step_list_status, isolated_step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{pending_payload['task']['id']}/steps", + query_params={"user_id": str(intruder['user_id'])}, + ) + isolated_step_detail_status, isolated_step_detail_payload = invoke_request( + "GET", + f"/v0/task-steps/{step_list_payload['items'][0]['id']}", + query_params={"user_id": str(intruder['user_id'])}, + ) + + assert list_status == 200 + assert [item["id"] for item in list_payload["items"]] == [ + pending_payload["task"]["id"], + ready_payload["task"]["id"], + denied_payload["task"]["id"], + ] + assert [item["status"] for item in list_payload["items"]] == [ + "executed", + "approved", + "denied", + ] + assert list_payload["summary"] == { + "total_count": 3, + "order": ["created_at_asc", "id_asc"], + } + + assert detail_status == 200 + assert detail_payload["task"]["id"] == pending_payload["task"]["id"] + assert detail_payload["task"]["status"] == "executed" + assert detail_payload["task"]["latest_approval_id"] == pending_payload["approval"]["id"] + assert detail_payload["task"]["latest_execution_id"] is not None + assert step_list_status == 200 + assert [item["sequence_no"] for item in step_list_payload["items"]] == [1] + assert step_list_payload["summary"] == { + "task_id": pending_payload["task"]["id"], + "total_count": 1, + "latest_sequence_no": 1, + "latest_status": "executed", + "next_sequence_no": 2, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert step_list_payload["items"][0] == { + "id": 
step_list_payload["items"][0]["id"], + "task_id": pending_payload["task"]["id"], + "sequence_no": 1, + "lineage": { + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + }, + "kind": "governed_request", + "status": "executed", + "request": pending_payload["request"], + "outcome": { + "routing_decision": "approval_required", + "approval_id": pending_payload["approval"]["id"], + "approval_status": "approved", + "execution_id": detail_payload["task"]["latest_execution_id"], + "execution_status": "completed", + "blocked_reason": None, + }, + "trace": { + "trace_id": execute_payload["trace"]["trace_id"], + "trace_kind": "tool.proxy.execute", + }, + "created_at": step_list_payload["items"][0]["created_at"], + "updated_at": step_list_payload["items"][0]["updated_at"], + } + assert step_detail_status == 200 + assert step_detail_payload == {"task_step": step_list_payload["items"][0]} + + assert isolated_list_status == 200 + assert isolated_list_payload == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + assert isolated_detail_status == 404 + assert isolated_detail_payload == { + "detail": f"task {pending_payload['task']['id']} was not found" + } + assert isolated_step_list_status == 404 + assert isolated_step_list_payload == { + "detail": f"task {pending_payload['task']['id']} was not found" + } + assert isolated_step_detail_status == 404 + assert isolated_step_detail_payload == { + "detail": f"task step {step_list_payload['items'][0]['id']} was not found" + } + + +def test_task_step_sequence_and_transition_endpoints_preserve_parent_consistency_trace_and_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner-sequence@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder-sequence@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + store.create_policy( + name="Require proxy approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "proxy.echo"}, + required_consents=[], + ) + + request_status, request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "seed-step"}, + }, + ) + assert request_status == 200 + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert approve_status == 200 + execute_status, execute_payload = invoke_request( + "POST", + f"/v0/approvals/{request_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert execute_status == 200 + initial_detail_status, initial_detail_payload = invoke_request( + "GET", + 
f"/v0/tasks/{request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + assert initial_detail_status == 200 + initial_step_list_status, initial_step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + assert initial_step_list_status == 200 + initial_execution_id = initial_detail_payload["task"]["latest_execution_id"] + assert initial_execution_id is not None + + create_status, create_payload = invoke_request( + "POST", + f"/v0/tasks/{request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "step-2"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": initial_step_list_payload["items"][0]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + + assert create_status == 201 + assert create_payload["task"]["status"] == "pending_approval" + assert create_payload["task"]["latest_approval_id"] == request_payload["approval"]["id"] + assert create_payload["task_step"]["sequence_no"] == 2 + assert create_payload["task_step"]["status"] == "created" + assert create_payload["task_step"]["lineage"] == { + "parent_step_id": initial_step_list_payload["items"][0]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + } + assert create_payload["sequencing"] == { + "task_id": request_payload["task"]["id"], + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "created", + "next_sequence_no": 3, + "append_allowed": False, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + + duplicate_create_status, duplicate_create_payload = invoke_request( + "POST", + f"/v0/tasks/{request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "step-3"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": create_payload["task_step"]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + assert duplicate_create_status == 409 + assert duplicate_create_payload["detail"] == ( + f"task {request_payload['task']['id']} latest step {create_payload['task_step']['id']} is created and cannot append a next step" + ) + + invalid_transition_status, invalid_transition_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "executed", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": str(uuid4()), + "execution_id": str(uuid4()), + "execution_status": "completed", + "blocked_reason": None, + }, + 
}, + ) + assert invalid_transition_status == 409 + assert invalid_transition_payload["detail"] == ( + f"task step {create_payload['task_step']['id']} is created and cannot transition to executed; allowed: approved, denied" + ) + + approve_step_status, approve_step_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "approved", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": request_payload["approval"]["id"], + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + }, + ) + assert approve_step_status == 200 + assert approve_step_payload["task"]["status"] == "approved" + assert approve_step_payload["task"]["latest_approval_id"] == request_payload["approval"]["id"] + assert approve_step_payload["task"]["latest_execution_id"] is None + assert approve_step_payload["task_step"]["status"] == "approved" + + execute_step_status, execute_step_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "executed", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": request_payload["approval"]["id"], + "execution_id": initial_execution_id, + "execution_status": "completed", + "blocked_reason": None, + }, + }, + ) + assert execute_step_status == 200 + assert execute_step_payload["task"]["status"] == "executed" + assert execute_step_payload["task"]["latest_approval_id"] == request_payload["approval"]["id"] + assert execute_step_payload["task"]["latest_execution_id"] == initial_execution_id + assert execute_step_payload["task_step"]["status"] == "executed" + assert execute_step_payload["sequencing"] == { + "task_id": request_payload["task"]["id"], + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "executed", + "next_sequence_no": 3, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + step_list_status, step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{request_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + step_detail_status, step_detail_payload = invoke_request( + "GET", + f"/v0/task-steps/{create_payload['task_step']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + assert detail_status == 200 + assert detail_payload["task"]["status"] == "executed" + assert detail_payload["task"]["latest_approval_id"] == request_payload["approval"]["id"] + assert detail_payload["task"]["latest_execution_id"] == initial_execution_id + assert step_list_status == 200 + assert [item["sequence_no"] for item in step_list_payload["items"]] == [1, 2] + assert step_list_payload["items"][1]["lineage"] == create_payload["task_step"]["lineage"] + assert step_list_payload["summary"] == { + "task_id": request_payload["task"]["id"], + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "executed", + "next_sequence_no": 3, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert step_detail_status == 200 + assert step_detail_payload["task_step"] == step_list_payload["items"][1] + assert step_detail_payload["task_step"]["lineage"] == 
create_payload["task_step"]["lineage"] + assert step_detail_payload["task_step"]["outcome"] == { + "routing_decision": "approval_required", + "approval_id": request_payload["approval"]["id"], + "approval_status": "approved", + "execution_id": initial_execution_id, + "execution_status": "completed", + "blocked_reason": None, + } + + isolated_create_status, isolated_create_payload = invoke_request( + "POST", + f"/v0/tasks/{request_payload['task']['id']}/steps", + payload={ + "user_id": str(intruder["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": create_payload["task_step"]["id"], + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + }, + }, + ) + isolated_transition_status, isolated_transition_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(intruder["user_id"]), + "status": "approved", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": str(uuid4()), + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + }, + ) + assert isolated_create_status == 404 + assert isolated_create_payload == { + "detail": f"task {request_payload['task']['id']} was not found" + } + assert isolated_transition_status == 404 + assert isolated_transition_payload == { + "detail": f"task step {create_payload['task_step']['id']} was not found" + } + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + create_trace_events = store.list_trace_events(UUID(create_payload["trace"]["trace_id"])) + transition_trace_events = store.list_trace_events(UUID(execute_step_payload["trace"]["trace_id"])) + + assert [event["kind"] for event in create_trace_events] == [ + "task.step.continuation.request", + "task.step.continuation.lineage", + "task.step.continuation.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert create_trace_events[1]["payload"] == { + "task_id": request_payload["task"]["id"], + "parent_task_step_id": step_list_payload["items"][0]["id"], + "parent_sequence_no": 1, + "parent_status": "executed", + "source_approval_id": request_payload["approval"]["id"], + "source_execution_id": initial_execution_id, + } + assert [event["kind"] for event in transition_trace_events] == [ + "task.step.transition.request", + "task.step.transition.state", + "task.step.transition.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert transition_trace_events[1]["payload"] == { + "task_id": request_payload["task"]["id"], + "task_step_id": create_payload["task_step"]["id"], + "sequence_no": 2, + "previous_status": "approved", + "current_status": "executed", + "allowed_next_statuses": ["executed", "blocked"], + "trace": { + "trace_id": execute_step_payload["trace"]["trace_id"], + "trace_kind": "task.step.transition", + }, + } + assert transition_trace_events[2]["payload"] == { + "task_id": 
request_payload["task"]["id"], + "task_step_id": create_payload["task_step"]["id"], + "sequence_no": 2, + "final_status": "executed", + "parent_task_status": "executed", + "trace": { + "trace_id": execute_step_payload["trace"]["trace_id"], + "trace_kind": "task.step.transition", + }, + } + + +def test_task_step_mutations_reject_visible_links_from_other_task_lineages( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner-lineage@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + store = ContinuityStore(conn) + tool = store.create_tool( + tool_key="proxy.echo", + name="Proxy Echo", + description="Deterministic proxy handler.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["proxy"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + store.create_policy( + name="Require proxy approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "proxy.echo"}, + required_consents=[], + ) + + first_request_status, first_request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "first"}, + }, + ) + assert first_request_status == 200 + first_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_approve_status == 200 + first_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{first_request_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert first_execute_status == 200 + first_detail_status, first_detail_payload = invoke_request( + "GET", + f"/v0/tasks/{first_request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + assert first_detail_status == 200 + first_step_list_status, first_step_list_payload = invoke_request( + "GET", + f"/v0/tasks/{first_request_payload['task']['id']}/steps", + query_params={"user_id": str(owner["user_id"])}, + ) + assert first_step_list_status == 200 + first_step_id = first_step_list_payload["items"][0]["id"] + first_execution_id = first_detail_payload["task"]["latest_execution_id"] + assert first_execution_id is not None + + second_request_status, second_request_payload = invoke_request( + "POST", + "/v0/approvals/requests", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "second"}, + }, + ) + assert second_request_status == 200 + second_approve_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_request_payload['approval']['id']}/approve", + payload={"user_id": str(owner["user_id"])}, + ) + assert second_approve_status == 200 + second_execute_status, _ = invoke_request( + "POST", + f"/v0/approvals/{second_request_payload['approval']['id']}/execute", + payload={"user_id": str(owner["user_id"])}, + ) + assert second_execute_status == 200 + second_detail_status, second_detail_payload = invoke_request( 
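+        # Read the second task's detail to capture its latest execution id, which the
+        # cross-lineage rejection assertions below rely on.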
+ "GET", + f"/v0/tasks/{second_request_payload['task']['id']}", + query_params={"user_id": str(owner["user_id"])}, + ) + assert second_detail_status == 200 + second_execution_id = second_detail_payload["task"]["latest_execution_id"] + assert second_execution_id is not None + + wrong_create_status, wrong_create_payload = invoke_request( + "POST", + f"/v0/tasks/{first_request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "lineage-mismatch"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": first_step_id, + "source_approval_id": second_request_payload["approval"]["id"], + "source_execution_id": None, + }, + }, + ) + assert wrong_create_status == 409 + assert wrong_create_payload == { + "detail": ( + f"approval {second_request_payload['approval']['id']} does not belong to task {first_request_payload['task']['id']}" + ) + } + + create_status, create_payload = invoke_request( + "POST", + f"/v0/tasks/{first_request_payload['task']['id']}/steps", + payload={ + "user_id": str(owner["user_id"]), + "kind": "governed_request", + "status": "created", + "request": { + "thread_id": str(owner["thread_id"]), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {"message": "valid"}, + }, + "outcome": { + "routing_decision": "approval_required", + "approval_status": None, + "approval_id": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "lineage": { + "parent_step_id": first_step_id, + "source_approval_id": first_request_payload["approval"]["id"], + "source_execution_id": first_execution_id, + }, + }, + ) + assert create_status == 201 + + approve_status, approve_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "approved", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": first_request_payload["approval"]["id"], + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + }, + ) + assert approve_status == 200 + + wrong_execute_status, wrong_execute_payload = invoke_request( + "POST", + f"/v0/task-steps/{create_payload['task_step']['id']}/transition", + payload={ + "user_id": str(owner["user_id"]), + "status": "executed", + "outcome": { + "routing_decision": "approval_required", + "approval_status": "approved", + "approval_id": first_request_payload["approval"]["id"], + "execution_id": second_execution_id, + "execution_status": "completed", + "blocked_reason": None, + }, + }, + ) + assert wrong_execute_status == 409 + assert wrong_execute_payload == { + "detail": ( + f"tool execution {second_execution_id} does not belong to task {first_request_payload['task']['id']}" + ) + } + + assert first_execution_id != second_execution_id + assert first_request_payload["approval"]["id"] != second_request_payload["approval"]["id"] + assert approve_payload["task"]["latest_approval_id"] == first_request_payload["approval"]["id"] diff --git a/tests/integration/test_tool_api.py b/tests/integration/test_tool_api.py new file mode 
100644 index 0000000..df7afd3 --- /dev/null +++ b/tests/integration/test_tool_api.py @@ -0,0 +1,930 @@ +from __future__ import annotations + +import json +from typing import Any +from urllib.parse import urlencode +from uuid import UUID, uuid4 + +import anyio + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.db import user_connection +from alicebot_api.store import ContinuityStore + + +def invoke_request( + method: str, + path: str, + *, + query_params: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, +) -> tuple[int, dict[str, Any]]: + messages: list[dict[str, object]] = [] + encoded_body = b"" if payload is None else json.dumps(payload).encode() + request_received = False + + async def receive() -> dict[str, object]: + nonlocal request_received + if request_received: + return {"type": "http.disconnect"} + + request_received = True + return {"type": "http.request", "body": encoded_body, "more_body": False} + + async def send(message: dict[str, object]) -> None: + messages.append(message) + + query_string = urlencode(query_params or {}).encode() + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "scheme": "http", + "path": path, + "raw_path": path.encode(), + "query_string": query_string, + "headers": [(b"content-type", b"application/json")], + "client": ("testclient", 50000), + "server": ("testserver", 80), + "root_path": "", + } + + anyio.run(main_module.app, scope, receive, send) + + start_message = next(message for message in messages if message["type"] == "http.response.start") + body = b"".join( + message.get("body", b"") + for message in messages + if message["type"] == "http.response.body" + ) + return start_message["status"], json.loads(body) + + +def seed_user(database_url: str, *, email: str) -> dict[str, UUID]: + user_id = uuid4() + + with user_connection(database_url, user_id) as conn: + store = ContinuityStore(conn) + store.create_user(user_id, email, email.split("@", 1)[0].title()) + thread = store.create_thread("Tool thread") + + return { + "user_id": user_id, + "thread_id": thread["id"], + } + + +def test_tool_endpoints_create_list_and_get_in_deterministic_order(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + second_status, second_payload = invoke_request( + "POST", + "/v0/tools", + payload={ + "user_id": str(seeded["user_id"]), + "tool_key": "zeta.fetch", + "name": "Zeta Fetch", + "description": "Fetch zeta records.", + "version": "2.0.0", + "active": True, + "tags": ["fetch"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + }, + ) + first_status, first_payload = invoke_request( + "POST", + "/v0/tools", + payload={ + "user_id": str(seeded["user_id"]), + "tool_key": "alpha.open", + "name": "Alpha Open", + "description": "Open alpha pages.", + "version": "1.0.0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + }, + ) + list_status, list_payload = invoke_request( + "GET", + "/v0/tools", + query_params={"user_id": str(seeded["user_id"])}, + ) + detail_status, detail_payload = invoke_request( 
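+        # Fetch the second-created tool by id; the detail response is asserted below
+        # to echo the create payload exactly.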
+ "GET", + f"/v0/tools/{second_payload['tool']['id']}", + query_params={"user_id": str(seeded['user_id'])}, + ) + + assert first_status == 201 + assert second_status == 201 + assert list_status == 200 + assert [item["tool_key"] for item in list_payload["items"]] == ["alpha.open", "zeta.fetch"] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert detail_status == 200 + assert detail_payload == {"tool": second_payload["tool"]} + assert first_payload["tool"]["metadata_version"] == "tool_metadata_v0" + + +def test_tool_allowlist_evaluation_returns_allowed_denied_and_approval_required_with_trace( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + allowed_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_by_metadata_tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + denied_by_consent_tool = store.create_tool( + tool_key="contacts.export", + name="Contacts Export", + description="Export contacts.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["contacts"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + approval_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + store.create_policy( + name="Allow contacts export with consent", + action="tool.run", + scope="workspace", + effect="allow", + priority=20, + active=True, + conditions={"tool_key": "contacts.export", "domain_hint": "docs"}, + required_consents=["contacts_consent"], + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=30, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + status_code, payload = invoke_request( + "POST", + "/v0/tools/allowlist/evaluate", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "action": "tool.run", + "scope": "workspace", + 
"domain_hint": "docs", + "attributes": {"channel": "chat"}, + }, + ) + + assert status_code == 200 + assert [item["tool"]["id"] for item in payload["allowed"]] == [str(allowed_tool["id"])] + assert [item["tool"]["id"] for item in payload["approval_required"]] == [str(approval_tool["id"])] + assert [item["tool"]["id"] for item in payload["denied"]] == [ + str(denied_by_metadata_tool["id"]), + str(denied_by_consent_tool["id"]), + ] + assert [reason["code"] for reason in payload["denied"][0]["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + assert [reason["code"] for reason in payload["denied"][1]["reasons"]] == [ + "tool_metadata_matched", + "matched_policy", + "consent_missing", + ] + assert payload["summary"] == { + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "evaluated_tool_count": 4, + "allowed_count": 1, + "denied_count": 2, + "approval_required_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert payload["trace"]["trace_event_count"] == 7 + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(payload["trace"]["trace_id"])) + + assert trace["kind"] == "tool.allowlist.evaluate" + assert trace["compiler_version"] == "tool_allowlist_evaluation_v0" + assert trace["limits"] == { + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + "active_tool_count": 4, + "active_policy_count": 3, + "consent_count": 1, + } + assert [event["kind"] for event in trace_events] == [ + "tool.allowlist.request", + "tool.allowlist.order", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.summary", + ] + assert trace_events[2]["payload"]["decision"] == "allowed" + assert trace_events[-1]["payload"] == { + "allowed_count": 1, + "denied_count": 2, + "approval_required_count": 1, + } + + +def test_tool_route_returns_ready_denied_and_approval_required_with_trace( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + ready_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + approval_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + 
action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + ready_policy = store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + approval_policy = store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=20, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + ready_status, ready_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {"channel": "chat"}, + }, + ) + denied_status, denied_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(denied_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + approval_status, approval_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(approval_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert ready_status == 200 + assert list(ready_payload) == ["request", "decision", "tool", "reasons", "summary", "trace"] + assert ready_payload["decision"] == "ready" + assert ready_payload["request"] == { + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "attributes": {"channel": "chat"}, + } + assert ready_payload["tool"]["id"] == str(ready_tool["id"]) + assert [reason["code"] for reason in ready_payload["reasons"]] == [ + "tool_metadata_matched", + "matched_policy", + "policy_effect_allow", + ] + assert ready_payload["summary"] == { + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 2, + "consent_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert ready_payload["trace"]["trace_event_count"] == 3 + + assert denied_status == 200 + assert denied_payload["decision"] == "denied" + assert [reason["code"] for reason in denied_payload["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + assert denied_payload["summary"]["decision"] == "denied" + + assert approval_status == 200 + assert approval_payload["decision"] == "approval_required" + assert approval_payload["summary"]["decision"] == "approval_required" + assert approval_payload["reasons"][-1] == { + "code": "policy_effect_require_approval", + "source": "policy", + "message": "Policy effect resolved the decision to 'require_approval'.", + "tool_id": str(approval_tool["id"]), + "policy_id": str(approval_policy["id"]), + "consent_key": None, + } + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + ready_trace = store.get_trace(UUID(ready_payload["trace"]["trace_id"])) + 
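+ # Read the persisted events for the same trace id; the assertions below check their kinds and payloads against the route response.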
ready_trace_events = store.list_trace_events(UUID(ready_payload["trace"]["trace_id"])) + + assert ready_trace["kind"] == "tool.route" + assert ready_trace["compiler_version"] == "tool_routing_v0" + assert ready_trace["limits"] == { + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + "evaluated_tool_count": 1, + "active_policy_count": 2, + "consent_count": 1, + } + assert [event["kind"] for event in ready_trace_events] == [ + "tool.route.request", + "tool.route.decision", + "tool.route.summary", + ] + assert ready_trace_events[1]["payload"] == { + "tool_id": str(ready_tool["id"]), + "tool_key": "browser.open", + "tool_version": "1.0.0", + "allowlist_decision": "allowed", + "routing_decision": "ready", + "matched_policy_id": str(ready_policy["id"]), + "reasons": ready_payload["reasons"], + } + assert ready_trace_events[2]["payload"] == { + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 2, + "consent_count": 1, + } + + +def test_tool_route_validates_invalid_thread_and_tool(migrated_database_urls, monkeypatch) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + tool = ContinuityStore(conn).create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + + invalid_thread_status, invalid_thread_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(uuid4()), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + invalid_tool_status, invalid_tool_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert invalid_thread_status == 400 + assert invalid_thread_payload == { + "detail": "thread_id must reference an existing thread owned by the user" + } + assert invalid_tool_status == 400 + assert invalid_tool_payload == { + "detail": "tool_id must reference an existing active tool owned by the user" + } + + +def test_tool_endpoints_and_allowlist_enforce_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + owner_tool = ContinuityStore(conn).create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + + list_status, list_payload = invoke_request( + "GET", + "/v0/tools", + query_params={"user_id": 
str(intruder["user_id"])}, + ) + detail_status, detail_payload = invoke_request( + "GET", + f"/v0/tools/{owner_tool['id']}", + query_params={"user_id": str(intruder["user_id"])}, + ) + evaluation_status, evaluation_payload = invoke_request( + "POST", + "/v0/tools/allowlist/evaluate", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert list_status == 200 + assert list_payload == { + "items": [], + "summary": { + "total_count": 0, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + }, + } + assert detail_status == 404 + assert detail_payload == {"detail": f"tool {owner_tool['id']} was not found"} + assert evaluation_status == 200 + assert evaluation_payload["allowed"] == [] + assert evaluation_payload["denied"] == [] + assert evaluation_payload["approval_required"] == [] + assert evaluation_payload["summary"]["evaluated_tool_count"] == 0 + + +def test_tool_routing_returns_ready_denied_and_approval_required_with_trace( + migrated_database_urls, + monkeypatch, +) -> None: + seeded = seed_user(migrated_database_urls["app"], email="owner@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + ready_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( + tool_key="contacts.export", + name="Contacts Export", + description="Export contacts.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["contacts"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + approval_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + store.create_policy( + name="Allow contacts export with consent", + action="tool.run", + scope="workspace", + effect="allow", + priority=20, + active=True, + conditions={"tool_key": "contacts.export", "domain_hint": "docs"}, + required_consents=["contacts_consent"], + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=30, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + ready_status, ready_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": 
str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {"channel": "chat"}, + }, + ) + denied_status, denied_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(denied_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "attributes": {}, + }, + ) + approval_status, approval_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(seeded["user_id"]), + "thread_id": str(seeded["thread_id"]), + "tool_id": str(approval_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert ready_status == 200 + assert ready_payload["decision"] == "ready" + assert ready_payload["tool"]["id"] == str(ready_tool["id"]) + assert ready_payload["summary"] == { + "thread_id": str(seeded["thread_id"]), + "tool_id": str(ready_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 3, + "consent_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert ready_payload["trace"]["trace_event_count"] == 3 + + assert denied_status == 200 + assert denied_payload["decision"] == "denied" + assert [reason["code"] for reason in denied_payload["reasons"]] == [ + "tool_metadata_matched", + "matched_policy", + "consent_missing", + ] + + assert approval_status == 200 + assert approval_payload["decision"] == "approval_required" + assert approval_payload["reasons"][-1]["code"] == "policy_effect_require_approval" + + with user_connection(migrated_database_urls["app"], seeded["user_id"]) as conn: + store = ContinuityStore(conn) + trace = store.get_trace(UUID(ready_payload["trace"]["trace_id"])) + trace_events = store.list_trace_events(UUID(ready_payload["trace"]["trace_id"])) + + assert trace["kind"] == "tool.route" + assert trace["compiler_version"] == "tool_routing_v0" + assert trace["limits"] == { + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + "evaluated_tool_count": 1, + "active_policy_count": 3, + "consent_count": 1, + } + assert [event["kind"] for event in trace_events] == [ + "tool.route.request", + "tool.route.decision", + "tool.route.summary", + ] + assert trace_events[1]["payload"]["allowlist_decision"] == "allowed" + assert trace_events[1]["payload"]["routing_decision"] == "ready" + assert trace_events[2]["payload"] == { + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 3, + "consent_count": 1, + } + + +def test_tool_routing_validates_invalid_references_and_per_user_isolation( + migrated_database_urls, + monkeypatch, +) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + owner_tool = ContinuityStore(conn).create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": 
"proxy"}, + ) + + invalid_thread_status, invalid_thread_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(uuid4()), + "tool_id": str(owner_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + invalid_tool_status, invalid_tool_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(owner["user_id"]), + "thread_id": str(owner["thread_id"]), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + isolation_status, isolation_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "tool_id": str(owner_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert invalid_thread_status == 400 + assert invalid_thread_payload == { + "detail": "thread_id must reference an existing thread owned by the user" + } + assert invalid_tool_status == 400 + assert invalid_tool_payload == { + "detail": "tool_id must reference an existing active tool owned by the user" + } + assert isolation_status == 400 + assert isolation_payload == { + "detail": "tool_id must reference an existing active tool owned by the user" + } + + +def test_tool_route_enforces_per_user_isolation(migrated_database_urls, monkeypatch) -> None: + owner = seed_user(migrated_database_urls["app"], email="owner@example.com") + intruder = seed_user(migrated_database_urls["app"], email="intruder@example.com") + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url=migrated_database_urls["app"])) + + with user_connection(migrated_database_urls["app"], owner["user_id"]) as conn: + owner_tool = ContinuityStore(conn).create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + + route_status, route_payload = invoke_request( + "POST", + "/v0/tools/route", + payload={ + "user_id": str(intruder["user_id"]), + "thread_id": str(intruder["thread_id"]), + "tool_id": str(owner_tool["id"]), + "action": "tool.run", + "scope": "workspace", + "attributes": {}, + }, + ) + + assert route_status == 400 + assert route_payload == { + "detail": "tool_id must reference an existing active tool owned by the user" + } diff --git a/tests/unit/test_20260310_0001_foundation_continuity.py b/tests/unit/test_20260310_0001_foundation_continuity.py new file mode 100644 index 0000000..9ac3fc7 --- /dev/null +++ b/tests/unit/test_20260310_0001_foundation_continuity.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260310_0001_foundation_continuity" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_BOOTSTRAP_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_TRIGGER_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE users ENABLE ROW LEVEL SECURITY", + "ALTER TABLE 
users FORCE ROW LEVEL SECURITY", + "ALTER TABLE threads ENABLE ROW LEVEL SECURITY", + "ALTER TABLE threads FORCE ROW LEVEL SECURITY", + "ALTER TABLE sessions ENABLE ROW LEVEL SECURITY", + "ALTER TABLE sessions FORCE ROW LEVEL SECURITY", + "ALTER TABLE events ENABLE ROW LEVEL SECURITY", + "ALTER TABLE events FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_base_downgrade_does_not_drop_global_extensions() -> None: + module = load_migration_module() + + assert "DROP EXTENSION IF EXISTS vector" not in module._DOWNGRADE_STATEMENTS + assert "DROP EXTENSION IF EXISTS pgcrypto" not in module._DOWNGRADE_STATEMENTS + + +def test_base_schema_does_not_create_redundant_events_sequence_index() -> None: + module = load_migration_module() + + assert "CREATE INDEX events_thread_sequence_idx" not in module._UPGRADE_SCHEMA_STATEMENT diff --git a/tests/unit/test_20260311_0002_tighten_runtime_privileges.py b/tests/unit/test_20260311_0002_tighten_runtime_privileges.py new file mode 100644 index 0000000..af0925d --- /dev/null +++ b/tests/unit/test_20260311_0002_tighten_runtime_privileges.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260311_0002_tighten_runtime_privileges" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_downgrade_reasserts_revision_0001_privilege_floor() -> None: + module = load_migration_module() + + assert module._DOWNGRADE_STATEMENTS == ( + "REVOKE UPDATE ON users FROM alicebot_app", + "REVOKE UPDATE ON threads FROM alicebot_app", + "REVOKE UPDATE ON sessions FROM alicebot_app", + ) diff --git a/tests/unit/test_20260311_0003_trace_backbone.py b/tests/unit/test_20260311_0003_trace_backbone.py new file mode 100644 index 0000000..5780912 --- /dev/null +++ b/tests/unit/test_20260311_0003_trace_backbone.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260311_0003_trace_backbone" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_BOOTSTRAP_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_TRIGGER_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE traces ENABLE ROW LEVEL SECURITY", + "ALTER TABLE traces FORCE ROW LEVEL SECURITY", + "ALTER TABLE trace_events ENABLE ROW LEVEL SECURITY", + "ALTER TABLE 
trace_events FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_trace_tables_keep_runtime_role_at_select_insert_only() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON traces TO alicebot_app", + "GRANT SELECT, INSERT ON trace_events TO alicebot_app", + ) diff --git a/tests/unit/test_20260311_0004_memory_admission.py b/tests/unit/test_20260311_0004_memory_admission.py new file mode 100644 index 0000000..fe561e1 --- /dev/null +++ b/tests/unit/test_20260311_0004_memory_admission.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260311_0004_memory_admission" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_BOOTSTRAP_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_TRIGGER_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE memories ENABLE ROW LEVEL SECURITY", + "ALTER TABLE memories FORCE ROW LEVEL SECURITY", + "ALTER TABLE memory_revisions ENABLE ROW LEVEL SECURITY", + "ALTER TABLE memory_revisions FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_memory_table_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT, UPDATE ON memories TO alicebot_app", + "GRANT SELECT, INSERT ON memory_revisions TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0005_memory_review_labels.py b/tests/unit/test_20260312_0005_memory_review_labels.py new file mode 100644 index 0000000..2476797 --- /dev/null +++ b/tests/unit/test_20260312_0005_memory_review_labels.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0005_memory_review_labels" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + *module._UPGRADE_BOOTSTRAP_STATEMENTS, + module._UPGRADE_SCHEMA_STATEMENT, + module._UPGRADE_TRIGGER_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE memory_review_labels ENABLE ROW LEVEL SECURITY", + "ALTER TABLE memory_review_labels FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + 
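+ # Swap op.execute for a plain list append so the downgrade's SQL is captured in order without a database connection.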
monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_memory_review_label_table_privileges_stay_append_only() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON memory_review_labels TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0006_entities_backbone.py b/tests/unit/test_20260312_0006_entities_backbone.py new file mode 100644 index 0000000..d099878 --- /dev/null +++ b/tests/unit/test_20260312_0006_entities_backbone.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0006_entities_backbone" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE entities ENABLE ROW LEVEL SECURITY", + "ALTER TABLE entities FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_entities_table_privileges_stay_insert_select_only() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON entities TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0007_entity_edges.py b/tests/unit/test_20260312_0007_entity_edges.py new file mode 100644 index 0000000..255b9fb --- /dev/null +++ b/tests/unit/test_20260312_0007_entity_edges.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0007_entity_edges" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE entity_edges ENABLE ROW LEVEL SECURITY", + "ALTER TABLE entity_edges FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_entity_edges_table_privileges_stay_insert_select_only() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON entity_edges TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0008_embedding_substrate.py b/tests/unit/test_20260312_0008_embedding_substrate.py new file mode 100644 index 0000000..240286f --- /dev/null +++ b/tests/unit/test_20260312_0008_embedding_substrate.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import 
importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0008_embedding_substrate" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE embedding_configs ENABLE ROW LEVEL SECURITY", + "ALTER TABLE embedding_configs FORCE ROW LEVEL SECURITY", + "ALTER TABLE memory_embeddings ENABLE ROW LEVEL SECURITY", + "ALTER TABLE memory_embeddings FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_embedding_tables_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON embedding_configs TO alicebot_app", + "GRANT SELECT, INSERT, UPDATE ON memory_embeddings TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0009_policy_and_consent_core.py b/tests/unit/test_20260312_0009_policy_and_consent_core.py new file mode 100644 index 0000000..b926485 --- /dev/null +++ b/tests/unit/test_20260312_0009_policy_and_consent_core.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0009_policy_and_consent_core" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE consents ENABLE ROW LEVEL SECURITY", + "ALTER TABLE consents FORCE ROW LEVEL SECURITY", + "ALTER TABLE policies ENABLE ROW LEVEL SECURITY", + "ALTER TABLE policies FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_policy_and_consent_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT, UPDATE ON consents TO alicebot_app", + "GRANT SELECT, INSERT ON policies TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0010_tools_registry_and_allowlist.py b/tests/unit/test_20260312_0010_tools_registry_and_allowlist.py new file mode 100644 index 0000000..b7c4215 --- /dev/null +++ b/tests/unit/test_20260312_0010_tools_registry_and_allowlist.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0010_tools_registry_and_allowlist" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def 
test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE tools ENABLE ROW LEVEL SECURITY", + "ALTER TABLE tools FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_tools_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON tools TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0011_approval_request_records.py b/tests/unit/test_20260312_0011_approval_request_records.py new file mode 100644 index 0000000..00c051b --- /dev/null +++ b/tests/unit/test_20260312_0011_approval_request_records.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0011_approval_request_records" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE approvals ENABLE ROW LEVEL SECURITY", + "ALTER TABLE approvals FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_approvals_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON approvals TO alicebot_app", + ) diff --git a/tests/unit/test_20260312_0012_approval_resolution.py b/tests/unit/test_20260312_0012_approval_resolution.py new file mode 100644 index 0000000..7e37cd5 --- /dev/null +++ b/tests/unit/test_20260312_0012_approval_resolution.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260312_0012_approval_resolution" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_approvals_resolution_privileges_stay_narrow() -> None: + 
module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT UPDATE ON approvals TO alicebot_app", + ) diff --git a/tests/unit/test_20260313_0013_tool_executions.py b/tests/unit/test_20260313_0013_tool_executions.py new file mode 100644 index 0000000..84e4f67 --- /dev/null +++ b/tests/unit/test_20260313_0013_tool_executions.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0013_tool_executions" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE tool_executions ENABLE ROW LEVEL SECURITY", + "ALTER TABLE tool_executions FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_tool_executions_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON tool_executions TO alicebot_app", + ) diff --git a/tests/unit/test_20260313_0014_execution_budgets.py b/tests/unit/test_20260313_0014_execution_budgets.py new file mode 100644 index 0000000..a1cadf3 --- /dev/null +++ b/tests/unit/test_20260313_0014_execution_budgets.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0014_execution_budgets" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE execution_budgets ENABLE ROW LEVEL SECURITY", + "ALTER TABLE execution_budgets FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_execution_budgets_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON execution_budgets TO alicebot_app", + ) diff --git a/tests/unit/test_20260313_0015_execution_budget_lifecycle.py b/tests/unit/test_20260313_0015_execution_budget_lifecycle.py new file mode 100644 index 0000000..f1a7468 --- /dev/null +++ b/tests/unit/test_20260313_0015_execution_budget_lifecycle.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0015_execution_budget_lifecycle" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def 
test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_execution_budget_lifecycle_privileges_stay_narrow() -> None: + module = load_migration_module() + + assert module._UPGRADE_STATEMENTS[-1] == "GRANT SELECT, INSERT, UPDATE ON execution_budgets TO alicebot_app" diff --git a/tests/unit/test_20260313_0016_execution_budget_rolling_window.py b/tests/unit/test_20260313_0016_execution_budget_rolling_window.py new file mode 100644 index 0000000..631b0bb --- /dev/null +++ b/tests/unit/test_20260313_0016_execution_budget_rolling_window.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0016_execution_budget_rolling_window" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) diff --git a/tests/unit/test_20260313_0018_task_steps.py b/tests/unit/test_20260313_0018_task_steps.py new file mode 100644 index 0000000..c3ab793 --- /dev/null +++ b/tests/unit/test_20260313_0018_task_steps.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0018_task_steps" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE task_steps ENABLE ROW LEVEL SECURITY", + "ALTER TABLE task_steps FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_task_step_privileges_allow_only_expected_runtime_writes() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT, UPDATE ON task_steps TO alicebot_app", + ) diff --git a/tests/unit/test_20260313_0019_task_step_lineage.py b/tests/unit/test_20260313_0019_task_step_lineage.py new file mode 100644 index 0000000..68fdcca --- /dev/null +++ b/tests/unit/test_20260313_0019_task_step_lineage.py @@ -0,0 
+1,32 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0019_task_step_lineage" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statement(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [module._UPGRADE_SCHEMA_STATEMENT] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) diff --git a/tests/unit/test_20260313_0020_approval_task_step_linkage.py b/tests/unit/test_20260313_0020_approval_task_step_linkage.py new file mode 100644 index 0000000..5f7816a --- /dev/null +++ b/tests/unit/test_20260313_0020_approval_task_step_linkage.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0020_approval_task_step_linkage" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [module._UPGRADE_SCHEMA_STATEMENT] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) diff --git a/tests/unit/test_20260313_0021_tool_execution_task_step_linkage.py b/tests/unit/test_20260313_0021_tool_execution_task_step_linkage.py new file mode 100644 index 0000000..31f5330 --- /dev/null +++ b/tests/unit/test_20260313_0021_tool_execution_task_step_linkage.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0021_tool_execution_task_step_linkage" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == list(module._UPGRADE_STATEMENTS) + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) diff --git a/tests/unit/test_20260313_0022_task_workspaces.py b/tests/unit/test_20260313_0022_task_workspaces.py new file mode 100644 index 0000000..6e352b9 --- /dev/null +++ b/tests/unit/test_20260313_0022_task_workspaces.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import importlib + + +MODULE_NAME = "apps.api.alembic.versions.20260313_0022_task_workspaces" + + +def load_migration_module(): + return importlib.import_module(MODULE_NAME) + + +def test_upgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: 
list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.upgrade() + + assert executed == [ + module._UPGRADE_SCHEMA_STATEMENT, + *module._UPGRADE_GRANT_STATEMENTS, + "ALTER TABLE task_workspaces ENABLE ROW LEVEL SECURITY", + "ALTER TABLE task_workspaces FORCE ROW LEVEL SECURITY", + module._UPGRADE_POLICY_STATEMENT, + ] + + +def test_downgrade_executes_expected_statements_in_order(monkeypatch) -> None: + module = load_migration_module() + executed: list[str] = [] + + monkeypatch.setattr(module.op, "execute", executed.append) + + module.downgrade() + + assert executed == list(module._DOWNGRADE_STATEMENTS) + + +def test_task_workspace_privileges_allow_only_expected_runtime_writes() -> None: + module = load_migration_module() + + assert module._UPGRADE_GRANT_STATEMENTS == ( + "GRANT SELECT, INSERT ON task_workspaces TO alicebot_app", + ) diff --git a/tests/unit/test_approval_store.py b/tests/unit/test_approval_store.py new file mode 100644 index 0000000..7a944e7 --- /dev/null +++ b/tests/unit/test_approval_store.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_approval_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + approval_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + task_step_id = uuid4() + routing_trace_id = uuid4() + resolved_by_user_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": approval_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": "pending", + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": {"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + "routing_trace_id": routing_trace_id, + "resolved_at": None, + "resolved_by_user_id": None, + }, + { + "id": approval_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": "pending", + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": {"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + "routing_trace_id": routing_trace_id, + "resolved_at": None, + "resolved_by_user_id": None, + }, + { + "id": approval_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": "approved", + "request": 
{"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": {"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + "routing_trace_id": routing_trace_id, + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": resolved_by_user_id, + }, + ], + fetchall_result=[ + { + "id": approval_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": "pending", + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": {"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + "routing_trace_id": routing_trace_id, + "resolved_at": None, + "resolved_by_user_id": None, + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_approval( + thread_id=thread_id, + tool_id=tool_id, + task_step_id=task_step_id, + status="pending", + request={"thread_id": str(thread_id), "tool_id": str(tool_id)}, + tool={"id": str(tool_id), "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "trace": {"trace_id": str(routing_trace_id)}}, + routing_trace_id=routing_trace_id, + ) + fetched = store.get_approval_optional(approval_id) + listed = store.list_approvals() + resolved = store.resolve_approval_optional(approval_id=approval_id, status="approved") + + assert created["id"] == approval_id + assert created["resolved_at"] is None + assert fetched is not None + assert listed[0]["id"] == approval_id + assert resolved is not None + assert resolved["status"] == "approved" + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO approvals" in create_query + assert create_params is not None + assert create_params[:4] == (thread_id, tool_id, task_step_id, "pending") + assert isinstance(create_params[4], Jsonb) + assert create_params[4].obj == {"thread_id": str(thread_id), "tool_id": str(tool_id)} + assert isinstance(create_params[5], Jsonb) + assert create_params[5].obj == {"id": str(tool_id), "tool_key": "shell.exec"} + assert isinstance(create_params[6], Jsonb) + assert create_params[6].obj == { + "decision": "approval_required", + "trace": {"trace_id": str(routing_trace_id)}, + } + assert create_params[7] == routing_trace_id + assert "resolved_at" in cursor.executed[1][0] + assert "ORDER BY created_at ASC, id ASC" in cursor.executed[2][0] + + resolve_query, resolve_params = cursor.executed[3] + assert "UPDATE approvals" in resolve_query + assert "WHERE id = %s" in resolve_query + assert "AND status = 'pending'" in resolve_query + assert resolve_params == ("approved", approval_id) diff --git a/tests/unit/test_approvals.py b/tests/unit/test_approvals.py new file mode 100644 index 0000000..2ac7b2f --- /dev/null +++ b/tests/unit/test_approvals.py @@ -0,0 +1,1200 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +from alicebot_api.approvals import ( + ApprovalNotFoundError, + ApprovalResolutionConflictError, + approve_approval_record, + get_approval_record, + list_approval_records, + reject_approval_record, + submit_approval_request, +) +from alicebot_api.contracts import ApprovalApproveInput, ApprovalRejectInput, ApprovalRequestCreateInput +from alicebot_api.tasks import TaskStepApprovalLinkageError + + +class ApprovalStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.user_id = uuid4() + 
self.thread_id = uuid4() + self.locked_task_ids: list[UUID] = [] + self.consents: dict[str, dict[str, object]] = {} + self.policies: list[dict[str, object]] = [] + self.tools: list[dict[str, object]] = [] + self.approvals: list[dict[str, object]] = [] + self.tasks: list[dict[str, object]] = [] + self.task_steps: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + + def create_consent(self, *, consent_key: str, status: str, metadata: dict[str, object]) -> dict[str, object]: + consent = { + "id": uuid4(), + "user_id": self.user_id, + "consent_key": consent_key, + "status": status, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.consents)), + "updated_at": self.base_time + timedelta(minutes=len(self.consents)), + } + self.consents[consent_key] = consent + return consent + + def list_consents(self) -> list[dict[str, object]]: + return sorted( + self.consents.values(), + key=lambda consent: (consent["consent_key"], consent["created_at"], consent["id"]), + ) + + def create_policy( + self, + *, + name: str, + action: str, + scope: str, + effect: str, + priority: int, + active: bool, + conditions: dict[str, object], + required_consents: list[str], + ) -> dict[str, object]: + policy = { + "id": uuid4(), + "user_id": self.user_id, + "name": name, + "action": action, + "scope": scope, + "effect": effect, + "priority": priority, + "active": active, + "conditions": conditions, + "required_consents": required_consents, + "created_at": self.base_time + timedelta(minutes=len(self.policies)), + "updated_at": self.base_time + timedelta(minutes=len(self.policies)), + } + self.policies.append(policy) + return policy + + def list_active_policies(self) -> list[dict[str, object]]: + return sorted( + [policy for policy in self.policies if policy["active"] is True], + key=lambda policy: (policy["priority"], policy["created_at"], policy["id"]), + ) + + def create_tool( + self, + *, + tool_key: str, + name: str, + description: str, + version: str, + metadata_version: str, + active: bool, + tags: list[str], + action_hints: list[str], + scope_hints: list[str], + domain_hints: list[str], + risk_hints: list[str], + metadata: dict[str, object], + ) -> dict[str, object]: + tool = { + "id": uuid4(), + "user_id": self.user_id, + "tool_key": tool_key, + "name": name, + "description": description, + "version": version, + "metadata_version": metadata_version, + "active": active, + "tags": tags, + "action_hints": action_hints, + "scope_hints": scope_hints, + "domain_hints": domain_hints, + "risk_hints": risk_hints, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.tools)), + } + self.tools.append(tool) + return tool + + def get_tool_optional(self, tool_id: UUID) -> dict[str, object] | None: + return next((tool for tool in self.tools if tool["id"] == tool_id), None) + + def list_active_tools(self) -> list[dict[str, object]]: + return [tool for tool in self.tools if tool["active"] is True] + + def get_thread_optional(self, thread_id: UUID) -> dict[str, object] | None: + if thread_id != self.thread_id: + return None + return { + "id": self.thread_id, + "user_id": self.user_id, + "title": "Approval thread", + "created_at": self.base_time, + "updated_at": self.base_time, + } + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": 
user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time + timedelta(minutes=len(self.traces)), + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.trace_events)), + } + self.trace_events.append(event) + return event + + def create_approval( + self, + *, + thread_id: UUID, + tool_id: UUID, + task_step_id: UUID | None, + status: str, + request: dict[str, object], + tool: dict[str, object], + routing: dict[str, object], + routing_trace_id: UUID, + ) -> dict[str, object]: + approval = { + "id": uuid4(), + "user_id": self.user_id, + "thread_id": thread_id, + "tool_id": tool_id, + "task_step_id": task_step_id, + "status": status, + "request": request, + "tool": tool, + "routing": routing, + "routing_trace_id": routing_trace_id, + "created_at": self.base_time + timedelta(minutes=len(self.approvals)), + "resolved_at": None, + "resolved_by_user_id": None, + } + self.approvals.append(approval) + return approval + + def get_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return next((approval for approval in self.approvals if approval["id"] == approval_id), None) + + def list_approvals(self) -> list[dict[str, object]]: + return sorted( + self.approvals, + key=lambda approval: (approval["created_at"], approval["id"]), + ) + + def create_task( + self, + *, + thread_id: UUID, + tool_id: UUID, + status: str, + request: dict[str, object], + tool: dict[str, object], + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object]: + task = { + "id": uuid4(), + "user_id": self.user_id, + "thread_id": thread_id, + "tool_id": tool_id, + "status": status, + "request": request, + "tool": tool, + "latest_approval_id": latest_approval_id, + "latest_execution_id": latest_execution_id, + "created_at": self.base_time + timedelta(minutes=len(self.tasks)), + "updated_at": self.base_time + timedelta(minutes=len(self.tasks)), + } + self.tasks.append(task) + return task + + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["id"] == task_id), None) + + def get_task_by_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["latest_approval_id"] == approval_id), None) + + def lock_task_steps(self, task_id: UUID) -> None: + self.locked_task_ids.append(task_id) + + def list_tasks(self) -> list[dict[str, object]]: + return sorted( + self.tasks, + key=lambda task: (task["created_at"], task["id"]), + ) + + def create_task_step( + self, + *, + task_id: UUID, + sequence_no: int, + parent_step_id: UUID | None = None, + source_approval_id: UUID | None = None, + source_execution_id: UUID | None = None, + kind: str, + status: str, + request: dict[str, object], + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object]: + task_step = { + "id": uuid4(), + "user_id": self.user_id, + "task_id": task_id, + "sequence_no": sequence_no, +
"parent_step_id": parent_step_id, + "source_approval_id": source_approval_id, + "source_execution_id": source_execution_id, + "kind": kind, + "status": status, + "request": request, + "outcome": outcome, + "trace_id": trace_id, + "trace_kind": trace_kind, + "created_at": self.base_time + timedelta(minutes=len(self.task_steps)), + "updated_at": self.base_time + timedelta(minutes=len(self.task_steps)), + } + self.task_steps.append(task_step) + return task_step + + def get_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + ) -> dict[str, object] | None: + return next( + ( + task_step + for task_step in self.task_steps + if task_step["task_id"] == task_id and task_step["sequence_no"] == sequence_no + ), + None, + ) + + def get_task_step_optional(self, task_step_id: UUID) -> dict[str, object] | None: + return next((task_step for task_step in self.task_steps if task_step["id"] == task_step_id), None) + + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + return sorted( + [task_step for task_step in self.task_steps if task_step["task_id"] == task_id], + key=lambda task_step: (task_step["sequence_no"], task_step["created_at"], task_step["id"]), + ) + + def update_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_for_task_sequence_optional(task_id=task_id, sequence_no=sequence_no) + if task_step is None: + return None + + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task_step + + def update_task_step_optional( + self, + *, + task_step_id: UUID, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_optional(task_step_id) + if task_step is None: + return None + + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task_step + + def update_task_status_by_approval_optional( + self, + *, + approval_id: UUID, + status: str, + ) -> dict[str, object] | None: + task = self.get_task_by_approval_optional(approval_id) + if task is None: + return None + task["status"] = status + task["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task + + def update_task_status_optional( + self, + *, + task_id: UUID, + status: str, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object] | None: + task = self.get_task_optional(task_id) + if task is None: + return None + task["status"] = status + task["latest_approval_id"] = latest_approval_id + task["latest_execution_id"] = latest_execution_id + task["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task + + def resolve_approval_optional(self, *, approval_id: UUID, status: str) -> dict[str, object] | None: + approval = self.get_approval_optional(approval_id) + if approval is None or approval["status"] != "pending": + return None + + approval["status"] = status + approval["resolved_at"] = self.base_time + timedelta(hours=1, 
minutes=len(self.trace_events)) + approval["resolved_by_user_id"] = self.user_id + return approval + + def update_approval_task_step_optional( + self, + *, + approval_id: UUID, + task_step_id: UUID, + ) -> dict[str, object] | None: + approval = self.get_approval_optional(approval_id) + if approval is None: + return None + approval["task_step_id"] = task_step_id + return approval + + +def test_submit_approval_request_persists_record_for_approval_required_route() -> None: + store = ApprovalStoreStub() + tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + policy = store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + payload = submit_approval_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalRequestCreateInput( + thread_id=store.thread_id, + tool_id=tool["id"], + action="tool.run", + scope="workspace", + attributes={"command": "ls"}, + ), + ) + + assert payload["decision"] == "approval_required" + assert payload["task"]["status"] == "pending_approval" + assert payload["task"]["latest_approval_id"] == payload["approval"]["id"] + assert payload["task"]["latest_execution_id"] is None + assert payload["approval"] is not None + assert payload["approval"]["status"] == "pending" + assert payload["approval"]["resolution"] is None + assert payload["approval"]["thread_id"] == str(store.thread_id) + assert payload["approval"]["task_step_id"] == str(store.task_steps[0]["id"]) + assert payload["approval"]["request"] == payload["request"] + assert payload["approval"]["tool"] == payload["tool"] + assert payload["approval"]["routing"] == { + "decision": "approval_required", + "reasons": payload["reasons"], + "trace": payload["routing_trace"], + } + assert payload["routing_trace"]["trace_event_count"] == 3 + assert payload["trace"]["trace_event_count"] == 8 + assert len(store.approvals) == 1 + assert len(store.tasks) == 1 + assert len(store.task_steps) == 1 + assert store.traces[0]["kind"] == "tool.route" + assert store.traces[1]["kind"] == "approval.request" + assert store.traces[1]["compiler_version"] == "approval_request_v0" + assert store.traces[1]["limits"] == { + "order": ["created_at_asc", "id_asc"], + "persisted": True, + } + assert [event["kind"] for event in store.trace_events[-8:]] == [ + "approval.request.request", + "approval.request.routing", + "approval.request.persisted", + "approval.request.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[-7]["payload"]["routing_trace_id"] == payload["routing_trace"]["trace_id"] + assert store.trace_events[-6]["payload"] == { + "approval_id": payload["approval"]["id"], + "task_step_id": payload["approval"]["task_step_id"], + "decision": "approval_required", + "persisted": True, + } + assert store.trace_events[-4]["payload"] == { + "task_id": payload["task"]["id"], + "source": "approval_request", + "previous_status": None, + "current_status": "pending_approval", + "latest_approval_id": payload["approval"]["id"], + "latest_execution_id": None, + } + assert 
store.trace_events[-2]["payload"] == { + "task_id": payload["task"]["id"], + "task_step_id": str(store.task_steps[0]["id"]), + "source": "approval_request", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": None, + "current_status": "created", + "trace": { + "trace_id": payload["trace"]["trace_id"], + "trace_kind": "approval.request", + }, + } + assert payload["reasons"][-1] == { + "code": "policy_effect_require_approval", + "source": "policy", + "message": "Policy effect resolved the decision to 'require_approval'.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + } + + +def test_submit_approval_request_does_not_persist_for_ready_or_denied_routes() -> None: + ready_store = ApprovalStoreStub() + ready_store.create_consent( + consent_key="web_access", + status="granted", + metadata={"source": "settings"}, + ) + ready_tool = ready_store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + ready_store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + + ready_payload = submit_approval_request( + ready_store, # type: ignore[arg-type] + user_id=ready_store.user_id, + request=ApprovalRequestCreateInput( + thread_id=ready_store.thread_id, + tool_id=ready_tool["id"], + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={}, + ), + ) + + denied_store = ApprovalStoreStub() + denied_tool = denied_store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + denied_payload = submit_approval_request( + denied_store, # type: ignore[arg-type] + user_id=denied_store.user_id, + request=ApprovalRequestCreateInput( + thread_id=denied_store.thread_id, + tool_id=denied_tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) + + assert ready_payload["decision"] == "ready" + assert ready_payload["task"]["status"] == "approved" + assert ready_payload["task"]["latest_approval_id"] is None + assert ready_payload["approval"] is None + assert ready_store.approvals == [] + assert len(ready_store.task_steps) == 1 + assert [event["kind"] for event in ready_store.trace_events[-8:]] == [ + "approval.request.request", + "approval.request.routing", + "approval.request.skipped", + "approval.request.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + + assert denied_payload["decision"] == "denied" + assert denied_payload["task"]["status"] == "denied" + assert denied_payload["task"]["latest_approval_id"] is None + assert denied_payload["approval"] is None + assert denied_store.approvals == [] + assert [reason["code"] for reason in denied_payload["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + + +def test_approve_approval_record_resolves_pending_and_records_trace() -> None: + store = 
ApprovalStoreStub() + approval = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-1"}, + tool={"id": "tool-1", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-1", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + created_task = store.create_task( + thread_id=store.thread_id, + tool_id=approval["tool_id"], + status="pending_approval", + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval["id"], + latest_execution_id=None, + ) + created_step = store.create_task_step( + task_id=created_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + approval["task_step_id"] = created_step["id"] + + payload = approve_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalApproveInput(approval_id=approval["id"]), + ) + + assert payload["approval"]["id"] == str(approval["id"]) + assert payload["approval"]["task_step_id"] == str(created_step["id"]) + assert payload["approval"]["status"] == "approved" + assert payload["approval"]["resolution"] == { + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": str(store.user_id), + } + assert payload["trace"]["trace_event_count"] == 7 + assert store.traces[0]["kind"] == "approval.resolve" + assert store.traces[0]["compiler_version"] == "approval_resolution_v0" + assert store.traces[0]["limits"] == { + "order": ["created_at_asc", "id_asc"], + "requested_action": "approve", + "outcome": "resolved", + } + assert [event["kind"] for event in store.trace_events] == [ + "approval.resolution.request", + "approval.resolution.state", + "approval.resolution.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[1]["payload"] == { + "approval_id": str(approval["id"]), + "task_step_id": str(approval["task_step_id"]), + "requested_action": "approve", + "previous_status": "pending", + "outcome": "resolved", + "current_status": "approved", + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": str(store.user_id), + } + assert store.trace_events[3]["payload"] == { + "task_id": str(store.tasks[0]["id"]), + "source": "approval_resolution", + "previous_status": "pending_approval", + "current_status": "approved", + "latest_approval_id": str(approval["id"]), + "latest_execution_id": None, + } + assert store.trace_events[5]["payload"] == { + "task_id": str(store.tasks[0]["id"]), + "task_step_id": str(store.task_steps[0]["id"]), + "source": "approval_resolution", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": "created", + "current_status": "approved", + "trace": { + "trace_id": str(store.traces[0]["id"]), + "trace_kind": "approval.resolve", + }, + } + + +def test_reject_approval_record_resolves_pending_and_records_trace() -> None: + store = ApprovalStoreStub() + approval = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-2"}, 
+ tool={"id": "tool-2", "tool_key": "browser.open"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-2", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + created_task = store.create_task( + thread_id=store.thread_id, + tool_id=approval["tool_id"], + status="pending_approval", + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval["id"], + latest_execution_id=None, + ) + created_step = store.create_task_step( + task_id=created_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + approval["task_step_id"] = created_step["id"] + + payload = reject_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalRejectInput(approval_id=approval["id"]), + ) + + assert payload["approval"]["status"] == "rejected" + assert payload["approval"]["task_step_id"] == str(created_step["id"]) + assert payload["approval"]["resolution"] == { + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": str(store.user_id), + } + assert store.trace_events[1]["payload"]["requested_action"] == "reject" + assert store.trace_events[1]["payload"]["current_status"] == "rejected" + + +def test_approval_resolution_locks_task_steps_before_task_and_step_mutation() -> None: + class LockingApprovalStoreStub(ApprovalStoreStub): + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + if task_id not in self.locked_task_ids: + raise AssertionError("task-step boundary was checked before the task-step lock was taken") + return super().list_task_steps_for_task(task_id) + + def update_task_status_by_approval_optional( + self, + *, + approval_id: UUID, + status: str, + ) -> dict[str, object] | None: + task = self.get_task_by_approval_optional(approval_id) + if task is None: + return None + if task["id"] not in self.locked_task_ids: + raise AssertionError("task status changed before the task-step lock was taken") + return super().update_task_status_by_approval_optional( + approval_id=approval_id, + status=status, + ) + + store = LockingApprovalStoreStub() + approval = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-lock"}, + tool={"id": "tool-lock", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-lock", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + task = store.create_task( + thread_id=store.thread_id, + tool_id=approval["tool_id"], + status="pending_approval", + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval["id"], + latest_execution_id=None, + ) + created_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + approval["task_step_id"] = created_step["id"] + + payload = 
approve_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalApproveInput(approval_id=approval["id"]), + ) + + assert payload["approval"]["status"] == "approved" + assert task["id"] in store.locked_task_ids + + +def test_resolution_rejects_duplicate_and_conflicting_updates_deterministically() -> None: + duplicate_store = ApprovalStoreStub() + duplicate_approval = duplicate_store.create_approval( + thread_id=duplicate_store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(duplicate_store.thread_id), "tool_id": "tool-3"}, + tool={"id": "tool-3", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-3", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + duplicate_task = duplicate_store.create_task( + thread_id=duplicate_store.thread_id, + tool_id=duplicate_approval["tool_id"], + status="pending_approval", + request=duplicate_approval["request"], + tool=duplicate_approval["tool"], + latest_approval_id=duplicate_approval["id"], + latest_execution_id=None, + ) + duplicate_step = duplicate_store.create_task_step( + task_id=duplicate_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=duplicate_approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(duplicate_approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + duplicate_approval["task_step_id"] = duplicate_step["id"] + approve_approval_record( + duplicate_store, # type: ignore[arg-type] + user_id=duplicate_store.user_id, + request=ApprovalApproveInput(approval_id=duplicate_approval["id"]), + ) + + try: + approve_approval_record( + duplicate_store, # type: ignore[arg-type] + user_id=duplicate_store.user_id, + request=ApprovalApproveInput(approval_id=duplicate_approval["id"]), + ) + except ApprovalResolutionConflictError as exc: + assert str(exc) == f"approval {duplicate_approval['id']} was already approved" + else: + raise AssertionError("expected ApprovalResolutionConflictError for duplicate approval") + + assert duplicate_store.trace_events[-6]["payload"]["outcome"] == "duplicate_rejected" + + conflict_store = ApprovalStoreStub() + conflict_approval = conflict_store.create_approval( + thread_id=conflict_store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(conflict_store.thread_id), "tool_id": "tool-4"}, + tool={"id": "tool-4", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-4", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + conflict_task = conflict_store.create_task( + thread_id=conflict_store.thread_id, + tool_id=conflict_approval["tool_id"], + status="pending_approval", + request=conflict_approval["request"], + tool=conflict_approval["tool"], + latest_approval_id=conflict_approval["id"], + latest_execution_id=None, + ) + conflict_step = conflict_store.create_task_step( + task_id=conflict_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=conflict_approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(conflict_approval["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + 
trace_kind="approval.request", + ) + conflict_approval["task_step_id"] = conflict_step["id"] + approve_approval_record( + conflict_store, # type: ignore[arg-type] + user_id=conflict_store.user_id, + request=ApprovalApproveInput(approval_id=conflict_approval["id"]), + ) + + try: + reject_approval_record( + conflict_store, # type: ignore[arg-type] + user_id=conflict_store.user_id, + request=ApprovalRejectInput(approval_id=conflict_approval["id"]), + ) + except ApprovalResolutionConflictError as exc: + assert str(exc) == ( + f"approval {conflict_approval['id']} was already approved and cannot be rejected" + ) + else: + raise AssertionError("expected ApprovalResolutionConflictError for conflicting rejection") + + assert conflict_store.trace_events[-6]["payload"]["outcome"] == "conflict_rejected" + + +def test_approval_resolution_rejects_inconsistent_linkage_without_mutating_task_state() -> None: + store = ApprovalStoreStub() + approval = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="approved", + request={"thread_id": str(store.thread_id), "tool_id": "tool-boundary"}, + tool={"id": "tool-boundary", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-boundary", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + task = store.create_task( + thread_id=store.thread_id, + tool_id=approval["tool_id"], + status="pending_approval", + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval["id"], + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "approved", + "execution_id": str(uuid4()), + "execution_status": "completed", + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval["id"], + source_execution_id=uuid4(), + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="task.step.continuation", + ) + approval["task_step_id"] = later_step["id"] + + original_first_trace_id = first_step["trace_id"] + original_first_outcome = dict(first_step["outcome"]) + original_later_trace_id = later_step["trace_id"] + + try: + approve_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalApproveInput(approval_id=approval["id"]), + ) + except TaskStepApprovalLinkageError as exc: + assert str(exc) == ( + f"approval {approval['id']} is inconsistent with linked task step {later_step['id']}" + ) + else: + raise AssertionError("expected TaskStepApprovalLinkageError") + + assert task["status"] == "pending_approval" + assert task["latest_execution_id"] is None + assert first_step["status"] == "executed" + assert first_step["trace_id"] == original_first_trace_id + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "created" + assert later_step["trace_id"] == original_later_trace_id + assert store.traces == [] + assert store.trace_events == 
[] + + +def test_list_and_get_approval_records_use_deterministic_order_after_resolution() -> None: + store = ApprovalStoreStub() + first = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-1"}, + tool={"id": "tool-1", "tool_key": "shell.exec"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-1", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + first_task = store.create_task( + thread_id=store.thread_id, + tool_id=first["tool_id"], + status="pending_approval", + request=first["request"], + tool=first["tool"], + latest_approval_id=first["id"], + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=first_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=first["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(first["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + first["task_step_id"] = first_step["id"] + second = store.create_approval( + thread_id=store.thread_id, + tool_id=uuid4(), + task_step_id=None, + status="pending", + request={"thread_id": str(store.thread_id), "tool_id": "tool-2"}, + tool={"id": "tool-2", "tool_key": "browser.open"}, + routing={"decision": "approval_required", "reasons": [], "trace": {"trace_id": "trace-2", "trace_event_count": 3}}, + routing_trace_id=uuid4(), + ) + second_task = store.create_task( + thread_id=store.thread_id, + tool_id=second["tool_id"], + status="pending_approval", + request=second["request"], + tool=second["tool"], + latest_approval_id=second["id"], + latest_execution_id=None, + ) + second_step = store.create_task_step( + task_id=second_task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=second["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(second["id"]), + "approval_status": "pending", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request", + ) + second["task_step_id"] = second_step["id"] + + approve_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalApproveInput(approval_id=first["id"]), + ) + reject_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ApprovalRejectInput(approval_id=second["id"]), + ) + + listed = list_approval_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail = get_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + approval_id=UUID(str(second["id"])), + ) + + assert [item["id"] for item in listed["items"]] == [str(first["id"]), str(second["id"])] + assert [item["task_step_id"] for item in listed["items"]] == [str(first_step["id"]), str(second_step["id"])] + assert [item["status"] for item in listed["items"]] == ["approved", "rejected"] + assert listed["items"][0]["resolution"] is not None + assert listed["items"][1]["resolution"] is not None + assert listed["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail["approval"]["id"] == str(second["id"]) + assert detail["approval"]["task_step_id"] == str(second_step["id"]) + assert detail["approval"]["status"] == "rejected" + assert 
detail["approval"]["resolution"] == { + "resolved_at": "2026-03-12T10:07:00+00:00", + "resolved_by_user_id": str(store.user_id), + } + + +def test_get_approval_record_raises_not_found_when_missing() -> None: + store = ApprovalStoreStub() + + try: + get_approval_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + approval_id=uuid4(), + ) + except ApprovalNotFoundError as exc: + assert "approval" in str(exc) + else: + raise AssertionError("expected ApprovalNotFoundError") diff --git a/tests/unit/test_approvals_main.py b/tests/unit/test_approvals_main.py new file mode 100644 index 0000000..833f78d --- /dev/null +++ b/tests/unit/test_approvals_main.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.approvals import ApprovalNotFoundError, ApprovalResolutionConflictError +from alicebot_api.tasks import TaskStepApprovalLinkageError +from alicebot_api.tools import ToolRoutingValidationError + + +def test_create_approval_request_endpoint_translates_request_and_returns_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_submit_approval_request(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "request": { + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"command": "ls"}, + }, + "decision": "approval_required", + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "reasons": [], + "approval": { + "id": "approval-123", + "thread_id": str(thread_id), + "task_step_id": "task-step-123", + "status": "pending", + "resolution": None, + "request": { + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"command": "ls"}, + }, + "tool": {"id": str(tool_id), "tool_key": "shell.exec"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + "routing_trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + "trace": {"trace_id": "approval-trace-123", "trace_event_count": 4}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "submit_approval_request", fake_submit_approval_request) + + response = main_module.create_approval_request( + main_module.CreateApprovalRequest( + user_id=user_id, + thread_id=thread_id, + tool_id=tool_id, + action="tool.run", + scope="workspace", + attributes={"command": "ls"}, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == { + "trace_id": "approval-trace-123", + "trace_event_count": 4, + } + assert captured["database_url"] == "postgresql://app" + assert 
captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].thread_id == thread_id + assert captured["request"].tool_id == tool_id + assert captured["request"].attributes == {"command": "ls"} + + +def test_create_approval_request_endpoint_maps_validation_errors_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_submit_approval_request(*_args, **_kwargs): + raise ToolRoutingValidationError("tool_id must reference an existing active tool owned by the user") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "submit_approval_request", fake_submit_approval_request) + + response = main_module.create_approval_request( + main_module.CreateApprovalRequest( + user_id=user_id, + thread_id=uuid4(), + tool_id=uuid4(), + action="tool.run", + scope="workspace", + attributes={}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "tool_id must reference an existing active tool owned by the user" + } + + +def test_list_approvals_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_approval_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + }, + ) + + response = main_module.list_approvals(user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + +def test_get_approval_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_approval_record(*_args, **_kwargs): + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_approval_record", fake_get_approval_record) + + response = main_module.get_approval(approval_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"approval {approval_id} was not found"} + + +def test_approve_approval_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_approve_approval_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "approval": { + "id": 
str(approval_id), + "thread_id": "thread-123", + "task_step_id": "task-step-123", + "status": "approved", + "resolution": { + "resolved_at": "2026-03-12T10:00:00+00:00", + "resolved_by_user_id": str(user_id), + }, + "request": {"thread_id": "thread-123", "tool_id": "tool-123"}, + "tool": {"id": "tool-123", "tool_key": "shell.exec"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + "trace": {"trace_id": "approval-resolution-trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "approve_approval_record", fake_approve_approval_record) + + response = main_module.approve_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == { + "trace_id": "approval-resolution-trace-123", + "trace_event_count": 3, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].approval_id == approval_id + + +def test_approve_approval_endpoint_maps_conflicts_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_approve_approval_record(*_args, **_kwargs): + raise ApprovalResolutionConflictError(f"approval {approval_id} was already approved") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "approve_approval_record", fake_approve_approval_record) + + response = main_module.approve_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == {"detail": f"approval {approval_id} was already approved"} + + +def test_approve_approval_endpoint_maps_linkage_errors_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_approve_approval_record(*_args, **_kwargs): + raise TaskStepApprovalLinkageError( + f"approval {approval_id} is inconsistent with linked task step task-step-123" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "approve_approval_record", fake_approve_approval_record) + + response = main_module.approve_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"approval {approval_id} is inconsistent with linked task step task-step-123" + } + + +def test_reject_approval_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, 
current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_reject_approval_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "approval": { + "id": str(approval_id), + "thread_id": "thread-123", + "task_step_id": "task-step-456", + "status": "rejected", + "resolution": { + "resolved_at": "2026-03-12T10:01:00+00:00", + "resolved_by_user_id": str(user_id), + }, + "request": {"thread_id": "thread-123", "tool_id": "tool-123"}, + "tool": {"id": "tool-123", "tool_key": "shell.exec"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + "trace": {"trace_id": "approval-resolution-trace-456", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "reject_approval_record", fake_reject_approval_record) + + response = main_module.reject_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == { + "trace_id": "approval-resolution-trace-456", + "trace_event_count": 3, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].approval_id == approval_id + + +def test_reject_approval_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_reject_approval_record(*_args, **_kwargs): + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "reject_approval_record", fake_reject_approval_record) + + response = main_module.reject_approval( + approval_id, + main_module.ResolveApprovalRequest(user_id=user_id), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"approval {approval_id} was not found"} diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py new file mode 100644 index 0000000..c221707 --- /dev/null +++ b/tests/unit/test_compiler.py @@ -0,0 +1,760 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +from alicebot_api.compiler import ( + SUMMARY_TRACE_EVENT_KIND, + _compile_memory_section, + compile_continuity_context, +) +from alicebot_api.contracts import CompileContextSemanticRetrievalInput, ContextCompilerLimits + + +def test_compile_continuity_context_is_deterministic_and_stably_ordered() -> None: + user_id = uuid4() + thread_id = uuid4() + base_time = datetime(2026, 3, 11, 9, 0, tzinfo=UTC) + session_ids = [uuid4(), uuid4(), uuid4()] + event_ids = [uuid4(), uuid4(), uuid4(), uuid4()] + memory_ids = [uuid4(), uuid4(), uuid4()] + entity_ids = [uuid4(), uuid4(), uuid4()] + edge_ids = [uuid4(), uuid4(), uuid4(), uuid4()] + + user = { + "id": 
user_id, + "email": "owner@example.com", + "display_name": "Owner", + "created_at": base_time, + } + thread = { + "id": thread_id, + "user_id": user_id, + "title": "Traceable thread", + "created_at": base_time, + "updated_at": base_time + timedelta(minutes=4), + } + sessions = [ + { + "id": session_ids[0], + "user_id": user_id, + "thread_id": thread_id, + "status": "done", + "started_at": base_time, + "ended_at": base_time + timedelta(minutes=1), + "created_at": base_time, + }, + { + "id": session_ids[1], + "user_id": user_id, + "thread_id": thread_id, + "status": "done", + "started_at": base_time + timedelta(minutes=2), + "ended_at": base_time + timedelta(minutes=3), + "created_at": base_time + timedelta(minutes=2), + }, + { + "id": session_ids[2], + "user_id": user_id, + "thread_id": thread_id, + "status": "active", + "started_at": base_time + timedelta(minutes=4), + "ended_at": None, + "created_at": base_time + timedelta(minutes=4), + }, + ] + events = [ + { + "id": event_ids[0], + "user_id": user_id, + "thread_id": thread_id, + "session_id": session_ids[0], + "sequence_no": 1, + "kind": "message.user", + "payload": {"text": "one"}, + "created_at": base_time, + }, + { + "id": event_ids[1], + "user_id": user_id, + "thread_id": thread_id, + "session_id": session_ids[1], + "sequence_no": 2, + "kind": "message.assistant", + "payload": {"text": "two"}, + "created_at": base_time + timedelta(minutes=2), + }, + { + "id": event_ids[2], + "user_id": user_id, + "thread_id": thread_id, + "session_id": session_ids[2], + "sequence_no": 3, + "kind": "message.user", + "payload": {"text": "three"}, + "created_at": base_time + timedelta(minutes=4), + }, + { + "id": event_ids[3], + "user_id": user_id, + "thread_id": thread_id, + "session_id": session_ids[2], + "sequence_no": 4, + "kind": "message.assistant", + "payload": {"text": "four"}, + "created_at": base_time + timedelta(minutes=5), + }, + ] + memories = [ + { + "id": memory_ids[0], + "user_id": user_id, + "memory_key": "user.preference.tea", + "value": {"likes": "green"}, + "status": "active", + "source_event_ids": [str(event_ids[0])], + "created_at": base_time, + "updated_at": base_time + timedelta(minutes=1), + "deleted_at": None, + }, + { + "id": memory_ids[1], + "user_id": user_id, + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(event_ids[1])], + "created_at": base_time + timedelta(minutes=1), + "updated_at": base_time + timedelta(minutes=4), + "deleted_at": None, + }, + { + "id": memory_ids[2], + "user_id": user_id, + "memory_key": "user.preference.snacks", + "value": {"likes": "almonds"}, + "status": "active", + "source_event_ids": [str(event_ids[2])], + "created_at": base_time + timedelta(minutes=2), + "updated_at": base_time + timedelta(minutes=5), + "deleted_at": None, + }, + ] + entities = [ + { + "id": entity_ids[0], + "user_id": user_id, + "entity_type": "person", + "name": "Samir", + "source_memory_ids": [str(memory_ids[0])], + "created_at": base_time, + }, + { + "id": entity_ids[1], + "user_id": user_id, + "entity_type": "merchant", + "name": "Neighborhood Cafe", + "source_memory_ids": [str(memory_ids[1])], + "created_at": base_time + timedelta(minutes=3), + }, + { + "id": entity_ids[2], + "user_id": user_id, + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(memory_ids[1]), str(memory_ids[2])], + "created_at": base_time + timedelta(minutes=6), + }, + ] + entity_edges = [ + { + "id": edge_ids[0], + "user_id": user_id, + 
"from_entity_id": entity_ids[0], + "to_entity_id": entity_ids[1], + "relationship_type": "visits", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(memory_ids[0])], + "created_at": base_time + timedelta(minutes=2), + }, + { + "id": edge_ids[1], + "user_id": user_id, + "from_entity_id": entity_ids[2], + "to_entity_id": entity_ids[0], + "relationship_type": "references", + "valid_from": base_time + timedelta(minutes=5), + "valid_to": None, + "source_memory_ids": [str(memory_ids[2])], + "created_at": base_time + timedelta(minutes=5), + }, + { + "id": edge_ids[2], + "user_id": user_id, + "from_entity_id": entity_ids[1], + "to_entity_id": entity_ids[2], + "relationship_type": "works_on", + "valid_from": None, + "valid_to": base_time + timedelta(minutes=8), + "source_memory_ids": [str(memory_ids[1]), str(memory_ids[2])], + "created_at": base_time + timedelta(minutes=8), + }, + { + "id": edge_ids[3], + "user_id": user_id, + "from_entity_id": entity_ids[0], + "to_entity_id": entity_ids[0], + "relationship_type": "self_loop", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(memory_ids[0])], + "created_at": base_time + timedelta(minutes=9), + }, + ] + limits = ContextCompilerLimits( + max_sessions=2, + max_events=2, + max_memories=2, + max_entities=2, + max_entity_edges=2, + ) + + first_run = compile_continuity_context( + user=user, + thread=thread, + sessions=sessions, + events=events, + memories=memories, + entities=entities, + entity_edges=entity_edges, + limits=limits, + ) + second_run = compile_continuity_context( + user=user, + thread=thread, + sessions=sessions, + events=events, + memories=memories, + entities=entities, + entity_edges=entity_edges, + limits=limits, + ) + + assert first_run.context_pack == second_run.context_pack + assert first_run.trace_events == second_run.trace_events + assert [session["id"] for session in first_run.context_pack["sessions"]] == [ + str(session_ids[1]), + str(session_ids[2]), + ] + assert [event["sequence_no"] for event in first_run.context_pack["events"]] == [3, 4] + assert [memory["memory_key"] for memory in first_run.context_pack["memories"]] == [ + "user.preference.coffee", + "user.preference.snacks", + ] + assert [memory["source_provenance"] for memory in first_run.context_pack["memories"]] == [ + {"sources": ["symbolic"], "semantic_score": None}, + {"sources": ["symbolic"], "semantic_score": None}, + ] + assert [entity["id"] for entity in first_run.context_pack["entities"]] == [ + str(entity_ids[1]), + str(entity_ids[2]), + ] + assert [edge["id"] for edge in first_run.context_pack["entity_edges"]] == [ + str(edge_ids[1]), + str(edge_ids[2]), + ] + assert first_run.context_pack["memory_summary"] == { + "candidate_count": 2, + "included_count": 2, + "excluded_deleted_count": 0, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 2, + "semantic_selected_count": 0, + "merged_candidate_count": 2, + "deduplicated_count": 0, + "included_symbolic_only_count": 2, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert first_run.context_pack["entity_summary"] == { + "candidate_count": 3, + "included_count": 2, + "excluded_limit_count": 1, + } + 
assert first_run.context_pack["entity_edge_summary"] == { + "anchor_entity_count": 2, + "candidate_count": 3, + "included_count": 2, + "excluded_limit_count": 1, + } + + +def test_compile_continuity_context_records_included_and_excluded_reasons() -> None: + user_id = uuid4() + thread_id = uuid4() + base_time = datetime(2026, 3, 11, 9, 0, tzinfo=UTC) + kept_session_id = uuid4() + dropped_session_id = uuid4() + dropped_by_session_event_id = uuid4() + dropped_by_event_limit_id = uuid4() + kept_event_id = uuid4() + dropped_by_memory_limit_id = uuid4() + kept_memory_id = uuid4() + deleted_memory_id = uuid4() + dropped_entity_id = uuid4() + kept_entity_id = uuid4() + dropped_entity_edge_id = uuid4() + kept_entity_edge_id = uuid4() + ignored_entity_edge_id = uuid4() + external_entity_id = uuid4() + kept_edge_valid_from = base_time + timedelta(minutes=5) + + compiler_run = compile_continuity_context( + user={ + "id": user_id, + "email": "owner@example.com", + "display_name": "Owner", + "created_at": base_time, + }, + thread={ + "id": thread_id, + "user_id": user_id, + "title": "Traceable thread", + "created_at": base_time, + "updated_at": base_time, + }, + sessions=[ + { + "id": dropped_session_id, + "user_id": user_id, + "thread_id": thread_id, + "status": "done", + "started_at": base_time, + "ended_at": base_time, + "created_at": base_time, + }, + { + "id": kept_session_id, + "user_id": user_id, + "thread_id": thread_id, + "status": "active", + "started_at": base_time + timedelta(minutes=1), + "ended_at": None, + "created_at": base_time + timedelta(minutes=1), + }, + ], + events=[ + { + "id": dropped_by_session_event_id, + "user_id": user_id, + "thread_id": thread_id, + "session_id": dropped_session_id, + "sequence_no": 1, + "kind": "message.user", + "payload": {"text": "old session"}, + "created_at": base_time, + }, + { + "id": dropped_by_event_limit_id, + "user_id": user_id, + "thread_id": thread_id, + "session_id": kept_session_id, + "sequence_no": 2, + "kind": "message.assistant", + "payload": {"text": "too old"}, + "created_at": base_time + timedelta(minutes=1), + }, + { + "id": kept_event_id, + "user_id": user_id, + "thread_id": thread_id, + "session_id": kept_session_id, + "sequence_no": 3, + "kind": "message.user", + "payload": {"text": "keep"}, + "created_at": base_time + timedelta(minutes=2), + }, + ], + memories=[ + { + "id": dropped_by_memory_limit_id, + "user_id": user_id, + "memory_key": "user.preference.old", + "value": {"likes": "black"}, + "status": "active", + "source_event_ids": [str(dropped_by_session_event_id)], + "created_at": base_time, + "updated_at": base_time, + "deleted_at": None, + }, + { + "id": kept_memory_id, + "user_id": user_id, + "memory_key": "user.preference.keep", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(kept_event_id)], + "created_at": base_time + timedelta(minutes=1), + "updated_at": base_time + timedelta(minutes=2), + "deleted_at": None, + }, + { + "id": deleted_memory_id, + "user_id": user_id, + "memory_key": "user.preference.deleted", + "value": {"likes": "espresso"}, + "status": "deleted", + "source_event_ids": [str(dropped_by_event_limit_id)], + "created_at": base_time + timedelta(minutes=2), + "updated_at": base_time + timedelta(minutes=3), + "deleted_at": base_time + timedelta(minutes=3), + }, + ], + entities=[ + { + "id": dropped_entity_id, + "user_id": user_id, + "entity_type": "person", + "name": "Samir", + "source_memory_ids": [str(dropped_by_memory_limit_id)], + "created_at": base_time, + }, + { + "id": 
kept_entity_id, + "user_id": user_id, + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(kept_memory_id)], + "created_at": base_time + timedelta(minutes=4), + }, + ], + entity_edges=[ + { + "id": dropped_entity_edge_id, + "user_id": user_id, + "from_entity_id": dropped_entity_id, + "to_entity_id": kept_entity_id, + "relationship_type": "related_to", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "created_at": base_time + timedelta(minutes=3), + }, + { + "id": kept_entity_edge_id, + "user_id": user_id, + "from_entity_id": kept_entity_id, + "to_entity_id": external_entity_id, + "relationship_type": "depends_on", + "valid_from": kept_edge_valid_from, + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "created_at": base_time + timedelta(minutes=5), + }, + { + "id": ignored_entity_edge_id, + "user_id": user_id, + "from_entity_id": dropped_entity_id, + "to_entity_id": external_entity_id, + "relationship_type": "ignored", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(dropped_by_memory_limit_id)], + "created_at": base_time + timedelta(minutes=6), + }, + ], + limits=ContextCompilerLimits( + max_sessions=1, + max_events=1, + max_memories=1, + max_entities=1, + max_entity_edges=1, + ), + ) + + trace_payloads = [trace_event.payload for trace_event in compiler_run.trace_events] + + assert {"entity_type": "session", "entity_id": str(kept_session_id), "reason": "within_session_limit", "position": 1} in trace_payloads + assert {"entity_type": "session", "entity_id": str(dropped_session_id), "reason": "session_limit_exceeded", "position": 1} in trace_payloads + assert {"entity_type": "event", "entity_id": str(dropped_by_session_event_id), "reason": "session_not_included", "position": 1} in trace_payloads + assert {"entity_type": "event", "entity_id": str(dropped_by_event_limit_id), "reason": "event_limit_exceeded", "position": 2} in trace_payloads + assert {"entity_type": "event", "entity_id": str(kept_event_id), "reason": "within_event_limit", "position": 3} in trace_payloads + assert { + "entity_type": "memory", + "entity_id": str(kept_memory_id), + "reason": "within_hybrid_memory_limit", + "position": 1, + "memory_key": "user.preference.keep", + "status": "active", + "source_event_ids": [str(kept_event_id)], + "embedding_config_id": None, + "selected_sources": ["symbolic"], + "semantic_score": None, + } in trace_payloads + assert { + "entity_type": "memory", + "entity_id": str(deleted_memory_id), + "reason": "hybrid_memory_deleted", + "position": 1, + "memory_key": "user.preference.deleted", + "status": "deleted", + "source_event_ids": [str(dropped_by_event_limit_id)], + "embedding_config_id": None, + "selected_sources": ["symbolic"], + "semantic_score": None, + } in trace_payloads + assert { + "entity_type": "entity", + "entity_id": str(dropped_entity_id), + "reason": "entity_limit_exceeded", + "position": 1, + "record_entity_type": "person", + "name": "Samir", + "source_memory_ids": [str(dropped_by_memory_limit_id)], + } in trace_payloads + assert { + "entity_type": "entity", + "entity_id": str(kept_entity_id), + "reason": "within_entity_limit", + "position": 2, + "record_entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(kept_memory_id)], + } in trace_payloads + assert { + "entity_type": "entity_edge", + "entity_id": str(dropped_entity_edge_id), + "reason": "entity_edge_limit_exceeded", + "position": 1, + "from_entity_id": str(dropped_entity_id), + "to_entity_id": 
str(kept_entity_id), + "relationship_type": "related_to", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "attached_included_entity_ids": [str(kept_entity_id)], + } in trace_payloads + assert { + "entity_type": "entity_edge", + "entity_id": str(kept_entity_edge_id), + "reason": "within_entity_edge_limit", + "position": 2, + "from_entity_id": str(kept_entity_id), + "to_entity_id": str(external_entity_id), + "relationship_type": "depends_on", + "valid_from": kept_edge_valid_from.isoformat(), + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "attached_included_entity_ids": [str(kept_entity_id)], + } in trace_payloads + assert all(payload.get("entity_id") != str(ignored_entity_edge_id) for payload in trace_payloads) + assert compiler_run.trace_events[-1].kind == SUMMARY_TRACE_EVENT_KIND + assert compiler_run.context_pack["events"][0]["id"] == str(kept_event_id) + assert compiler_run.context_pack["memories"] == [ + { + "id": str(kept_memory_id), + "memory_key": "user.preference.keep", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(kept_event_id)], + "created_at": (base_time + timedelta(minutes=1)).isoformat(), + "updated_at": (base_time + timedelta(minutes=2)).isoformat(), + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ] + assert compiler_run.context_pack["memory_summary"] == { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert compiler_run.context_pack["entities"] == [ + { + "id": str(kept_entity_id), + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(kept_memory_id)], + "created_at": (base_time + timedelta(minutes=4)).isoformat(), + } + ] + assert compiler_run.context_pack["entity_edges"] == [ + { + "id": str(kept_entity_edge_id), + "from_entity_id": str(kept_entity_id), + "to_entity_id": str(external_entity_id), + "relationship_type": "depends_on", + "valid_from": kept_edge_valid_from.isoformat(), + "valid_to": None, + "source_memory_ids": [str(kept_memory_id)], + "created_at": (base_time + timedelta(minutes=5)).isoformat(), + } + ] + assert compiler_run.context_pack["entity_edge_summary"] == { + "anchor_entity_count": 1, + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + } + assert compiler_run.trace_events[-1].payload["included_entity_edge_count"] == 1 + assert compiler_run.trace_events[-1].payload["excluded_entity_edge_limit_count"] == 1 + assert compiler_run.trace_events[-1].payload["hybrid_memory_requested"] is False + assert compiler_run.trace_events[-1].payload["hybrid_memory_candidate_count"] == 2 + assert compiler_run.trace_events[-1].payload["hybrid_memory_merged_candidate_count"] == 1 + assert compiler_run.trace_events[-1].payload["hybrid_memory_deduplicated_count"] == 0 + + +class SemanticCompileStoreStub: + def __init__(self) -> None: + self.base_time = 
datetime(2026, 3, 12, 12, 0, tzinfo=UTC) + self.config_id = uuid4() + self.memory_ids = [uuid4(), uuid4(), uuid4()] + self.event_ids = [uuid4(), uuid4(), uuid4()] + + def get_embedding_config_optional(self, embedding_config_id): + if embedding_config_id != self.config_id: + return None + return {"id": self.config_id, "dimensions": 3} + + def retrieve_semantic_memory_matches(self, *, embedding_config_id, query_vector, limit): + assert embedding_config_id == self.config_id + assert query_vector == [1.0, 0.0, 0.0] + assert limit > 1000 + return [ + { + "id": self.memory_ids[0], + "user_id": uuid4(), + "memory_key": "user.preference.breakfast", + "value": {"likes": "porridge"}, + "status": "active", + "source_event_ids": [str(self.event_ids[0])], + "created_at": self.base_time, + "updated_at": self.base_time, + "deleted_at": None, + "score": 1.0, + }, + { + "id": self.memory_ids[1], + "user_id": uuid4(), + "memory_key": "user.preference.lunch", + "value": {"likes": "ramen"}, + "status": "active", + "source_event_ids": [str(self.event_ids[1])], + "created_at": self.base_time + timedelta(minutes=1), + "updated_at": self.base_time + timedelta(minutes=1), + "deleted_at": None, + "score": 1.0, + }, + ] + + def list_memory_embeddings_for_config(self, embedding_config_id): + assert embedding_config_id == self.config_id + return [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": self.memory_ids[2], + "embedding_config_id": self.config_id, + "dimensions": 3, + "vector": [1.0, 0.0, 0.0], + "created_at": self.base_time + timedelta(minutes=2), + "updated_at": self.base_time + timedelta(minutes=2), + } + ] + + +def test_compile_memory_section_orders_limits_and_excludes_deleted() -> None: + store = SemanticCompileStoreStub() + deleted_memory = { + "id": store.memory_ids[2], + "user_id": uuid4(), + "memory_key": "user.preference.deleted", + "value": {"likes": "hidden"}, + "status": "deleted", + "source_event_ids": [str(store.event_ids[2])], + "created_at": store.base_time + timedelta(minutes=2), + "updated_at": store.base_time + timedelta(minutes=3), + "deleted_at": store.base_time + timedelta(minutes=3), + } + + memory_section = _compile_memory_section( + store, # type: ignore[arg-type] + memories=[deleted_memory], + limits=ContextCompilerLimits(max_memories=1), + semantic_retrieval=CompileContextSemanticRetrievalInput( + embedding_config_id=store.config_id, + query_vector=(1.0, 0.0, 0.0), + limit=1, + ), + ) + + assert memory_section.items == [ + { + "id": str(store.memory_ids[0]), + "memory_key": "user.preference.breakfast", + "value": {"likes": "porridge"}, + "status": "active", + "source_event_ids": [str(store.event_ids[0])], + "created_at": store.base_time.isoformat(), + "updated_at": store.base_time.isoformat(), + "source_provenance": { + "sources": ["semantic"], + "semantic_score": 1.0, + }, + } + ] + assert memory_section.summary == { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": True, + "embedding_config_id": str(store.config_id), + "query_vector_dimensions": 3, + "semantic_limit": 1, + "symbolic_selected_count": 0, + "semantic_selected_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 1, + "included_dual_source_count": 0, + "similarity_metric": "cosine_similarity", + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": 
["score_desc", "created_at_asc", "id_asc"], + }, + } + assert [decision.reason for decision in memory_section.decisions] == [ + "within_hybrid_memory_limit", + "hybrid_memory_deleted", + ] + assert memory_section.decisions[-1].metadata["selected_sources"] == ["symbolic"] diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..6d10d22 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import pytest + +from alicebot_api.config import Settings + + +def test_settings_defaults(monkeypatch): + for key in ( + "APP_ENV", + "APP_HOST", + "APP_PORT", + "DATABASE_URL", + "DATABASE_ADMIN_URL", + "REDIS_URL", + "S3_ENDPOINT_URL", + "S3_ACCESS_KEY", + "S3_SECRET_KEY", + "S3_BUCKET", + "HEALTHCHECK_TIMEOUT_SECONDS", + "MODEL_PROVIDER", + "MODEL_BASE_URL", + "MODEL_NAME", + "MODEL_API_KEY", + "MODEL_TIMEOUT_SECONDS", + "TASK_WORKSPACE_ROOT", + ): + monkeypatch.delenv(key, raising=False) + + settings = Settings.from_env() + + assert settings.app_env == "development" + assert settings.app_port == 8000 + assert settings.database_url.endswith("/alicebot") + assert settings.database_admin_url.endswith("/alicebot") + assert settings.s3_bucket == "alicebot-local" + assert settings.model_provider == "openai_responses" + assert settings.model_base_url == "https://api.openai.com/v1" + assert settings.model_name == "gpt-5-mini" + assert settings.model_timeout_seconds == 30 + assert settings.task_workspace_root == "/tmp/alicebot/task-workspaces" + + +def test_settings_honor_environment_overrides(monkeypatch): + monkeypatch.setenv("APP_ENV", "test") + monkeypatch.setenv("APP_PORT", "8100") + monkeypatch.setenv("DATABASE_URL", "postgresql://app:secret@localhost:5432/custom") + monkeypatch.setenv("HEALTHCHECK_TIMEOUT_SECONDS", "9") + monkeypatch.setenv("MODEL_BASE_URL", "https://example.test/v1") + monkeypatch.setenv("MODEL_NAME", "gpt-5") + monkeypatch.setenv("MODEL_TIMEOUT_SECONDS", "45") + monkeypatch.setenv("TASK_WORKSPACE_ROOT", "/tmp/custom-workspaces") + + settings = Settings.from_env() + + assert settings.app_env == "test" + assert settings.app_port == 8100 + assert settings.database_url == "postgresql://app:secret@localhost:5432/custom" + assert settings.healthcheck_timeout_seconds == 9 + assert settings.model_base_url == "https://example.test/v1" + assert settings.model_name == "gpt-5" + assert settings.model_timeout_seconds == 45 + assert settings.task_workspace_root == "/tmp/custom-workspaces" + + +def test_settings_can_be_loaded_from_an_explicit_environment_mapping() -> None: + settings = Settings.from_env( + { + "APP_ENV": "test", + "APP_PORT": "8200", + "DATABASE_URL": "postgresql://app:secret@localhost:5432/mapped", + "MODEL_PROVIDER": "openai_responses", + "MODEL_NAME": "gpt-5-mini", + "TASK_WORKSPACE_ROOT": "/tmp/mapped-workspaces", + } + ) + + assert settings.app_env == "test" + assert settings.app_port == 8200 + assert settings.database_url == "postgresql://app:secret@localhost:5432/mapped" + assert settings.model_provider == "openai_responses" + assert settings.model_name == "gpt-5-mini" + assert settings.task_workspace_root == "/tmp/mapped-workspaces" + + +def test_settings_raise_clear_error_for_invalid_integer_values() -> None: + with pytest.raises(ValueError, match="APP_PORT must be an integer"): + Settings.from_env({"APP_PORT": "not-an-integer"}) + + with pytest.raises(ValueError, match="MODEL_TIMEOUT_SECONDS must be an integer"): + Settings.from_env({"MODEL_TIMEOUT_SECONDS": "not-an-integer"}) diff 
--git a/tests/unit/test_db.py b/tests/unit/test_db.py new file mode 100644 index 0000000..95559eb --- /dev/null +++ b/tests/unit/test_db.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from collections.abc import Iterator +from uuid import uuid4 + +import psycopg + +from alicebot_api import db + + +class RecordingCursor: + def __init__(self) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> tuple[int]: + return (1,) + + +class TransactionContext: + def __init__(self) -> None: + self.entered = False + self.exited = False + + def __enter__(self) -> None: + self.entered = True + return None + + def __exit__(self, exc_type, exc, tb) -> None: + self.exited = True + return None + + +class RecordingConnection: + def __init__(self) -> None: + self.cursor_instance = RecordingCursor() + self.transaction_context = TransactionContext() + + def __enter__(self) -> "RecordingConnection": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + def transaction(self) -> TransactionContext: + return self.transaction_context + + +def test_ping_database_returns_true_when_select_succeeds(monkeypatch) -> None: + connection = RecordingConnection() + captured: dict[str, object] = {} + + def fake_connect(database_url: str, **kwargs: object) -> RecordingConnection: + captured["database_url"] = database_url + captured["kwargs"] = kwargs + return connection + + monkeypatch.setattr(db.psycopg, "connect", fake_connect) + + assert db.ping_database("postgresql://example", timeout_seconds=3) is True + assert captured["database_url"] == "postgresql://example" + assert captured["kwargs"] == {"connect_timeout": 3} + assert connection.cursor_instance.executed == [("SELECT 1", None)] + + +def test_ping_database_returns_false_on_psycopg_error(monkeypatch) -> None: + def fake_connect(_database_url: str, **_kwargs: object) -> RecordingConnection: + raise psycopg.Error("boom") + + monkeypatch.setattr(db.psycopg, "connect", fake_connect) + + assert db.ping_database("postgresql://example", timeout_seconds=3) is False + + +def test_set_current_user_sets_database_context() -> None: + connection = RecordingConnection() + user_id = uuid4() + + db.set_current_user(connection, user_id) + + assert connection.cursor_instance.executed == [ + ("SELECT set_config('app.current_user_id', %s, true)", (str(user_id),)), + ] + + +def test_user_connection_sets_current_user_inside_transaction(monkeypatch) -> None: + connection = RecordingConnection() + user_id = uuid4() + captured: dict[str, object] = {} + set_current_user_calls: list[tuple[RecordingConnection, object]] = [] + + def fake_connect(database_url: str, **kwargs: object) -> RecordingConnection: + captured["database_url"] = database_url + captured["kwargs"] = kwargs + return connection + + def fake_set_current_user(conn: RecordingConnection, current_user_id: object) -> None: + set_current_user_calls.append((conn, current_user_id)) + + monkeypatch.setattr(db.psycopg, "connect", fake_connect) + monkeypatch.setattr(db, "set_current_user", fake_set_current_user) + + with db.user_connection("postgresql://example", user_id) as conn: + assert conn is connection + assert 
connection.transaction_context.entered is True + assert connection.transaction_context.exited is False + + assert captured["database_url"] == "postgresql://example" + assert captured["kwargs"] == {"row_factory": db.dict_row} + assert set_current_user_calls == [(connection, user_id)] + assert connection.transaction_context.exited is True diff --git a/tests/unit/test_embedding.py b/tests/unit/test_embedding.py new file mode 100644 index 0000000..44401d4 --- /dev/null +++ b/tests/unit/test_embedding.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import psycopg +import pytest + +from alicebot_api.contracts import EmbeddingConfigCreateInput, MemoryEmbeddingUpsertInput +from alicebot_api.embedding import ( + EmbeddingConfigValidationError, + MemoryEmbeddingNotFoundError, + MemoryEmbeddingValidationError, + create_embedding_config_record, + get_memory_embedding_record, + list_embedding_config_records, + list_memory_embedding_records, + upsert_memory_embedding_record, +) + + +class EmbeddingStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.memories: dict[UUID, dict[str, object]] = {} + self.configs: list[dict[str, object]] = [] + self.config_by_id: dict[UUID, dict[str, object]] = {} + self.embeddings: list[dict[str, object]] = [] + self.embedding_by_id: dict[UUID, dict[str, object]] = {} + + def create_embedding_config( + self, + *, + provider: str, + model: str, + version: str, + dimensions: int, + status: str, + metadata: dict[str, object], + ) -> dict[str, object]: + config_id = uuid4() + record = { + "id": config_id, + "user_id": uuid4(), + "provider": provider, + "model": model, + "version": version, + "dimensions": dimensions, + "status": status, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.configs)), + } + self.configs.append(record) + self.config_by_id[config_id] = record + return record + + def list_embedding_configs(self) -> list[dict[str, object]]: + return list(self.configs) + + def get_embedding_config_optional(self, embedding_config_id: UUID) -> dict[str, object] | None: + return self.config_by_id.get(embedding_config_id) + + def get_embedding_config_by_identity_optional( + self, + *, + provider: str, + model: str, + version: str, + ) -> dict[str, object] | None: + for config in self.configs: + if ( + config["provider"] == provider + and config["model"] == model + and config["version"] == version + ): + return config + return None + + def get_memory_optional(self, memory_id: UUID) -> dict[str, object] | None: + return self.memories.get(memory_id) + + def get_memory_embedding_by_memory_and_config_optional( + self, + *, + memory_id: UUID, + embedding_config_id: UUID, + ) -> dict[str, object] | None: + for embedding in self.embeddings: + if ( + embedding["memory_id"] == memory_id + and embedding["embedding_config_id"] == embedding_config_id + ): + return embedding + return None + + def create_memory_embedding( + self, + *, + memory_id: UUID, + embedding_config_id: UUID, + dimensions: int, + vector: list[float], + ) -> dict[str, object]: + embedding_id = uuid4() + record = { + "id": embedding_id, + "user_id": uuid4(), + "memory_id": memory_id, + "embedding_config_id": embedding_config_id, + "dimensions": dimensions, + "vector": vector, + "created_at": self.base_time + timedelta(minutes=len(self.embeddings)), + "updated_at": self.base_time + timedelta(minutes=len(self.embeddings)), + } + 
self.embeddings.append(record) + self.embedding_by_id[embedding_id] = record + return record + + def update_memory_embedding( + self, + *, + memory_embedding_id: UUID, + dimensions: int, + vector: list[float], + ) -> dict[str, object]: + record = self.embedding_by_id[memory_embedding_id] + updated = { + **record, + "dimensions": dimensions, + "vector": vector, + "updated_at": self.base_time + timedelta(minutes=10), + } + self.embedding_by_id[memory_embedding_id] = updated + for index, existing in enumerate(self.embeddings): + if existing["id"] == memory_embedding_id: + self.embeddings[index] = updated + return updated + + def get_memory_embedding_optional(self, memory_embedding_id: UUID) -> dict[str, object] | None: + return self.embedding_by_id.get(memory_embedding_id) + + def list_memory_embeddings_for_memory(self, memory_id: UUID) -> list[dict[str, object]]: + return [embedding for embedding in self.embeddings if embedding["memory_id"] == memory_id] + + +def seed_memory(store: EmbeddingStoreStub) -> UUID: + memory_id = uuid4() + store.memories[memory_id] = { + "id": memory_id, + "memory_key": "user.preference.coffee", + } + return memory_id + + +def seed_config(store: EmbeddingStoreStub, *, dimensions: int = 3) -> UUID: + created = store.create_embedding_config( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=dimensions, + status="active", + metadata={"task": "memory_retrieval"}, + ) + return created["id"] # type: ignore[return-value] + + +def test_create_and_list_embedding_configs_return_deterministic_shape() -> None: + store = EmbeddingStoreStub() + first = create_embedding_config_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-small", + version="2026-03-11", + dimensions=1536, + status="active", + metadata={"task": "memory_retrieval"}, + ), + ) + second = create_embedding_config_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + status="deprecated", + metadata={"task": "memory_retrieval"}, + ), + ) + + payload = list_embedding_config_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + ) + + assert first["embedding_config"]["provider"] == "openai" + assert second["embedding_config"]["status"] == "deprecated" + assert payload == { + "items": [ + first["embedding_config"], + second["embedding_config"], + ], + "summary": { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_create_embedding_config_rejects_duplicate_provider_model_version() -> None: + store = EmbeddingStoreStub() + create_embedding_config_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + status="active", + metadata={"task": "memory_retrieval"}, + ), + ) + + with pytest.raises( + EmbeddingConfigValidationError, + match="embedding config already exists for provider/model/version under the user scope", + ): + create_embedding_config_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + status="active", + metadata={"task": "memory_retrieval"}, + ), + ) + + +def 
test_create_embedding_config_translates_database_unique_violation_into_validation_error() -> None: + class DuplicateConfigStoreStub(EmbeddingStoreStub): + def get_embedding_config_by_identity_optional( + self, + *, + provider: str, + model: str, + version: str, + ) -> dict[str, object] | None: + return None + + def create_embedding_config( + self, + *, + provider: str, + model: str, + version: str, + dimensions: int, + status: str, + metadata: dict[str, object], + ) -> dict[str, object]: + raise psycopg.errors.UniqueViolation("duplicate key value violates unique constraint") + + with pytest.raises( + EmbeddingConfigValidationError, + match="embedding config already exists for provider/model/version under the user scope", + ): + create_embedding_config_record( + DuplicateConfigStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + config=EmbeddingConfigCreateInput( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3072, + status="active", + metadata={"task": "memory_retrieval"}, + ), + ) + + +def test_upsert_memory_embedding_creates_then_updates_existing_record() -> None: + store = EmbeddingStoreStub() + memory_id = seed_memory(store) + config_id = seed_config(store, dimensions=3) + + created = upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + updated = upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.3, 0.2, 0.1), + ), + ) + + assert created["write_mode"] == "created" + assert created["embedding"]["vector"] == [0.1, 0.2, 0.3] + assert updated["write_mode"] == "updated" + assert updated["embedding"]["id"] == created["embedding"]["id"] + assert updated["embedding"]["vector"] == [0.3, 0.2, 0.1] + + +def test_upsert_memory_embedding_rejects_missing_memory() -> None: + store = EmbeddingStoreStub() + config_id = seed_config(store) + + with pytest.raises( + MemoryEmbeddingValidationError, + match="memory_id must reference an existing memory owned by the user", + ): + upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=uuid4(), + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + + +def test_upsert_memory_embedding_rejects_missing_embedding_config() -> None: + store = EmbeddingStoreStub() + memory_id = seed_memory(store) + + with pytest.raises( + MemoryEmbeddingValidationError, + match="embedding_config_id must reference an existing embedding config owned by the user", + ): + upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=uuid4(), + vector=(0.1, 0.2, 0.3), + ), + ) + + +def test_upsert_memory_embedding_rejects_dimension_mismatch_and_non_finite_values() -> None: + store = EmbeddingStoreStub() + memory_id = seed_memory(store) + config_id = seed_config(store, dimensions=2) + + with pytest.raises( + MemoryEmbeddingValidationError, + match="vector length must match embedding config dimensions", + ): + upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + + with pytest.raises( + 
MemoryEmbeddingValidationError, + match="vector must contain only finite numeric values", + ): + upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.1, float("inf")), + ), + ) + + +def test_memory_embedding_reads_return_deterministic_shape_and_not_found() -> None: + store = EmbeddingStoreStub() + memory_id = seed_memory(store) + config_id = seed_config(store, dimensions=3) + created = upsert_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=MemoryEmbeddingUpsertInput( + memory_id=memory_id, + embedding_config_id=config_id, + vector=(0.1, 0.2, 0.3), + ), + ) + + listed = list_memory_embedding_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=memory_id, + ) + detail = get_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_embedding_id=UUID(created["embedding"]["id"]), + ) + + assert listed == { + "items": [created["embedding"]], + "summary": { + "memory_id": str(memory_id), + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + assert detail == {"embedding": created["embedding"]} + + with pytest.raises(MemoryEmbeddingNotFoundError, match="memory .* was not found"): + list_memory_embedding_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=uuid4(), + ) + + with pytest.raises(MemoryEmbeddingNotFoundError, match="memory embedding .* was not found"): + get_memory_embedding_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_embedding_id=uuid4(), + ) diff --git a/tests/unit/test_embedding_store.py b/tests/unit/test_embedding_store.py new file mode 100644 index 0000000..5a2b695 --- /dev/null +++ b/tests/unit/test_embedding_store.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__( + self, + fetchone_results: list[dict[str, Any]], + fetchall_results: list[list[dict[str, Any]]] | None = None, + ) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_results = list(fetchall_results or []) + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + if not self.fetchall_results: + return [] + return self.fetchall_results.pop(0) + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_embedding_store_methods_use_expected_queries_and_serialization() -> None: + config_id = uuid4() + memory_id = uuid4() + embedding_id = uuid4() + created_at = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + updated_at = datetime(2026, 3, 12, 9, 5, tzinfo=UTC) + cursor = RecordingCursor( + fetchone_results=[ + { + "id": config_id, + "user_id": uuid4(), + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + "created_at": created_at, + }, + { + "id": embedding_id, + "user_id": uuid4(), + "memory_id": memory_id, + "embedding_config_id": config_id, + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": created_at, + "updated_at": created_at, + }, + { + "id": embedding_id, + "user_id": uuid4(), + "memory_id": memory_id, + "embedding_config_id": config_id, + "dimensions": 3, + "vector": [0.3, 0.2, 0.1], + "created_at": created_at, + "updated_at": updated_at, + }, + ], + fetchall_results=[ + [ + { + "id": config_id, + "provider": "openai", + "version": "2026-03-12", + } + ], + [ + { + "id": embedding_id, + "memory_id": memory_id, + "embedding_config_id": config_id, + } + ], + [ + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(uuid4())], + "created_at": created_at, + "updated_at": updated_at, + "deleted_at": None, + "score": 1.0, + } + ], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created_config = store.create_embedding_config( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + status="active", + metadata={"task": "memory_retrieval"}, + ) + listed_configs = store.list_embedding_configs() + created_embedding = store.create_memory_embedding( + memory_id=memory_id, + embedding_config_id=config_id, + dimensions=3, + vector=[0.1, 0.2, 0.3], + ) + updated_embedding = store.update_memory_embedding( + memory_embedding_id=embedding_id, + dimensions=3, + vector=[0.3, 0.2, 0.1], + ) + listed_embeddings = store.list_memory_embeddings_for_memory(memory_id) + retrieval_matches = store.retrieve_semantic_memory_matches( + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=5, + ) + + assert created_config["id"] == config_id + assert listed_configs == [{"id": config_id, "provider": "openai", "version": "2026-03-12"}] + assert created_embedding["id"] == embedding_id + assert updated_embedding["updated_at"] == updated_at + assert listed_embeddings == [ + {"id": embedding_id, "memory_id": memory_id, "embedding_config_id": config_id} + ] + assert len(retrieval_matches) == 1 + assert retrieval_matches[0]["id"] == memory_id + assert retrieval_matches[0]["memory_key"] == "user.preference.coffee" + assert retrieval_matches[0]["status"] == "active" + assert retrieval_matches[0]["score"] == 1.0 + + create_config_query, create_config_params = cursor.executed[0] + assert "INSERT 
INTO embedding_configs" in create_config_query + assert create_config_params is not None + assert create_config_params[:5] == ( + "openai", + "text-embedding-3-large", + "2026-03-12", + 3, + "active", + ) + assert isinstance(create_config_params[5], Jsonb) + assert create_config_params[5].obj == {"task": "memory_retrieval"} + + list_config_query, list_config_params = cursor.executed[1] + assert "FROM embedding_configs" in list_config_query + assert "ORDER BY created_at ASC, id ASC" in list_config_query + assert list_config_params is None + + create_embedding_query, create_embedding_params = cursor.executed[2] + assert "INSERT INTO memory_embeddings" in create_embedding_query + assert create_embedding_params is not None + assert create_embedding_params[:3] == (memory_id, config_id, 3) + assert isinstance(create_embedding_params[3], Jsonb) + assert create_embedding_params[3].obj == [0.1, 0.2, 0.3] + + update_embedding_query, update_embedding_params = cursor.executed[3] + assert "UPDATE memory_embeddings" in update_embedding_query + assert update_embedding_params is not None + assert update_embedding_params[0] == 3 + assert isinstance(update_embedding_params[1], Jsonb) + assert update_embedding_params[1].obj == [0.3, 0.2, 0.1] + assert update_embedding_params[2] == embedding_id + + list_embedding_query, list_embedding_params = cursor.executed[4] + assert "FROM memory_embeddings" in list_embedding_query + assert "ORDER BY created_at ASC, id ASC" in list_embedding_query + assert list_embedding_params == (memory_id,) + + retrieval_query, retrieval_params = cursor.executed[5] + assert "replace(memory_embeddings.vector::text, ' ', '')::vector <=> %s::vector" in retrieval_query + assert "JOIN memories" in retrieval_query + assert "memories.status = 'active'" in retrieval_query + assert "ORDER BY score DESC, memories.created_at ASC, memories.id ASC" in retrieval_query + assert retrieval_params == ("[0.1,0.2,0.3]", config_id, 3, 5) + + +def test_embedding_store_optional_reads_return_none_when_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + assert store.get_embedding_config_optional(uuid4()) is None + assert store.get_embedding_config_by_identity_optional( + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + ) is None + assert store.get_memory_embedding_optional(uuid4()) is None + assert store.get_memory_embedding_by_memory_and_config_optional( + memory_id=uuid4(), + embedding_config_id=uuid4(), + ) is None diff --git a/tests/unit/test_entity.py b/tests/unit/test_entity.py new file mode 100644 index 0000000..c417b55 --- /dev/null +++ b/tests/unit/test_entity.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import EntityCreateInput +from alicebot_api.entity import ( + EntityNotFoundError, + EntityValidationError, + create_entity_record, + get_entity_record, + list_entity_records, +) + + +class EntityStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.memories: dict[UUID, dict[str, object]] = {} + self.created_entities: list[dict[str, object]] = [] + self.entity_by_id: dict[UUID, dict[str, object]] = {} + + def list_memories_by_ids(self, memory_ids: list[UUID]) -> list[dict[str, object]]: + return [self.memories[memory_id] for memory_id in memory_ids if memory_id in self.memories] + + def create_entity( + 
self, + *, + entity_type: str, + name: str, + source_memory_ids: list[str], + ) -> dict[str, object]: + entity_id = uuid4() + entity = { + "id": entity_id, + "user_id": uuid4(), + "entity_type": entity_type, + "name": name, + "source_memory_ids": source_memory_ids, + "created_at": self.base_time + timedelta(minutes=len(self.created_entities)), + } + self.created_entities.append(entity) + self.entity_by_id[entity_id] = entity + return entity + + def list_entities(self) -> list[dict[str, object]]: + return list(self.created_entities) + + def get_entity_optional(self, entity_id: UUID) -> dict[str, object] | None: + return self.entity_by_id.get(entity_id) + + +def seed_memory(store: EntityStoreStub) -> UUID: + memory_id = uuid4() + store.memories[memory_id] = { + "id": memory_id, + "memory_key": "user.preference.coffee", + } + return memory_id + + +def test_create_entity_record_rejects_empty_source_memory_ids() -> None: + store = EntityStoreStub() + + with pytest.raises( + EntityValidationError, + match="source_memory_ids must include at least one existing memory owned by the user", + ): + create_entity_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + entity=EntityCreateInput( + entity_type="person", + name="Samir", + source_memory_ids=(), + ), + ) + + +def test_create_entity_record_rejects_missing_source_memories() -> None: + store = EntityStoreStub() + + with pytest.raises( + EntityValidationError, + match="source_memory_ids must all reference existing memories owned by the user", + ): + create_entity_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + entity=EntityCreateInput( + entity_type="project", + name="AliceBot", + source_memory_ids=(uuid4(),), + ), + ) + + +def test_create_entity_record_creates_entity_with_deduped_source_memories() -> None: + store = EntityStoreStub() + first_memory_id = seed_memory(store) + second_memory_id = seed_memory(store) + + payload = create_entity_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + entity=EntityCreateInput( + entity_type="project", + name="AliceBot", + source_memory_ids=(first_memory_id, first_memory_id, second_memory_id), + ), + ) + + assert payload["entity"]["entity_type"] == "project" + assert payload["entity"]["name"] == "AliceBot" + assert payload["entity"]["source_memory_ids"] == [str(first_memory_id), str(second_memory_id)] + + +def test_list_entity_records_returns_deterministic_shape() -> None: + store = EntityStoreStub() + first_memory_id = seed_memory(store) + second_memory_id = seed_memory(store) + first_entity = store.create_entity( + entity_type="person", + name="Samir", + source_memory_ids=[str(first_memory_id)], + ) + second_entity = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(second_memory_id)], + ) + + payload = list_entity_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + ) + + assert payload == { + "items": [ + { + "id": str(first_entity["id"]), + "entity_type": "person", + "name": "Samir", + "source_memory_ids": [str(first_memory_id)], + "created_at": first_entity["created_at"].isoformat(), + }, + { + "id": str(second_entity["id"]), + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(second_memory_id)], + "created_at": second_entity["created_at"].isoformat(), + }, + ], + "summary": { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_get_entity_record_raises_not_found_for_inaccessible_entity() -> None: + with pytest.raises(EntityNotFoundError, match="entity .* was not 
found"): + get_entity_record( + EntityStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + entity_id=uuid4(), + ) diff --git a/tests/unit/test_entity_edge.py b/tests/unit/test_entity_edge.py new file mode 100644 index 0000000..d30f376 --- /dev/null +++ b/tests/unit/test_entity_edge.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import EntityEdgeCreateInput +from alicebot_api.entity import EntityNotFoundError +from alicebot_api.entity_edge import ( + EntityEdgeValidationError, + create_entity_edge_record, + list_entity_edge_records, +) + + +class EntityEdgeStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.memories: dict[UUID, dict[str, object]] = {} + self.entities: dict[UUID, dict[str, object]] = {} + self.created_edges: list[dict[str, object]] = [] + + def list_memories_by_ids(self, memory_ids: list[UUID]) -> list[dict[str, object]]: + return [self.memories[memory_id] for memory_id in memory_ids if memory_id in self.memories] + + def get_entity_optional(self, entity_id: UUID) -> dict[str, object] | None: + return self.entities.get(entity_id) + + def create_entity_edge( + self, + *, + from_entity_id: UUID, + to_entity_id: UUID, + relationship_type: str, + valid_from: datetime | None, + valid_to: datetime | None, + source_memory_ids: list[str], + ) -> dict[str, object]: + edge_id = uuid4() + edge = { + "id": edge_id, + "user_id": uuid4(), + "from_entity_id": from_entity_id, + "to_entity_id": to_entity_id, + "relationship_type": relationship_type, + "valid_from": valid_from, + "valid_to": valid_to, + "source_memory_ids": source_memory_ids, + "created_at": self.base_time + timedelta(minutes=len(self.created_edges)), + } + self.created_edges.append(edge) + return edge + + def list_entity_edges_for_entity(self, entity_id: UUID) -> list[dict[str, object]]: + return [ + edge + for edge in self.created_edges + if edge["from_entity_id"] == entity_id or edge["to_entity_id"] == entity_id + ] + + +def seed_memory(store: EntityEdgeStoreStub) -> UUID: + memory_id = uuid4() + store.memories[memory_id] = { + "id": memory_id, + "memory_key": "user.project.current", + } + return memory_id + + +def seed_entity(store: EntityEdgeStoreStub) -> UUID: + entity_id = uuid4() + store.entities[entity_id] = { + "id": entity_id, + "name": "entity", + } + return entity_id + + +def test_create_entity_edge_record_rejects_missing_entities() -> None: + store = EntityEdgeStoreStub() + memory_id = seed_memory(store) + + with pytest.raises( + EntityEdgeValidationError, + match="from_entity_id must reference an existing entity owned by the user", + ): + create_entity_edge_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + edge=EntityEdgeCreateInput( + from_entity_id=uuid4(), + to_entity_id=uuid4(), + relationship_type="works_on", + valid_from=None, + valid_to=None, + source_memory_ids=(memory_id,), + ), + ) + + +def test_create_entity_edge_record_rejects_invalid_temporal_range() -> None: + store = EntityEdgeStoreStub() + from_entity_id = seed_entity(store) + to_entity_id = seed_entity(store) + memory_id = seed_memory(store) + + with pytest.raises( + EntityEdgeValidationError, + match="valid_to must be greater than or equal to valid_from", + ): + create_entity_edge_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + edge=EntityEdgeCreateInput( + from_entity_id=from_entity_id, + to_entity_id=to_entity_id, + 
relationship_type="works_on", + valid_from=datetime(2026, 3, 12, 11, 0, tzinfo=UTC), + valid_to=datetime(2026, 3, 12, 10, 0, tzinfo=UTC), + source_memory_ids=(memory_id,), + ), + ) + + +def test_create_entity_edge_record_creates_edge_with_deduped_source_memories() -> None: + store = EntityEdgeStoreStub() + from_entity_id = seed_entity(store) + to_entity_id = seed_entity(store) + first_memory_id = seed_memory(store) + second_memory_id = seed_memory(store) + valid_from = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + valid_to = datetime(2026, 3, 12, 10, 0, tzinfo=UTC) + + payload = create_entity_edge_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + edge=EntityEdgeCreateInput( + from_entity_id=from_entity_id, + to_entity_id=to_entity_id, + relationship_type="works_on", + valid_from=valid_from, + valid_to=valid_to, + source_memory_ids=(first_memory_id, first_memory_id, second_memory_id), + ), + ) + + assert payload == { + "edge": { + "id": payload["edge"]["id"], + "from_entity_id": str(from_entity_id), + "to_entity_id": str(to_entity_id), + "relationship_type": "works_on", + "valid_from": valid_from.isoformat(), + "valid_to": valid_to.isoformat(), + "source_memory_ids": [str(first_memory_id), str(second_memory_id)], + "created_at": store.created_edges[0]["created_at"].isoformat(), + } + } + + +def test_list_entity_edge_records_returns_deterministic_shape() -> None: + store = EntityEdgeStoreStub() + primary_entity_id = seed_entity(store) + secondary_entity_id = seed_entity(store) + tertiary_entity_id = seed_entity(store) + first_memory_id = seed_memory(store) + second_memory_id = seed_memory(store) + + first_edge = store.create_entity_edge( + from_entity_id=primary_entity_id, + to_entity_id=secondary_entity_id, + relationship_type="works_on", + valid_from=None, + valid_to=None, + source_memory_ids=[str(first_memory_id)], + ) + second_edge = store.create_entity_edge( + from_entity_id=tertiary_entity_id, + to_entity_id=primary_entity_id, + relationship_type="references", + valid_from=None, + valid_to=None, + source_memory_ids=[str(second_memory_id)], + ) + + payload = list_entity_edge_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + entity_id=primary_entity_id, + ) + + assert payload == { + "items": [ + { + "id": str(first_edge["id"]), + "from_entity_id": str(primary_entity_id), + "to_entity_id": str(secondary_entity_id), + "relationship_type": "works_on", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(first_memory_id)], + "created_at": first_edge["created_at"].isoformat(), + }, + { + "id": str(second_edge["id"]), + "from_entity_id": str(tertiary_entity_id), + "to_entity_id": str(primary_entity_id), + "relationship_type": "references", + "valid_from": None, + "valid_to": None, + "source_memory_ids": [str(second_memory_id)], + "created_at": second_edge["created_at"].isoformat(), + }, + ], + "summary": { + "entity_id": str(primary_entity_id), + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_list_entity_edge_records_raises_not_found_for_inaccessible_entity() -> None: + with pytest.raises(EntityNotFoundError, match="entity .* was not found"): + list_entity_edge_records( + EntityEdgeStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + entity_id=uuid4(), + ) diff --git a/tests/unit/test_entity_store.py b/tests/unit/test_entity_store.py new file mode 100644 index 0000000..7b377ca --- /dev/null +++ b/tests/unit/test_entity_store.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from datetime import UTC, 
datetime +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__( + self, + fetchone_results: list[dict[str, Any]], + fetchall_results: list[list[dict[str, Any]]] | None = None, + ) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_results = list(fetchall_results or []) + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + if not self.fetchall_results: + return [] + return self.fetchall_results.pop(0) + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_entity_methods_use_expected_queries_and_deterministic_order() -> None: + entity_id = uuid4() + first_memory_id = uuid4() + second_memory_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": entity_id, + "user_id": uuid4(), + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(first_memory_id), str(second_memory_id)], + "created_at": "ignored", + } + ], + fetchall_results=[ + [{"id": first_memory_id}, {"id": second_memory_id}], + [{"id": entity_id, "name": "AliceBot"}], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_entity( + entity_type="project", + name="AliceBot", + source_memory_ids=[str(first_memory_id), str(second_memory_id)], + ) + listed_memories = store.list_memories_by_ids([first_memory_id, second_memory_id]) + listed_entities = store.list_entities() + + assert created["id"] == entity_id + assert listed_memories == [{"id": first_memory_id}, {"id": second_memory_id}] + assert listed_entities == [{"id": entity_id, "name": "AliceBot"}] + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO entities" in create_query + assert create_params is not None + assert create_params[0] == "project" + assert create_params[1] == "AliceBot" + assert isinstance(create_params[2], Jsonb) + assert create_params[2].obj == [str(first_memory_id), str(second_memory_id)] + + assert cursor.executed[1] == ( + """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE id = ANY(%s) + ORDER BY created_at ASC, id ASC + """, + ([first_memory_id, second_memory_id],), + ) + assert cursor.executed[2] == ( + """ + SELECT id, user_id, entity_type, name, source_memory_ids, created_at + FROM entities + ORDER BY created_at ASC, id ASC + """, + None, + ) + + +def test_get_entity_optional_returns_none_when_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + assert store.get_entity_optional(uuid4()) is None + + +def test_entity_edge_methods_use_expected_queries_and_deterministic_order() -> None: + edge_id = uuid4() + from_entity_id = uuid4() + to_entity_id = uuid4() + related_entity_id = uuid4() + source_memory_id = uuid4() + valid_from = datetime(2026, 3, 12, 10, 0, tzinfo=UTC) + cursor 
= RecordingCursor( + fetchone_results=[ + { + "id": edge_id, + "user_id": uuid4(), + "from_entity_id": from_entity_id, + "to_entity_id": to_entity_id, + "relationship_type": "works_on", + "valid_from": valid_from, + "valid_to": None, + "source_memory_ids": [str(source_memory_id)], + "created_at": "ignored", + } + ], + fetchall_results=[ + [{"id": edge_id, "relationship_type": "works_on"}], + [{"id": edge_id, "relationship_type": "works_on"}], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_entity_edge( + from_entity_id=from_entity_id, + to_entity_id=to_entity_id, + relationship_type="works_on", + valid_from=valid_from, + valid_to=None, + source_memory_ids=[str(source_memory_id)], + ) + listed_edges = store.list_entity_edges_for_entity(from_entity_id) + listed_edges_for_entities = store.list_entity_edges_for_entities([from_entity_id, related_entity_id]) + + assert created["id"] == edge_id + assert listed_edges == [{"id": edge_id, "relationship_type": "works_on"}] + assert listed_edges_for_entities == [{"id": edge_id, "relationship_type": "works_on"}] + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO entity_edges" in create_query + assert create_params is not None + assert create_params[0] == from_entity_id + assert create_params[1] == to_entity_id + assert create_params[2] == "works_on" + assert create_params[3] == valid_from + assert create_params[4] is None + assert isinstance(create_params[5], Jsonb) + assert create_params[5].obj == [str(source_memory_id)] + + assert cursor.executed[1] == ( + """ + SELECT + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + FROM entity_edges + WHERE from_entity_id = %s OR to_entity_id = %s + ORDER BY created_at ASC, id ASC + """, + (from_entity_id, from_entity_id), + ) + assert cursor.executed[2] == ( + """ + SELECT + id, + user_id, + from_entity_id, + to_entity_id, + relationship_type, + valid_from, + valid_to, + source_memory_ids, + created_at + FROM entity_edges + WHERE from_entity_id = ANY(%s) OR to_entity_id = ANY(%s) + ORDER BY created_at ASC, id ASC + """, + ([from_entity_id, related_entity_id], [from_entity_id, related_entity_id]), + ) diff --git a/tests/unit/test_env.py b/tests/unit/test_env.py new file mode 100644 index 0000000..b0fdb49 --- /dev/null +++ b/tests/unit/test_env.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +from contextlib import contextmanager +import importlib +import sys +from typing import Any + + +MODULE_NAME = "apps.api.alembic.env" + + +class FakeAlembicConfig: + def __init__(self, sqlalchemy_url: str, section: dict[str, Any] | None = None) -> None: + self.config_file_name = "alembic.ini" + self.config_ini_section = "alembic" + self.sqlalchemy_url = sqlalchemy_url + self.section = section or {} + + def get_main_option(self, option: str) -> str: + assert option == "sqlalchemy.url" + return self.sqlalchemy_url + + def get_section(self, section_name: str, default: dict[str, Any] | None = None) -> dict[str, Any]: + assert section_name == self.config_ini_section + base = dict(default or {}) + base.update(self.section) + return base + + +class RecordingConnectable: + def __init__(self) -> None: + self.connection = object() + self.connected = False + + @contextmanager + def connect(self): + self.connected = True + yield self.connection + + +def load_env_module( + monkeypatch, + *, + offline_mode: bool, + admin_url: str | None = None, + app_url: str | None = None, + 
config_url: str = "postgresql://config-user:secret@localhost:5432/configdb", + config_section: dict[str, Any] | None = None, +) -> tuple[Any, dict[str, Any]]: + records: dict[str, Any] = { + "file_config_calls": [], + "configure_calls": [], + "run_migrations_calls": 0, + "begin_calls": 0, + "engine_calls": [], + } + fake_config = FakeAlembicConfig(config_url, config_section) + connectable = RecordingConnectable() + + if admin_url is None: + monkeypatch.delenv("DATABASE_ADMIN_URL", raising=False) + else: + monkeypatch.setenv("DATABASE_ADMIN_URL", admin_url) + if app_url is None: + monkeypatch.delenv("DATABASE_URL", raising=False) + else: + monkeypatch.setenv("DATABASE_URL", app_url) + + monkeypatch.setattr("logging.config.fileConfig", records["file_config_calls"].append) + monkeypatch.setattr("alembic.context.config", fake_config, raising=False) + monkeypatch.setattr("alembic.context.is_offline_mode", lambda: offline_mode, raising=False) + monkeypatch.setattr( + "alembic.context.configure", + lambda **kwargs: records["configure_calls"].append(kwargs), + raising=False, + ) + + @contextmanager + def begin_transaction(): + records["begin_calls"] += 1 + yield + + monkeypatch.setattr("alembic.context.begin_transaction", begin_transaction, raising=False) + monkeypatch.setattr( + "alembic.context.run_migrations", + lambda: records.__setitem__("run_migrations_calls", records["run_migrations_calls"] + 1), + raising=False, + ) + + def fake_engine_from_config(configuration: dict[str, Any], **kwargs: Any) -> RecordingConnectable: + records["engine_calls"].append((dict(configuration), kwargs)) + return connectable + + monkeypatch.setattr("sqlalchemy.engine_from_config", fake_engine_from_config) + + sys.modules.pop(MODULE_NAME, None) + module = importlib.import_module(MODULE_NAME) + records["connectable"] = connectable + return module, records + + +def test_normalize_sqlalchemy_url_rewrites_postgresql_scheme(monkeypatch) -> None: + module, _records = load_env_module(monkeypatch, offline_mode=True) + + assert module.normalize_sqlalchemy_url("postgresql://user:pw@localhost/db") == ( + "postgresql+psycopg://user:pw@localhost/db" + ) + assert module.normalize_sqlalchemy_url("sqlite:///tmp/test.db") == "sqlite:///tmp/test.db" + + +def test_get_url_prefers_admin_env_then_database_env_then_config(monkeypatch) -> None: + module, _records = load_env_module( + monkeypatch, + offline_mode=True, + admin_url="postgresql://admin-user:secret@localhost:5432/admin_db", + app_url="postgresql://app-user:secret@localhost:5432/app_db", + ) + + assert module.get_url() == "postgresql+psycopg://admin-user:secret@localhost:5432/admin_db" + + module, _records = load_env_module( + monkeypatch, + offline_mode=True, + admin_url=None, + app_url="postgresql://app-user:secret@localhost:5432/app_db", + ) + + assert module.get_url() == "postgresql+psycopg://app-user:secret@localhost:5432/app_db" + + module, _records = load_env_module(monkeypatch, offline_mode=True, admin_url=None, app_url=None) + + assert module.get_url() == "postgresql+psycopg://config-user:secret@localhost:5432/configdb" + + +def test_run_migrations_offline_configures_context_with_normalized_url(monkeypatch) -> None: + _module, records = load_env_module( + monkeypatch, + offline_mode=True, + admin_url="postgresql://admin-user:secret@localhost:5432/admin_db", + ) + + assert records["file_config_calls"] == ["alembic.ini"] + assert records["begin_calls"] == 1 + assert records["run_migrations_calls"] == 1 + assert records["configure_calls"] == [ + { + "url": 
"postgresql+psycopg://admin-user:secret@localhost:5432/admin_db", + "target_metadata": None, + "literal_binds": True, + "dialect_opts": {"paramstyle": "named"}, + } + ] + assert records["engine_calls"] == [] + + +def test_run_migrations_online_builds_engine_configuration(monkeypatch) -> None: + _module, records = load_env_module( + monkeypatch, + offline_mode=False, + app_url="postgresql://app-user:secret@localhost:5432/app_db", + config_section={"sqlalchemy.echo": "false"}, + ) + + configuration, engine_kwargs = records["engine_calls"][0] + + assert records["file_config_calls"] == ["alembic.ini"] + assert configuration == { + "sqlalchemy.echo": "false", + "sqlalchemy.url": "postgresql+psycopg://app-user:secret@localhost:5432/app_db", + } + assert engine_kwargs["prefix"] == "sqlalchemy." + assert engine_kwargs["poolclass"].__name__ == "NullPool" + assert records["connectable"].connected is True + assert records["configure_calls"] == [ + {"connection": records["connectable"].connection, "target_metadata": None} + ] + assert records["begin_calls"] == 1 + assert records["run_migrations_calls"] == 1 diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py new file mode 100644 index 0000000..7e64d9d --- /dev/null +++ b/tests/unit/test_events.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import pytest + +from alicebot_api.store import AppendOnlyViolation, ContinuityStore + + +def test_event_updates_are_rejected_by_contract(): + store = ContinuityStore(conn=None) # type: ignore[arg-type] + + with pytest.raises(AppendOnlyViolation, match="append-only"): + store.update_event("event-id", {"text": "mutated"}) + + +def test_event_deletes_are_rejected_by_contract(): + store = ContinuityStore(conn=None) # type: ignore[arg-type] + + with pytest.raises(AppendOnlyViolation, match="append-only"): + store.delete_event("event-id") + diff --git a/tests/unit/test_execution_budget_store.py b/tests/unit/test_execution_budget_store.py new file mode 100644 index 0000000..05e7b2e --- /dev/null +++ b/tests/unit/test_execution_budget_store.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_execution_budget_store_methods_use_expected_queries_and_parameters() -> None: + execution_budget_id = uuid4() + replacement_budget_id = uuid4() + row = { + "id": execution_budget_id, + "tool_key": "proxy.echo", + "domain_hint": "docs", + "max_completed_executions": 2, + "rolling_window_seconds": 3600, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + } + cursor = RecordingCursor( + fetchone_results=[ + row, + row, + {**row, "status": "inactive"}, + {**row, "status": "superseded", "superseded_by_budget_id": replacement_budget_id}, + ], + fetchall_result=[row], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + rolling_window_seconds=3600, + ) + fetched = store.get_execution_budget_optional(execution_budget_id) + listed = store.list_execution_budgets() + deactivated = store.deactivate_execution_budget_optional(execution_budget_id) + superseded = store.supersede_execution_budget_optional( + execution_budget_id=execution_budget_id, + superseded_by_budget_id=replacement_budget_id, + ) + + assert created["id"] == execution_budget_id + assert fetched is not None + assert fetched["id"] == execution_budget_id + assert listed[0]["id"] == execution_budget_id + assert deactivated is not None + assert deactivated["status"] == "inactive" + assert superseded is not None + assert superseded["superseded_by_budget_id"] == replacement_budget_id + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO execution_budgets" in create_query + assert create_params == (None, "proxy.echo", "docs", 2, 3600, None) + assert "FROM execution_budgets" in cursor.executed[1][0] + assert "ORDER BY created_at ASC, id ASC" in cursor.executed[2][0] + assert "UPDATE execution_budgets" in cursor.executed[3][0] + assert cursor.executed[4][1] == (replacement_budget_id, execution_budget_id) diff --git a/tests/unit/test_execution_budgets.py b/tests/unit/test_execution_budgets.py new file mode 100644 index 0000000..752f7d1 --- /dev/null +++ b/tests/unit/test_execution_budgets.py @@ -0,0 +1,709 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import ( + ExecutionBudgetCreateInput, + ExecutionBudgetDeactivateInput, + ExecutionBudgetSupersedeInput, +) +from alicebot_api.execution_budgets import ( + ExecutionBudgetLifecycleError, + ExecutionBudgetNotFoundError, + ExecutionBudgetValidationError, + create_execution_budget_record, + deactivate_execution_budget_record, + evaluate_execution_budget, + get_execution_budget_record, + list_execution_budget_records, + supersede_execution_budget_record, +) + + +class _SavepointConnection: + def __init__(self, store: "ExecutionBudgetStoreStub") -> None: + self.store = store + + def transaction(self) -> "_Savepoint": + return _Savepoint(self.store) + + +class 
_Savepoint: + def __init__(self, store: "ExecutionBudgetStoreStub") -> None: + self.store = store + self.snapshot: list[dict[str, object]] | None = None + + def __enter__(self) -> "_Savepoint": + self.snapshot = [dict(row) for row in self.store.budgets] + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + if exc_type is not None and self.snapshot is not None: + self.store.budgets = [dict(row) for row in self.snapshot] + return False + + +class ExecutionBudgetStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 11, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.budgets: list[dict[str, object]] = [] + self.executions: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + self.fail_next_supersede_update = False + self.conn = _SavepointConnection(self) + + def current_time(self) -> datetime: + return self.base_time + timedelta(minutes=len(self.executions)) + + def get_thread_optional(self, thread_id: UUID) -> dict[str, object] | None: + if thread_id != self.thread_id: + return None + return { + "id": self.thread_id, + "user_id": self.user_id, + "title": "Budget lifecycle thread", + "created_at": self.base_time, + "updated_at": self.base_time, + } + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time + timedelta(minutes=len(self.traces)), + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.trace_events)), + } + self.trace_events.append(event) + return event + + def create_execution_budget( + self, + *, + budget_id: UUID | None = None, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, + supersedes_budget_id: UUID | None = None, + ) -> dict[str, object]: + row = { + "id": uuid4() if budget_id is None else budget_id, + "user_id": self.user_id, + "tool_key": tool_key, + "domain_hint": domain_hint, + "max_completed_executions": max_completed_executions, + "rolling_window_seconds": rolling_window_seconds, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": supersedes_budget_id, + "created_at": self.base_time + timedelta(minutes=len(self.budgets)), + } + self.budgets.append(row) + self.budgets.sort(key=lambda item: (item["created_at"], item["id"])) + return row + + def deactivate_execution_budget_optional( + self, + execution_budget_id: UUID, + ) -> dict[str, object] | None: + row = self.get_execution_budget_optional(execution_budget_id) + if row is None or row["status"] != "active": + return None + row["status"] = "inactive" + row["deactivated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return row + + def supersede_execution_budget_optional( + self, + *, + execution_budget_id: UUID, + superseded_by_budget_id: UUID, + ) -> dict[str, object] | None: + if 
self.fail_next_supersede_update: + self.fail_next_supersede_update = False + return None + row = self.get_execution_budget_optional(execution_budget_id) + if row is None or row["status"] != "active": + return None + row["status"] = "superseded" + row["deactivated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + row["superseded_by_budget_id"] = superseded_by_budget_id + return row + + def get_execution_budget_optional(self, execution_budget_id: UUID) -> dict[str, object] | None: + return next((row for row in self.budgets if row["id"] == execution_budget_id), None) + + def list_execution_budgets(self) -> list[dict[str, object]]: + return list(self.budgets) + + def seed_execution( + self, + *, + tool_key: str, + domain_hint: str | None, + status: str, + offset_minutes: int, + ) -> None: + tool_id = uuid4() + self.executions.append( + { + "id": uuid4(), + "user_id": self.user_id, + "approval_id": uuid4(), + "thread_id": self.thread_id, + "tool_id": tool_id, + "trace_id": uuid4(), + "request_event_id": None, + "result_event_id": None, + "status": status, + "handler_key": None if status == "blocked" else tool_key, + "request": { + "thread_id": str(self.thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": domain_hint, + "risk_hint": None, + "attributes": {}, + }, + "tool": { + "id": str(tool_id), + "tool_key": tool_key, + }, + "result": { + "handler_key": None if status == "blocked" else tool_key, + "status": status, + "output": None, + "reason": None, + }, + "executed_at": self.base_time + timedelta(minutes=offset_minutes), + } + ) + + def list_tool_executions(self) -> list[dict[str, object]]: + return list(self.executions) + + +def test_create_execution_budget_requires_at_least_one_selector() -> None: + store = ExecutionBudgetStoreStub() + + with pytest.raises( + ExecutionBudgetValidationError, + match="execution budget requires at least one selector: tool_key or domain_hint", + ): + create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key=None, + domain_hint=None, + max_completed_executions=1, + ), + ) + + +def test_create_execution_budget_rejects_duplicate_active_scope() -> None: + store = ExecutionBudgetStoreStub() + create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ), + ) + + with pytest.raises( + ExecutionBudgetValidationError, + match="active execution budget already exists for selector scope", + ): + create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + ), + ) + + +def test_create_execution_budget_includes_optional_rolling_window_seconds() -> None: + store = ExecutionBudgetStoreStub() + + payload = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + rolling_window_seconds=3600, + ), + ) + + assert payload["execution_budget"]["rolling_window_seconds"] == 3600 + assert store.budgets[0]["rolling_window_seconds"] == 3600 + + +def test_create_list_and_get_execution_budget_records_are_deterministic() -> None: + store = ExecutionBudgetStoreStub() + second = 
create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + ), + ) + first = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key=None, + domain_hint="docs", + max_completed_executions=1, + ), + ) + + listed = list_execution_budget_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail = get_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_budget_id=UUID(second["execution_budget"]["id"]), + ) + + assert [item["id"] for item in listed["items"]] == [ + second["execution_budget"]["id"], + first["execution_budget"]["id"], + ] + assert listed["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail == {"execution_budget": second["execution_budget"]} + assert detail["execution_budget"]["status"] == "active" + assert detail["execution_budget"]["deactivated_at"] is None + assert detail["execution_budget"]["rolling_window_seconds"] is None + + +def test_deactivate_execution_budget_marks_row_inactive_and_records_trace() -> None: + store = ExecutionBudgetStoreStub() + created = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ), + ) + + payload = deactivate_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetDeactivateInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + ), + ) + + assert payload["execution_budget"]["status"] == "inactive" + assert payload["execution_budget"]["deactivated_at"] == "2026-03-13T12:00:00+00:00" + assert payload["trace"]["trace_event_count"] == 3 + assert store.traces[0]["kind"] == "execution_budget.lifecycle" + assert store.traces[0]["compiler_version"] == "execution_budget_lifecycle_v0" + assert [event["kind"] for event in store.trace_events] == [ + "execution_budget.lifecycle.request", + "execution_budget.lifecycle.state", + "execution_budget.lifecycle.summary", + ] + assert store.trace_events[1]["payload"] == { + "execution_budget_id": created["execution_budget"]["id"], + "requested_action": "deactivate", + "previous_status": "active", + "current_status": "inactive", + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "deactivated_at": "2026-03-13T12:00:00+00:00", + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "replacement_budget_id": None, + "replacement_status": None, + "replacement_max_completed_executions": None, + "replacement_rolling_window_seconds": None, + "rejection_reason": None, + } + + +def test_supersede_execution_budget_replaces_active_budget_and_records_trace() -> None: + store = ExecutionBudgetStoreStub() + created = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ), + ) + + payload = supersede_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetSupersedeInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + 
max_completed_executions=3, + ), + ) + + assert payload["superseded_budget"]["status"] == "superseded" + assert payload["replacement_budget"]["status"] == "active" + assert payload["replacement_budget"]["max_completed_executions"] == 3 + assert payload["replacement_budget"]["tool_key"] == "proxy.echo" + assert payload["replacement_budget"]["domain_hint"] == "docs" + assert payload["replacement_budget"]["rolling_window_seconds"] is None + assert payload["replacement_budget"]["supersedes_budget_id"] == created["execution_budget"]["id"] + assert payload["superseded_budget"]["superseded_by_budget_id"] == payload["replacement_budget"]["id"] + assert payload["trace"]["trace_event_count"] == 3 + assert store.trace_events[1]["payload"]["replacement_budget_id"] == payload["replacement_budget"]["id"] + assert store.trace_events[2]["payload"] == { + "execution_budget_id": created["execution_budget"]["id"], + "requested_action": "supersede", + "outcome": "superseded", + "replacement_budget_id": payload["replacement_budget"]["id"], + "active_budget_id": payload["replacement_budget"]["id"], + } + + +def test_lifecycle_rejects_invalid_transition_and_records_trace() -> None: + store = ExecutionBudgetStoreStub() + created = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ), + ) + deactivate_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetDeactivateInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + ), + ) + + with pytest.raises( + ExecutionBudgetLifecycleError, + match=f"execution budget {created['execution_budget']['id']} is inactive and cannot be deactivated", + ): + deactivate_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetDeactivateInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + ), + ) + + assert store.trace_events[-2]["payload"]["current_status"] == "inactive" + assert store.trace_events[-2]["payload"]["rejection_reason"] == ( + f"execution budget {created['execution_budget']['id']} is inactive and cannot be deactivated" + ) + assert store.trace_events[-1]["payload"]["outcome"] == "rejected" + + +def test_supersede_execution_budget_rolls_back_replacement_when_source_update_fails() -> None: + store = ExecutionBudgetStoreStub() + created = create_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetCreateInput( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ), + ) + store.fail_next_supersede_update = True + + with pytest.raises( + ExecutionBudgetLifecycleError, + match=f"execution budget {created['execution_budget']['id']} could not be superseded", + ): + supersede_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ExecutionBudgetSupersedeInput( + thread_id=store.thread_id, + execution_budget_id=UUID(created["execution_budget"]["id"]), + max_completed_executions=3, + ), + ) + + assert len(store.budgets) == 1 + assert store.budgets[0]["id"] == UUID(created["execution_budget"]["id"]) + assert store.budgets[0]["status"] == "active" + assert store.budgets[0]["superseded_by_budget_id"] is None + assert store.trace_events[-1]["payload"]["outcome"] == "rejected" + + +def 
test_get_execution_budget_record_raises_clear_error_when_missing() -> None: + store = ExecutionBudgetStoreStub() + + with pytest.raises(ExecutionBudgetNotFoundError, match="execution budget .* was not found"): + get_execution_budget_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_budget_id=uuid4(), + ) + + +def test_evaluate_execution_budget_prefers_more_specific_active_match_and_ignores_inactive_rows() -> None: + store = ExecutionBudgetStoreStub() + inactive = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=1, + ) + store.deactivate_execution_budget_optional(inactive["id"]) + store.create_execution_budget(tool_key=None, domain_hint="docs", max_completed_executions=1) + matched = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint="docs", + max_completed_executions=2, + ) + store.seed_execution(tool_key="proxy.echo", domain_hint="docs", status="completed", offset_minutes=0) + store.seed_execution(tool_key="proxy.echo", domain_hint="docs", status="blocked", offset_minutes=1) + + decision = evaluate_execution_budget( + store, # type: ignore[arg-type] + tool={"id": str(uuid4()), "tool_key": "proxy.echo"}, + request={ + "thread_id": str(store.thread_id), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "attributes": {}, + }, + ) + + assert decision.record == { + "matched_budget_id": str(matched["id"]), + "tool_key": "proxy.echo", + "domain_hint": "docs", + "budget_tool_key": "proxy.echo", + "budget_domain_hint": "docs", + "max_completed_executions": 2, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "allow", + "reason": "within_budget", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + assert decision.blocked_result is None + + +def test_evaluate_execution_budget_blocks_when_projected_completed_count_would_exceed_limit() -> None: + store = ExecutionBudgetStoreStub() + matched = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + store.seed_execution(tool_key="proxy.echo", domain_hint=None, status="completed", offset_minutes=0) + + decision = evaluate_execution_budget( + store, # type: ignore[arg-type] + tool={"id": str(uuid4()), "tool_key": "proxy.echo"}, + request={ + "thread_id": str(store.thread_id), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + ) + + assert decision.record == { + "matched_budget_id": str(matched["id"]), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + assert decision.blocked_result == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": ( + f"execution budget {matched['id']} blocks execution: projected completed executions " + "2 would exceed limit 1" + ), + "budget_decision": 
decision.record, + } + + +def test_evaluate_execution_budget_uses_only_recent_completed_history_inside_window() -> None: + store = ExecutionBudgetStoreStub() + matched = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=2, + rolling_window_seconds=3600, + ) + store.seed_execution(tool_key="proxy.echo", domain_hint=None, status="completed", offset_minutes=-120) + store.seed_execution(tool_key="proxy.echo", domain_hint=None, status="completed", offset_minutes=-10) + + decision = evaluate_execution_budget( + store, # type: ignore[arg-type] + tool={"id": str(uuid4()), "tool_key": "proxy.echo"}, + request={ + "thread_id": str(store.thread_id), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + ) + + assert decision.record == { + "matched_budget_id": str(matched["id"]), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 2, + "rolling_window_seconds": 3600, + "count_scope": "rolling_window", + "window_started_at": "2026-03-13T10:02:00+00:00", + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "allow", + "reason": "within_budget", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + assert decision.blocked_result is None + + +def test_evaluate_execution_budget_blocks_when_recent_window_history_exceeds_limit() -> None: + store = ExecutionBudgetStoreStub() + matched = store.create_execution_budget( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + rolling_window_seconds=900, + ) + store.seed_execution(tool_key="proxy.echo", domain_hint=None, status="completed", offset_minutes=-5) + + decision = evaluate_execution_budget( + store, # type: ignore[arg-type] + tool={"id": str(uuid4()), "tool_key": "proxy.echo"}, + request={ + "thread_id": str(store.thread_id), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + ) + + assert decision.record == { + "matched_budget_id": str(matched["id"]), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": 900, + "count_scope": "rolling_window", + "window_started_at": "2026-03-13T10:46:00+00:00", + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + assert decision.blocked_result == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": ( + f"execution budget {matched['id']} blocks execution: projected completed executions " + "2 within rolling window 900 seconds would exceed limit 1" + ), + "budget_decision": decision.record, + } diff --git a/tests/unit/test_execution_budgets_main.py b/tests/unit/test_execution_budgets_main.py new file mode 100644 index 0000000..bf7c1cf --- /dev/null +++ b/tests/unit/test_execution_budgets_main.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from 
alicebot_api.execution_budgets import ( + ExecutionBudgetLifecycleError, + ExecutionBudgetNotFoundError, + ExecutionBudgetValidationError, +) + + +def test_create_execution_budget_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_execution_budget_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "execution_budget": { + "id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": 3600, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_execution_budget_record", fake_create_execution_budget_record) + + response = main_module.create_execution_budget( + main_module.CreateExecutionBudgetRequest( + user_id=user_id, + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + rolling_window_seconds=3600, + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body)["execution_budget"]["id"] == "budget-123" + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].tool_key == "proxy.echo" + assert captured["request"].rolling_window_seconds == 3600 + + +def test_create_execution_budget_endpoint_maps_validation_error_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_execution_budget_record(*_args, **_kwargs): + raise ExecutionBudgetValidationError( + "execution budget requires at least one selector: tool_key or domain_hint" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_execution_budget_record", fake_create_execution_budget_record) + + response = main_module.create_execution_budget( + main_module.CreateExecutionBudgetRequest( + user_id=user_id, + tool_key=None, + domain_hint="docs", + max_completed_executions=1, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "execution budget requires at least one selector: tool_key or domain_hint" + } + + +def test_list_execution_budgets_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_execution_budget_records(store, *, user_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + 
return { + "items": [ + { + "id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + } + ], + "summary": {"total_count": 1, "order": ["created_at_asc", "id_asc"]}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_execution_budget_records", fake_list_execution_budget_records) + + response = main_module.list_execution_budgets(user_id) + + assert response.status_code == 200 + assert json.loads(response.body)["summary"] == { + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + } + assert captured == { + "database_url": "postgresql://app", + "current_user_id": user_id, + "store_type": "ContinuityStore", + "user_id": user_id, + } + + +def test_get_execution_budget_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_get_execution_budget_record(store, *, user_id, execution_budget_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["execution_budget_id"] = execution_budget_id + return { + "execution_budget": { + "id": str(execution_budget_id), + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_execution_budget_record", fake_get_execution_budget_record) + + response = main_module.get_execution_budget(execution_budget_id, user_id) + + assert response.status_code == 200 + assert json.loads(response.body)["execution_budget"]["id"] == str(execution_budget_id) + assert captured == { + "database_url": "postgresql://app", + "current_user_id": user_id, + "store_type": "ContinuityStore", + "user_id": user_id, + "execution_budget_id": execution_budget_id, + } + + +def test_get_execution_budget_endpoint_maps_missing_record_to_404(monkeypatch) -> None: + user_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_execution_budget_record(*_args, **_kwargs): + raise ExecutionBudgetNotFoundError(f"execution budget {execution_budget_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_execution_budget_record", fake_get_execution_budget_record) + + response = main_module.get_execution_budget(execution_budget_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"execution budget 
{execution_budget_id} was not found" + } + + +def test_deactivate_execution_budget_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_deactivate_execution_budget_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "execution_budget": { + "id": str(execution_budget_id), + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "status": "inactive", + "deactivated_at": "2026-03-13T12:00:00+00:00", + "superseded_by_budget_id": None, + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + }, + "trace": {"trace_id": "trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "deactivate_execution_budget_record", fake_deactivate_execution_budget_record) + + response = main_module.deactivate_execution_budget( + execution_budget_id, + main_module.DeactivateExecutionBudgetRequest( + user_id=user_id, + thread_id=thread_id, + ), + ) + + assert response.status_code == 200 + assert json.loads(response.body)["execution_budget"]["status"] == "inactive" + assert captured["request"].thread_id == thread_id + assert captured["request"].execution_budget_id == execution_budget_id + + +def test_deactivate_execution_budget_endpoint_maps_lifecycle_error_to_409(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_deactivate_execution_budget_record(*_args, **_kwargs): + raise ExecutionBudgetLifecycleError( + f"execution budget {execution_budget_id} is inactive and cannot be deactivated" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "deactivate_execution_budget_record", fake_deactivate_execution_budget_record) + + response = main_module.deactivate_execution_budget( + execution_budget_id, + main_module.DeactivateExecutionBudgetRequest( + user_id=user_id, + thread_id=thread_id, + ), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"execution budget {execution_budget_id} is inactive and cannot be deactivated" + } + + +def test_supersede_execution_budget_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + execution_budget_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_supersede_execution_budget_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + 
return { + "superseded_budget": { + "id": str(execution_budget_id), + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": 1800, + "status": "superseded", + "deactivated_at": "2026-03-13T12:00:00+00:00", + "superseded_by_budget_id": "budget-456", + "supersedes_budget_id": None, + "created_at": "2026-03-13T11:00:00+00:00", + }, + "replacement_budget": { + "id": "budget-456", + "tool_key": "proxy.echo", + "domain_hint": None, + "max_completed_executions": 3, + "rolling_window_seconds": 1800, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": str(execution_budget_id), + "created_at": "2026-03-13T11:01:00+00:00", + }, + "trace": {"trace_id": "trace-456", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "supersede_execution_budget_record", fake_supersede_execution_budget_record) + + response = main_module.supersede_execution_budget( + execution_budget_id, + main_module.SupersedeExecutionBudgetRequest( + user_id=user_id, + thread_id=thread_id, + max_completed_executions=3, + ), + ) + + assert response.status_code == 200 + body = json.loads(response.body) + assert body["superseded_budget"]["status"] == "superseded" + assert body["replacement_budget"]["status"] == "active" + assert captured["request"].thread_id == thread_id + assert captured["request"].execution_budget_id == execution_budget_id + assert captured["request"].max_completed_executions == 3 diff --git a/tests/unit/test_executions.py b/tests/unit/test_executions.py new file mode 100644 index 0000000..01dac78 --- /dev/null +++ b/tests/unit/test_executions.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.executions import ( + ToolExecutionNotFoundError, + get_tool_execution_record, + list_tool_execution_records, +) + + +class ToolExecutionStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.executions: list[dict[str, object]] = [] + + def seed_execution(self, *, tool_key: str, offset_minutes: int) -> dict[str, object]: + tool_id = uuid4() + execution = { + "id": uuid4(), + "user_id": self.user_id, + "approval_id": uuid4(), + "task_step_id": uuid4(), + "thread_id": self.thread_id, + "tool_id": tool_id, + "trace_id": uuid4(), + "request_event_id": uuid4(), + "result_event_id": uuid4(), + "status": "completed", + "handler_key": tool_key, + "request": { + "thread_id": str(self.thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": tool_key}, + }, + "tool": { + "id": str(tool_id), + "tool_key": tool_key, + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": (self.base_time + timedelta(minutes=offset_minutes)).isoformat(), + }, + "result": { + "handler_key": tool_key, + "status": "completed", + "output": {"mode": "no_side_effect", "tool_key": tool_key}, + "reason": 
None, + }, + "executed_at": self.base_time + timedelta(minutes=offset_minutes), + } + self.executions.append(execution) + self.executions.sort(key=lambda row: (row["executed_at"], row["id"])) + return execution + + def seed_blocked_execution(self, *, tool_key: str, offset_minutes: int) -> dict[str, object]: + tool_id = uuid4() + execution = { + "id": uuid4(), + "user_id": self.user_id, + "approval_id": uuid4(), + "task_step_id": uuid4(), + "thread_id": self.thread_id, + "tool_id": tool_id, + "trace_id": uuid4(), + "request_event_id": None, + "result_event_id": None, + "status": "blocked", + "handler_key": None, + "request": { + "thread_id": str(self.thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": tool_key}, + }, + "tool": { + "id": str(tool_id), + "tool_key": tool_key, + "name": "Missing Proxy", + "description": "Missing handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": (self.base_time + timedelta(minutes=offset_minutes)).isoformat(), + }, + "result": { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": f"tool '{tool_key}' has no registered proxy handler", + }, + "executed_at": self.base_time + timedelta(minutes=offset_minutes), + } + self.executions.append(execution) + self.executions.sort(key=lambda row: (row["executed_at"], row["id"])) + return execution + + def list_tool_executions(self) -> list[dict[str, object]]: + return list(self.executions) + + def get_tool_execution_optional(self, execution_id: UUID) -> dict[str, object] | None: + return next((row for row in self.executions if row["id"] == execution_id), None) + + +def test_list_tool_execution_records_uses_explicit_order_and_summary() -> None: + store = ToolExecutionStoreStub() + first = store.seed_execution(tool_key="proxy.echo", offset_minutes=0) + second = store.seed_execution(tool_key="proxy.echo", offset_minutes=5) + + payload = list_tool_execution_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + + assert [item["id"] for item in payload["items"]] == [str(first["id"]), str(second["id"])] + assert payload["summary"] == { + "total_count": 2, + "order": ["executed_at_asc", "id_asc"], + } + + +def test_get_tool_execution_record_returns_detail_shape() -> None: + store = ToolExecutionStoreStub() + execution = store.seed_execution(tool_key="proxy.echo", offset_minutes=0) + + payload = get_tool_execution_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_id=execution["id"], + ) + + assert payload["execution"]["id"] == str(execution["id"]) + assert payload["execution"]["approval_id"] == str(execution["approval_id"]) + assert payload["execution"]["task_step_id"] == str(execution["task_step_id"]) + assert payload["execution"]["status"] == "completed" + assert payload["execution"]["tool"]["tool_key"] == "proxy.echo" + assert payload["execution"]["result"]["output"] == { + "mode": "no_side_effect", + "tool_key": "proxy.echo", + } + + +def test_get_tool_execution_record_preserves_blocked_attempt_shape() -> None: + store = ToolExecutionStoreStub() + execution = store.seed_blocked_execution(tool_key="proxy.missing", offset_minutes=0) + + payload = get_tool_execution_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + 
execution_id=execution["id"], + ) + + assert payload["execution"]["status"] == "blocked" + assert payload["execution"]["handler_key"] is None + assert payload["execution"]["request_event_id"] is None + assert payload["execution"]["result_event_id"] is None + assert payload["execution"]["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + } + + +def test_get_tool_execution_record_preserves_budget_blocked_attempt_shape() -> None: + store = ToolExecutionStoreStub() + execution = store.seed_blocked_execution(tool_key="proxy.echo", offset_minutes=0) + execution["result"] = { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "execution budget budget-123 blocks execution: projected completed executions 2 would exceed limit 1", + "budget_decision": { + "matched_budget_id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + }, + } + + payload = get_tool_execution_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_id=execution["id"], + ) + + assert payload["execution"]["result"]["budget_decision"] == { + "matched_budget_id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + } + + +def test_get_tool_execution_record_raises_clear_error_when_missing() -> None: + store = ToolExecutionStoreStub() + + with pytest.raises(ToolExecutionNotFoundError, match="tool execution .* was not found"): + get_tool_execution_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + execution_id=uuid4(), + ) diff --git a/tests/unit/test_executions_main.py b/tests/unit/test_executions_main.py new file mode 100644 index 0000000..9070c0e --- /dev/null +++ b/tests/unit/test_executions_main.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.executions import ToolExecutionNotFoundError + + +def test_list_tool_executions_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_tool_execution_records(store, *, user_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + return { + "items": [ + { + "id": "execution-123", + "approval_id": 
"approval-123", + "task_step_id": "task-step-123", + "thread_id": "thread-123", + "tool_id": "tool-123", + "trace_id": "trace-123", + "request_event_id": "event-1", + "result_event_id": "event-2", + "status": "completed", + "handler_key": "proxy.echo", + "request": { + "thread_id": "thread-123", + "tool_id": "tool-123", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello"}, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + "executed_at": "2026-03-13T10:00:00+00:00", + } + ], + "summary": {"total_count": 1, "order": ["executed_at_asc", "id_asc"]}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_tool_execution_records", fake_list_tool_execution_records) + + response = main_module.list_tool_executions(user_id) + + assert response.status_code == 200 + assert json.loads(response.body)["summary"] == { + "total_count": 1, + "order": ["executed_at_asc", "id_asc"], + } + assert captured == { + "database_url": "postgresql://app", + "current_user_id": user_id, + "store_type": "ContinuityStore", + "user_id": user_id, + } + + +def test_get_tool_execution_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + execution_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_get_tool_execution_record(store, *, user_id, execution_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["execution_id"] = execution_id + return { + "execution": { + "id": str(execution_id), + "approval_id": "approval-123", + "task_step_id": "task-step-123", + "thread_id": "thread-123", + "tool_id": "tool-123", + "trace_id": "trace-123", + "request_event_id": "event-1", + "result_event_id": "event-2", + "status": "completed", + "handler_key": "proxy.echo", + "request": { + "thread_id": "thread-123", + "tool_id": "tool-123", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello"}, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + "executed_at": "2026-03-13T10:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_tool_execution_record", fake_get_tool_execution_record) + + response = main_module.get_tool_execution(execution_id, user_id) + + assert response.status_code == 200 + assert json.loads(response.body)["execution"]["id"] == str(execution_id) + assert captured == { + "database_url": "postgresql://app", + "current_user_id": user_id, + "store_type": "ContinuityStore", + "user_id": user_id, + "execution_id": execution_id, + } + + +def test_get_tool_execution_endpoint_maps_missing_record_to_404(monkeypatch) -> None: + user_id = uuid4() + execution_id = uuid4() + settings = 
Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_tool_execution_record(*_args, **_kwargs): + raise ToolExecutionNotFoundError(f"tool execution {execution_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_tool_execution_record", fake_get_tool_execution_record) + + response = main_module.get_tool_execution(execution_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"tool execution {execution_id} was not found" + } diff --git a/tests/unit/test_explicit_preferences.py b/tests/unit/test_explicit_preferences.py new file mode 100644 index 0000000..7fb5a31 --- /dev/null +++ b/tests/unit/test_explicit_preferences.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import AdmissionDecisionOutput, ExplicitPreferenceExtractionRequestInput +from alicebot_api.explicit_preferences import ( + ExplicitPreferenceExtractionValidationError, + _build_memory_key, + extract_and_admit_explicit_preferences, + extract_explicit_preference_candidates, +) + + +class ExplicitPreferenceStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.events: dict[UUID, dict[str, object]] = {} + + def list_events_by_ids(self, event_ids: list[UUID]) -> list[dict[str, object]]: + return [self.events[event_id] for event_id in event_ids if event_id in self.events] + + +def seed_event( + store: ExplicitPreferenceStoreStub, + *, + kind: str = "message.user", + text: str = "I like black coffee.", +) -> UUID: + event_id = uuid4() + store.events[event_id] = { + "id": event_id, + "sequence_no": 1, + "kind": kind, + "payload": {"text": text}, + "created_at": store.base_time, + } + return event_id + + +def test_extract_explicit_preference_candidates_returns_supported_candidate_shape() -> None: + event_id = UUID("11111111-1111-1111-1111-111111111111") + memory_key = _build_memory_key("black coffee") + + payload = extract_explicit_preference_candidates( + source_event_id=event_id, + text="I like black coffee.", + ) + + assert payload == [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(event_id)], + "delete_requested": False, + "pattern": "i_like", + "subject_text": "black coffee", + } + ] + + +def test_extract_explicit_preference_candidates_keeps_remember_pattern_deterministic() -> None: + event_id = UUID("22222222-2222-2222-2222-222222222222") + memory_key = _build_memory_key("oat milk") + + payload = extract_explicit_preference_candidates( + source_event_id=event_id, + text=" remember that I prefer oat milk!! 
", + ) + + assert payload == [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "prefer", + "text": "oat milk", + }, + "source_event_ids": [str(event_id)], + "delete_requested": False, + "pattern": "remember_that_i_prefer", + "subject_text": "oat milk", + } + ] + + +def test_extract_explicit_preference_candidates_returns_empty_for_unsupported_text() -> None: + assert extract_explicit_preference_candidates( + source_event_id=uuid4(), + text="I had coffee yesterday.", + ) == [] + + +def test_extract_explicit_preference_candidates_rejects_clause_style_text() -> None: + assert extract_explicit_preference_candidates( + source_event_id=uuid4(), + text="I prefer that we meet tomorrow.", + ) == [] + + +def test_build_memory_key_keeps_symbol_bearing_subjects_distinct() -> None: + c_plus_plus_key = _build_memory_key("C++") + c_hash_key = _build_memory_key("C#") + + assert c_plus_plus_key != c_hash_key + assert c_plus_plus_key.startswith("user.preference.c__") + assert c_hash_key.startswith("user.preference.c__") + + +def test_build_memory_key_is_case_insensitive_for_the_same_subject() -> None: + assert _build_memory_key("Black Coffee") == _build_memory_key("black coffee") + + +def test_extract_and_admit_explicit_preferences_rejects_invalid_source_event() -> None: + store = ExplicitPreferenceStoreStub() + event_id = seed_event(store, kind="message.assistant", text="I like black coffee.") + + with pytest.raises( + ExplicitPreferenceExtractionValidationError, + match="source_event_id must reference an existing message.user event owned by the user", + ): + extract_and_admit_explicit_preferences( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=ExplicitPreferenceExtractionRequestInput(source_event_id=event_id), + ) + + +def test_extract_and_admit_explicit_preferences_routes_candidate_through_memory_admission( + monkeypatch, +) -> None: + store = ExplicitPreferenceStoreStub() + user_id = uuid4() + event_id = seed_event(store, text="I don't like black coffee.") + memory_key = _build_memory_key("black coffee") + captured: dict[str, object] = {} + + def fake_admit_memory_candidate(store_arg, *, user_id, candidate): + captured["store"] = store_arg + captured["user_id"] = user_id + captured["candidate"] = candidate + return AdmissionDecisionOutput( + action="ADD", + reason="source_backed_add", + memory={ + "id": "memory-123", + "user_id": str(user_id), + "memory_key": candidate.memory_key, + "value": candidate.value, + "status": "active", + "source_event_ids": [str(event_id)], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "deleted_at": None, + }, + revision={ + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": candidate.memory_key, + "previous_value": None, + "new_value": candidate.value, + "source_event_ids": [str(event_id)], + "candidate": candidate.as_payload(), + "created_at": "2026-03-12T09:00:00+00:00", + }, + ) + + monkeypatch.setattr( + "alicebot_api.explicit_preferences.admit_memory_candidate", + fake_admit_memory_candidate, + ) + + payload = extract_and_admit_explicit_preferences( + store, # type: ignore[arg-type] + user_id=user_id, + request=ExplicitPreferenceExtractionRequestInput(source_event_id=event_id), + ) + + assert captured["store"] is store + assert captured["user_id"] == user_id + assert captured["candidate"].memory_key == memory_key + assert captured["candidate"].value == { + "kind": 
"explicit_preference", + "preference": "dislike", + "text": "black coffee", + } + assert payload == { + "candidates": [ + { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "source_event_ids": [str(event_id)], + "delete_requested": False, + "pattern": "i_dont_like", + "subject_text": "black coffee", + } + ], + "admissions": [ + { + "decision": "ADD", + "reason": "source_backed_add", + "memory": { + "id": "memory-123", + "user_id": str(user_id), + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "status": "active", + "source_event_ids": [str(event_id)], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "deleted_at": None, + }, + "revision": { + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": memory_key, + "previous_value": None, + "new_value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "source_event_ids": [str(event_id)], + "candidate": { + "memory_key": memory_key, + "value": { + "kind": "explicit_preference", + "preference": "dislike", + "text": "black coffee", + }, + "source_event_ids": [str(event_id)], + "delete_requested": False, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + } + ], + "summary": { + "source_event_id": str(event_id), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + }, + } diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py new file mode 100644 index 0000000..dc1e5ca --- /dev/null +++ b/tests/unit/test_main.py @@ -0,0 +1,2378 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.compiler import CompiledTraceRun +from alicebot_api.contracts import AdmissionDecisionOutput +from alicebot_api.embedding import ( + EmbeddingConfigValidationError, + MemoryEmbeddingNotFoundError, + MemoryEmbeddingValidationError, +) +from alicebot_api.entity import EntityNotFoundError, EntityValidationError +from alicebot_api.entity_edge import EntityEdgeValidationError +from alicebot_api.memory import MemoryAdmissionValidationError, MemoryReviewNotFoundError +from alicebot_api.response_generation import ResponseFailure +from alicebot_api.semantic_retrieval import SemanticMemoryRetrievalValidationError +from alicebot_api.store import ContinuityStoreInvariantError + + +def test_healthcheck_reports_ok_when_database_is_reachable(monkeypatch) -> None: + settings = Settings( + app_env="test", + database_url="postgresql://db", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + healthcheck_timeout_seconds=7, + ) + ping_calls: list[tuple[str, int]] = [] + + def fake_get_settings() -> Settings: + return settings + + def fake_ping_database(database_url: str, timeout_seconds: int) -> bool: + ping_calls.append((database_url, timeout_seconds)) + return True + + monkeypatch.setattr(main_module, "get_settings", fake_get_settings) + monkeypatch.setattr(main_module, "ping_database", fake_ping_database) + + response = main_module.healthcheck() + + assert response.status_code == 200 + assert json.loads(response.body) == { + 
"status": "ok", + "environment": "test", + "services": { + "database": {"status": "ok"}, + "redis": {"status": "not_checked", "url": "redis://cache:6379/0"}, + "object_storage": { + "status": "not_checked", + "endpoint_url": "http://object-store", + }, + }, + } + assert ping_calls == [("postgresql://db", 7)] + + +def test_healthcheck_reports_degraded_when_database_is_unreachable(monkeypatch) -> None: + settings = Settings( + app_env="test", + database_url="postgresql://db", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + healthcheck_timeout_seconds=4, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "ping_database", lambda *_args, **_kwargs: False) + + response = main_module.healthcheck() + + assert response.status_code == 503 + assert json.loads(response.body) == { + "status": "degraded", + "environment": "test", + "services": { + "database": {"status": "unreachable"}, + "redis": {"status": "not_checked", "url": "redis://cache:6379/0"}, + "object_storage": { + "status": "not_checked", + "endpoint_url": "http://object-store", + }, + }, + } + + +def test_healthcheck_route_is_registered() -> None: + route_paths = {route.path for route in main_module.app.routes} + + assert "/healthz" in route_paths + assert "/v0/context/compile" in route_paths + assert "/v0/responses" in route_paths + assert "/v0/memories/admit" in route_paths + assert "/v0/consents" in route_paths + assert "/v0/policies" in route_paths + assert "/v0/policies/{policy_id}" in route_paths + assert "/v0/policies/evaluate" in route_paths + assert "/v0/memories/extract-explicit-preferences" in route_paths + assert "/v0/memories" in route_paths + assert "/v0/memories/review-queue" in route_paths + assert "/v0/memories/evaluation-summary" in route_paths + assert "/v0/memories/semantic-retrieval" in route_paths + assert "/v0/memories/{memory_id}" in route_paths + assert "/v0/memories/{memory_id}/revisions" in route_paths + assert "/v0/memories/{memory_id}/labels" in route_paths + assert "/v0/embedding-configs" in route_paths + assert "/v0/memory-embeddings" in route_paths + assert "/v0/memories/{memory_id}/embeddings" in route_paths + assert "/v0/memory-embeddings/{memory_embedding_id}" in route_paths + assert "/v0/entities" in route_paths + assert "/v0/entity-edges" in route_paths + assert "/v0/tools/route" in route_paths + assert "/v0/execution-budgets" in route_paths + assert "/v0/execution-budgets/{execution_budget_id}" in route_paths + assert "/v0/execution-budgets/{execution_budget_id}/deactivate" in route_paths + assert "/v0/execution-budgets/{execution_budget_id}/supersede" in route_paths + assert "/v0/tool-executions" in route_paths + assert "/v0/tool-executions/{execution_id}" in route_paths + assert "/v0/tasks" in route_paths + assert "/v0/tasks/{task_id}" in route_paths + assert "/v0/tasks/{task_id}/workspace" in route_paths + assert "/v0/tasks/{task_id}/steps" in route_paths + assert "/v0/task-workspaces" in route_paths + assert "/v0/task-workspaces/{task_workspace_id}" in route_paths + assert "/v0/task-steps/{task_step_id}" in route_paths + assert "/v0/task-steps/{task_step_id}/transition" in route_paths + assert "/v0/entities/{entity_id}" in route_paths + assert "/v0/entities/{entity_id}/edges" in route_paths + + +def test_redact_url_credentials_strips_embedded_secrets() -> None: + assert main_module.redact_url_credentials("redis://alicebot:supersecret@cache:6379/0") == ( + "redis://cache:6379/0" + ) + 
assert main_module.redact_url_credentials("redis://cache:6379/0") == "redis://cache:6379/0" + + +def test_build_healthcheck_payload_keeps_boundary_statuses_consistent() -> None: + settings = Settings( + app_env="test", + redis_url="redis://alicebot:supersecret@cache:6379/0", + s3_endpoint_url="http://object-store", + ) + + assert main_module.build_healthcheck_payload(settings, database_ok=True) == { + "status": "ok", + "environment": "test", + "services": { + "database": {"status": "ok"}, + "redis": {"status": "not_checked", "url": "redis://cache:6379/0"}, + "object_storage": { + "status": "not_checked", + "endpoint_url": "http://object-store", + }, + }, + } + assert main_module.build_healthcheck_payload(settings, database_ok=False)["services"][ + "database" + ] == {"status": "unreachable"} + + +def test_compile_context_returns_trace_and_context_pack(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semantic_retrieval): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["thread_id"] = thread_id + captured["limits"] = limits + captured["semantic_retrieval"] = semantic_retrieval + return CompiledTraceRun( + trace_id="trace-123", + trace_event_count=5, + context_pack={ + "compiler_version": "continuity_v0", + "scope": {"user_id": str(user_id), "thread_id": str(thread_id)}, + "limits": { + "max_sessions": 2, + "max_events": 4, + "max_memories": 3, + "max_entities": 2, + "max_entity_edges": 6, + }, + "user": { + "id": str(user_id), + "email": "owner@example.com", + "display_name": "Owner", + "created_at": "2026-03-11T09:00:00+00:00", + }, + "thread": { + "id": str(thread_id), + "title": "Thread", + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:01:00+00:00", + }, + "sessions": [], + "events": [], + "memories": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:02:00+00:00", + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ], + "memory_summary": { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + }, + "entities": [ + { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": ["memory-123"], + "created_at": "2026-03-11T09:03:00+00:00", + } + ], + "entity_summary": { + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + }, + "entity_edges": [ + { + "id": "edge-123", + 
"from_entity_id": "entity-123", + "to_entity_id": "entity-999", + "relationship_type": "depends_on", + "valid_from": "2026-03-11T09:04:00+00:00", + "valid_to": None, + "source_memory_ids": ["memory-123"], + "created_at": "2026-03-11T09:04:00+00:00", + } + ], + "entity_edge_summary": { + "anchor_entity_count": 1, + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + }, + }, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "compile_and_persist_trace", fake_compile_and_persist_trace) + + response = main_module.compile_context( + main_module.CompileContextRequest( + user_id=user_id, + thread_id=thread_id, + max_sessions=2, + max_events=4, + max_memories=3, + max_entities=2, + max_entity_edges=6, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "trace_id": "trace-123", + "trace_event_count": 5, + "context_pack": { + "compiler_version": "continuity_v0", + "scope": {"user_id": str(user_id), "thread_id": str(thread_id)}, + "limits": { + "max_sessions": 2, + "max_events": 4, + "max_memories": 3, + "max_entities": 2, + "max_entity_edges": 6, + }, + "user": { + "id": str(user_id), + "email": "owner@example.com", + "display_name": "Owner", + "created_at": "2026-03-11T09:00:00+00:00", + }, + "thread": { + "id": str(thread_id), + "title": "Thread", + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:01:00+00:00", + }, + "sessions": [], + "events": [], + "memories": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:02:00+00:00", + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ], + "memory_summary": { + "candidate_count": 2, + "included_count": 1, + "excluded_deleted_count": 1, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + }, + "entities": [ + { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": ["memory-123"], + "created_at": "2026-03-11T09:03:00+00:00", + } + ], + "entity_summary": { + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + }, + "entity_edges": [ + { + "id": "edge-123", + "from_entity_id": "entity-123", + "to_entity_id": "entity-999", + "relationship_type": "depends_on", + "valid_from": "2026-03-11T09:04:00+00:00", + "valid_to": None, + "source_memory_ids": ["memory-123"], + "created_at": "2026-03-11T09:04:00+00:00", + } + ], + "entity_edge_summary": { + "anchor_entity_count": 1, + "candidate_count": 2, + "included_count": 1, + "excluded_limit_count": 1, + }, + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert 
captured["thread_id"] == thread_id + assert captured["limits"].max_sessions == 2 + assert captured["limits"].max_events == 4 + assert captured["limits"].max_memories == 3 + assert captured["limits"].max_entities == 2 + assert captured["limits"].max_entity_edges == 6 + assert captured["semantic_retrieval"] is None + + +def test_compile_context_returns_not_found_when_scope_row_is_missing(monkeypatch) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "compile_and_persist_trace", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + ContinuityStoreInvariantError("get_thread did not return a row from the database") + ), + ) + + response = main_module.compile_context( + main_module.CompileContextRequest(user_id=uuid4(), thread_id=uuid4()) + ) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": "get_thread did not return a row from the database", + } + + +def test_compile_context_routes_semantic_inputs_and_validation_errors(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_compile_and_persist_trace(store, *, user_id, thread_id, limits, semantic_retrieval): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["thread_id"] = thread_id + captured["limits"] = limits + captured["semantic_retrieval"] = semantic_retrieval + return CompiledTraceRun( + trace_id="trace-semantic", + trace_event_count=7, + context_pack={ + "compiler_version": "continuity_v0", + "scope": {"user_id": str(user_id), "thread_id": str(thread_id)}, + "limits": { + "max_sessions": 3, + "max_events": 8, + "max_memories": 5, + "max_entities": 5, + "max_entity_edges": 10, + }, + "user": { + "id": str(user_id), + "email": "owner@example.com", + "display_name": "Owner", + "created_at": "2026-03-12T09:00:00+00:00", + }, + "thread": { + "id": str(thread_id), + "title": "Thread", + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:01:00+00:00", + }, + "sessions": [], + "events": [], + "memories": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-123"], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "source_provenance": { + "sources": ["symbolic", "semantic"], + "semantic_score": 0.99, + }, + } + ], + "memory_summary": { + "candidate_count": 1, + "included_count": 1, + "excluded_deleted_count": 0, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": True, + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "semantic_limit": 2, + "symbolic_selected_count": 1, + "semantic_selected_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 1, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 1, + "similarity_metric": "cosine_similarity", + "source_precedence": ["symbolic", "semantic"], + 
"symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + }, + "entities": [], + "entity_summary": { + "candidate_count": 0, + "included_count": 0, + "excluded_limit_count": 0, + }, + "entity_edges": [], + "entity_edge_summary": { + "anchor_entity_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_limit_count": 0, + }, + }, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "compile_and_persist_trace", fake_compile_and_persist_trace) + + response = main_module.compile_context( + main_module.CompileContextRequest( + user_id=user_id, + thread_id=thread_id, + semantic=main_module.CompileContextSemanticRequest( + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=2, + ), + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["context_pack"]["memory_summary"]["hybrid_retrieval"] == { + "requested": True, + "embedding_config_id": str(config_id), + "query_vector_dimensions": 3, + "semantic_limit": 2, + "symbolic_selected_count": 1, + "semantic_selected_count": 1, + "merged_candidate_count": 1, + "deduplicated_count": 1, + "included_symbolic_only_count": 0, + "included_semantic_only_count": 0, + "included_dual_source_count": 1, + "similarity_metric": "cosine_similarity", + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["semantic_retrieval"].embedding_config_id == config_id + assert captured["semantic_retrieval"].query_vector == (0.1, 0.2, 0.3) + assert captured["semantic_retrieval"].limit == 2 + + monkeypatch.setattr( + main_module, + "compile_and_persist_trace", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + SemanticMemoryRetrievalValidationError( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + ), + ) + + error_response = main_module.compile_context( + main_module.CompileContextRequest( + user_id=user_id, + thread_id=thread_id, + semantic=main_module.CompileContextSemanticRequest( + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=2, + ), + ) + ) + + assert error_response.status_code == 400 + assert json.loads(error_response.body) == { + "detail": "embedding_config_id must reference an existing embedding config owned by the user" + } + + +def test_generate_assistant_response_returns_assistant_and_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + settings = Settings( + database_url="postgresql://app", + model_provider="openai_responses", + model_name="gpt-5-mini", + ) + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_generate_response(store, *, settings, user_id, thread_id, message_text, limits): + captured["store_type"] = type(store).__name__ + captured["settings"] = settings + captured["user_id"] = user_id + captured["thread_id"] = thread_id + captured["message_text"] = message_text + captured["limits"] = limits + return { + "assistant": { + "event_id": "assistant-event-123", + 
"sequence_no": 5, + "text": "Hello back.", + "model_provider": "openai_responses", + "model": "gpt-5-mini", + }, + "trace": { + "compile_trace_id": "compile-trace-123", + "compile_trace_event_count": 11, + "response_trace_id": "response-trace-123", + "response_trace_event_count": 2, + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "generate_response", fake_generate_response) + + response = main_module.generate_assistant_response( + main_module.GenerateResponseRequest( + user_id=user_id, + thread_id=thread_id, + message="Hello?", + max_sessions=2, + max_events=4, + max_memories=3, + max_entities=2, + max_entity_edges=6, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "assistant": { + "event_id": "assistant-event-123", + "sequence_no": 5, + "text": "Hello back.", + "model_provider": "openai_responses", + "model": "gpt-5-mini", + }, + "trace": { + "compile_trace_id": "compile-trace-123", + "compile_trace_event_count": 11, + "response_trace_id": "response-trace-123", + "response_trace_event_count": 2, + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["thread_id"] == thread_id + assert captured["message_text"] == "Hello?" + assert captured["limits"].max_sessions == 2 + assert captured["limits"].max_events == 4 + assert captured["limits"].max_memories == 3 + assert captured["limits"].max_entities == 2 + assert captured["limits"].max_entity_edges == 6 + + +def test_generate_assistant_response_returns_502_with_trace_when_model_invocation_fails( + monkeypatch, +) -> None: + user_id = uuid4() + thread_id = uuid4() + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "generate_response", + lambda *_args, **_kwargs: ResponseFailure( + detail="upstream timeout", + trace={ + "compile_trace_id": "compile-trace-123", + "compile_trace_event_count": 9, + "response_trace_id": "response-trace-123", + "response_trace_event_count": 2, + }, + ), + ) + + response = main_module.generate_assistant_response( + main_module.GenerateResponseRequest( + user_id=user_id, + thread_id=thread_id, + message="Hello?", + ) + ) + + assert response.status_code == 502 + assert json.loads(response.body) == { + "detail": "upstream timeout", + "trace": { + "compile_trace_id": "compile-trace-123", + "compile_trace_event_count": 9, + "response_trace_id": "response-trace-123", + "response_trace_event_count": 2, + }, + } + + +def test_admit_memory_returns_decision_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_admit_memory_candidate(store, *, user_id, candidate): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["candidate"] = candidate + return AdmissionDecisionOutput( + action="ADD", + reason="source_backed_add", + memory={ + "id": 
"memory-123", + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:00:00+00:00", + "deleted_at": None, + }, + revision={ + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "oat milk"}, + "source_event_ids": ["event-1"], + "candidate": { + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "source_event_ids": ["event-1"], + "delete_requested": False, + }, + "created_at": "2026-03-11T09:00:00+00:00", + }, + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "admit_memory_candidate", fake_admit_memory_candidate) + + response = main_module.admit_memory( + main_module.AdmitMemoryRequest( + user_id=user_id, + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=[uuid4()], + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "decision": "ADD", + "reason": "source_backed_add", + "memory": { + "id": "memory-123", + "user_id": str(user_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:00:00+00:00", + "deleted_at": None, + }, + "revision": { + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "oat milk"}, + "source_event_ids": ["event-1"], + "candidate": { + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "source_event_ids": ["event-1"], + "delete_requested": False, + }, + "created_at": "2026-03-11T09:00:00+00:00", + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["candidate"].memory_key == "user.preference.coffee" + + +def test_admit_memory_returns_bad_request_when_source_validation_fails(monkeypatch) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "admit_memory_candidate", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + MemoryAdmissionValidationError("source_event_ids must all reference existing events owned by the user") + ), + ) + + response = main_module.admit_memory( + main_module.AdmitMemoryRequest( + user_id=uuid4(), + memory_key="user.preference.coffee", + value={"likes": "black"}, + source_event_ids=[uuid4()], + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "source_event_ids must all reference existing events owned by the user", + } + + +def test_extract_explicit_preferences_returns_payload(monkeypatch) -> None: + user_id = uuid4() + source_event_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, 
object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_extract_and_admit_explicit_preferences(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "candidates": [ + { + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + "pattern": "i_like", + "subject_text": "black coffee", + } + ], + "admissions": [ + { + "decision": "ADD", + "reason": "source_backed_add", + "memory": { + "id": "memory-123", + "user_id": str(user_id), + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "status": "active", + "source_event_ids": [str(source_event_id)], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "deleted_at": None, + }, + "revision": { + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.black_coffee", + "previous_value": None, + "new_value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "candidate": { + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + } + ], + "summary": { + "source_event_id": str(source_event_id), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "extract_and_admit_explicit_preferences", + fake_extract_and_admit_explicit_preferences, + ) + + response = main_module.extract_explicit_preferences( + main_module.ExtractExplicitPreferencesRequest( + user_id=user_id, + source_event_id=source_event_id, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "candidates": [ + { + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + "pattern": "i_like", + "subject_text": "black coffee", + } + ], + "admissions": [ + { + "decision": "ADD", + "reason": "source_backed_add", + "memory": { + "id": "memory-123", + "user_id": str(user_id), + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "status": "active", + "source_event_ids": [str(source_event_id)], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + "deleted_at": None, + }, + "revision": { + "id": "revision-123", + "user_id": str(user_id), + "memory_id": "memory-123", + "sequence_no": 1, + "action": "ADD", + "memory_key": 
"user.preference.black_coffee", + "previous_value": None, + "new_value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "candidate": { + "memory_key": "user.preference.black_coffee", + "value": { + "kind": "explicit_preference", + "preference": "like", + "text": "black coffee", + }, + "source_event_ids": [str(source_event_id)], + "delete_requested": False, + }, + "created_at": "2026-03-12T09:00:00+00:00", + }, + } + ], + "summary": { + "source_event_id": str(source_event_id), + "source_event_kind": "message.user", + "candidate_count": 1, + "admission_count": 1, + "persisted_change_count": 1, + "noop_count": 0, + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["request"].source_event_id == source_event_id + + +def test_extract_explicit_preferences_returns_bad_request_when_source_event_is_invalid( + monkeypatch, +) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "extract_and_admit_explicit_preferences", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + main_module.ExplicitPreferenceExtractionValidationError( + "source_event_id must reference an existing message.user event owned by the user" + ) + ), + ) + + response = main_module.extract_explicit_preferences( + main_module.ExtractExplicitPreferencesRequest( + user_id=uuid4(), + source_event_id=uuid4(), + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "source_event_id must reference an existing message.user event owned by the user", + } + + +def test_list_memories_returns_review_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_memory_review_records(store, *, user_id, status, limit): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["status"] = status + captured["limit"] = limit + return { + "items": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:02:00+00:00", + "deleted_at": None, + } + ], + "summary": { + "status": "active", + "limit": 10, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_memory_review_records", fake_list_memory_review_records) + + response = main_module.list_memories(user_id=user_id, status="active", limit=10) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat 
milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + "updated_at": "2026-03-11T09:02:00+00:00", + "deleted_at": None, + } + ], + "summary": { + "status": "active", + "limit": 10, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["status"] == "active" + assert captured["limit"] == 10 + + +def test_get_memory_returns_not_found_when_memory_is_inaccessible(monkeypatch) -> None: + memory_id = uuid4() + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "get_memory_review_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + main_module.MemoryReviewNotFoundError(f"memory {memory_id} was not found") + ), + ) + + response = main_module.get_memory(memory_id=memory_id, user_id=uuid4()) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"memory {memory_id} was not found", + } + + +def test_list_memory_review_queue_returns_unlabeled_active_queue_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_memory_review_queue_records(store, *, user_id, limit): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["limit"] = limit + return { + "items": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:02:00+00:00", + } + ], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 7, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_memory_review_queue_records", fake_list_memory_review_queue_records) + + response = main_module.list_memory_review_queue(user_id=user_id, limit=7) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:02:00+00:00", + } + ], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 7, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert 
captured["limit"] == 7 + + +def test_get_memories_evaluation_summary_returns_aggregate_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_get_memory_evaluation_summary(store, *, user_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + return { + "summary": { + "total_memory_count": 4, + "active_memory_count": 3, + "deleted_memory_count": 1, + "labeled_memory_count": 2, + "unlabeled_memory_count": 2, + "total_label_row_count": 3, + "label_row_counts_by_value": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 1, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_memory_evaluation_summary", fake_get_memory_evaluation_summary) + + response = main_module.get_memories_evaluation_summary(user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "summary": { + "total_memory_count": 4, + "active_memory_count": 3, + "deleted_memory_count": 1, + "labeled_memory_count": 2, + "unlabeled_memory_count": 2, + "total_label_row_count": 3, + "label_row_counts_by_value": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 1, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + + +def test_list_memory_revisions_returns_review_payload(monkeypatch) -> None: + user_id = uuid4() + memory_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_memory_revision_review_records(store, *, user_id, memory_id, limit): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["memory_id"] = memory_id + captured["limit"] = limit + return { + "items": [ + { + "id": "revision-123", + "memory_id": str(memory_id), + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + } + ], + "summary": { + "memory_id": str(memory_id), + "limit": 5, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["sequence_no_asc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_memory_revision_review_records", + fake_list_memory_revision_review_records, + ) + + response = main_module.list_memory_revisions(memory_id=memory_id, user_id=user_id, limit=5) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": 
"revision-123", + "memory_id": str(memory_id), + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T09:00:00+00:00", + } + ], + "summary": { + "memory_id": str(memory_id), + "limit": 5, + "returned_count": 1, + "total_count": 1, + "has_more": False, + "order": ["sequence_no_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["memory_id"] == memory_id + assert captured["limit"] == 5 + + +def test_create_memory_review_label_returns_created_payload(monkeypatch) -> None: + memory_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_memory_review_label_record(store, *, user_id, memory_id, label, note): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["memory_id"] = memory_id + captured["label"] = label + captured["note"] = note + return { + "label": { + "id": "label-123", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "correct", + "note": "Backed by the latest source.", + "created_at": "2026-03-12T09:00:00+00:00", + }, + "summary": { + "memory_id": str(memory_id), + "total_count": 1, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_memory_review_label_record", + fake_create_memory_review_label_record, + ) + + response = main_module.create_memory_review_label( + memory_id, + main_module.CreateMemoryReviewLabelRequest( + user_id=user_id, + label="correct", + note="Backed by the latest source.", + ), + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "label": { + "id": "label-123", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "correct", + "note": "Backed by the latest source.", + "created_at": "2026-03-12T09:00:00+00:00", + }, + "summary": { + "memory_id": str(memory_id), + "total_count": 1, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 0, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["memory_id"] == memory_id + assert captured["label"] == "correct" + assert captured["note"] == "Backed by the latest source." 
+ + +def test_create_memory_review_label_returns_not_found_for_inaccessible_memory(monkeypatch) -> None: + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_memory_review_label_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw(MemoryReviewNotFoundError("memory missing")), + ) + + response = main_module.create_memory_review_label( + uuid4(), + main_module.CreateMemoryReviewLabelRequest(user_id=uuid4(), label="incorrect"), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": "memory missing"} + + +def test_list_memory_review_labels_returns_deterministic_items_and_summary(monkeypatch) -> None: + memory_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_memory_review_label_records(store, *, user_id, memory_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["memory_id"] = memory_id + return { + "items": [ + { + "id": "label-123", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "incorrect", + "note": "Conflicts with the latest event.", + "created_at": "2026-03-12T09:00:00+00:00", + }, + { + "id": "label-124", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "outdated", + "note": None, + "created_at": "2026-03-12T09:01:00+00:00", + }, + ], + "summary": { + "memory_id": str(memory_id), + "total_count": 2, + "counts_by_label": { + "correct": 0, + "incorrect": 1, + "outdated": 1, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_memory_review_label_records", + fake_list_memory_review_label_records, + ) + + response = main_module.list_memory_review_labels(memory_id=memory_id, user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "label-123", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "incorrect", + "note": "Conflicts with the latest event.", + "created_at": "2026-03-12T09:00:00+00:00", + }, + { + "id": "label-124", + "memory_id": str(memory_id), + "reviewer_user_id": str(user_id), + "label": "outdated", + "note": None, + "created_at": "2026-03-12T09:01:00+00:00", + }, + ], + "summary": { + "memory_id": str(memory_id), + "total_count": 2, + "counts_by_label": { + "correct": 0, + "incorrect": 1, + "outdated": 1, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["memory_id"] == memory_id + + +def test_list_memory_review_labels_returns_not_found_for_inaccessible_memory(monkeypatch) -> None: + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + + @contextmanager + def 
fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_memory_review_label_records", + lambda *_args, **_kwargs: (_ for _ in ()).throw(MemoryReviewNotFoundError("memory hidden")), + ) + + response = main_module.list_memory_review_labels(uuid4(), uuid4()) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": "memory hidden"} + + +def test_create_embedding_config_returns_created_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_embedding_config_record(store, *, user_id, config): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["config"] = config + return { + "embedding_config": { + "id": "config-123", + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + "created_at": "2026-03-12T10:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_embedding_config_record", fake_create_embedding_config_record) + + response = main_module.create_embedding_config( + main_module.CreateEmbeddingConfigRequest( + user_id=user_id, + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + status="active", + metadata={"task": "memory_retrieval"}, + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "embedding_config": { + "id": "config-123", + "provider": "openai", + "model": "text-embedding-3-large", + "version": "2026-03-12", + "dimensions": 3, + "status": "active", + "metadata": {"task": "memory_retrieval"}, + "created_at": "2026-03-12T10:00:00+00:00", + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["config"].provider == "openai" + + +def test_create_embedding_config_returns_bad_request_for_validation_failure(monkeypatch) -> None: + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_embedding_config_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EmbeddingConfigValidationError( + "embedding config already exists for provider/model/version under the user scope: " + "openai/text-embedding-3-large/2026-03-12" + ) + ), + ) + + response = main_module.create_embedding_config( + main_module.CreateEmbeddingConfigRequest( + user_id=uuid4(), + provider="openai", + model="text-embedding-3-large", + version="2026-03-12", + dimensions=3, + status="active", + metadata={"task": "memory_retrieval"}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": ( + "embedding config already exists for provider/model/version 
under the user scope: " + "openai/text-embedding-3-large/2026-03-12" + ) + } + + +def test_upsert_memory_embedding_routes_success_and_validation_errors(monkeypatch) -> None: + user_id = uuid4() + memory_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_upsert_memory_embedding_record(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "embedding": { + "id": "embedding-123", + "memory_id": str(memory_id), + "embedding_config_id": str(config_id), + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": "2026-03-12T10:00:00+00:00", + "updated_at": "2026-03-12T10:00:00+00:00", + }, + "write_mode": "created", + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "upsert_memory_embedding_record", fake_upsert_memory_embedding_record) + + response = main_module.upsert_memory_embedding( + main_module.UpsertMemoryEmbeddingRequest( + user_id=user_id, + memory_id=memory_id, + embedding_config_id=config_id, + vector=[0.1, 0.2, 0.3], + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body)["write_mode"] == "created" + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["request"].memory_id == memory_id + + monkeypatch.setattr( + main_module, + "upsert_memory_embedding_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + MemoryEmbeddingValidationError( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + ), + ) + + error_response = main_module.upsert_memory_embedding( + main_module.UpsertMemoryEmbeddingRequest( + user_id=user_id, + memory_id=memory_id, + embedding_config_id=config_id, + vector=[0.1, 0.2, 0.3], + ) + ) + + assert error_response.status_code == 400 + assert json.loads(error_response.body) == { + "detail": "embedding_config_id must reference an existing embedding config owned by the user" + } + + +def test_retrieve_semantic_memories_routes_success_and_validation_errors(monkeypatch) -> None: + user_id = uuid4() + config_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_retrieve_semantic_memory_records(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "items": [ + { + "memory_id": "memory-123", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "source_event_ids": ["event-123"], + "created_at": "2026-03-12T10:00:00+00:00", + "updated_at": "2026-03-12T10:00:00+00:00", + "score": 0.99, + } + ], + "summary": { + "embedding_config_id": str(config_id), + "limit": 5, + "returned_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + + monkeypatch.setattr(main_module, 
"get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "retrieve_semantic_memory_records", + fake_retrieve_semantic_memory_records, + ) + + response = main_module.retrieve_semantic_memories( + main_module.RetrieveSemanticMemoriesRequest( + user_id=user_id, + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=5, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["summary"] == { + "embedding_config_id": str(config_id), + "limit": 5, + "returned_count": 1, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "created_at_asc", "id_asc"], + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["request"].embedding_config_id == config_id + assert captured["request"].query_vector == (0.1, 0.2, 0.3) + + monkeypatch.setattr( + main_module, + "retrieve_semantic_memory_records", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + SemanticMemoryRetrievalValidationError( + "embedding_config_id must reference an existing embedding config owned by the user" + ) + ), + ) + + error_response = main_module.retrieve_semantic_memories( + main_module.RetrieveSemanticMemoriesRequest( + user_id=user_id, + embedding_config_id=config_id, + query_vector=[0.1, 0.2, 0.3], + limit=5, + ) + ) + + assert error_response.status_code == 400 + assert json.loads(error_response.body) == { + "detail": "embedding_config_id must reference an existing embedding config owned by the user" + } + + +def test_memory_embedding_read_routes_return_payload_and_not_found(monkeypatch) -> None: + user_id = uuid4() + memory_id = uuid4() + embedding_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_memory_embedding_records", + lambda *_args, **_kwargs: { + "items": [ + { + "id": str(embedding_id), + "memory_id": str(memory_id), + "embedding_config_id": "config-123", + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": "2026-03-12T10:00:00+00:00", + "updated_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "memory_id": str(memory_id), + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + }, + ) + monkeypatch.setattr( + main_module, + "get_memory_embedding_record", + lambda *_args, **_kwargs: { + "embedding": { + "id": str(embedding_id), + "memory_id": str(memory_id), + "embedding_config_id": "config-123", + "dimensions": 3, + "vector": [0.1, 0.2, 0.3], + "created_at": "2026-03-12T10:00:00+00:00", + "updated_at": "2026-03-12T10:00:00+00:00", + } + }, + ) + + list_response = main_module.list_memory_embeddings(memory_id=memory_id, user_id=user_id) + detail_response = main_module.get_memory_embedding(memory_embedding_id=embedding_id, user_id=user_id) + + assert list_response.status_code == 200 + assert json.loads(list_response.body)["summary"]["memory_id"] == str(memory_id) + assert detail_response.status_code == 200 + assert json.loads(detail_response.body)["embedding"]["id"] == str(embedding_id) + + monkeypatch.setattr( + main_module, + "get_memory_embedding_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + 
MemoryEmbeddingNotFoundError(f"memory embedding {embedding_id} was not found") + ), + ) + + not_found_response = main_module.get_memory_embedding( + memory_embedding_id=embedding_id, + user_id=user_id, + ) + + assert not_found_response.status_code == 404 + assert json.loads(not_found_response.body) == { + "detail": f"memory embedding {embedding_id} was not found" + } + + +def test_create_entity_returns_created_payload(monkeypatch) -> None: + user_id = uuid4() + first_memory_id = uuid4() + second_memory_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_entity_record(store, *, user_id, entity): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["entity"] = entity + return { + "entity": { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(first_memory_id), str(second_memory_id)], + "created_at": "2026-03-12T10:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_entity_record", fake_create_entity_record) + + response = main_module.create_entity( + main_module.CreateEntityRequest( + user_id=user_id, + entity_type="project", + name="AliceBot", + source_memory_ids=[first_memory_id, second_memory_id], + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "entity": { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": [str(first_memory_id), str(second_memory_id)], + "created_at": "2026-03-12T10:00:00+00:00", + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["entity"].entity_type == "project" + assert captured["entity"].name == "AliceBot" + + +def test_create_entity_returns_bad_request_when_source_memory_validation_fails(monkeypatch) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_entity_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EntityValidationError("source_memory_ids must all reference existing memories owned by the user") + ), + ) + + response = main_module.create_entity( + main_module.CreateEntityRequest( + user_id=uuid4(), + entity_type="person", + name="Samir", + source_memory_ids=[uuid4()], + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "source_memory_ids must all reference existing memories owned by the user", + } + + +def test_create_entity_edge_returns_created_payload(monkeypatch) -> None: + user_id = uuid4() + from_entity_id = uuid4() + to_entity_id = uuid4() + source_memory_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = 
current_user_id + yield object() + + def fake_create_entity_edge_record(store, *, user_id, edge): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["edge"] = edge + return { + "edge": { + "id": "edge-123", + "from_entity_id": str(from_entity_id), + "to_entity_id": str(to_entity_id), + "relationship_type": "works_on", + "valid_from": "2026-03-12T10:00:00+00:00", + "valid_to": None, + "source_memory_ids": [str(source_memory_id)], + "created_at": "2026-03-12T10:01:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_entity_edge_record", fake_create_entity_edge_record) + + response = main_module.create_entity_edge( + main_module.CreateEntityEdgeRequest( + user_id=user_id, + from_entity_id=from_entity_id, + to_entity_id=to_entity_id, + relationship_type="works_on", + valid_from="2026-03-12T10:00:00+00:00", + source_memory_ids=[source_memory_id], + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "edge": { + "id": "edge-123", + "from_entity_id": str(from_entity_id), + "to_entity_id": str(to_entity_id), + "relationship_type": "works_on", + "valid_from": "2026-03-12T10:00:00+00:00", + "valid_to": None, + "source_memory_ids": [str(source_memory_id)], + "created_at": "2026-03-12T10:01:00+00:00", + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["edge"].from_entity_id == from_entity_id + assert captured["edge"].to_entity_id == to_entity_id + + +def test_create_entity_edge_returns_bad_request_for_validation_failure(monkeypatch) -> None: + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "create_entity_edge_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EntityEdgeValidationError("valid_to must be greater than or equal to valid_from") + ), + ) + + response = main_module.create_entity_edge( + main_module.CreateEntityEdgeRequest( + user_id=uuid4(), + from_entity_id=uuid4(), + to_entity_id=uuid4(), + relationship_type="works_on", + valid_from="2026-03-12T11:00:00+00:00", + valid_to="2026-03-12T10:00:00+00:00", + source_memory_ids=[uuid4()], + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "valid_to must be greater than or equal to valid_from", + } + + +def test_list_entities_returns_deterministic_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_entity_records(store, *, user_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + return { + "items": [ + { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + + 
monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_entity_records", fake_list_entity_records) + + response = main_module.list_entities(user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "entity-123", + "entity_type": "project", + "name": "AliceBot", + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + + +def test_list_entity_edges_returns_deterministic_payload(monkeypatch) -> None: + user_id = uuid4() + entity_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_list_entity_edge_records(store, *, user_id, entity_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["entity_id"] = entity_id + return { + "items": [ + { + "id": "edge-123", + "from_entity_id": str(entity_id), + "to_entity_id": "entity-456", + "relationship_type": "works_on", + "valid_from": None, + "valid_to": None, + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "entity_id": str(entity_id), + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_entity_edge_records", fake_list_entity_edge_records) + + response = main_module.list_entity_edges(entity_id=entity_id, user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [ + { + "id": "edge-123", + "from_entity_id": str(entity_id), + "to_entity_id": "entity-456", + "relationship_type": "works_on", + "valid_from": None, + "valid_to": None, + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + ], + "summary": { + "entity_id": str(entity_id), + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["entity_id"] == entity_id + + +def test_list_entity_edges_returns_not_found_for_inaccessible_entity(monkeypatch) -> None: + entity_id = uuid4() + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_entity_edge_records", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EntityNotFoundError(f"entity {entity_id} was not found") + ), + ) + + response = main_module.list_entity_edges(entity_id=entity_id, user_id=uuid4()) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"entity {entity_id} was not found", + } + + 
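+# The entity route tests below reuse the stub-and-capture pattern of the embedding
+# route tests above: patch `get_settings`, `user_connection`, and the record-layer
+# function on `main_module`, call the handler directly with its request model, then
+# assert on the response status code, the decoded JSON body, and the captured keyword
+# arguments. Failure paths are simulated with a lambda that delegates to
+# `(_ for _ in ()).throw(...)`, a compact idiom for raising the relevant error from
+# inside a lambda.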
+def test_get_entity_returns_detail_payload(monkeypatch) -> None: + user_id = uuid4() + entity_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_get_entity_record(store, *, user_id, entity_id): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["entity_id"] = entity_id + return { + "entity": { + "id": str(entity_id), + "entity_type": "person", + "name": "Samir", + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_entity_record", fake_get_entity_record) + + response = main_module.get_entity(entity_id=entity_id, user_id=user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "entity": { + "id": str(entity_id), + "entity_type": "person", + "name": "Samir", + "source_memory_ids": ["memory-1"], + "created_at": "2026-03-12T10:00:00+00:00", + } + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["user_id"] == user_id + assert captured["entity_id"] == entity_id + + +def test_get_entity_returns_not_found_for_inaccessible_entity(monkeypatch) -> None: + entity_id = uuid4() + + @contextmanager + def fake_user_connection(_database_url: str, _current_user_id): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: Settings(database_url="postgresql://app")) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "get_entity_record", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + EntityNotFoundError(f"entity {entity_id} was not found") + ), + ) + + response = main_module.get_entity(entity_id=entity_id, user_id=uuid4()) + + assert response.status_code == 404 + assert json.loads(response.body) == { + "detail": f"entity {entity_id} was not found", + } diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py new file mode 100644 index 0000000..1ce8211 --- /dev/null +++ b/tests/unit/test_memory.py @@ -0,0 +1,897 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import MemoryCandidateInput +from alicebot_api.memory import ( + MemoryAdmissionValidationError, + MemoryReviewNotFoundError, + admit_memory_candidate, + create_memory_review_label_record, + get_memory_evaluation_summary, + get_memory_review_record, + list_memory_review_queue_records, + list_memory_review_label_records, + list_memory_review_records, + list_memory_revision_review_records, +) + + +class MemoryStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + self.events: dict[UUID, dict[str, object]] = {} + self.memory: dict[str, object] | None = None + self.revisions: list[dict[str, object]] = [] + + def list_events_by_ids(self, event_ids: list[UUID]) -> list[dict[str, object]]: + return [self.events[event_id] for event_id in event_ids if event_id in self.events] + + def get_memory_by_key(self, memory_key: str) -> dict[str, object] | None: + if self.memory is None or 
self.memory["memory_key"] != memory_key: + return None + return self.memory + + def create_memory( + self, + *, + memory_key: str, + value, + status: str, + source_event_ids: list[str], + ) -> dict[str, object]: + self.memory = { + "id": uuid4(), + "user_id": uuid4(), + "memory_key": memory_key, + "value": value, + "status": status, + "source_event_ids": source_event_ids, + "created_at": self.base_time, + "updated_at": self.base_time, + "deleted_at": None, + } + return self.memory + + def update_memory( + self, + *, + memory_id: UUID, + value, + status: str, + source_event_ids: list[str], + ) -> dict[str, object]: + assert self.memory is not None + assert self.memory["id"] == memory_id + updated_at = self.base_time + timedelta(minutes=len(self.revisions) + 1) + self.memory = { + **self.memory, + "value": value, + "status": status, + "source_event_ids": source_event_ids, + "updated_at": updated_at, + "deleted_at": updated_at if status == "deleted" else None, + } + return self.memory + + def append_memory_revision( + self, + *, + memory_id: UUID, + action: str, + memory_key: str, + previous_value, + new_value, + source_event_ids: list[str], + candidate: dict[str, object], + ) -> dict[str, object]: + revision = { + "id": uuid4(), + "user_id": self.memory["user_id"] if self.memory is not None else uuid4(), + "memory_id": memory_id, + "sequence_no": len(self.revisions) + 1, + "action": action, + "memory_key": memory_key, + "previous_value": previous_value, + "new_value": new_value, + "source_event_ids": source_event_ids, + "candidate": candidate, + "created_at": self.base_time + timedelta(minutes=len(self.revisions) + 1), + } + self.revisions.append(revision) + return revision + + +def seed_event(store: MemoryStoreStub) -> UUID: + event_id = uuid4() + store.events[event_id] = { + "id": event_id, + "sequence_no": 1, + "kind": "message.user", + "payload": {"text": "evidence"}, + "created_at": store.base_time, + } + return event_id + + +def test_admit_memory_candidate_defaults_to_noop_when_value_is_missing() -> None: + store = MemoryStoreStub() + event_id = seed_event(store) + + decision = admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value=None, + source_event_ids=(event_id,), + ), + ) + + assert decision.action == "NOOP" + assert decision.reason == "candidate_value_missing" + assert decision.memory is None + assert decision.revision is None + + +def test_admit_memory_candidate_rejects_missing_source_events() -> None: + store = MemoryStoreStub() + + with pytest.raises( + MemoryAdmissionValidationError, + match="source_event_ids must all reference existing events owned by the user", + ): + admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.tea", + value={"likes": True}, + source_event_ids=(uuid4(),), + ), + ) + + +def test_admit_memory_candidate_rejects_empty_source_event_ids() -> None: + store = MemoryStoreStub() + + with pytest.raises( + MemoryAdmissionValidationError, + match="source_event_ids must include at least one existing event owned by the user", + ): + admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.tea", + value={"likes": True}, + source_event_ids=(), + ), + ) + + +def test_admit_memory_candidate_adds_new_memory_with_first_revision() -> None: + store = MemoryStoreStub() + event_id = seed_event(store) + 
+ decision = admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_id,), + ), + ) + + assert decision.action == "ADD" + assert decision.reason == "source_backed_add" + assert decision.memory is not None + assert decision.memory["memory_key"] == "user.preference.coffee" + assert decision.memory["status"] == "active" + assert decision.revision is not None + assert decision.revision["sequence_no"] == 1 + assert decision.revision["action"] == "ADD" + assert decision.revision["new_value"] == {"likes": "oat milk"} + + +def test_admit_memory_candidate_updates_existing_memory_and_appends_revision() -> None: + store = MemoryStoreStub() + event_id = seed_event(store) + created = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_id)], + ) + store.append_memory_revision( + memory_id=created["id"], + action="ADD", + memory_key="user.preference.coffee", + previous_value=None, + new_value={"likes": "black"}, + source_event_ids=[str(event_id)], + candidate={ + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "source_event_ids": [str(event_id)], + "delete_requested": False, + }, + ) + + decision = admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value={"likes": "oat milk"}, + source_event_ids=(event_id,), + ), + ) + + assert decision.action == "UPDATE" + assert decision.reason == "source_backed_update" + assert decision.memory is not None + assert decision.memory["value"] == {"likes": "oat milk"} + assert decision.revision is not None + assert decision.revision["sequence_no"] == 2 + assert decision.revision["previous_value"] == {"likes": "black"} + assert decision.revision["new_value"] == {"likes": "oat milk"} + + +def test_admit_memory_candidate_marks_memory_deleted_and_appends_revision() -> None: + store = MemoryStoreStub() + event_id = seed_event(store) + created = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_id)], + ) + + decision = admit_memory_candidate( + store, # type: ignore[arg-type] + user_id=uuid4(), + candidate=MemoryCandidateInput( + memory_key="user.preference.coffee", + value=None, + source_event_ids=(event_id,), + delete_requested=True, + ), + ) + + assert decision.action == "DELETE" + assert decision.reason == "source_backed_delete" + assert decision.memory is not None + assert UUID(decision.memory["id"]) == created["id"] + assert decision.memory["status"] == "deleted" + assert decision.revision is not None + assert decision.revision["sequence_no"] == 1 + assert decision.revision["action"] == "DELETE" + assert decision.revision["new_value"] is None + + +class MemoryReviewStoreStub: + def __init__(self) -> None: + self.memories: list[dict[str, object]] = [] + self.revisions: dict[UUID, list[dict[str, object]]] = {} + self.labels: dict[UUID, list[dict[str, object]]] = {} + + def count_memories(self, *, status: str | None = None) -> int: + return len(self._filtered_memories(status)) + + def list_review_memories(self, *, status: str | None = None, limit: int) -> list[dict[str, object]]: + return self._review_sorted_memories(self._filtered_memories(status))[:limit] + + def count_unlabeled_review_memories(self) -> int: + return len( + [memory 
for memory in self.memories if memory["status"] == "active" and not self.labels.get(memory["id"])] + ) + + def list_unlabeled_review_memories(self, *, limit: int) -> list[dict[str, object]]: + return self._review_sorted_memories( + [ + memory + for memory in self.memories + if memory["status"] == "active" and not self.labels.get(memory["id"]) + ] + )[:limit] + + def get_memory_optional(self, memory_id: UUID) -> dict[str, object] | None: + for memory in self.memories: + if memory["id"] == memory_id: + return memory + return None + + def count_memory_revisions(self, memory_id: UUID) -> int: + return len(self.revisions.get(memory_id, [])) + + def list_memory_revisions(self, memory_id: UUID, *, limit: int | None = None) -> list[dict[str, object]]: + revisions = self.revisions.get(memory_id, []) + if limit is None: + return revisions + return revisions[:limit] + + def create_memory_review_label( + self, + *, + memory_id: UUID, + label: str, + note: str | None, + ) -> dict[str, object]: + memory = self.get_memory_optional(memory_id) + created = { + "id": uuid4(), + "user_id": uuid4() if memory is None else memory["user_id"], + "memory_id": memory_id, + "label": label, + "note": note, + "created_at": datetime(2026, 3, 11, 13, len(self.labels.get(memory_id, [])), tzinfo=UTC), + } + self.labels.setdefault(memory_id, []).append(created) + return created + + def list_memory_review_labels(self, memory_id: UUID) -> list[dict[str, object]]: + return list(self.labels.get(memory_id, [])) + + def list_memory_review_label_counts(self, memory_id: UUID) -> list[dict[str, object]]: + counts: dict[str, int] = {} + for label in self.labels.get(memory_id, []): + label_name = label["label"] + counts[label_name] = counts.get(label_name, 0) + 1 + return [{"label": label, "count": count} for label, count in sorted(counts.items())] + + def count_labeled_memories(self) -> int: + return len([memory for memory in self.memories if self.labels.get(memory["id"])]) + + def count_unlabeled_memories(self) -> int: + return len([memory for memory in self.memories if not self.labels.get(memory["id"])]) + + def list_all_memory_review_label_counts(self) -> list[dict[str, object]]: + counts: dict[str, int] = {} + for labels in self.labels.values(): + for label in labels: + label_name = label["label"] + counts[label_name] = counts.get(label_name, 0) + 1 + return [{"label": label, "count": count} for label, count in sorted(counts.items())] + + def _filtered_memories(self, status: str | None) -> list[dict[str, object]]: + if status is None: + return list(self.memories) + return [memory for memory in self.memories if memory["status"] == status] + + def _review_sorted_memories(self, memories: list[dict[str, object]]) -> list[dict[str, object]]: + return sorted( + memories, + key=lambda memory: (memory["updated_at"], memory["created_at"], memory["id"]), + reverse=True, + ) + + +def test_list_memory_review_records_returns_summary_and_stable_shape() -> None: + store = MemoryReviewStoreStub() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + deleted_time = base_time + timedelta(minutes=1) + active_time = base_time + timedelta(minutes=2) + deleted_id = uuid4() + active_id = uuid4() + store.memories = [ + { + "id": active_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time, + "updated_at": active_time, + "deleted_at": None, + }, + { + "id": deleted_id, + "user_id": uuid4(), + "memory_key": 
"user.preference.tea", + "value": {"likes": "green"}, + "status": "deleted", + "source_event_ids": ["event-1"], + "created_at": base_time, + "updated_at": deleted_time, + "deleted_at": deleted_time, + }, + ] + + payload = list_memory_review_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + status="all", + limit=1, + ) + + assert payload == { + "items": [ + { + "id": str(active_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": "2026-03-11T12:00:00+00:00", + "updated_at": "2026-03-11T12:02:00+00:00", + "deleted_at": None, + } + ], + "summary": { + "status": "all", + "limit": 1, + "returned_count": 1, + "total_count": 2, + "has_more": True, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + +def test_get_memory_review_record_raises_not_found_for_inaccessible_memory() -> None: + store = MemoryReviewStoreStub() + + with pytest.raises(MemoryReviewNotFoundError, match="was not found"): + get_memory_review_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=uuid4(), + ) + + +def test_list_memory_review_queue_records_returns_only_active_unlabeled_memories_in_stable_order() -> None: + store = MemoryReviewStoreStub() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + deleted_id = uuid4() + labeled_id = uuid4() + newest_unlabeled_id = uuid4() + older_unlabeled_id = uuid4() + store.memories = [ + { + "id": newest_unlabeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-4"], + "created_at": base_time + timedelta(minutes=3), + "updated_at": base_time + timedelta(minutes=6), + "deleted_at": None, + }, + { + "id": labeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.snack", + "value": {"likes": "chips"}, + "status": "active", + "source_event_ids": ["event-3"], + "created_at": base_time + timedelta(minutes=2), + "updated_at": base_time + timedelta(minutes=5), + "deleted_at": None, + }, + { + "id": older_unlabeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.book", + "value": {"genre": "science fiction"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time + timedelta(minutes=1), + "updated_at": base_time + timedelta(minutes=4), + "deleted_at": None, + }, + { + "id": deleted_id, + "user_id": uuid4(), + "memory_key": "user.preference.tea", + "value": {"likes": "green"}, + "status": "deleted", + "source_event_ids": ["event-1"], + "created_at": base_time, + "updated_at": base_time + timedelta(minutes=7), + "deleted_at": base_time + timedelta(minutes=7), + }, + ] + store.labels[labeled_id] = [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": labeled_id, + "label": "correct", + "note": "Already reviewed.", + "created_at": base_time + timedelta(minutes=8), + } + ] + + payload = list_memory_review_queue_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + limit=2, + ) + + assert payload == { + "items": [ + { + "id": str(newest_unlabeled_id), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-4"], + "created_at": "2026-03-11T12:03:00+00:00", + "updated_at": "2026-03-11T12:06:00+00:00", + }, + { + "id": str(older_unlabeled_id), + "memory_key": "user.preference.book", + "value": {"genre": "science fiction"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": 
"2026-03-11T12:01:00+00:00", + "updated_at": "2026-03-11T12:04:00+00:00", + }, + ], + "summary": { + "memory_status": "active", + "review_state": "unlabeled", + "limit": 2, + "returned_count": 2, + "total_count": 2, + "has_more": False, + "order": ["updated_at_desc", "created_at_desc", "id_desc"], + }, + } + + +def test_list_memory_revision_review_records_returns_deterministic_revision_order() -> None: + store = MemoryReviewStoreStub() + memory_id = uuid4() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + store.memories = [ + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time, + "updated_at": base_time + timedelta(minutes=2), + "deleted_at": None, + } + ] + store.revisions[memory_id] = [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": memory_id, + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": ["event-1"], + "candidate": {"memory_key": "user.preference.coffee"}, + "created_at": base_time, + }, + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": memory_id, + "sequence_no": 2, + "action": "UPDATE", + "memory_key": "user.preference.coffee", + "previous_value": {"likes": "black"}, + "new_value": {"likes": "oat milk"}, + "source_event_ids": ["event-2"], + "candidate": {"memory_key": "user.preference.coffee"}, + "created_at": base_time + timedelta(minutes=1), + }, + ] + + payload = list_memory_revision_review_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=memory_id, + limit=10, + ) + + assert payload == { + "items": [ + { + "id": str(store.revisions[memory_id][0]["id"]), + "memory_id": str(memory_id), + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": ["event-1"], + "created_at": "2026-03-11T12:00:00+00:00", + }, + { + "id": str(store.revisions[memory_id][1]["id"]), + "memory_id": str(memory_id), + "sequence_no": 2, + "action": "UPDATE", + "memory_key": "user.preference.coffee", + "previous_value": {"likes": "black"}, + "new_value": {"likes": "oat milk"}, + "source_event_ids": ["event-2"], + "created_at": "2026-03-11T12:01:00+00:00", + }, + ], + "summary": { + "memory_id": str(memory_id), + "limit": 10, + "returned_count": 2, + "total_count": 2, + "has_more": False, + "order": ["sequence_no_asc"], + }, + } + + +def test_create_memory_review_label_record_returns_created_label_and_summary_counts() -> None: + store = MemoryReviewStoreStub() + memory_id = uuid4() + reviewer_user_id = uuid4() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + store.memories = [ + { + "id": memory_id, + "user_id": reviewer_user_id, + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time, + "updated_at": base_time, + "deleted_at": None, + } + ] + store.labels[memory_id] = [ + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "correct", + "note": "Matches the latest cited event.", + "created_at": datetime(2026, 3, 11, 12, 30, tzinfo=UTC), + } + ] + + payload = create_memory_review_label_record( + store, # type: ignore[arg-type] + user_id=reviewer_user_id, + memory_id=memory_id, + label="outdated", + note="Superseded by the newer milk preference.", + ) + + 
assert payload == { + "label": { + "id": payload["label"]["id"], + "memory_id": str(memory_id), + "reviewer_user_id": payload["label"]["reviewer_user_id"], + "label": "outdated", + "note": "Superseded by the newer milk preference.", + "created_at": "2026-03-11T13:01:00+00:00", + }, + "summary": { + "memory_id": str(memory_id), + "total_count": 2, + "counts_by_label": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 0, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_create_memory_review_label_record_raises_not_found_for_inaccessible_memory() -> None: + store = MemoryReviewStoreStub() + + with pytest.raises(MemoryReviewNotFoundError, match="was not found"): + create_memory_review_label_record( + store, # type: ignore[arg-type] + user_id=uuid4(), + memory_id=uuid4(), + label="correct", + note=None, + ) + + +def test_list_memory_review_label_records_returns_deterministic_order_and_zero_filled_counts() -> None: + store = MemoryReviewStoreStub() + memory_id = uuid4() + reviewer_user_id = uuid4() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + store.memories = [ + { + "id": memory_id, + "user_id": reviewer_user_id, + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time, + "updated_at": base_time, + "deleted_at": None, + } + ] + store.labels[memory_id] = [ + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "incorrect", + "note": "The source event only mentions tea.", + "created_at": datetime(2026, 3, 11, 12, 15, tzinfo=UTC), + }, + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "insufficient_evidence", + "note": None, + "created_at": datetime(2026, 3, 11, 12, 16, tzinfo=UTC), + }, + ] + + payload = list_memory_review_label_records( + store, # type: ignore[arg-type] + user_id=reviewer_user_id, + memory_id=memory_id, + ) + + assert payload == { + "items": [ + { + "id": str(store.labels[memory_id][0]["id"]), + "memory_id": str(memory_id), + "reviewer_user_id": str(reviewer_user_id), + "label": "incorrect", + "note": "The source event only mentions tea.", + "created_at": "2026-03-11T12:15:00+00:00", + }, + { + "id": str(store.labels[memory_id][1]["id"]), + "memory_id": str(memory_id), + "reviewer_user_id": str(reviewer_user_id), + "label": "insufficient_evidence", + "note": None, + "created_at": "2026-03-11T12:16:00+00:00", + }, + ], + "summary": { + "memory_id": str(memory_id), + "total_count": 2, + "counts_by_label": { + "correct": 0, + "incorrect": 1, + "outdated": 0, + "insufficient_evidence": 1, + }, + "order": ["created_at_asc", "id_asc"], + }, + } + + +def test_get_memory_evaluation_summary_returns_explicit_consistent_counts() -> None: + store = MemoryReviewStoreStub() + base_time = datetime(2026, 3, 11, 12, 0, tzinfo=UTC) + active_labeled_id = uuid4() + active_unlabeled_id = uuid4() + deleted_labeled_id = uuid4() + deleted_unlabeled_id = uuid4() + store.memories = [ + { + "id": active_labeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["event-1"], + "created_at": base_time, + "updated_at": base_time, + "deleted_at": None, + }, + { + "id": active_unlabeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.book", + "value": {"genre": "science fiction"}, + "status": "active", + "source_event_ids": ["event-2"], + "created_at": base_time + 
timedelta(minutes=1), + "updated_at": base_time + timedelta(minutes=1), + "deleted_at": None, + }, + { + "id": deleted_labeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.snack", + "value": {"likes": "chips"}, + "status": "deleted", + "source_event_ids": ["event-3"], + "created_at": base_time + timedelta(minutes=2), + "updated_at": base_time + timedelta(minutes=2), + "deleted_at": base_time + timedelta(minutes=2), + }, + { + "id": deleted_unlabeled_id, + "user_id": uuid4(), + "memory_key": "user.preference.tea", + "value": {"likes": "green"}, + "status": "deleted", + "source_event_ids": ["event-4"], + "created_at": base_time + timedelta(minutes=3), + "updated_at": base_time + timedelta(minutes=3), + "deleted_at": base_time + timedelta(minutes=3), + }, + ] + store.labels[active_labeled_id] = [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": active_labeled_id, + "label": "correct", + "note": "Looks right.", + "created_at": base_time + timedelta(minutes=4), + }, + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": active_labeled_id, + "label": "insufficient_evidence", + "note": "Needs another source.", + "created_at": base_time + timedelta(minutes=5), + }, + ] + store.labels[deleted_labeled_id] = [ + { + "id": uuid4(), + "user_id": uuid4(), + "memory_id": deleted_labeled_id, + "label": "outdated", + "note": None, + "created_at": base_time + timedelta(minutes=6), + } + ] + + payload = get_memory_evaluation_summary( + store, # type: ignore[arg-type] + user_id=uuid4(), + ) + + assert payload == { + "summary": { + "total_memory_count": 4, + "active_memory_count": 2, + "deleted_memory_count": 2, + "labeled_memory_count": 2, + "unlabeled_memory_count": 2, + "total_label_row_count": 3, + "label_row_counts_by_value": { + "correct": 1, + "incorrect": 0, + "outdated": 1, + "insufficient_evidence": 1, + }, + "label_value_order": [ + "correct", + "incorrect", + "outdated", + "insufficient_evidence", + ], + } + } diff --git a/tests/unit/test_memory_store.py b/tests/unit/test_memory_store.py new file mode 100644 index 0000000..9b09755 --- /dev/null +++ b/tests/unit/test_memory_store.py @@ -0,0 +1,357 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb +import pytest + +from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError + + +class RecordingCursor: + def __init__( + self, + fetchone_results: list[dict[str, Any]], + fetchall_results: list[list[dict[str, Any]]] | None = None, + ) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_results = list(fetchall_results or []) + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + if not self.fetchall_results: + return [] + return self.fetchall_results.pop(0) + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_memory_methods_use_expected_queries_and_payload_serialization() -> None: + memory_id = uuid4() + event_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "status": "active", + "source_event_ids": [str(event_id)], + }, + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": [str(event_id)], + }, + { + "id": uuid4(), + "memory_id": memory_id, + "sequence_no": 1, + "action": "ADD", + "memory_key": "user.preference.coffee", + "previous_value": None, + "new_value": {"likes": "black"}, + "source_event_ids": [str(event_id)], + "candidate": {"memory_key": "user.preference.coffee"}, + }, + ], + fetchall_results=[ + [{"id": event_id, "sequence_no": 1}], + [{"sequence_no": 1, "action": "ADD"}], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_memory( + memory_key="user.preference.coffee", + value={"likes": "black"}, + status="active", + source_event_ids=[str(event_id)], + ) + updated = store.update_memory( + memory_id=memory_id, + value={"likes": "oat milk"}, + status="active", + source_event_ids=[str(event_id)], + ) + revision = store.append_memory_revision( + memory_id=memory_id, + action="ADD", + memory_key="user.preference.coffee", + previous_value=None, + new_value={"likes": "black"}, + source_event_ids=[str(event_id)], + candidate={"memory_key": "user.preference.coffee"}, + ) + listed_events = store.list_events_by_ids([event_id]) + listed_revisions = store.list_memory_revisions(memory_id) + listed_context_memories = store.list_context_memories() + + assert created["id"] == memory_id + assert updated["value"] == {"likes": "oat milk"} + assert revision["sequence_no"] == 1 + assert listed_events == [{"id": event_id, "sequence_no": 1}] + assert listed_revisions == [{"sequence_no": 1, "action": "ADD"}] + assert listed_context_memories == [] + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO memories" in create_query + assert "clock_timestamp()" in create_query + assert create_params is not None + assert create_params[0] == "user.preference.coffee" + assert isinstance(create_params[1], Jsonb) + assert create_params[1].obj == {"likes": "black"} + assert create_params[2] == "active" + assert isinstance(create_params[3], Jsonb) + assert create_params[3].obj == [str(event_id)] + + update_query, update_params = cursor.executed[1] + assert "UPDATE memories" in update_query + assert "updated_at = clock_timestamp()" in update_query + assert "THEN clock_timestamp()" in update_query + assert update_params is not None + assert isinstance(update_params[0], Jsonb) + assert update_params[0].obj == {"likes": "oat milk"} + assert update_params[1] == "active" + assert isinstance(update_params[2], Jsonb) + assert update_params[2].obj == [str(event_id)] + assert update_params[3] == "active" + 
assert update_params[4] == memory_id + + assert cursor.executed[2] == ( + "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 1))", + (str(memory_id),), + ) + append_revision_query, append_revision_params = cursor.executed[3] + assert "INSERT INTO memory_revisions" in append_revision_query + assert append_revision_params is not None + assert append_revision_params[:4] == ( + memory_id, + memory_id, + "ADD", + "user.preference.coffee", + ) + assert isinstance(append_revision_params[4], Jsonb) + assert append_revision_params[4].obj is None + assert isinstance(append_revision_params[5], Jsonb) + assert append_revision_params[5].obj == {"likes": "black"} + assert isinstance(append_revision_params[6], Jsonb) + assert append_revision_params[6].obj == [str(event_id)] + assert isinstance(append_revision_params[7], Jsonb) + assert append_revision_params[7].obj == {"memory_key": "user.preference.coffee"} + assert cursor.executed[6] == ( + """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + ORDER BY updated_at ASC, created_at ASC, id ASC + """, + None, + ) + + +def test_get_memory_by_key_returns_none_when_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + assert store.get_memory_by_key("user.preference.coffee") is None + + +def test_append_memory_revision_raises_clear_error_when_returning_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + with pytest.raises( + ContinuityStoreInvariantError, + match="append_memory_revision did not return a row", + ): + store.append_memory_revision( + memory_id=uuid4(), + action="ADD", + memory_key="user.preference.coffee", + previous_value=None, + new_value={"likes": "black"}, + source_event_ids=["event-1"], + candidate={"memory_key": "user.preference.coffee"}, + ) + + +def test_memory_review_read_methods_use_explicit_order_filter_and_limit() -> None: + memory_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": memory_id, + "user_id": uuid4(), + "memory_key": "user.preference.coffee", + "value": {"likes": "black"}, + "status": "active", + "source_event_ids": ["event-1"], + }, + {"count": 2}, + {"count": 3}, + ], + fetchall_results=[ + [{"id": memory_id, "memory_key": "user.preference.coffee"}], + [{"sequence_no": 1, "action": "ADD"}], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + memory = store.get_memory_optional(memory_id) + memory_count = store.count_memories(status="active") + listed_memories = store.list_review_memories(status="active", limit=5) + revision_count = store.count_memory_revisions(memory_id) + listed_revisions = store.list_memory_revisions(memory_id, limit=2) + + assert memory is not None + assert memory["id"] == memory_id + assert memory_count == 2 + assert listed_memories == [{"id": memory_id, "memory_key": "user.preference.coffee"}] + assert revision_count == 3 + assert listed_revisions == [{"sequence_no": 1, "action": "ADD"}] + assert cursor.executed == [ + ( + """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE id = %s + """, + (memory_id,), + ), + ( + """ + SELECT COUNT(*) AS count + FROM memories + WHERE status = %s + """, + ("active",), + ), + ( + """ + SELECT id, user_id, memory_key, value, status, source_event_ids, created_at, updated_at, deleted_at + FROM memories + WHERE status = %s 
+ ORDER BY updated_at DESC, created_at DESC, id DESC + LIMIT %s + """, + ("active", 5), + ), + ( + """ + SELECT COUNT(*) AS count + FROM memory_revisions + WHERE memory_id = %s + """, + (memory_id,), + ), + ( + """ + SELECT id, user_id, memory_id, sequence_no, action, memory_key, previous_value, new_value, source_event_ids, candidate, created_at + FROM memory_revisions + WHERE memory_id = %s + ORDER BY sequence_no ASC + LIMIT %s + """, + (memory_id, 2), + ), + ] + + +def test_memory_review_label_methods_use_append_only_queries_and_deterministic_order() -> None: + memory_id = uuid4() + reviewer_user_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "correct", + "note": "Supported by the latest event.", + "created_at": "2026-03-12T09:00:00+00:00", + } + ], + fetchall_results=[ + [ + { + "id": uuid4(), + "user_id": reviewer_user_id, + "memory_id": memory_id, + "label": "correct", + "note": "Supported by the latest event.", + "created_at": "2026-03-12T09:00:00+00:00", + } + ], + [ + {"label": "correct", "count": 1}, + {"label": "outdated", "count": 2}, + ], + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_memory_review_label( + memory_id=memory_id, + label="correct", + note="Supported by the latest event.", + ) + listed = store.list_memory_review_labels(memory_id) + counts = store.list_memory_review_label_counts(memory_id) + + assert created["memory_id"] == memory_id + assert listed[0]["label"] == "correct" + assert counts == [{"label": "correct", "count": 1}, {"label": "outdated", "count": 2}] + assert cursor.executed == [ + ( + """ + INSERT INTO memory_review_labels (user_id, memory_id, label, note) + VALUES (app.current_user_id(), %s, %s, %s) + RETURNING id, user_id, memory_id, label, note, created_at + """, + (memory_id, "correct", "Supported by the latest event."), + ), + ( + """ + SELECT id, user_id, memory_id, label, note, created_at + FROM memory_review_labels + WHERE memory_id = %s + ORDER BY created_at ASC, id ASC + """, + (memory_id,), + ), + ( + """ + SELECT label, COUNT(*) AS count + FROM memory_review_labels + WHERE memory_id = %s + GROUP BY label + ORDER BY label ASC + """, + (memory_id,), + ), + ] diff --git a/tests/unit/test_ops_assets.py b/tests/unit/test_ops_assets.py new file mode 100644 index 0000000..ddabcbb --- /dev/null +++ b/tests/unit/test_ops_assets.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def test_dev_up_waits_for_postgres_and_role_bootstrap() -> None: + script = (REPO_ROOT / "scripts" / "dev_up.sh").read_text() + + assert "Timed out waiting for Postgres readiness and alicebot_app bootstrap" in script + assert "SELECT 1 FROM pg_roles WHERE rolname = %s" in script + + +def test_runtime_role_init_only_grants_connect_on_alicebot_database() -> None: + init_sql = (REPO_ROOT / "infra" / "postgres" / "init" / "001_roles.sql").read_text() + + assert "GRANT CONNECT ON DATABASE alicebot TO alicebot_app;" in init_sql + assert "GRANT CONNECT ON DATABASE postgres TO alicebot_app;" not in init_sql diff --git a/tests/unit/test_policy.py b/tests/unit/test_policy.py new file mode 100644 index 0000000..9f5c40e --- /dev/null +++ b/tests/unit/test_policy.py @@ -0,0 +1,447 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import 
ConsentUpsertInput, PolicyCreateInput, PolicyEvaluationRequestInput +from alicebot_api.policy import ( + PolicyEvaluationValidationError, + PolicyNotFoundError, + create_policy_record, + evaluate_policy_request, + get_policy_record, + list_consent_records, + list_policy_records, + upsert_consent_record, +) + + +class PolicyStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.consents: dict[str, dict[str, object]] = {} + self.policies: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + + def create_consent(self, *, consent_key: str, status: str, metadata: dict[str, object]) -> dict[str, object]: + consent = { + "id": uuid4(), + "user_id": self.user_id, + "consent_key": consent_key, + "status": status, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.consents)), + "updated_at": self.base_time + timedelta(minutes=len(self.consents)), + } + self.consents[consent_key] = consent + return consent + + def get_consent_by_key_optional(self, consent_key: str) -> dict[str, object] | None: + return self.consents.get(consent_key) + + def list_consents(self) -> list[dict[str, object]]: + return sorted( + self.consents.values(), + key=lambda consent: (consent["consent_key"], consent["created_at"], consent["id"]), + ) + + def update_consent(self, *, consent_id: UUID, status: str, metadata: dict[str, object]) -> dict[str, object]: + for consent in self.consents.values(): + if consent["id"] != consent_id: + continue + consent["status"] = status + consent["metadata"] = metadata + consent["updated_at"] = consent["updated_at"] + timedelta(minutes=5) + return consent + raise AssertionError("missing consent") + + def create_policy( + self, + *, + name: str, + action: str, + scope: str, + effect: str, + priority: int, + active: bool, + conditions: dict[str, object], + required_consents: list[str], + ) -> dict[str, object]: + policy = { + "id": uuid4(), + "user_id": self.user_id, + "name": name, + "action": action, + "scope": scope, + "effect": effect, + "priority": priority, + "active": active, + "conditions": conditions, + "required_consents": required_consents, + "created_at": self.base_time + timedelta(minutes=len(self.policies)), + "updated_at": self.base_time + timedelta(minutes=len(self.policies)), + } + self.policies.append(policy) + return policy + + def list_policies(self) -> list[dict[str, object]]: + return sorted( + self.policies, + key=lambda policy: (policy["priority"], policy["created_at"], policy["id"]), + ) + + def get_policy_optional(self, policy_id: UUID) -> dict[str, object] | None: + return next((policy for policy in self.policies if policy["id"] == policy_id), None) + + def list_active_policies(self) -> list[dict[str, object]]: + return [policy for policy in self.list_policies() if policy["active"] is True] + + def get_thread_optional(self, thread_id: UUID) -> dict[str, object] | None: + if thread_id != self.thread_id: + return None + return { + "id": self.thread_id, + "user_id": self.user_id, + "title": "Policy thread", + "created_at": self.base_time, + "updated_at": self.base_time, + } + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": 
compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time, + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time, + } + self.trace_events.append(event) + return event + + +def test_upsert_consent_record_creates_and_updates_in_place() -> None: + store = PolicyStoreStub() + + created = upsert_consent_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + consent=ConsentUpsertInput( + consent_key="email_marketing", + status="granted", + metadata={"source": "settings"}, + ), + ) + updated = upsert_consent_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + consent=ConsentUpsertInput( + consent_key="email_marketing", + status="revoked", + metadata={"source": "banner"}, + ), + ) + + assert created["write_mode"] == "created" + assert updated["write_mode"] == "updated" + assert updated["consent"]["id"] == created["consent"]["id"] + assert updated["consent"]["status"] == "revoked" + assert updated["consent"]["metadata"] == {"source": "banner"} + + +def test_list_consent_records_returns_deterministic_shape() -> None: + store = PolicyStoreStub() + zeta = store.create_consent(consent_key="zeta", status="granted", metadata={}) + alpha = store.create_consent(consent_key="alpha", status="revoked", metadata={"reason": "user"}) + + payload = list_consent_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + + assert payload == { + "items": [ + { + "id": str(alpha["id"]), + "consent_key": "alpha", + "status": "revoked", + "metadata": {"reason": "user"}, + "created_at": alpha["created_at"].isoformat(), + "updated_at": alpha["updated_at"].isoformat(), + }, + { + "id": str(zeta["id"]), + "consent_key": "zeta", + "status": "granted", + "metadata": {}, + "created_at": zeta["created_at"].isoformat(), + "updated_at": zeta["updated_at"].isoformat(), + }, + ], + "summary": { + "total_count": 2, + "order": ["consent_key_asc", "created_at_asc", "id_asc"], + }, + } + + +def test_create_and_list_policy_records_preserve_priority_order_and_shape() -> None: + store = PolicyStoreStub() + first = create_policy_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + policy=PolicyCreateInput( + name="Require approval for exports", + action="memory.export", + scope="profile", + effect="require_approval", + priority=20, + active=True, + conditions={"channel": "email"}, + required_consents=("email_marketing", "email_marketing"), + ), + ) + second_policy = store.create_policy( + name="Allow low risk read", + action="memory.read", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=[], + ) + + list_payload = list_policy_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail_payload = get_policy_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + policy_id=UUID(first["policy"]["id"]), + ) + + assert first["policy"]["required_consents"] == ["email_marketing"] + assert [item["id"] for item in list_payload["items"]] == [ + str(second_policy["id"]), + first["policy"]["id"], + ] + assert list_payload["summary"] == { + "total_count": 2, + "order": ["priority_asc", "created_at_asc", "id_asc"], + } + assert detail_payload == {"policy": first["policy"]} + + +def 
test_get_policy_record_raises_not_found_for_inaccessible_policy() -> None: + with pytest.raises(PolicyNotFoundError, match="policy .* was not found"): + get_policy_record( + PolicyStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + policy_id=uuid4(), + ) + + +def test_evaluate_policy_request_uses_first_matching_policy_and_emits_trace() -> None: + store = PolicyStoreStub() + store.create_consent(consent_key="email_marketing", status="granted", metadata={"source": "settings"}) + higher_priority_match = store.create_policy( + name="Allow email export", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={"channel": "email"}, + required_consents=["email_marketing"], + ) + store.create_policy( + name="Deny fallback export", + action="memory.export", + scope="profile", + effect="deny", + priority=20, + active=True, + conditions={"channel": "email"}, + required_consents=[], + ) + + payload = evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=store.thread_id, + action="memory.export", + scope="profile", + attributes={"channel": "email"}, + ), + ) + + assert payload["decision"] == "allow" + assert payload["matched_policy"]["id"] == str(higher_priority_match["id"]) + assert [reason["code"] for reason in payload["reasons"]] == [ + "matched_policy", + "policy_effect_allow", + ] + assert payload["evaluation"] == { + "action": "memory.export", + "scope": "profile", + "evaluated_policy_count": 2, + "matched_policy_id": str(higher_priority_match["id"]), + "order": ["priority_asc", "created_at_asc", "id_asc"], + } + assert payload["trace"]["trace_event_count"] == 3 + assert [event["kind"] for event in store.trace_events] == [ + "policy.evaluate.request", + "policy.evaluate.order", + "policy.evaluate.decision", + ] + + +def test_evaluate_policy_request_denies_when_required_consent_is_missing() -> None: + store = PolicyStoreStub() + matched_policy = store.create_policy( + name="Allow export with consent", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=["email_marketing"], + ) + + payload = evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=store.thread_id, + action="memory.export", + scope="profile", + attributes={}, + ), + ) + + assert payload["decision"] == "deny" + assert payload["matched_policy"]["id"] == str(matched_policy["id"]) + assert [reason["code"] for reason in payload["reasons"]] == [ + "matched_policy", + "consent_missing", + ] + + +def test_evaluate_policy_request_denies_when_required_consent_is_revoked() -> None: + store = PolicyStoreStub() + matched_policy = store.create_policy( + name="Allow export with consent", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={}, + required_consents=["email_marketing"], + ) + store.create_consent( + consent_key="email_marketing", + status="revoked", + metadata={"source": "settings"}, + ) + + payload = evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=store.thread_id, + action="memory.export", + scope="profile", + attributes={}, + ), + ) + + assert payload["decision"] == "deny" + assert payload["matched_policy"]["id"] == str(matched_policy["id"]) + assert [reason["code"] for reason in payload["reasons"]] == 
[ + "matched_policy", + "consent_revoked", + ] + + +def test_evaluate_policy_request_returns_require_approval_and_validates_thread_scope() -> None: + store = PolicyStoreStub() + matched_policy = store.create_policy( + name="Escalate export", + action="memory.export", + scope="profile", + effect="require_approval", + priority=10, + active=True, + conditions={}, + required_consents=[], + ) + + payload = evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=store.thread_id, + action="memory.export", + scope="profile", + attributes={}, + ), + ) + + assert payload["decision"] == "require_approval" + assert payload["matched_policy"]["id"] == str(matched_policy["id"]) + assert payload["reasons"][-1]["code"] == "policy_effect_require_approval" + + with pytest.raises( + PolicyEvaluationValidationError, + match="thread_id must reference an existing thread owned by the user", + ): + evaluate_policy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=PolicyEvaluationRequestInput( + thread_id=uuid4(), + action="memory.export", + scope="profile", + attributes={}, + ), + ) diff --git a/tests/unit/test_policy_main.py b/tests/unit/test_policy_main.py new file mode 100644 index 0000000..fa3e4e5 --- /dev/null +++ b/tests/unit/test_policy_main.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.policy import PolicyEvaluationValidationError, PolicyNotFoundError + + +def test_upsert_consent_endpoint_translates_request_and_returns_created_status(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_upsert_consent_record(store, *, user_id, consent): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["consent"] = consent + return { + "consent": { + "id": "consent-123", + "consent_key": "email_marketing", + "status": "granted", + "metadata": {"source": "settings"}, + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + }, + "write_mode": "created", + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "upsert_consent_record", fake_upsert_consent_record) + + response = main_module.upsert_consent( + main_module.UpsertConsentRequest( + user_id=user_id, + consent_key="email_marketing", + status="granted", + metadata={"source": "settings"}, + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body) == { + "consent": { + "id": "consent-123", + "consent_key": "email_marketing", + "status": "granted", + "metadata": {"source": "settings"}, + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + }, + "write_mode": "created", + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["consent"].consent_key == 
"email_marketing" + assert captured["consent"].status == "granted" + assert captured["consent"].metadata == {"source": "settings"} + + +def test_get_policy_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + policy_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_policy_record(*_args, **_kwargs): + raise PolicyNotFoundError(f"policy {policy_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_policy_record", fake_get_policy_record) + + response = main_module.get_policy(policy_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"policy {policy_id} was not found"} + + +def test_evaluate_policy_endpoint_translates_request_and_returns_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_evaluate_policy_request(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "decision": "allow", + "matched_policy": { + "id": "policy-123", + "name": "Allow export", + "action": "memory.export", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing"], + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:00:00+00:00", + }, + "reasons": [ + { + "code": "matched_policy", + "source": "policy", + "message": "Matched policy 'Allow export' at priority 10.", + "policy_id": "policy-123", + "consent_key": None, + }, + { + "code": "policy_effect_allow", + "source": "policy", + "message": "Policy effect resolved the decision to 'allow'.", + "policy_id": "policy-123", + "consent_key": None, + }, + ], + "evaluation": { + "action": "memory.export", + "scope": "profile", + "evaluated_policy_count": 1, + "matched_policy_id": "policy-123", + "order": ["priority_asc", "created_at_asc", "id_asc"], + }, + "trace": {"trace_id": "trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "evaluate_policy_request", fake_evaluate_policy_request) + + response = main_module.evaluate_policy( + main_module.EvaluatePolicyRequest( + user_id=user_id, + thread_id=thread_id, + action="memory.export", + scope="profile", + attributes={"channel": "email"}, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == {"trace_id": "trace-123", "trace_event_count": 3} + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].thread_id == thread_id + assert captured["request"].action == "memory.export" + assert captured["request"].scope == "profile" + assert captured["request"].attributes == {"channel": "email"} + 
+ +def test_evaluate_policy_endpoint_maps_validation_errors_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_evaluate_policy_request(*_args, **_kwargs): + raise PolicyEvaluationValidationError("thread_id must reference an existing thread owned by the user") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "evaluate_policy_request", fake_evaluate_policy_request) + + response = main_module.evaluate_policy( + main_module.EvaluatePolicyRequest( + user_id=user_id, + thread_id=uuid4(), + action="memory.export", + scope="profile", + attributes={}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "thread_id must reference an existing thread owned by the user" + } diff --git a/tests/unit/test_policy_store.py b/tests/unit/test_policy_store.py new file mode 100644 index 0000000..6d33734 --- /dev/null +++ b/tests/unit/test_policy_store.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_consent_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + consent_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + {"id": consent_id, "consent_key": "email_marketing", "status": "granted", "metadata": {}}, + {"id": consent_id, "consent_key": "email_marketing", "status": "revoked", "metadata": {"source": "banner"}}, + ], + fetchall_result=[{"id": consent_id, "consent_key": "email_marketing", "status": "revoked", "metadata": {}}], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_consent( + consent_key="email_marketing", + status="granted", + metadata={"source": "settings"}, + ) + updated = store.update_consent( + consent_id=consent_id, + status="revoked", + metadata={"source": "banner"}, + ) + listed = store.list_consents() + + assert created["id"] == consent_id + assert updated["status"] == "revoked" + assert listed == [{"id": consent_id, "consent_key": "email_marketing", "status": "revoked", "metadata": {}}] + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO consents" in create_query + assert create_params is not None + assert create_params[:2] == ("email_marketing", "granted") + assert isinstance(create_params[2], Jsonb) + assert create_params[2].obj == {"source": "settings"} + + update_query, update_params = cursor.executed[1] + assert "UPDATE consents" in update_query + assert update_params is not None + assert update_params[0] == "revoked" + assert isinstance(update_params[1], Jsonb) + assert update_params[1].obj == {"source": "banner"} + assert update_params[2] == consent_id + + assert cursor.executed[2] == ( + """ + SELECT id, user_id, consent_key, status, metadata, created_at, updated_at + FROM consents + ORDER BY consent_key ASC, created_at ASC, id ASC + """, + None, + ) + + +def test_policy_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + policy_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": policy_id, + "name": "Allow export", + "action": "memory.export", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing"], + }, + { + "id": policy_id, + "name": "Allow export", + "action": "memory.export", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing"], + }, + ], + fetchall_result=[ + { + "id": policy_id, + "name": "Allow export", + "action": "memory.export", + "scope": "profile", + "effect": "allow", + "priority": 10, + "active": True, + "conditions": {"channel": "email"}, + "required_consents": ["email_marketing"], + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_policy( + name="Allow export", + action="memory.export", + scope="profile", + effect="allow", + priority=10, + active=True, + conditions={"channel": "email"}, + required_consents=["email_marketing"], + ) + fetched = store.get_policy_optional(policy_id) + listed = 
store.list_active_policies() + + assert created["id"] == policy_id + assert fetched is not None + assert listed[0]["id"] == policy_id + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO policies" in create_query + assert create_params is not None + assert create_params[:6] == ("Allow export", "memory.export", "profile", "allow", 10, True) + assert isinstance(create_params[6], Jsonb) + assert create_params[6].obj == {"channel": "email"} + assert isinstance(create_params[7], Jsonb) + assert create_params[7].obj == ["email_marketing"] + + assert cursor.executed[1] == ( + """ + SELECT + id, + user_id, + name, + action, + scope, + effect, + priority, + active, + conditions, + required_consents, + created_at, + updated_at + FROM policies + WHERE id = %s + """, + (policy_id,), + ) + assert "WHERE active = TRUE" in cursor.executed[2][0] diff --git a/tests/unit/test_proxy_execution.py b/tests/unit/test_proxy_execution.py new file mode 100644 index 0000000..d1f4a61 --- /dev/null +++ b/tests/unit/test_proxy_execution.py @@ -0,0 +1,783 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.approvals import ApprovalNotFoundError +from alicebot_api.contracts import ProxyExecutionRequestInput +from alicebot_api.proxy_execution import ( + PROXY_EXECUTION_REQUEST_EVENT_KIND, + PROXY_EXECUTION_RESULT_EVENT_KIND, + ProxyExecutionApprovalStateError, + ProxyExecutionHandlerNotFoundError, + execute_approved_proxy_request, + registered_proxy_handler_keys, +) + + +class ProxyExecutionStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 9, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.locked_task_ids: list[UUID] = [] + self.approvals: dict[UUID, dict[str, object]] = {} + self.tasks: list[dict[str, object]] = [] + self.task_steps: list[dict[str, object]] = [] + self.events: list[dict[str, object]] = [] + self.tool_executions: list[dict[str, object]] = [] + self.execution_budgets: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + + def current_time(self) -> datetime: + return self.base_time + timedelta(minutes=len(self.tool_executions)) + + def seed_approval(self, *, status: str, tool_key: str) -> dict[str, object]: + approval_id = uuid4() + tool_id = uuid4() + created_at = self.base_time + timedelta(minutes=len(self.approvals)) + approval = { + "id": approval_id, + "user_id": self.user_id, + "thread_id": self.thread_id, + "tool_id": tool_id, + "task_step_id": None, + "status": status, + "request": { + "thread_id": str(self.thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello", "count": 2}, + }, + "tool": { + "id": str(tool_id), + "tool_key": tool_key, + "name": "Proxy Echo" if tool_key == "proxy.echo" else "Unregistered Proxy", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": created_at.isoformat(), + }, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": str(uuid4()), "trace_event_count": 3}, + }, + "routing_trace_id": uuid4(), + "created_at": created_at, 
+ "resolved_at": None if status == "pending" else created_at + timedelta(minutes=30), + "resolved_by_user_id": None if status == "pending" else self.user_id, + } + self.approvals[approval_id] = approval + task = self.create_task( + thread_id=self.thread_id, + tool_id=tool_id, + status={ + "pending": "pending_approval", + "approved": "approved", + "rejected": "denied", + }[status], + request=approval["request"], + tool=approval["tool"], + latest_approval_id=approval_id, + latest_execution_id=None, + ) + task_step = self.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status={ + "pending": "created", + "approved": "approved", + "rejected": "denied", + }[status], + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval_id), + "approval_status": status, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + trace_kind="approval.request" if status == "pending" else "approval.resolve", + ) + approval["task_step_id"] = task_step["id"] + return approval + + def seed_execution_budget( + self, + *, + tool_key: str | None, + domain_hint: str | None, + max_completed_executions: int, + rolling_window_seconds: int | None = None, + supersedes_budget_id: UUID | None = None, + ) -> dict[str, object]: + row = { + "id": uuid4(), + "user_id": self.user_id, + "tool_key": tool_key, + "domain_hint": domain_hint, + "max_completed_executions": max_completed_executions, + "rolling_window_seconds": rolling_window_seconds, + "status": "active", + "deactivated_at": None, + "superseded_by_budget_id": None, + "supersedes_budget_id": supersedes_budget_id, + "created_at": self.base_time + timedelta(minutes=len(self.execution_budgets)), + } + self.execution_budgets.append(row) + self.execution_budgets.sort(key=lambda item: (item["created_at"], item["id"])) + return row + + def get_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return self.approvals.get(approval_id) + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time + timedelta(minutes=len(self.traces)), + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.trace_events)), + } + self.trace_events.append(event) + return event + + def create_task( + self, + *, + thread_id: UUID, + tool_id: UUID, + status: str, + request: dict[str, object], + tool: dict[str, object], + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object]: + task = { + "id": uuid4(), + "user_id": self.user_id, + "thread_id": thread_id, + "tool_id": tool_id, + "status": status, + "request": request, + "tool": tool, + "latest_approval_id": latest_approval_id, + "latest_execution_id": latest_execution_id, + "created_at": self.base_time + timedelta(minutes=len(self.tasks)), + "updated_at": self.base_time + timedelta(minutes=len(self.tasks)), + } + 
self.tasks.append(task) + return task + + def get_task_by_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["latest_approval_id"] == approval_id), None) + + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["id"] == task_id), None) + + def get_task_step_optional(self, task_step_id: UUID) -> dict[str, object] | None: + return next((task_step for task_step in self.task_steps if task_step["id"] == task_step_id), None) + + def lock_task_steps(self, task_id: UUID) -> None: + self.locked_task_ids.append(task_id) + + def update_task_execution_by_approval_optional( + self, + *, + approval_id: UUID, + latest_execution_id: UUID, + status: str, + ) -> dict[str, object] | None: + task = self.get_task_by_approval_optional(approval_id) + if task is None: + return None + task["status"] = status + task["latest_execution_id"] = latest_execution_id + task["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task + + def append_event( + self, + thread_id: UUID, + session_id: UUID | None, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "user_id": self.user_id, + "thread_id": thread_id, + "session_id": session_id, + "sequence_no": len(self.events) + 1, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.events)), + } + self.events.append(event) + return event + + def create_task_step( + self, + *, + task_id: UUID, + sequence_no: int, + parent_step_id: UUID | None = None, + source_approval_id: UUID | None = None, + source_execution_id: UUID | None = None, + kind: str, + status: str, + request: dict[str, object], + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object]: + task_step = { + "id": uuid4(), + "user_id": self.user_id, + "task_id": task_id, + "sequence_no": sequence_no, + "parent_step_id": parent_step_id, + "source_approval_id": source_approval_id, + "source_execution_id": source_execution_id, + "kind": kind, + "status": status, + "request": request, + "outcome": outcome, + "trace_id": trace_id, + "trace_kind": trace_kind, + "created_at": self.base_time + timedelta(minutes=len(self.task_steps)), + "updated_at": self.base_time + timedelta(minutes=len(self.task_steps)), + } + self.task_steps.append(task_step) + return task_step + + def get_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + ) -> dict[str, object] | None: + return next( + ( + task_step + for task_step in self.task_steps + if task_step["task_id"] == task_id and task_step["sequence_no"] == sequence_no + ), + None, + ) + + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + return sorted( + [task_step for task_step in self.task_steps if task_step["task_id"] == task_id], + key=lambda task_step: (task_step["sequence_no"], task_step["created_at"], task_step["id"]), + ) + + def update_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_for_task_sequence_optional(task_id=task_id, sequence_no=sequence_no) + if task_step is None: + return None + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + 
task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task_step + + def update_task_step_optional( + self, + *, + task_step_id: UUID, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_optional(task_step_id) + if task_step is None: + return None + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task_step + + def create_tool_execution( + self, + *, + approval_id: UUID, + task_step_id: UUID, + thread_id: UUID, + tool_id: UUID, + trace_id: UUID, + request_event_id: UUID | None, + result_event_id: UUID | None, + status: str, + handler_key: str | None, + request: dict[str, object], + tool: dict[str, object], + result: dict[str, object], + ) -> dict[str, object]: + execution = { + "id": uuid4(), + "user_id": self.user_id, + "approval_id": approval_id, + "task_step_id": task_step_id, + "thread_id": thread_id, + "tool_id": tool_id, + "trace_id": trace_id, + "request_event_id": request_event_id, + "result_event_id": result_event_id, + "status": status, + "handler_key": handler_key, + "request": request, + "tool": tool, + "result": result, + "executed_at": self.base_time + timedelta(minutes=len(self.tool_executions)), + } + self.tool_executions.append(execution) + return execution + + def list_execution_budgets(self) -> list[dict[str, object]]: + return list(self.execution_budgets) + + def list_tool_executions(self) -> list[dict[str, object]]: + return list(self.tool_executions) + + +def test_execute_approved_proxy_request_returns_result_and_persists_events() -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.echo") + + payload = execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert list(payload) == ["request", "approval", "tool", "result", "events", "trace"] + assert payload["request"] == { + "approval_id": str(approval["id"]), + "task_step_id": str(approval["task_step_id"]), + } + assert payload["approval"]["status"] == "approved" + assert payload["tool"]["tool_key"] == "proxy.echo" + assert payload["result"] == { + "handler_key": "proxy.echo", + "status": "completed", + "output": { + "mode": "no_side_effect", + "tool_key": "proxy.echo", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello", "count": 2}, + }, + } + assert payload["events"]["request_sequence_no"] == 1 + assert payload["events"]["result_sequence_no"] == 2 + assert payload["trace"]["trace_event_count"] == 9 + assert len(store.tool_executions) == 1 + assert store.tool_executions[0]["approval_id"] == approval["id"] + assert store.tool_executions[0]["task_step_id"] == approval["task_step_id"] + assert store.tool_executions[0]["trace_id"] == UUID(payload["trace"]["trace_id"]) + assert store.tool_executions[0]["handler_key"] == "proxy.echo" + assert store.tasks[0]["status"] == "executed" + assert store.task_steps[0]["status"] == "executed" + assert store.tasks[0]["latest_execution_id"] == store.tool_executions[0]["id"] + assert store.tool_executions[0]["result"] == { + "handler_key": "proxy.echo", + "status": "completed", + "output": 
payload["result"]["output"], + "reason": None, + } + assert [event["kind"] for event in store.events] == [ + PROXY_EXECUTION_REQUEST_EVENT_KIND, + PROXY_EXECUTION_RESULT_EVENT_KIND, + ] + assert [event["kind"] for event in store.trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + + +def test_execute_approved_proxy_request_locks_task_steps_before_persisting_execution_state() -> None: + class LockingProxyExecutionStoreStub(ProxyExecutionStoreStub): + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + if task_id not in self.locked_task_ids: + raise AssertionError("task-step boundary was checked before the task-step lock was taken") + return super().list_task_steps_for_task(task_id) + + def create_tool_execution( + self, + *, + approval_id: UUID, + task_step_id: UUID, + thread_id: UUID, + tool_id: UUID, + trace_id: UUID, + request_event_id: UUID | None, + result_event_id: UUID | None, + status: str, + handler_key: str | None, + request: dict[str, object], + tool: dict[str, object], + result: dict[str, object], + ) -> dict[str, object]: + task = self.get_task_by_approval_optional(approval_id) + if task is None: + raise AssertionError("expected task for approval before execution persistence") + if task["id"] not in self.locked_task_ids: + raise AssertionError("tool execution persisted before the task-step lock was taken") + return super().create_tool_execution( + approval_id=approval_id, + task_step_id=task_step_id, + thread_id=thread_id, + tool_id=tool_id, + trace_id=trace_id, + request_event_id=request_event_id, + result_event_id=result_event_id, + status=status, + handler_key=handler_key, + request=request, + tool=tool, + result=result, + ) + + store = LockingProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.echo") + + payload = execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert payload["result"]["status"] == "completed" + assert store.tasks[0]["id"] in store.locked_task_ids + + +def test_execute_approved_proxy_request_updates_the_linked_later_step_without_mutating_the_original_step() -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.echo") + task = store.tasks[0] + first_step = store.task_steps[0] + initial_execution_id = uuid4() + task["status"] = "pending_approval" + task["latest_execution_id"] = None + first_step["status"] = "executed" + first_step["outcome"] = { + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "approved", + "execution_id": str(initial_execution_id), + "execution_status": "completed", + "blocked_reason": None, + } + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval["id"], + source_execution_id=initial_execution_id, + kind="governed_request", + status="created", + request=approval["request"], + outcome={ + "routing_decision": "approval_required", + "approval_id": str(approval["id"]), + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=uuid4(), + 
trace_kind="task.step.continuation", + ) + + original_first_trace_id = first_step["trace_id"] + original_first_outcome = dict(first_step["outcome"]) + original_later_trace_id = later_step["trace_id"] + approval["task_step_id"] = later_step["id"] + + payload = execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert payload["result"]["status"] == "completed" + assert task["status"] == "executed" + assert task["latest_execution_id"] == store.tool_executions[0]["id"] + assert first_step["status"] == "executed" + assert first_step["trace_id"] == original_first_trace_id + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "executed" + assert later_step["trace_id"] == UUID(payload["trace"]["trace_id"]) + assert later_step["trace_id"] != original_later_trace_id + assert later_step["outcome"]["execution_id"] == str(store.tool_executions[0]["id"]) + assert later_step["outcome"]["execution_status"] == "completed" + assert store.tool_executions[0]["task_step_id"] == later_step["id"] + assert store.events[0]["payload"]["task_step_id"] == str(later_step["id"]) + assert store.events[1]["payload"]["task_step_id"] == str(later_step["id"]) + assert store.trace_events[0]["payload"] == { + "approval_id": str(approval["id"]), + "task_step_id": str(later_step["id"]), + } + assert store.trace_events[3]["payload"]["task_step_id"] == str(later_step["id"]) + assert store.trace_events[4]["payload"]["task_step_id"] == str(later_step["id"]) + + +@pytest.mark.parametrize("status", ["pending", "rejected"]) +def test_execute_approved_proxy_request_rejects_non_approved_statuses(status: str) -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status=status, tool_key="proxy.echo") + + with pytest.raises( + ProxyExecutionApprovalStateError, + match=rf"approval {approval['id']} is {status} and cannot be executed", + ): + execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert store.events == [] + assert store.tool_executions == [] + assert [event["kind"] for event in store.trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + ] + assert store.trace_events[2]["payload"]["dispatch_status"] == "blocked" + assert store.trace_events[3]["payload"]["execution_status"] == "blocked" + + +def test_execute_approved_proxy_request_rejects_missing_handlers() -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.missing") + + with pytest.raises( + ProxyExecutionHandlerNotFoundError, + match="tool 'proxy.missing' has no registered proxy handler", + ): + execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert store.events == [] + assert len(store.tool_executions) == 1 + assert store.tool_executions[0]["status"] == "blocked" + assert store.tool_executions[0]["task_step_id"] == approval["task_step_id"] + assert store.tool_executions[0]["handler_key"] is None + assert store.tool_executions[0]["request_event_id"] is None + assert store.tool_executions[0]["result_event_id"] is None + assert store.tasks[0]["status"] == "blocked" + assert store.task_steps[0]["status"] == "blocked" + assert 
store.tasks[0]["latest_execution_id"] == store.tool_executions[0]["id"] + assert store.tool_executions[0]["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + } + assert store.trace_events[2]["payload"]["decision"] == "allow" + assert store.trace_events[3]["payload"] == { + "approval_id": str(approval["id"]), + "task_step_id": str(approval["task_step_id"]), + "tool_id": approval["tool"]["id"], + "tool_key": "proxy.missing", + "handler_key": None, + "dispatch_status": "blocked", + "reason": "tool 'proxy.missing' has no registered proxy handler", + "result_status": "blocked", + "output": None, + } + assert [event["kind"] for event in store.trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + + +def test_execute_approved_proxy_request_returns_blocked_budget_response_and_persists_review_record() -> None: + store = ProxyExecutionStoreStub() + approval = store.seed_approval(status="approved", tool_key="proxy.echo") + budget = store.seed_execution_budget( + tool_key="proxy.echo", + domain_hint=None, + max_completed_executions=1, + ) + store.create_tool_execution( + approval_id=uuid4(), + task_step_id=uuid4(), + thread_id=store.thread_id, + tool_id=UUID(approval["tool"]["id"]), + trace_id=uuid4(), + request_event_id=uuid4(), + result_event_id=uuid4(), + status="completed", + handler_key="proxy.echo", + request={ + "thread_id": str(store.thread_id), + "tool_id": approval["tool"]["id"], + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "seed"}, + }, + tool=approval["tool"], + result={ + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + ) + + payload = execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=approval["id"]), + ) + + assert payload["events"] is None + assert payload["trace"]["trace_event_count"] == 9 + assert payload["result"] == { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": ( + f"execution budget {budget['id']} blocks execution: projected completed executions " + "2 would exceed limit 1" + ), + "budget_decision": { + "matched_budget_id": str(budget["id"]), + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + }, + } + assert len(store.events) == 0 + assert len(store.tool_executions) == 2 + assert store.tool_executions[-1]["status"] == "blocked" + assert store.tool_executions[-1]["request_event_id"] is None + assert store.tool_executions[-1]["result_event_id"] is None + assert store.tasks[0]["status"] == "blocked" + assert store.task_steps[0]["status"] == "blocked" + assert store.tasks[0]["latest_execution_id"] == store.tool_executions[-1]["id"] + assert 
store.tool_executions[-1]["result"] == payload["result"] + assert [event["kind"] for event in store.trace_events] == [ + "tool.proxy.execute.request", + "tool.proxy.execute.approval", + "tool.proxy.execute.budget", + "tool.proxy.execute.dispatch", + "tool.proxy.execute.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[2]["payload"] == payload["result"]["budget_decision"] + + +def test_execute_approved_proxy_request_rejects_missing_visible_approval() -> None: + store = ProxyExecutionStoreStub() + + with pytest.raises(ApprovalNotFoundError, match="was not found"): + execute_approved_proxy_request( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ProxyExecutionRequestInput(approval_id=uuid4()), + ) + + +def test_registered_proxy_handler_keys_are_sorted_and_explicit() -> None: + assert registered_proxy_handler_keys() == ("proxy.echo",) diff --git a/tests/unit/test_proxy_execution_main.py b/tests/unit/test_proxy_execution_main.py new file mode 100644 index 0000000..b98b7bb --- /dev/null +++ b/tests/unit/test_proxy_execution_main.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.approvals import ApprovalNotFoundError +from alicebot_api.proxy_execution import ( + ProxyExecutionApprovalStateError, + ProxyExecutionHandlerNotFoundError, +) +from alicebot_api.tasks import TaskStepApprovalLinkageError + + +def test_execute_approved_proxy_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_execute_approved_proxy_request(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "request": {"approval_id": str(approval_id), "task_step_id": "task-step-123"}, + "approval": { + "id": str(approval_id), + "thread_id": "thread-123", + "task_step_id": "task-step-123", + "status": "approved", + "request": { + "thread_id": "thread-123", + "tool_id": "tool-123", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello"}, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-13T09:00:00+00:00", + "resolution": { + "resolved_at": "2026-03-13T09:30:00+00:00", + "resolved_by_user_id": str(user_id), + }, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + }, + "events": { + "request_event_id": "event-request-123", + "request_sequence_no": 1, + "result_event_id": "event-result-123", + "result_sequence_no": 2, + }, + "trace": {"trace_id": "proxy-trace-123", "trace_event_count": 5}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, 
"user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert response.status_code == 200 + assert json.loads(response.body)["request"] == { + "approval_id": str(approval_id), + "task_step_id": "task-step-123", + } + assert json.loads(response.body)["trace"] == { + "trace_id": "proxy-trace-123", + "trace_event_count": 5, + } + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].approval_id == approval_id + + +def test_execute_approved_proxy_endpoint_maps_missing_approval_to_404(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + raise ApprovalNotFoundError(f"approval {approval_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"approval {approval_id} was not found"} + + +def test_execute_approved_proxy_endpoint_maps_blocked_approval_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + raise ProxyExecutionApprovalStateError( + f"approval {approval_id} is pending and cannot be executed" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"approval {approval_id} is pending and cannot be executed" + } + + +def test_execute_approved_proxy_endpoint_maps_missing_handler_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + raise ProxyExecutionHandlerNotFoundError( + "tool 'proxy.missing' has no registered proxy handler" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) 
+ + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": "tool 'proxy.missing' has no registered proxy handler" + } + + +def test_execute_approved_proxy_endpoint_maps_linkage_error_to_409(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + raise TaskStepApprovalLinkageError( + f"approval {approval_id} is missing linked task_step_id" + ) + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"approval {approval_id} is missing linked task_step_id" + } + + +def test_execute_approved_proxy_endpoint_returns_budget_blocked_payload(monkeypatch) -> None: + user_id = uuid4() + approval_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_execute_approved_proxy_request(*_args, **_kwargs): + return { + "request": {"approval_id": str(approval_id), "task_step_id": "task-step-123"}, + "approval": { + "id": str(approval_id), + "thread_id": "thread-123", + "task_step_id": "task-step-123", + "status": "approved", + "request": { + "thread_id": "thread-123", + "tool_id": "tool-123", + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {"message": "hello"}, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "routing": { + "decision": "approval_required", + "reasons": [], + "trace": {"trace_id": "routing-trace-123", "trace_event_count": 3}, + }, + "created_at": "2026-03-13T09:00:00+00:00", + "resolution": { + "resolved_at": "2026-03-13T09:30:00+00:00", + "resolved_by_user_id": str(user_id), + }, + }, + "tool": {"id": "tool-123", "tool_key": "proxy.echo"}, + "result": { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "execution budget budget-123 blocks execution: projected completed executions 2 would exceed limit 1", + "budget_decision": { + "matched_budget_id": "budget-123", + "tool_key": "proxy.echo", + "domain_hint": None, + "budget_tool_key": "proxy.echo", + "budget_domain_hint": None, + "max_completed_executions": 1, + "rolling_window_seconds": None, + "count_scope": "lifetime", + "window_started_at": None, + "completed_execution_count": 1, + "projected_completed_execution_count": 2, + "decision": "block", + "reason": "budget_exceeded", + "order": ["specificity_desc", "created_at_asc", "id_asc"], + "history_order": ["executed_at_asc", "id_asc"], + }, + }, + "events": None, + "trace": {"trace_id": "proxy-trace-456", "trace_event_count": 5}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "execute_approved_proxy_request", fake_execute_approved_proxy_request) + + response = main_module.execute_approved_proxy( + approval_id, + main_module.ExecuteApprovedProxyRequest(user_id=user_id), + ) + + assert 
response.status_code == 200 + assert json.loads(response.body)["events"] is None diff --git a/tests/unit/test_response_generation.py b/tests/unit/test_response_generation.py new file mode 100644 index 0000000..f91c051 --- /dev/null +++ b/tests/unit/test_response_generation.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +import json + +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.contracts import ( + ModelInvocationRequest, + ModelInvocationResponse, + PROMPT_ASSEMBLY_VERSION_V0, + PromptAssemblyInput, +) +from alicebot_api.response_generation import ( + assemble_prompt, + build_assistant_response_payload, + invoke_model, +) + + +def make_context_pack() -> dict[str, object]: + return { + "compiler_version": "continuity_v0", + "scope": { + "user_id": "11111111-1111-1111-8111-111111111111", + "thread_id": "22222222-2222-2222-8222-222222222222", + }, + "limits": { + "max_sessions": 3, + "max_events": 8, + "max_memories": 5, + "max_entities": 5, + "max_entity_edges": 10, + }, + "user": { + "id": "11111111-1111-1111-8111-111111111111", + "email": "owner@example.com", + "display_name": "Owner", + "created_at": "2026-03-12T09:00:00+00:00", + }, + "thread": { + "id": "22222222-2222-2222-8222-222222222222", + "title": "Thread", + "created_at": "2026-03-12T09:00:00+00:00", + "updated_at": "2026-03-12T09:05:00+00:00", + }, + "sessions": [], + "events": [ + { + "id": "33333333-3333-3333-8333-333333333333", + "session_id": None, + "sequence_no": 1, + "kind": "message.user", + "payload": {"text": "Hello"}, + "created_at": "2026-03-12T09:06:00+00:00", + } + ], + "memories": [ + { + "id": "44444444-4444-4444-8444-444444444444", + "memory_key": "user.preference.coffee", + "value": {"likes": "oat milk"}, + "status": "active", + "source_event_ids": ["33333333-3333-3333-8333-333333333333"], + "created_at": "2026-03-12T09:04:00+00:00", + "updated_at": "2026-03-12T09:05:00+00:00", + "source_provenance": {"sources": ["symbolic"], "semantic_score": None}, + } + ], + "memory_summary": { + "candidate_count": 1, + "included_count": 1, + "excluded_deleted_count": 0, + "excluded_limit_count": 0, + "hybrid_retrieval": { + "requested": False, + "embedding_config_id": None, + "query_vector_dimensions": 0, + "semantic_limit": 0, + "symbolic_selected_count": 1, + "semantic_selected_count": 0, + "merged_candidate_count": 1, + "deduplicated_count": 0, + "included_symbolic_only_count": 1, + "included_semantic_only_count": 0, + "included_dual_source_count": 0, + "similarity_metric": None, + "source_precedence": ["symbolic", "semantic"], + "symbolic_order": ["updated_at_asc", "created_at_asc", "id_asc"], + "semantic_order": ["score_desc", "created_at_asc", "id_asc"], + }, + }, + "entities": [], + "entity_summary": { + "candidate_count": 0, + "included_count": 0, + "excluded_limit_count": 0, + }, + "entity_edges": [], + "entity_edge_summary": { + "anchor_entity_count": 0, + "candidate_count": 0, + "included_count": 0, + "excluded_limit_count": 0, + }, + } + + +def test_assemble_prompt_is_deterministic_and_explicit() -> None: + first = assemble_prompt( + request=PromptAssemblyInput( + context_pack=make_context_pack(), + system_instruction="System instruction", + developer_instruction="Developer instruction", + ), + compile_trace_id="compile-trace-123", + ) + second = assemble_prompt( + request=PromptAssemblyInput( + context_pack=make_context_pack(), + system_instruction="System instruction", + developer_instruction="Developer instruction", + ), + compile_trace_id="compile-trace-123", + ) + + 
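+    # Determinism check: two assemblies of identical inputs must yield the same
+    # prompt text, the same SHA-256, and the same trace payload.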
assert first.prompt_text == second.prompt_text + assert first.prompt_sha256 == second.prompt_sha256 + assert first.trace_payload == second.trace_payload + assert [section.name for section in first.sections] == [ + "system", + "developer", + "context", + "conversation", + ] + assert "[SYSTEM]\nSystem instruction" in first.prompt_text + assert "[DEVELOPER]\nDeveloper instruction" in first.prompt_text + assert '"memory_key":"user.preference.coffee"' in first.prompt_text + assert first.trace_payload["version"] == PROMPT_ASSEMBLY_VERSION_V0 + assert first.trace_payload["compile_trace_id"] == "compile-trace-123" + assert first.trace_payload["included_event_count"] == 1 + assert first.trace_payload["included_memory_count"] == 1 + + +class FakeHTTPResponse: + def __init__(self, body: bytes) -> None: + self.body = body + + def __enter__(self) -> "FakeHTTPResponse": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def read(self) -> bytes: + return self.body + + +def test_invoke_model_sends_tools_disabled_request_and_parses_response(monkeypatch) -> None: + captured: dict[str, object] = {} + + def fake_urlopen(request, timeout): + captured["url"] = request.full_url + captured["timeout"] = timeout + captured["headers"] = dict(request.header_items()) + captured["body"] = json.loads(request.data.decode("utf-8")) + return FakeHTTPResponse( + json.dumps( + { + "id": "resp_123", + "status": "completed", + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": "Assistant reply"}], + } + ], + "usage": { + "input_tokens": 12, + "output_tokens": 4, + "total_tokens": 16, + }, + } + ).encode("utf-8") + ) + + monkeypatch.setattr("alicebot_api.response_generation.urlopen", fake_urlopen) + + prompt = assemble_prompt( + request=PromptAssemblyInput( + context_pack=make_context_pack(), + system_instruction="System instruction", + developer_instruction="Developer instruction", + ), + compile_trace_id="compile-trace-123", + ) + response = invoke_model( + settings=Settings( + model_provider="openai_responses", + model_base_url="https://example.test/v1", + model_name="gpt-5-mini", + model_api_key="secret-key", + model_timeout_seconds=17, + ), + request=ModelInvocationRequest( + provider="openai_responses", + model="gpt-5-mini", + prompt=prompt, + ), + ) + + assert captured["url"] == "https://example.test/v1/responses" + assert captured["timeout"] == 17 + assert captured["headers"]["Authorization"] == "Bearer secret-key" + assert captured["body"]["tool_choice"] == "none" + assert captured["body"]["tools"] == [] + assert captured["body"]["store"] is False + assert [item["role"] for item in captured["body"]["input"]] == [ + "system", + "developer", + "user", + "user", + ] + assert response == ModelInvocationResponse( + provider="openai_responses", + model="gpt-5-mini", + response_id="resp_123", + finish_reason="completed", + output_text="Assistant reply", + usage={"input_tokens": 12, "output_tokens": 4, "total_tokens": 16}, + ) + + +def test_build_assistant_response_payload_captures_model_and_prompt_metadata() -> None: + prompt = assemble_prompt( + request=PromptAssemblyInput( + context_pack=make_context_pack(), + system_instruction="System instruction", + developer_instruction="Developer instruction", + ), + compile_trace_id="compile-trace-123", + ) + payload = build_assistant_response_payload( + prompt=prompt, + model_response=ModelInvocationResponse( + provider="openai_responses", + model="gpt-5-mini", + response_id="resp_123", + finish_reason="completed", + 
output_text="Assistant reply", + usage={"input_tokens": 12, "output_tokens": 4, "total_tokens": 16}, + ), + ) + + assert payload == { + "text": "Assistant reply", + "model": { + "provider": "openai_responses", + "model": "gpt-5-mini", + "response_id": "resp_123", + "finish_reason": "completed", + "usage": {"input_tokens": 12, "output_tokens": 4, "total_tokens": 16}, + }, + "prompt": { + "assembly_version": "prompt_assembly_v0", + "prompt_sha256": prompt.prompt_sha256, + "section_order": ["system", "developer", "context", "conversation"], + }, + } diff --git a/tests/unit/test_semantic_retrieval.py b/tests/unit/test_semantic_retrieval.py new file mode 100644 index 0000000..780b4e4 --- /dev/null +++ b/tests/unit/test_semantic_retrieval.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import SemanticMemoryRetrievalRequestInput +from alicebot_api.semantic_retrieval import ( + SemanticMemoryRetrievalValidationError, + retrieve_semantic_memory_records, +) + + +class SemanticRetrievalStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.config_by_id: dict[UUID, dict[str, object]] = {} + self.retrieval_rows: list[dict[str, object]] = [] + self.last_query: dict[str, object] | None = None + + def get_embedding_config_optional(self, embedding_config_id: UUID) -> dict[str, object] | None: + return self.config_by_id.get(embedding_config_id) + + def retrieve_semantic_memory_matches( + self, + *, + embedding_config_id: UUID, + query_vector: list[float], + limit: int, + ) -> list[dict[str, object]]: + self.last_query = { + "embedding_config_id": embedding_config_id, + "query_vector": query_vector, + "limit": limit, + } + return list(self.retrieval_rows[:limit]) + + +def seed_config(store: SemanticRetrievalStoreStub, *, dimensions: int = 3) -> UUID: + config_id = uuid4() + store.config_by_id[config_id] = { + "id": config_id, + "dimensions": dimensions, + } + return config_id + + +def active_row( + store: SemanticRetrievalStoreStub, + *, + memory_key: str, + score: float, + minute_offset: int, +) -> dict[str, object]: + return { + "id": uuid4(), + "user_id": uuid4(), + "memory_key": memory_key, + "value": {"memory_key": memory_key}, + "status": "active", + "source_event_ids": [str(uuid4())], + "created_at": store.base_time + timedelta(minutes=minute_offset), + "updated_at": store.base_time + timedelta(minutes=minute_offset + 1), + "deleted_at": None, + "score": score, + } + + +def test_retrieve_semantic_memory_records_returns_stable_shape_and_summary() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + first_row = active_row(store, memory_key="user.preference.coffee", score=1.0, minute_offset=0) + second_row = active_row(store, memory_key="user.preference.tea", score=0.75, minute_offset=1) + store.retrieval_rows = [first_row, second_row] + + payload = retrieve_semantic_memory_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=config_id, + query_vector=(0.1, 0.2, 0.3), + limit=2, + ), + ) + + assert payload == { + "items": [ + { + "memory_id": str(first_row["id"]), + "memory_key": "user.preference.coffee", + "value": {"memory_key": "user.preference.coffee"}, + "source_event_ids": first_row["source_event_ids"], + "created_at": first_row["created_at"].isoformat(), + "updated_at": 
first_row["updated_at"].isoformat(), + "score": 1.0, + }, + { + "memory_id": str(second_row["id"]), + "memory_key": "user.preference.tea", + "value": {"memory_key": "user.preference.tea"}, + "source_event_ids": second_row["source_event_ids"], + "created_at": second_row["created_at"].isoformat(), + "updated_at": second_row["updated_at"].isoformat(), + "score": 0.75, + }, + ], + "summary": { + "embedding_config_id": str(config_id), + "limit": 2, + "returned_count": 2, + "similarity_metric": "cosine_similarity", + "order": ["score_desc", "created_at_asc", "id_asc"], + }, + } + assert store.last_query == { + "embedding_config_id": config_id, + "query_vector": [0.1, 0.2, 0.3], + "limit": 2, + } + + +def test_retrieve_semantic_memory_records_rejects_missing_config() -> None: + store = SemanticRetrievalStoreStub() + + with pytest.raises( + SemanticMemoryRetrievalValidationError, + match="embedding_config_id must reference an existing embedding config owned by the user", + ): + retrieve_semantic_memory_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=uuid4(), + query_vector=(0.1, 0.2, 0.3), + ), + ) + + +def test_retrieve_semantic_memory_records_rejects_dimension_mismatch() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + + with pytest.raises( + SemanticMemoryRetrievalValidationError, + match="query_vector length must match embedding config dimensions \\(3\\): 2", + ): + retrieve_semantic_memory_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=config_id, + query_vector=(0.1, 0.2), + ), + ) + + +def test_retrieve_semantic_memory_records_rejects_non_active_memory_rows() -> None: + store = SemanticRetrievalStoreStub() + config_id = seed_config(store, dimensions=3) + invalid_row = active_row(store, memory_key="user.preference.music", score=0.5, minute_offset=0) + invalid_row["status"] = "deleted" + store.retrieval_rows = [invalid_row] + + with pytest.raises( + SemanticMemoryRetrievalValidationError, + match="semantic retrieval only supports active memories", + ): + retrieve_semantic_memory_records( + store, # type: ignore[arg-type] + user_id=uuid4(), + request=SemanticMemoryRetrievalRequestInput( + embedding_config_id=config_id, + query_vector=(0.1, 0.2, 0.3), + ), + ) diff --git a/tests/unit/test_store.py b/tests/unit/test_store.py new file mode 100644 index 0000000..15c1f1e --- /dev/null +++ b/tests/unit/test_store.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb +import pytest + +from alicebot_api.store import ContinuityStore, ContinuityStoreInvariantError + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_create_methods_return_cursor_rows_and_use_expected_parameters() -> None: + user_id = uuid4() + thread_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + {"id": user_id, "email": "owner@example.com", "display_name": "Owner"}, + {"id": thread_id, "title": "Starter thread"}, + {"id": uuid4(), "thread_id": thread_id, "status": "active"}, + ] + ) + store = ContinuityStore(RecordingConnection(cursor)) + + user = store.create_user(user_id, "owner@example.com", "Owner") + thread = store.create_thread("Starter thread") + session = store.create_session(thread_id) + + assert user["id"] == user_id + assert thread["id"] == thread_id + assert session["thread_id"] == thread_id + assert cursor.executed == [ + ( + """ + INSERT INTO users (id, email, display_name) + VALUES (%s, %s, %s) + RETURNING id, email, display_name, created_at + """, + (user_id, "owner@example.com", "Owner"), + ), + ( + """ + INSERT INTO threads (user_id, title) + VALUES (app.current_user_id(), %s) + RETURNING id, user_id, title, created_at, updated_at + """, + ("Starter thread",), + ), + ( + """ + INSERT INTO sessions (user_id, thread_id, status) + VALUES (app.current_user_id(), %s, %s) + RETURNING id, user_id, thread_id, status, started_at, ended_at, created_at + """, + (thread_id, "active"), + ), + ] + + +def test_append_event_locks_thread_and_serializes_payload() -> None: + thread_id = uuid4() + session_id = uuid4() + payload = {"text": "hello"} + cursor = RecordingCursor( + fetchone_results=[ + { + "id": uuid4(), + "thread_id": thread_id, + "session_id": session_id, + "sequence_no": 1, + "kind": "message.user", + "payload": payload, + } + ] + ) + store = ContinuityStore(RecordingConnection(cursor)) + + event = store.append_event(thread_id, session_id, "message.user", payload) + + assert event["sequence_no"] == 1 + assert cursor.executed[0] == ( + "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 0))", + (str(thread_id),), + ) + insert_query, insert_params = cursor.executed[1] + assert "WITH next_sequence AS" in insert_query + assert insert_params is not None + assert insert_params[:4] == (thread_id, thread_id, session_id, "message.user") + assert isinstance(insert_params[4], Jsonb) + assert insert_params[4].obj == payload + + +def test_list_thread_events_returns_all_rows_in_order() -> None: + thread_id = uuid4() + events = [ + {"sequence_no": 1, "kind": "message.user"}, + {"sequence_no": 2, "kind": "message.assistant"}, + ] + cursor = RecordingCursor(fetchone_results=[], fetchall_result=events) + store = ContinuityStore(RecordingConnection(cursor)) + + result = store.list_thread_events(thread_id) + + assert result == events + assert cursor.executed == [ + ( + """ + SELECT id, user_id, thread_id, session_id, sequence_no, kind, payload, created_at + FROM events + WHERE thread_id = %s + ORDER BY sequence_no ASC + """, + (thread_id,), + ), + ] + + +def test_create_user_raises_clear_error_when_returning_row_is_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + with 
pytest.raises( + ContinuityStoreInvariantError, + match="create_user did not return a row", + ): + store.create_user(uuid4(), "owner@example.com") diff --git a/tests/unit/test_task_step_store.py b/tests/unit/test_task_step_store.py new file mode 100644 index 0000000..af764b4 --- /dev/null +++ b/tests/unit/test_task_step_store.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_task_step_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + task_step_id = uuid4() + task_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + trace_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + "kind": "governed_request", + "status": "created", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.request", + }, + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + "kind": "governed_request", + "status": "created", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.request", + }, + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + "kind": "governed_request", + "status": "approved", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": str(uuid4()), + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.resolve", + }, + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": 
None, + "source_execution_id": None, + "kind": "governed_request", + "status": "approved", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": str(uuid4()), + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.resolve", + }, + { + "id": task_id, + "user_id": uuid4(), + "thread_id": thread_id, + "tool_id": tool_id, + "status": "approved", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "tool": {"id": str(tool_id), "tool_key": "proxy.echo"}, + "latest_approval_id": None, + "latest_execution_id": None, + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:05:00+00:00", + }, + ], + fetchall_result=[ + { + "id": task_step_id, + "user_id": uuid4(), + "task_id": task_id, + "sequence_no": 1, + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + "kind": "governed_request", + "status": "created", + "request": {"thread_id": str(uuid4()), "tool_id": str(uuid4())}, + "outcome": { + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + "trace_id": trace_id, + "trace_kind": "approval.request", + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_task_step( + task_id=task_id, + sequence_no=1, + kind="governed_request", + status="created", + request={"thread_id": "thread-123", "tool_id": "tool-123"}, + outcome={ + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=trace_id, + trace_kind="approval.request", + ) + fetched = store.get_task_step_optional(task_step_id) + listed = store.list_task_steps_for_task(task_id) + updated = store.update_task_step_for_task_sequence_optional( + task_id=task_id, + sequence_no=1, + status="approved", + outcome={ + "routing_decision": "approval_required", + "approval_id": "approval-123", + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=trace_id, + trace_kind="approval.resolve", + ) + updated_by_id = store.update_task_step_optional( + task_step_id=task_step_id, + status="approved", + outcome={ + "routing_decision": "approval_required", + "approval_id": "approval-123", + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + }, + trace_id=trace_id, + trace_kind="approval.resolve", + ) + updated_task = store.update_task_status_optional( + task_id=task_id, + status="approved", + latest_approval_id=None, + latest_execution_id=None, + ) + + assert created["id"] == task_step_id + assert fetched is not None + assert listed[0]["id"] == task_step_id + assert updated is not None + assert updated["status"] == "approved" + assert updated_by_id is not None + assert updated_by_id["status"] == "approved" + assert updated_task is not None + assert updated_task["status"] == "approved" + + lock_query, lock_params = cursor.executed[0] + assert "pg_advisory_xact_lock" in lock_query + assert lock_params == (str(task_id),) + + create_query, create_params = cursor.executed[1] + assert "INSERT INTO task_steps" in create_query + assert create_params is not None + assert create_params[:7] == (task_id, 
1, None, None, None, "governed_request", "created") + assert isinstance(create_params[7], Jsonb) + assert create_params[7].obj == {"thread_id": "thread-123", "tool_id": "tool-123"} + assert isinstance(create_params[8], Jsonb) + assert create_params[8].obj == { + "routing_decision": "approval_required", + "approval_id": None, + "approval_status": None, + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + } + assert create_params[9] == trace_id + assert create_params[10] == "approval.request" + assert "FROM task_steps" in cursor.executed[2][0] + assert "ORDER BY sequence_no ASC, created_at ASC, id ASC" in cursor.executed[3][0] + + update_query, update_params = cursor.executed[4] + assert "UPDATE task_steps" in update_query + assert "WHERE task_id = %s" in update_query + assert update_params is not None + assert update_params[0] == "approved" + assert isinstance(update_params[1], Jsonb) + assert update_params[1].obj["approval_status"] == "approved" + assert update_params[2] == trace_id + assert update_params[3] == "approval.resolve" + assert update_params[4:] == (task_id, 1) + + update_by_id_query, update_by_id_params = cursor.executed[5] + assert "UPDATE task_steps" in update_by_id_query + assert "WHERE id = %s" in update_by_id_query + assert update_by_id_params is not None + assert update_by_id_params[0] == "approved" + assert isinstance(update_by_id_params[1], Jsonb) + assert update_by_id_params[1].obj["approval_status"] == "approved" + assert update_by_id_params[2] == trace_id + assert update_by_id_params[3] == "approval.resolve" + assert update_by_id_params[4] == task_step_id + + task_update_query, task_update_params = cursor.executed[6] + assert "UPDATE tasks" in task_update_query + assert task_update_params == ("approved", None, None, task_id) diff --git a/tests/unit/test_task_workspace_store.py b/tests/unit/test_task_workspace_store.py new file mode 100644 index 0000000..16bd0ae --- /dev/null +++ b/tests/unit/test_task_workspace_store.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_task_workspace_store_methods_use_expected_queries() -> None: + task_workspace_id = uuid4() + task_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": task_id, + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/user/task", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": task_id, + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/user/task", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": task_id, + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/user/task", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + }, + ], + fetchall_result=[ + { + "id": task_workspace_id, + "user_id": uuid4(), + "task_id": task_id, + "status": "active", + "local_path": "/tmp/alicebot/task-workspaces/user/task", + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_task_workspace( + task_id=task_id, + status="active", + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + fetched = store.get_task_workspace_optional(task_workspace_id) + active = store.get_active_task_workspace_for_task_optional(task_id) + listed = store.list_task_workspaces() + store.lock_task_workspaces(task_id) + + assert created["id"] == task_workspace_id + assert fetched is not None + assert active is not None + assert listed[0]["id"] == task_workspace_id + assert cursor.executed == [ + ( + """ + INSERT INTO task_workspaces ( + user_id, + task_id, + status, + local_path, + created_at, + updated_at + ) + VALUES ( + app.current_user_id(), + %s, + %s, + %s, + clock_timestamp(), + clock_timestamp() + ) + RETURNING + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + """, + (task_id, "active", "/tmp/alicebot/task-workspaces/user/task"), + ), + ( + """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + WHERE id = %s + """, + (task_workspace_id,), + ), + ( + """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + WHERE task_id = %s + AND status = 'active' + ORDER BY created_at ASC, id ASC + LIMIT 1 + """, + (task_id,), + ), + ( + """ + SELECT + id, + user_id, + task_id, + status, + local_path, + created_at, + updated_at + FROM task_workspaces + ORDER BY created_at ASC, id ASC + """, + None, + ), + ( + "SELECT pg_advisory_xact_lock(hashtextextended(%s::text, 3))", + (str(task_id),), + ), + ] diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py new file mode 100644 index 0000000..142f048 --- /dev/null +++ b/tests/unit/test_tasks.py @@ -0,0 +1,1663 @@ +from __future__ import annotations + +from 
datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +from alicebot_api.tasks import ( + TaskNotFoundError, + TaskStepApprovalLinkageError, + TaskStepExecutionLinkageError, + TaskStepNotFoundError, + TaskStepSequenceError, + TaskStepTransitionError, + allowed_task_step_transitions, + create_next_task_step_record, + create_task_step_for_governed_request, + get_task_step_record, + get_task_record, + list_task_records, + list_task_step_records, + sync_task_with_task_step_status, + sync_task_step_with_approval, + sync_task_step_with_execution, + task_status_for_step_status, + next_task_status_for_approval, + task_lifecycle_trace_events, + task_step_lifecycle_trace_events, + task_step_outcome_snapshot, + task_step_status_for_approval_status, + task_step_status_for_execution_status, + task_step_status_for_routing_decision, + task_status_for_approval_status, + task_status_for_execution_status, + task_status_for_routing_decision, + transition_task_step_record, +) +from alicebot_api.contracts import ( + TaskStepCreateInput, + TaskStepLineageInput, + TaskStepNextCreateInput, + TaskStepTransitionInput, +) + + +class TaskStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) + self.user_id = uuid4() + self.tasks: list[dict[str, object]] = [] + self.task_steps: list[dict[str, object]] = [] + self.approvals: list[dict[str, object]] = [] + self.tool_executions: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + self.locked_task_ids: list[UUID] = [] + + def create_task( + self, + *, + status: str, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object]: + task = { + "id": uuid4(), + "user_id": self.user_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + "status": status, + "request": { + "thread_id": str(uuid4()), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + "tool": { + "id": str(uuid4()), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": self.base_time.isoformat(), + }, + "latest_approval_id": latest_approval_id, + "latest_execution_id": latest_execution_id, + "created_at": self.base_time + timedelta(minutes=len(self.tasks)), + "updated_at": self.base_time + timedelta(minutes=len(self.tasks)), + } + self.tasks.append(task) + return task + + def list_tasks(self) -> list[dict[str, object]]: + return sorted(self.tasks, key=lambda task: (task["created_at"], task["id"])) + + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["id"] == task_id), None) + + def get_task_by_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["latest_approval_id"] == approval_id), None) + + def update_task_status_optional( + self, + *, + task_id: UUID, + status: str, + latest_approval_id: UUID | None, + latest_execution_id: UUID | None, + ) -> dict[str, object] | None: + task = self.get_task_optional(task_id) + if task is None: + return None + task["status"] = status + task["latest_approval_id"] = 
latest_approval_id + task["latest_execution_id"] = latest_execution_id + task["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.trace_events)) + return task + + def lock_task_steps(self, task_id: UUID) -> None: + self.locked_task_ids.append(task_id) + + def create_task_step( + self, + *, + task_id: UUID, + sequence_no: int, + parent_step_id: UUID | None = None, + source_approval_id: UUID | None = None, + source_execution_id: UUID | None = None, + kind: str, + status: str, + request: dict[str, object], + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object]: + task_step = { + "id": uuid4(), + "user_id": self.user_id, + "task_id": task_id, + "sequence_no": sequence_no, + "parent_step_id": parent_step_id, + "source_approval_id": source_approval_id, + "source_execution_id": source_execution_id, + "kind": kind, + "status": status, + "request": request, + "outcome": outcome, + "trace_id": trace_id, + "trace_kind": trace_kind, + "created_at": self.base_time + timedelta(minutes=len(self.task_steps)), + "updated_at": self.base_time + timedelta(minutes=len(self.task_steps)), + } + self.task_steps.append(task_step) + return task_step + + def get_task_step_optional(self, task_step_id: UUID) -> dict[str, object] | None: + return next((task_step for task_step in self.task_steps if task_step["id"] == task_step_id), None) + + def get_approval_optional(self, approval_id: UUID) -> dict[str, object] | None: + return next((approval for approval in self.approvals if approval["id"] == approval_id), None) + + def get_tool_execution_optional(self, execution_id: UUID) -> dict[str, object] | None: + return next((execution for execution in self.tool_executions if execution["id"] == execution_id), None) + + def get_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + ) -> dict[str, object] | None: + return next( + ( + task_step + for task_step in self.task_steps + if task_step["task_id"] == task_id and task_step["sequence_no"] == sequence_no + ), + None, + ) + + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + return sorted( + [task_step for task_step in self.task_steps if task_step["task_id"] == task_id], + key=lambda task_step: (task_step["sequence_no"], task_step["created_at"], task_step["id"]), + ) + + def update_task_step_for_task_sequence_optional( + self, + *, + task_id: UUID, + sequence_no: int, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_for_task_sequence_optional(task_id=task_id, sequence_no=sequence_no) + if task_step is None: + return None + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.task_steps)) + return task_step + + def update_task_step_optional( + self, + *, + task_step_id: UUID, + status: str, + outcome: dict[str, object], + trace_id: UUID, + trace_kind: str, + ) -> dict[str, object] | None: + task_step = self.get_task_step_optional(task_step_id) + if task_step is None: + return None + task_step["status"] = status + task_step["outcome"] = outcome + task_step["trace_id"] = trace_id + task_step["trace_kind"] = trace_kind + task_step["updated_at"] = self.base_time + timedelta(hours=1, minutes=len(self.task_steps)) + return task_step + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: 
str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time + timedelta(minutes=len(self.traces)), + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time + timedelta(minutes=len(self.trace_events)), + } + self.trace_events.append(event) + return event + + +def test_list_and_get_task_records_are_deterministic() -> None: + store = TaskStoreStub() + first = store.create_task( + status="approved", + latest_approval_id=None, + latest_execution_id=None, + ) + second = store.create_task( + status="blocked", + latest_approval_id=uuid4(), + latest_execution_id=uuid4(), + ) + + listed = list_task_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail = get_task_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_id=second["id"], + ) + + assert [item["id"] for item in listed["items"]] == [str(first["id"]), str(second["id"])] + assert [item["status"] for item in listed["items"]] == ["approved", "blocked"] + assert listed["summary"] == { + "total_count": 2, + "order": ["created_at_asc", "id_asc"], + } + assert detail["task"]["id"] == str(second["id"]) + assert detail["task"]["status"] == "blocked" + assert detail["task"]["latest_approval_id"] == str(second["latest_approval_id"]) + assert detail["task"]["latest_execution_id"] == str(second["latest_execution_id"]) + + +def test_task_lifecycle_helpers_return_deterministic_statuses_and_trace_payloads() -> None: + assert task_status_for_routing_decision("approval_required") == "pending_approval" + assert task_status_for_routing_decision("ready") == "approved" + assert task_status_for_routing_decision("denied") == "denied" + assert task_status_for_approval_status("approved") == "approved" + assert task_status_for_approval_status("rejected") == "denied" + assert next_task_status_for_approval(current_status="pending_approval", approval_status="approved") == "approved" + assert next_task_status_for_approval(current_status="executed", approval_status="approved") == "executed" + assert task_status_for_execution_status("completed") == "executed" + assert task_status_for_execution_status("blocked") == "blocked" + assert task_step_status_for_routing_decision("approval_required") == "created" + assert task_step_status_for_routing_decision("ready") == "approved" + assert task_step_status_for_routing_decision("denied") == "denied" + assert task_step_status_for_approval_status("approved") == "approved" + assert task_step_status_for_approval_status("rejected") == "denied" + assert task_step_status_for_execution_status("completed") == "executed" + assert task_step_status_for_execution_status("blocked") == "blocked" + assert task_status_for_step_status("created") == "pending_approval" + assert task_status_for_step_status("approved") == "approved" + assert task_status_for_step_status("executed") == "executed" + assert allowed_task_step_transitions("created") == ["approved", "denied"] + assert allowed_task_step_transitions("approved") == ["executed", "blocked"] + assert 
allowed_task_step_transitions("executed") == [] + + task = { + "id": str(uuid4()), + "thread_id": str(uuid4()), + "tool_id": str(uuid4()), + "status": "executed", + "request": { + "thread_id": str(uuid4()), + "tool_id": str(uuid4()), + "action": "tool.run", + "scope": "workspace", + "domain_hint": None, + "risk_hint": None, + "attributes": {}, + }, + "tool": { + "id": str(uuid4()), + "tool_key": "proxy.echo", + "name": "Proxy Echo", + "description": "Deterministic proxy handler.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["proxy"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": [], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": "2026-03-13T10:00:00+00:00", + }, + "latest_approval_id": str(uuid4()), + "latest_execution_id": str(uuid4()), + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:05:00+00:00", + } + + events = task_lifecycle_trace_events( + task=task, + previous_status="approved", + source="proxy_execution", + ) + + assert events == [ + ( + "task.lifecycle.state", + { + "task_id": task["id"], + "source": "proxy_execution", + "previous_status": "approved", + "current_status": "executed", + "latest_approval_id": task["latest_approval_id"], + "latest_execution_id": task["latest_execution_id"], + }, + ), + ( + "task.lifecycle.summary", + { + "task_id": task["id"], + "source": "proxy_execution", + "final_status": "executed", + "latest_approval_id": task["latest_approval_id"], + "latest_execution_id": task["latest_execution_id"], + }, + ), + ] + + task_step = { + "id": str(uuid4()), + "task_id": task["id"], + "sequence_no": 1, + "lineage": { + "parent_step_id": None, + "source_approval_id": None, + "source_execution_id": None, + }, + "kind": "governed_request", + "status": "executed", + "request": task["request"], + "outcome": task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=task["latest_approval_id"], + approval_status="approved", + execution_id=task["latest_execution_id"], + execution_status="completed", + blocked_reason=None, + ), + "trace": { + "trace_id": str(uuid4()), + "trace_kind": "tool.proxy.execute", + }, + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:05:00+00:00", + } + + task_step_events = task_step_lifecycle_trace_events( + task_step=task_step, + previous_status="approved", + source="proxy_execution", + ) + + assert task_step_events == [ + ( + "task.step.lifecycle.state", + { + "task_id": task["id"], + "task_step_id": task_step["id"], + "source": "proxy_execution", + "sequence_no": 1, + "kind": "governed_request", + "previous_status": "approved", + "current_status": "executed", + "trace": task_step["trace"], + }, + ), + ( + "task.step.lifecycle.summary", + { + "task_id": task["id"], + "task_step_id": task_step["id"], + "source": "proxy_execution", + "sequence_no": 1, + "kind": "governed_request", + "final_status": "executed", + "trace": task_step["trace"], + }, + ), + ] + + +def test_get_task_record_raises_not_found_when_missing() -> None: + store = TaskStoreStub() + + try: + get_task_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_id=uuid4(), + ) + except TaskNotFoundError as exc: + assert "task" in str(exc) + else: + raise AssertionError("expected TaskNotFoundError") + + +def test_task_step_list_get_and_lifecycle_updates_are_deterministic() -> None: + store = TaskStoreStub() + task = store.create_task( + status="pending_approval", + 
latest_approval_id=uuid4(), + latest_execution_id=None, + ) + first_trace_id = uuid4() + create_payload = create_task_step_for_governed_request( + store, # type: ignore[arg-type] + request=TaskStepCreateInput( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=first_trace_id, + trace_kind="approval.request", + ), + ) + second_trace_id = uuid4() + approval_transition = sync_task_step_with_approval( + store, # type: ignore[arg-type] + approval_id=UUID(str(task["latest_approval_id"])), + task_step_id=UUID(create_payload["task_step"]["id"]), + approval_status="approved", + trace_id=second_trace_id, + trace_kind="approval.resolve", + ) + execution = { + "id": uuid4(), + "approval_id": task["latest_approval_id"], + "task_step_id": UUID(create_payload["task_step"]["id"]), + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + } + third_trace_id = uuid4() + execution_transition = sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution=execution, # type: ignore[arg-type] + trace_id=third_trace_id, + trace_kind="tool.proxy.execute", + ) + store.create_task_step( + task_id=task["id"], + sequence_no=2, + kind="governed_request", + status="denied", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="denied", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.request", + ) + + listed = list_task_step_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_id=task["id"], + ) + detail = get_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_step_id=UUID(create_payload["task_step"]["id"]), + ) + + assert [item["sequence_no"] for item in listed["items"]] == [1, 2] + assert listed["summary"] == { + "task_id": str(task["id"]), + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "denied", + "next_sequence_no": 3, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert detail["task_step"]["id"] == create_payload["task_step"]["id"] + assert detail["task_step"]["status"] == "executed" + assert detail["task_step"]["trace"] == { + "trace_id": str(third_trace_id), + "trace_kind": "tool.proxy.execute", + } + assert detail["task_step"]["outcome"] == { + "routing_decision": "approval_required", + "approval_id": str(task["latest_approval_id"]), + "approval_status": "approved", + "execution_id": str(execution["id"]), + "execution_status": "completed", + "blocked_reason": None, + } + + +def test_sync_task_step_with_approval_updates_explicitly_linked_later_step_only() -> None: + store = TaskStoreStub() + approval_id = uuid4() + initial_execution_id = uuid4() + task = store.create_task( + status="pending_approval", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + 
approval_id=str(approval_id), + approval_status="approved", + execution_id=str(initial_execution_id), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.continuation", + ) + + original_first_trace_id = first_step["trace_id"] + original_first_trace_kind = first_step["trace_kind"] + original_first_outcome = dict(first_step["outcome"]) + later_trace_id = uuid4() + + transition = sync_task_step_with_approval( + store, # type: ignore[arg-type] + approval_id=approval_id, + task_step_id=later_step["id"], + approval_status="approved", + trace_id=later_trace_id, + trace_kind="approval.resolve", + ) + + assert transition.previous_status == "created" + assert transition.task_step["id"] == str(later_step["id"]) + assert transition.task_step["status"] == "approved" + assert first_step["status"] == "executed" + assert first_step["trace_id"] == original_first_trace_id + assert first_step["trace_kind"] == original_first_trace_kind + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "approved" + assert later_step["trace_id"] == later_trace_id + assert later_step["trace_kind"] == "approval.resolve" + assert later_step["outcome"] == { + "routing_decision": "approval_required", + "approval_id": str(approval_id), + "approval_status": "approved", + "execution_id": None, + "execution_status": None, + "blocked_reason": None, + } + assert task["status"] == "pending_approval" + assert task["latest_execution_id"] is None + + +def test_sync_task_step_with_approval_rejects_inconsistent_linkage_without_mutating_steps() -> None: + store = TaskStoreStub() + approval_id = uuid4() + initial_execution_id = uuid4() + task = store.create_task( + status="pending_approval", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=str(initial_execution_id), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.continuation", + ) + + original_first_outcome = dict(first_step["outcome"]) + original_later_trace_id = later_step["trace_id"] + + try: + sync_task_step_with_approval( + store, # type: ignore[arg-type] + approval_id=approval_id, + 
task_step_id=later_step["id"], + approval_status="approved", + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + except TaskStepApprovalLinkageError as exc: + assert str(exc) == ( + f"approval {approval_id} is inconsistent with linked task step {later_step['id']}" + ) + else: + raise AssertionError("expected TaskStepApprovalLinkageError") + + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "created" + assert later_step["trace_id"] == original_later_trace_id + assert later_step["trace_kind"] == "task.step.continuation" + + +def test_sync_task_step_with_execution_updates_the_linked_later_step_without_mutating_initial_step() -> None: + store = TaskStoreStub() + approval_id = uuid4() + initial_execution_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=str(initial_execution_id), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + later_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + parent_step_id=first_step["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.transition", + ) + + original_first_trace_id = first_step["trace_id"] + original_first_trace_kind = first_step["trace_kind"] + original_first_outcome = dict(first_step["outcome"]) + execution = { + "id": uuid4(), + "approval_id": approval_id, + "task_step_id": later_step["id"], + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + } + + transition = sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution=execution, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + assert transition.previous_status == "approved" + assert transition.task_step["id"] == str(later_step["id"]) + assert transition.task_step["status"] == "executed" + assert first_step["status"] == "executed" + assert first_step["trace_id"] == original_first_trace_id + assert first_step["trace_kind"] == original_first_trace_kind + assert first_step["outcome"] == original_first_outcome + assert later_step["status"] == "executed" + assert later_step["trace_kind"] == "tool.proxy.execute" + assert later_step["outcome"] == { + "routing_decision": "approval_required", + "approval_id": str(approval_id), + "approval_status": "approved", + "execution_id": str(execution["id"]), + "execution_status": "completed", + "blocked_reason": None, + } + assert task["status"] == "approved" + assert task["latest_execution_id"] is None + + +def test_sync_task_step_with_execution_rejects_missing_linkage_without_mutating_steps() -> None: + store = TaskStoreStub() + approval_id = uuid4() + task = store.create_task( + status="approved", + 
latest_approval_id=approval_id, + latest_execution_id=None, + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + execution_id = uuid4() + + try: + sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution={ + "id": execution_id, + "approval_id": approval_id, + "task_step_id": None, + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + }, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + except TaskStepExecutionLinkageError as exc: + assert str(exc) == f"tool execution {execution_id} is missing linked task_step_id" + else: + raise AssertionError("expected TaskStepExecutionLinkageError") + + assert first_step["status"] == "approved" + assert first_step["outcome"]["execution_id"] is None + + +def test_sync_task_step_with_execution_rejects_unknown_or_out_of_task_linkage() -> None: + store = TaskStoreStub() + approval_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + other_task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + other_step = store.create_task_step( + task_id=other_task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=other_task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + + missing_execution_id = uuid4() + missing_task_step_id = uuid4() + try: + sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution={ + "id": missing_execution_id, + "approval_id": approval_id, + "task_step_id": missing_task_step_id, + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + }, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + except TaskStepExecutionLinkageError as exc: + assert str(exc) == ( + f"tool execution {missing_execution_id} references linked task step " + f"{missing_task_step_id} that was not found" + ) + else: + raise AssertionError("expected TaskStepExecutionLinkageError") + + outside_execution_id = uuid4() + try: + sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution={ + "id": outside_execution_id, + "approval_id": approval_id, + "task_step_id": other_step["id"], + "status": "completed", + "result": { + 
"handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + }, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + except TaskStepExecutionLinkageError as exc: + assert str(exc) == ( + f"tool execution {outside_execution_id} links task step {other_step['id']} " + f"outside task {task['id']}" + ) + else: + raise AssertionError("expected TaskStepExecutionLinkageError") + + +def test_sync_task_step_with_execution_rejects_inconsistent_linkage_without_mutating_steps() -> None: + store = TaskStoreStub() + approval_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.resolve", + ) + inconsistent_execution_id = uuid4() + + try: + sync_task_step_with_execution( + store, # type: ignore[arg-type] + task_id=task["id"], + execution={ + "id": inconsistent_execution_id, + "approval_id": uuid4(), + "task_step_id": step["id"], + "status": "completed", + "result": { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + }, + }, # type: ignore[arg-type] + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + except TaskStepExecutionLinkageError as exc: + assert str(exc) == ( + f"tool execution {inconsistent_execution_id} is inconsistent with linked task step {step['id']}" + ) + else: + raise AssertionError("expected TaskStepExecutionLinkageError") + + assert step["status"] == "approved" + assert step["outcome"]["execution_id"] is None + + +def test_sync_task_with_task_step_status_updates_parent_through_task_seam() -> None: + store = TaskStoreStub() + task = store.create_task( + status="executed", + latest_approval_id=uuid4(), + latest_execution_id=uuid4(), + ) + + transition = sync_task_with_task_step_status( + store, # type: ignore[arg-type] + task_id=task["id"], + task_step_status="created", + linked_approval_id=task["latest_approval_id"], + linked_execution_id=None, + ) + + assert transition.previous_status == "executed" + assert transition.task["status"] == "pending_approval" + assert transition.task["latest_execution_id"] is None + assert store.tasks[0]["status"] == "pending_approval" + assert store.tasks[0]["latest_execution_id"] is None + + +def test_create_next_task_step_assigns_deterministic_sequence_updates_parent_and_records_trace() -> None: + store = TaskStoreStub() + approval_id = uuid4() + initial_execution_id = uuid4() + task = store.create_task( + status="executed", + latest_approval_id=approval_id, + latest_execution_id=initial_execution_id, + ) + store.approvals.append({"id": approval_id, "thread_id": task["thread_id"], "tool_id": task["tool_id"]}) + store.tool_executions.append( + { + "id": task["latest_execution_id"], + "thread_id": task["thread_id"], + "tool_id": task["tool_id"], + "approval_id": approval_id, + } + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), 
+ approval_status="approved", + execution_id=str(task["latest_execution_id"]), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + payload = create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput( + parent_step_id=store.task_steps[0]["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + ), + ), + ) + + assert payload["task"]["status"] == "pending_approval" + assert payload["task"]["latest_approval_id"] == str(approval_id) + assert payload["task"]["latest_execution_id"] is None + assert payload["task_step"]["sequence_no"] == 2 + assert payload["task_step"]["status"] == "created" + assert payload["task_step"]["lineage"] == { + "parent_step_id": str(store.task_steps[0]["id"]), + "source_approval_id": str(approval_id), + "source_execution_id": str(initial_execution_id), + } + assert payload["task_step"]["trace"]["trace_kind"] == "task.step.continuation" + assert payload["sequencing"] == { + "task_id": str(task["id"]), + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "created", + "next_sequence_no": 3, + "append_allowed": False, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert payload["trace"]["trace_event_count"] == 7 + assert [event["kind"] for event in store.trace_events] == [ + "task.step.continuation.request", + "task.step.continuation.lineage", + "task.step.continuation.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[1]["payload"] == { + "task_id": str(task["id"]), + "parent_task_step_id": str(store.task_steps[0]["id"]), + "parent_sequence_no": 1, + "parent_status": "executed", + "source_approval_id": str(approval_id), + "source_execution_id": str(initial_execution_id), + } + + +def test_create_next_task_step_rejects_when_latest_step_is_not_terminal() -> None: + store = TaskStoreStub() + task = store.create_task( + status="pending_approval", + latest_approval_id=uuid4(), + latest_execution_id=None, + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.request", + ) + + try: + create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput(parent_step_id=store.task_steps[0]["id"]), + ), + ) + except TaskStepSequenceError as exc: + assert str(exc) == ( + f"task {task['id']} latest step {store.task_steps[0]['id']} is 
created and cannot append a next step" + ) + else: + raise AssertionError("expected TaskStepSequenceError") + + +def test_transition_task_step_updates_latest_step_parent_and_trace() -> None: + store = TaskStoreStub() + first_approval_id = uuid4() + first_execution_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=first_approval_id, + latest_execution_id=first_execution_id, + ) + store.approvals.extend( + [ + {"id": first_approval_id, "thread_id": task["thread_id"], "tool_id": task["tool_id"]}, + ] + ) + store.tool_executions.extend( + [ + { + "id": first_execution_id, + "thread_id": task["thread_id"], + "tool_id": task["tool_id"], + "approval_id": first_approval_id, + }, + ] + ) + first_step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(first_approval_id), + approval_status="approved", + execution_id=str(first_execution_id), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + second_step = store.create_task_step( + task_id=task["id"], + sequence_no=2, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.sequence", + ) + + payload = transition_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepTransitionInput( + task_step_id=second_step["id"], + status="executed", + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=str(first_approval_id), + approval_status="approved", + execution_id=str(first_execution_id), + execution_status="completed", + blocked_reason=None, + ), + ), + ) + + assert first_step["status"] == "executed" + assert payload["task"]["status"] == "executed" + assert payload["task"]["latest_approval_id"] == str(first_approval_id) + assert payload["task"]["latest_execution_id"] == str(first_execution_id) + assert payload["task_step"]["id"] == str(second_step["id"]) + assert payload["task_step"]["status"] == "executed" + assert payload["task_step"]["trace"]["trace_kind"] == "task.step.transition" + assert payload["sequencing"] == { + "task_id": str(task["id"]), + "total_count": 2, + "latest_sequence_no": 2, + "latest_status": "executed", + "next_sequence_no": 3, + "append_allowed": True, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + } + assert [event["kind"] for event in store.trace_events] == [ + "task.step.transition.request", + "task.step.transition.state", + "task.step.transition.summary", + "task.lifecycle.state", + "task.lifecycle.summary", + "task.step.lifecycle.state", + "task.step.lifecycle.summary", + ] + assert store.trace_events[1]["payload"]["allowed_next_statuses"] == ["executed", "blocked"] + + +def test_create_next_task_step_locks_before_listing_existing_steps() -> None: + class LockingTaskStoreStub(TaskStoreStub): + def list_task_steps_for_task(self, task_id: UUID) -> list[dict[str, object]]: + if task_id not in self.locked_task_ids: + raise AssertionError("task steps were listed before the advisory lock was taken") + return super().list_task_steps_for_task(task_id) + + store = LockingTaskStoreStub() + approval_id = uuid4() + initial_execution_id = 
uuid4() + task = store.create_task( + status="executed", + latest_approval_id=approval_id, + latest_execution_id=initial_execution_id, + ) + store.approvals.append({"id": approval_id, "thread_id": task["thread_id"], "tool_id": task["tool_id"]}) + store.tool_executions.append( + { + "id": task["latest_execution_id"], + "thread_id": task["thread_id"], + "tool_id": task["tool_id"], + "approval_id": approval_id, + } + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=str(task["latest_execution_id"]), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + payload = create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput( + parent_step_id=store.task_steps[0]["id"], + source_approval_id=approval_id, + source_execution_id=initial_execution_id, + ), + ), + ) + + assert payload["task_step"]["sequence_no"] == 2 + + +def test_create_next_task_step_rejects_visible_approval_from_unrelated_task_lineage() -> None: + store = TaskStoreStub() + task = store.create_task( + status="executed", + latest_approval_id=uuid4(), + latest_execution_id=uuid4(), + ) + unrelated_approval_id = uuid4() + store.approvals.append( + { + "id": unrelated_approval_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + } + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="approved", + execution_id=str(task["latest_execution_id"]), + execution_status="completed", + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + try: + create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput( + parent_step_id=store.task_steps[0]["id"], + source_approval_id=unrelated_approval_id, + ), + ), + ) + except TaskStepSequenceError as exc: + assert str(exc) == f"approval {unrelated_approval_id} does not belong to task {task['id']}" + else: + raise AssertionError("expected TaskStepSequenceError") + + +def test_create_next_task_step_rejects_parent_step_from_unrelated_task() -> None: + store = TaskStoreStub() + task = store.create_task( + status="executed", + latest_approval_id=None, + latest_execution_id=None, + ) + unrelated_task = store.create_task( + status="executed", + latest_approval_id=None, + latest_execution_id=None, + ) + store.create_task_step( + task_id=task["id"], + sequence_no=1, + 
kind="governed_request", + status="executed", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + unrelated_step = store.create_task_step( + task_id=unrelated_task["id"], + sequence_no=1, + kind="governed_request", + status="executed", + request=unrelated_task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="tool.proxy.execute", + ) + + try: + create_next_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepNextCreateInput( + task_id=task["id"], + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="ready", + approval_id=None, + approval_status=None, + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + lineage=TaskStepLineageInput(parent_step_id=unrelated_step["id"]), + ), + ) + except TaskStepSequenceError as exc: + assert str(exc) == f"task step {unrelated_step['id']} does not belong to task {task['id']}" + else: + raise AssertionError("expected TaskStepSequenceError") + + +def test_transition_task_step_rejects_invalid_status_graph_edge() -> None: + store = TaskStoreStub() + task = store.create_task( + status="pending_approval", + latest_approval_id=uuid4(), + latest_execution_id=None, + ) + step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="created", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="pending", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="approval.request", + ) + + try: + transition_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepTransitionInput( + task_step_id=step["id"], + status="executed", + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(task["latest_approval_id"]), + approval_status="approved", + execution_id=str(uuid4()), + execution_status="completed", + blocked_reason=None, + ), + ), + ) + except TaskStepTransitionError as exc: + assert str(exc) == ( + f"task step {step['id']} is created and cannot transition to executed; allowed: approved, denied" + ) + else: + raise AssertionError("expected TaskStepTransitionError") + + +def test_transition_task_step_rejects_visible_execution_from_unrelated_task_lineage() -> None: + store = TaskStoreStub() + approval_id = uuid4() + task = store.create_task( + status="approved", + latest_approval_id=approval_id, + latest_execution_id=None, + ) + step = store.create_task_step( + task_id=task["id"], + sequence_no=1, + kind="governed_request", + status="approved", + request=task["request"], + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=None, + execution_status=None, + blocked_reason=None, + ), + trace_id=uuid4(), + trace_kind="task.step.sequence", + ) + store.approvals.append({"id": approval_id, "thread_id": task["thread_id"], "tool_id": task["tool_id"]}) + 
unrelated_execution_id = uuid4() + store.tool_executions.append( + { + "id": unrelated_execution_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + "approval_id": approval_id, + } + ) + + try: + transition_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=TaskStepTransitionInput( + task_step_id=step["id"], + status="executed", + outcome=task_step_outcome_snapshot( + routing_decision="approval_required", + approval_id=str(approval_id), + approval_status="approved", + execution_id=str(unrelated_execution_id), + execution_status="completed", + blocked_reason=None, + ), + ), + ) + except TaskStepTransitionError as exc: + assert str(exc) == f"tool execution {unrelated_execution_id} does not belong to task {task['id']}" + else: + raise AssertionError("expected TaskStepTransitionError") + + +def test_get_task_step_record_raises_not_found_when_missing() -> None: + store = TaskStoreStub() + + try: + get_task_step_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + task_step_id=uuid4(), + ) + except TaskStepNotFoundError as exc: + assert "task step" in str(exc) + else: + raise AssertionError("expected TaskStepNotFoundError") diff --git a/tests/unit/test_tasks_main.py b/tests/unit/test_tasks_main.py new file mode 100644 index 0000000..0be9960 --- /dev/null +++ b/tests/unit/test_tasks_main.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.tasks import ( + TaskNotFoundError, + TaskStepNotFoundError, + TaskStepSequenceError, + TaskStepTransitionError, +) + + +def test_list_task_steps_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_task_step_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": { + "task_id": str(task_id), + "total_count": 0, + "latest_sequence_no": None, + "latest_status": None, + "next_sequence_no": 1, + "append_allowed": False, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + }, + }, + ) + + response = main_module.list_task_steps(task_id, user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": { + "task_id": str(task_id), + "total_count": 0, + "latest_sequence_no": None, + "latest_status": None, + "next_sequence_no": 1, + "append_allowed": False, + "order": ["sequence_no_asc", "created_at_asc", "id_asc"], + }, + } + + +def test_list_task_steps_endpoint_maps_task_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_list_task_step_records(*_args, **_kwargs): + raise TaskNotFoundError(f"task {task_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "list_task_step_records", fake_list_task_step_records) + + response = 
main_module.list_task_steps(task_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task {task_id} was not found"} + + +def test_get_task_step_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_step_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_task_step_record(*_args, **_kwargs): + raise TaskStepNotFoundError(f"task step {task_step_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_task_step_record", fake_get_task_step_record) + + response = main_module.get_task_step(task_step_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task step {task_step_id} was not found"} + + +def test_create_next_task_step_endpoint_maps_sequence_conflict_to_409(monkeypatch) -> None: + task_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_next_task_step_record(*_args, **_kwargs): + raise TaskStepSequenceError(f"task {task_id} latest step blocked append") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_next_task_step_record", fake_create_next_task_step_record) + + response = main_module.create_next_task_step( + task_id, + main_module.CreateNextTaskStepRequest( + user_id=user_id, + kind="governed_request", + status="created", + request=main_module.TaskStepRequestSnapshot( + thread_id=uuid4(), + tool_id=uuid4(), + action="tool.run", + scope="workspace", + attributes={}, + ), + outcome=main_module.TaskStepOutcomeRequest( + routing_decision="approval_required", + approval_status="pending", + ), + lineage=main_module.TaskStepLineageRequest(parent_step_id=uuid4()), + ), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == {"detail": f"task {task_id} latest step blocked append"} + + +def test_transition_task_step_endpoint_maps_transition_conflict_to_409(monkeypatch) -> None: + task_step_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_transition_task_step_record(*_args, **_kwargs): + raise TaskStepTransitionError(f"task step {task_step_id} is created and cannot transition") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "transition_task_step_record", fake_transition_task_step_record) + + response = main_module.transition_task_step( + task_step_id, + main_module.TransitionTaskStepRequest( + user_id=user_id, + status="approved", + outcome=main_module.TaskStepOutcomeRequest( + routing_decision="approval_required", + approval_status="approved", + ), + ), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"task step {task_step_id} is created and cannot transition" + } diff --git a/tests/unit/test_tool_execution_store.py b/tests/unit/test_tool_execution_store.py new file mode 100644 index 
0000000..f0715fc --- /dev/null +++ b/tests/unit/test_tool_execution_store.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_tool_execution_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + execution_id = uuid4() + approval_id = uuid4() + task_step_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + trace_id = uuid4() + request_event_id = uuid4() + result_event_id = uuid4() + row = { + "id": execution_id, + "approval_id": approval_id, + "task_step_id": task_step_id, + "thread_id": thread_id, + "tool_id": tool_id, + "trace_id": trace_id, + "request_event_id": request_event_id, + "result_event_id": result_event_id, + "status": "completed", + "handler_key": "proxy.echo", + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "proxy.echo"}, + "result": {"handler_key": "proxy.echo", "status": "completed", "output": {"mode": "no_side_effect"}, "reason": None}, + "executed_at": "2026-03-13T10:00:00+00:00", + } + cursor = RecordingCursor( + fetchone_results=[row, row], + fetchall_result=[row], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_tool_execution( + approval_id=approval_id, + task_step_id=task_step_id, + thread_id=thread_id, + tool_id=tool_id, + trace_id=trace_id, + request_event_id=request_event_id, + result_event_id=result_event_id, + status="completed", + handler_key="proxy.echo", + request={"thread_id": str(thread_id), "tool_id": str(tool_id)}, + tool={"id": str(tool_id), "tool_key": "proxy.echo"}, + result={"handler_key": "proxy.echo", "status": "completed", "output": {"mode": "no_side_effect"}, "reason": None}, + ) + fetched = store.get_tool_execution_optional(execution_id) + listed = store.list_tool_executions() + + assert created["id"] == execution_id + assert fetched is not None + assert fetched["id"] == execution_id + assert listed[0]["id"] == execution_id + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO tool_executions" in create_query + assert create_params is not None + assert create_params[:9] == ( + approval_id, + task_step_id, + thread_id, + tool_id, + trace_id, + request_event_id, + result_event_id, + "completed", + "proxy.echo", + ) + assert isinstance(create_params[9], Jsonb) + assert create_params[9].obj == {"thread_id": str(thread_id), "tool_id": str(tool_id)} + assert isinstance(create_params[10], Jsonb) + assert 
create_params[10].obj == {"id": str(tool_id), "tool_key": "proxy.echo"} + assert isinstance(create_params[11], Jsonb) + assert create_params[11].obj == { + "handler_key": "proxy.echo", + "status": "completed", + "output": {"mode": "no_side_effect"}, + "reason": None, + } + assert "FROM tool_executions" in cursor.executed[1][0] + assert "ORDER BY executed_at ASC, id ASC" in cursor.executed[2][0] + + +def test_create_tool_execution_accepts_blocked_attempt_without_event_ids() -> None: + execution_id = uuid4() + approval_id = uuid4() + task_step_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + trace_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": execution_id, + "approval_id": approval_id, + "task_step_id": task_step_id, + "thread_id": thread_id, + "tool_id": tool_id, + "trace_id": trace_id, + "request_event_id": None, + "result_event_id": None, + "status": "blocked", + "handler_key": None, + "request": {"thread_id": str(thread_id), "tool_id": str(tool_id)}, + "tool": {"id": str(tool_id), "tool_key": "proxy.missing"}, + "result": { + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + }, + "executed_at": "2026-03-13T10:05:00+00:00", + } + ] + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_tool_execution( + approval_id=approval_id, + task_step_id=task_step_id, + thread_id=thread_id, + tool_id=tool_id, + trace_id=trace_id, + request_event_id=None, + result_event_id=None, + status="blocked", + handler_key=None, + request={"thread_id": str(thread_id), "tool_id": str(tool_id)}, + tool={"id": str(tool_id), "tool_key": "proxy.missing"}, + result={ + "handler_key": None, + "status": "blocked", + "output": None, + "reason": "tool 'proxy.missing' has no registered proxy handler", + }, + ) + + assert created["status"] == "blocked" + create_query, create_params = cursor.executed[0] + assert "INSERT INTO tool_executions" in create_query + assert create_params is not None + assert create_params[5] is None + assert create_params[6] is None + assert create_params[8] is None diff --git a/tests/unit/test_tool_store.py b/tests/unit/test_tool_store.py new file mode 100644 index 0000000..6f9fc09 --- /dev/null +++ b/tests/unit/test_tool_store.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb + +from alicebot_api.store import ContinuityStore + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] 
| None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_tool_store_methods_use_expected_queries_and_jsonb_parameters() -> None: + tool_id = uuid4() + cursor = RecordingCursor( + fetchone_results=[ + { + "id": tool_id, + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + }, + { + "id": tool_id, + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + }, + ], + fetchall_result=[ + { + "id": tool_id, + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + } + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + created = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + fetched = store.get_tool_optional(tool_id) + listed = store.list_active_tools() + + assert created["id"] == tool_id + assert fetched is not None + assert listed[0]["id"] == tool_id + + create_query, create_params = cursor.executed[0] + assert "INSERT INTO tools" in create_query + assert create_params is not None + assert create_params[:6] == ( + "browser.open", + "Browser Open", + "Open documentation pages.", + "1.0.0", + "tool_metadata_v0", + True, + ) + for index, expected in ( + (6, ["browser"]), + (7, ["tool.run"]), + (8, ["workspace"]), + (9, ["docs"]), + (10, []), + ): + assert isinstance(create_params[index], Jsonb) + assert create_params[index].obj == expected + assert isinstance(create_params[11], Jsonb) + assert create_params[11].obj == {"transport": "proxy"} + + assert cursor.executed[1] == ( + """ + SELECT + id, + user_id, + tool_key, + name, + description, + version, + metadata_version, + active, + tags, + action_hints, + scope_hints, + domain_hints, + risk_hints, + metadata, + created_at + FROM tools + WHERE id = %s + """, + (tool_id,), + ) + assert "WHERE active = TRUE" in cursor.executed[2][0] diff --git a/tests/unit/test_tools.py b/tests/unit/test_tools.py new file mode 100644 index 0000000..169e24e --- /dev/null +++ b/tests/unit/test_tools.py @@ -0,0 +1,688 @@ +from __future__ import annotations + 
+from datetime import UTC, datetime, timedelta +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.contracts import ( + ToolAllowlistEvaluationRequestInput, + ToolCreateInput, + ToolRoutingRequestInput, +) +from alicebot_api.tools import ( + ToolAllowlistValidationError, + create_tool_record, + evaluate_tool_allowlist, + get_tool_record, + list_tool_records, + route_tool_invocation, + ToolRoutingValidationError, +) + + +class ToolStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 12, 9, 0, tzinfo=UTC) + self.user_id = uuid4() + self.thread_id = uuid4() + self.consents: dict[str, dict[str, object]] = {} + self.policies: list[dict[str, object]] = [] + self.tools: list[dict[str, object]] = [] + self.traces: list[dict[str, object]] = [] + self.trace_events: list[dict[str, object]] = [] + + def create_consent(self, *, consent_key: str, status: str, metadata: dict[str, object]) -> dict[str, object]: + consent = { + "id": uuid4(), + "user_id": self.user_id, + "consent_key": consent_key, + "status": status, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.consents)), + "updated_at": self.base_time + timedelta(minutes=len(self.consents)), + } + self.consents[consent_key] = consent + return consent + + def list_consents(self) -> list[dict[str, object]]: + return sorted( + self.consents.values(), + key=lambda consent: (consent["consent_key"], consent["created_at"], consent["id"]), + ) + + def create_policy( + self, + *, + name: str, + action: str, + scope: str, + effect: str, + priority: int, + active: bool, + conditions: dict[str, object], + required_consents: list[str], + ) -> dict[str, object]: + policy = { + "id": uuid4(), + "user_id": self.user_id, + "name": name, + "action": action, + "scope": scope, + "effect": effect, + "priority": priority, + "active": active, + "conditions": conditions, + "required_consents": required_consents, + "created_at": self.base_time + timedelta(minutes=len(self.policies)), + "updated_at": self.base_time + timedelta(minutes=len(self.policies)), + } + self.policies.append(policy) + return policy + + def list_active_policies(self) -> list[dict[str, object]]: + return sorted( + [policy for policy in self.policies if policy["active"] is True], + key=lambda policy: (policy["priority"], policy["created_at"], policy["id"]), + ) + + def create_tool( + self, + *, + tool_key: str, + name: str, + description: str, + version: str, + metadata_version: str, + active: bool, + tags: list[str], + action_hints: list[str], + scope_hints: list[str], + domain_hints: list[str], + risk_hints: list[str], + metadata: dict[str, object], + ) -> dict[str, object]: + tool = { + "id": uuid4(), + "user_id": self.user_id, + "tool_key": tool_key, + "name": name, + "description": description, + "version": version, + "metadata_version": metadata_version, + "active": active, + "tags": tags, + "action_hints": action_hints, + "scope_hints": scope_hints, + "domain_hints": domain_hints, + "risk_hints": risk_hints, + "metadata": metadata, + "created_at": self.base_time + timedelta(minutes=len(self.tools)), + } + self.tools.append(tool) + return tool + + def get_tool_optional(self, tool_id: UUID) -> dict[str, object] | None: + return next((tool for tool in self.tools if tool["id"] == tool_id), None) + + def list_tools(self) -> list[dict[str, object]]: + return sorted( + self.tools, + key=lambda tool: (tool["tool_key"], tool["version"], tool["created_at"], tool["id"]), + ) + + def list_active_tools(self) -> list[dict[str, object]]: 
+ return [tool for tool in self.list_tools() if tool["active"] is True] + + def get_thread_optional(self, thread_id: UUID) -> dict[str, object] | None: + if thread_id != self.thread_id: + return None + return { + "id": self.thread_id, + "user_id": self.user_id, + "title": "Tool thread", + "created_at": self.base_time, + "updated_at": self.base_time, + } + + def create_trace( + self, + *, + user_id: UUID, + thread_id: UUID, + kind: str, + compiler_version: str, + status: str, + limits: dict[str, object], + ) -> dict[str, object]: + trace = { + "id": uuid4(), + "user_id": user_id, + "thread_id": thread_id, + "kind": kind, + "compiler_version": compiler_version, + "status": status, + "limits": limits, + "created_at": self.base_time, + } + self.traces.append(trace) + return trace + + def append_trace_event( + self, + *, + trace_id: UUID, + sequence_no: int, + kind: str, + payload: dict[str, object], + ) -> dict[str, object]: + event = { + "id": uuid4(), + "trace_id": trace_id, + "sequence_no": sequence_no, + "kind": kind, + "payload": payload, + "created_at": self.base_time, + } + self.trace_events.append(event) + return event + + +def test_create_list_and_get_tool_records_preserve_deterministic_order() -> None: + store = ToolStoreStub() + later = create_tool_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + tool=ToolCreateInput( + tool_key="zeta.fetch", + name="Zeta Fetch", + description="Fetch zeta records.", + version="2.0.0", + action_hints=("tool.run",), + scope_hints=("workspace",), + ), + ) + earlier = store.create_tool( + tool_key="alpha.open", + name="Alpha Open", + description="Open alpha pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + + listed = list_tool_records( + store, # type: ignore[arg-type] + user_id=store.user_id, + ) + detail = get_tool_record( + store, # type: ignore[arg-type] + user_id=store.user_id, + tool_id=UUID(later["tool"]["id"]), + ) + + assert [item["tool_key"] for item in listed["items"]] == ["alpha.open", "zeta.fetch"] + assert listed["summary"] == { + "total_count": 2, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert detail == {"tool": later["tool"]} + assert listed["items"][0]["id"] == str(earlier["id"]) + + +def test_evaluate_tool_allowlist_splits_allowed_denied_and_approval_required() -> None: + store = ToolStoreStub() + store.create_consent(consent_key="web_access", status="granted", metadata={"source": "settings"}) + allowed_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + denied_tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", + description="Read a calendar.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + approval_tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + 
action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + + store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=20, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + payload = evaluate_tool_allowlist( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolAllowlistEvaluationRequestInput( + thread_id=store.thread_id, + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={}, + ), + ) + + assert payload["allowed"] == [ + { + "decision": "allowed", + "tool": { + "id": str(allowed_tool["id"]), + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": allowed_tool["created_at"].isoformat(), + }, + "reasons": [ + { + "code": "tool_metadata_matched", + "source": "tool", + "message": "Tool metadata matched the requested action, scope, and optional hints.", + "tool_id": str(allowed_tool["id"]), + "policy_id": None, + "consent_key": None, + }, + { + "code": "matched_policy", + "source": "policy", + "message": "Matched policy 'Allow docs browser' at priority 10.", + "tool_id": str(allowed_tool["id"]), + "policy_id": str(store.policies[0]["id"]), + "consent_key": None, + }, + { + "code": "policy_effect_allow", + "source": "policy", + "message": "Policy effect resolved the decision to 'allow'.", + "tool_id": str(allowed_tool["id"]), + "policy_id": str(store.policies[0]["id"]), + "consent_key": None, + }, + ], + } + ] + assert [item["tool"]["id"] for item in payload["approval_required"]] == [str(approval_tool["id"])] + assert payload["approval_required"][0]["reasons"][-1]["code"] == "policy_effect_require_approval" + assert [item["tool"]["id"] for item in payload["denied"]] == [str(denied_tool["id"])] + assert [reason["code"] for reason in payload["denied"][0]["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + assert payload["summary"] == { + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "evaluated_tool_count": 3, + "allowed_count": 1, + "denied_count": 1, + "approval_required_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + } + assert payload["trace"]["trace_event_count"] == 6 + assert [event["kind"] for event in store.trace_events] == [ + "tool.allowlist.request", + "tool.allowlist.order", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.decision", + "tool.allowlist.summary", + ] + + +def test_evaluate_tool_allowlist_validates_thread_scope() -> None: + with pytest.raises( + ToolAllowlistValidationError, + match="thread_id must reference an existing thread owned by the user", + ): + evaluate_tool_allowlist( + ToolStoreStub(), # type: ignore[arg-type] + user_id=uuid4(), + request=ToolAllowlistEvaluationRequestInput( + thread_id=uuid4(), + action="tool.run", + scope="workspace", + 
attributes={}, + ), + ) + + +def test_route_tool_invocation_returns_ready_with_trace() -> None: + store = ToolStoreStub() + store.create_consent(consent_key="web_access", status="granted", metadata={"source": "settings"}) + tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + policy = store.create_policy( + name="Allow docs browser", + action="tool.run", + scope="workspace", + effect="allow", + priority=10, + active=True, + conditions={"tool_key": "browser.open", "domain_hint": "docs"}, + required_consents=["web_access"], + ) + + payload = route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + thread_id=store.thread_id, + tool_id=tool["id"], + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={"channel": "chat"}, + ), + ) + + assert payload == { + "request": { + "thread_id": str(store.thread_id), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "attributes": {"channel": "chat"}, + }, + "decision": "ready", + "tool": { + "id": str(tool["id"]), + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": tool["created_at"].isoformat(), + }, + "reasons": [ + { + "code": "tool_metadata_matched", + "source": "tool", + "message": "Tool metadata matched the requested action, scope, and optional hints.", + "tool_id": str(tool["id"]), + "policy_id": None, + "consent_key": None, + }, + { + "code": "matched_policy", + "source": "policy", + "message": "Matched policy 'Allow docs browser' at priority 10.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + }, + { + "code": "policy_effect_allow", + "source": "policy", + "message": "Policy effect resolved the decision to 'allow'.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + }, + ], + "summary": { + "thread_id": str(store.thread_id), + "tool_id": str(tool["id"]), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 1, + "consent_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + }, + "trace": { + "trace_id": str(store.traces[0]["id"]), + "trace_event_count": 3, + }, + } + assert store.traces[0]["kind"] == "tool.route" + assert store.traces[0]["compiler_version"] == "tool_routing_v0" + assert [event["kind"] for event in store.trace_events] == [ + "tool.route.request", + "tool.route.decision", + "tool.route.summary", + ] + assert store.trace_events[1]["payload"]["allowlist_decision"] == "allowed" + assert store.trace_events[1]["payload"]["routing_decision"] == "ready" + + +def test_route_tool_invocation_returns_denied_for_metadata_or_policy_denial() -> None: + store = ToolStoreStub() + tool = store.create_tool( + tool_key="calendar.read", + name="Calendar Read", 
+ description="Read calendars.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["calendar"], + action_hints=["calendar.read"], + scope_hints=["calendar"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + + payload = route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + thread_id=store.thread_id, + tool_id=tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) + + assert payload["decision"] == "denied" + assert [reason["code"] for reason in payload["reasons"]] == [ + "tool_action_unsupported", + "tool_scope_unsupported", + ] + assert payload["summary"]["decision"] == "denied" + assert payload["trace"]["trace_event_count"] == 3 + + +def test_route_tool_invocation_returns_approval_required() -> None: + store = ToolStoreStub() + tool = store.create_tool( + tool_key="shell.exec", + name="Shell Exec", + description="Run shell commands.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["shell"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={"transport": "local"}, + ) + policy = store.create_policy( + name="Require shell approval", + action="tool.run", + scope="workspace", + effect="require_approval", + priority=10, + active=True, + conditions={"tool_key": "shell.exec"}, + required_consents=[], + ) + + payload = route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + thread_id=store.thread_id, + tool_id=tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) + + assert payload["decision"] == "approval_required" + assert payload["summary"]["decision"] == "approval_required" + assert payload["reasons"][-1] == { + "code": "policy_effect_require_approval", + "source": "policy", + "message": "Policy effect resolved the decision to 'require_approval'.", + "tool_id": str(tool["id"]), + "policy_id": str(policy["id"]), + "consent_key": None, + } + + +def test_route_tool_invocation_validates_thread_scope() -> None: + store = ToolStoreStub() + tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=True, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + + with pytest.raises( + ToolRoutingValidationError, + match="thread_id must reference an existing thread owned by the user", + ): + route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + thread_id=uuid4(), + tool_id=tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) + + +def test_route_tool_invocation_validates_active_tool_scope() -> None: + store = ToolStoreStub() + inactive_tool = store.create_tool( + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + metadata_version="tool_metadata_v0", + active=False, + tags=["browser"], + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=[], + risk_hints=[], + metadata={}, + ) + + with pytest.raises( + ToolRoutingValidationError, + match="tool_id must reference an existing active tool owned by the user", + ): + route_tool_invocation( + store, # type: ignore[arg-type] + user_id=store.user_id, + request=ToolRoutingRequestInput( + 
thread_id=store.thread_id, + tool_id=inactive_tool["id"], + action="tool.run", + scope="workspace", + attributes={}, + ), + ) diff --git a/tests/unit/test_tools_main.py b/tests/unit/test_tools_main.py new file mode 100644 index 0000000..fc86a9c --- /dev/null +++ b/tests/unit/test_tools_main.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import apps.api.src.alicebot_api.main as main_module +from apps.api.src.alicebot_api.config import Settings +from alicebot_api.tools import ( + ToolAllowlistValidationError, + ToolNotFoundError, + ToolRoutingValidationError, +) + + +def test_create_tool_endpoint_translates_request_and_returns_created_status(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_create_tool_record(store, *, user_id, tool): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["tool"] = tool + return { + "tool": { + "id": "tool-123", + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": "2026-03-12T09:00:00+00:00", + } + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_tool_record", fake_create_tool_record) + + response = main_module.create_tool( + main_module.CreateToolRequest( + user_id=user_id, + tool_key="browser.open", + name="Browser Open", + description="Open documentation pages.", + version="1.0.0", + action_hints=["tool.run"], + scope_hints=["workspace"], + domain_hints=["docs"], + risk_hints=[], + metadata={"transport": "proxy"}, + ) + ) + + assert response.status_code == 201 + assert json.loads(response.body)["tool"]["tool_key"] == "browser.open" + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["tool"].tool_key == "browser.open" + assert captured["tool"].action_hints == ("tool.run",) + assert captured["tool"].scope_hints == ("workspace",) + + +def test_get_tool_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + tool_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_tool_record(*_args, **_kwargs): + raise ToolNotFoundError(f"tool {tool_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_tool_record", fake_get_tool_record) + + response = main_module.get_tool(tool_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"tool {tool_id} was not found"} + + +def 
test_evaluate_tool_allowlist_endpoint_translates_request_and_returns_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = current_user_id + yield object() + + def fake_evaluate_tool_allowlist(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "allowed": [], + "denied": [], + "approval_required": [], + "summary": { + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "evaluated_tool_count": 0, + "allowed_count": 0, + "denied_count": 0, + "approval_required_count": 0, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + }, + "trace": {"trace_id": "trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "evaluate_tool_allowlist", fake_evaluate_tool_allowlist) + + response = main_module.evaluate_tools_allowlist( + main_module.EvaluateToolAllowlistRequest( + user_id=user_id, + thread_id=thread_id, + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={"channel": "chat"}, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == {"trace_id": "trace-123", "trace_event_count": 3} + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].thread_id == thread_id + assert captured["request"].action == "tool.run" + assert captured["request"].scope == "workspace" + assert captured["request"].domain_hint == "docs" + assert captured["request"].attributes == {"channel": "chat"} + + +def test_evaluate_tool_allowlist_endpoint_maps_validation_errors_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_evaluate_tool_allowlist(*_args, **_kwargs): + raise ToolAllowlistValidationError("thread_id must reference an existing thread owned by the user") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "evaluate_tool_allowlist", fake_evaluate_tool_allowlist) + + response = main_module.evaluate_tools_allowlist( + main_module.EvaluateToolAllowlistRequest( + user_id=user_id, + thread_id=uuid4(), + action="tool.run", + scope="workspace", + attributes={}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "thread_id must reference an existing thread owned by the user" + } + + +def test_route_tool_endpoint_translates_request_and_returns_trace_payload(monkeypatch) -> None: + user_id = uuid4() + thread_id = uuid4() + tool_id = uuid4() + settings = Settings(database_url="postgresql://app") + captured: dict[str, object] = {} + + @contextmanager + def fake_user_connection(database_url: str, current_user_id): + captured["database_url"] = database_url + captured["current_user_id"] = 
current_user_id + yield object() + + def fake_route_tool_invocation(store, *, user_id, request): + captured["store_type"] = type(store).__name__ + captured["user_id"] = user_id + captured["request"] = request + return { + "request": { + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "attributes": {"channel": "chat"}, + }, + "decision": "ready", + "tool": { + "id": str(tool_id), + "tool_key": "browser.open", + "name": "Browser Open", + "description": "Open documentation pages.", + "version": "1.0.0", + "metadata_version": "tool_metadata_v0", + "active": True, + "tags": ["browser"], + "action_hints": ["tool.run"], + "scope_hints": ["workspace"], + "domain_hints": ["docs"], + "risk_hints": [], + "metadata": {"transport": "proxy"}, + "created_at": "2026-03-12T09:00:00+00:00", + }, + "reasons": [], + "summary": { + "thread_id": str(thread_id), + "tool_id": str(tool_id), + "action": "tool.run", + "scope": "workspace", + "domain_hint": "docs", + "risk_hint": None, + "decision": "ready", + "evaluated_tool_count": 1, + "active_policy_count": 1, + "consent_count": 1, + "order": ["tool_key_asc", "version_asc", "created_at_asc", "id_asc"], + }, + "trace": {"trace_id": "trace-123", "trace_event_count": 3}, + } + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "route_tool_invocation", fake_route_tool_invocation) + + response = main_module.route_tool( + main_module.RouteToolRequest( + user_id=user_id, + thread_id=thread_id, + tool_id=tool_id, + action="tool.run", + scope="workspace", + domain_hint="docs", + attributes={"channel": "chat"}, + ) + ) + + assert response.status_code == 200 + assert json.loads(response.body)["trace"] == {"trace_id": "trace-123", "trace_event_count": 3} + assert captured["database_url"] == "postgresql://app" + assert captured["current_user_id"] == user_id + assert captured["store_type"] == "ContinuityStore" + assert captured["user_id"] == user_id + assert captured["request"].thread_id == thread_id + assert captured["request"].tool_id == tool_id + assert captured["request"].action == "tool.run" + assert captured["request"].scope == "workspace" + assert captured["request"].domain_hint == "docs" + assert captured["request"].attributes == {"channel": "chat"} + + +def test_route_tool_endpoint_maps_validation_errors_to_400(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_route_tool_invocation(*_args, **_kwargs): + raise ToolRoutingValidationError("tool_id must reference an existing active tool owned by the user") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "route_tool_invocation", fake_route_tool_invocation) + + response = main_module.route_tool( + main_module.RouteToolRequest( + user_id=user_id, + thread_id=uuid4(), + tool_id=uuid4(), + action="tool.run", + scope="workspace", + attributes={}, + ) + ) + + assert response.status_code == 400 + assert json.loads(response.body) == { + "detail": "tool_id must reference an existing active tool owned by the user" + } diff --git a/tests/unit/test_trace_store.py b/tests/unit/test_trace_store.py new file mode 100644 index 
0000000..28bbac2 --- /dev/null +++ b/tests/unit/test_trace_store.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from typing import Any +from uuid import uuid4 + +from psycopg.types.json import Jsonb +import pytest + +from alicebot_api.store import AppendOnlyViolation, ContinuityStore, ContinuityStoreInvariantError + + +class RecordingCursor: + def __init__(self, fetchone_results: list[dict[str, Any]], fetchall_result: list[dict[str, Any]] | None = None) -> None: + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.fetchone_results = list(fetchone_results) + self.fetchall_result = fetchall_result or [] + + def __enter__(self) -> "RecordingCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execute(self, query: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((query, params)) + + def fetchone(self) -> dict[str, Any] | None: + if not self.fetchone_results: + return None + return self.fetchone_results.pop(0) + + def fetchall(self) -> list[dict[str, Any]]: + return self.fetchall_result + + +class RecordingConnection: + def __init__(self, cursor: RecordingCursor) -> None: + self.cursor_instance = cursor + + def cursor(self) -> RecordingCursor: + return self.cursor_instance + + +def test_trace_methods_use_expected_queries_and_payload_serialization() -> None: + user_id = uuid4() + thread_id = uuid4() + trace_id = uuid4() + payload = {"reason": "within_event_limit"} + cursor = RecordingCursor( + fetchone_results=[ + {"id": user_id, "email": "owner@example.com", "display_name": "Owner"}, + {"id": thread_id, "user_id": user_id, "title": "Thread"}, + {"id": trace_id, "user_id": user_id, "thread_id": thread_id, "kind": "context.compile"}, + { + "id": uuid4(), + "user_id": user_id, + "trace_id": trace_id, + "sequence_no": 1, + "kind": "context.include", + "payload": payload, + }, + ], + fetchall_result=[ + {"sequence_no": 1, "kind": "context.include", "payload": payload}, + ], + ) + store = ContinuityStore(RecordingConnection(cursor)) + + user = store.get_user(user_id) + thread = store.get_thread(thread_id) + trace = store.create_trace( + user_id=user_id, + thread_id=thread_id, + kind="context.compile", + compiler_version="continuity_v0", + status="completed", + limits={"max_sessions": 3, "max_events": 8}, + ) + trace_event = store.append_trace_event( + trace_id=trace_id, + sequence_no=1, + kind="context.include", + payload=payload, + ) + listed_trace_events = store.list_trace_events(trace_id) + + assert user["id"] == user_id + assert thread["id"] == thread_id + assert trace["id"] == trace_id + assert trace_event["sequence_no"] == 1 + assert listed_trace_events == [{"sequence_no": 1, "kind": "context.include", "payload": payload}] + + assert cursor.executed[0] == ( + """ + SELECT id, email, display_name, created_at + FROM users + WHERE id = %s + """, + (user_id,), + ) + assert cursor.executed[1] == ( + """ + SELECT id, user_id, title, created_at, updated_at + FROM threads + WHERE id = %s + """, + (thread_id,), + ) + create_trace_query, create_trace_params = cursor.executed[2] + assert "INSERT INTO traces" in create_trace_query + assert create_trace_params is not None + assert create_trace_params[:5] == ( + user_id, + thread_id, + "context.compile", + "continuity_v0", + "completed", + ) + assert isinstance(create_trace_params[5], Jsonb) + assert create_trace_params[5].obj == {"max_sessions": 3, "max_events": 8} + + append_trace_query, append_trace_params = cursor.executed[3] + assert "INSERT INTO 
trace_events" in append_trace_query + assert append_trace_params is not None + assert append_trace_params[:3] == (trace_id, 1, "context.include") + assert isinstance(append_trace_params[3], Jsonb) + assert append_trace_params[3].obj == payload + + +def test_trace_event_updates_and_deletes_are_rejected_by_contract() -> None: + store = ContinuityStore(conn=None) # type: ignore[arg-type] + + with pytest.raises(AppendOnlyViolation, match="append-only"): + store.update_trace_event("trace-event-id", {"text": "mutated"}) + + with pytest.raises(AppendOnlyViolation, match="append-only"): + store.delete_trace_event("trace-event-id") + + +def test_get_trace_raises_clear_error_when_missing() -> None: + cursor = RecordingCursor(fetchone_results=[]) + store = ContinuityStore(RecordingConnection(cursor)) + + with pytest.raises( + ContinuityStoreInvariantError, + match="get_trace did not return a row", + ): + store.get_trace(uuid4()) diff --git a/tests/unit/test_worker_main.py b/tests/unit/test_worker_main.py new file mode 100644 index 0000000..b6391d2 --- /dev/null +++ b/tests/unit/test_worker_main.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import logging +import os +from pathlib import Path +import subprocess +import sys + +from workers.alicebot_worker.main import run + + +def test_run_logs_scaffold_message(caplog) -> None: + with caplog.at_level(logging.INFO, logger="alicebot.worker"): + run() + + assert caplog.messages == [ + "Worker scaffold initialized; no background jobs are in scope for this sprint." + ] + + +def test_module_entrypoint_logs_scaffold_message() -> None: + repo_root = Path(__file__).resolve().parents[2] + env = os.environ.copy() + pythonpath_entries = [str(repo_root / "apps" / "api" / "src"), str(repo_root / "workers")] + existing_pythonpath = env.get("PYTHONPATH") + if existing_pythonpath: + pythonpath_entries.append(existing_pythonpath) + env["PYTHONPATH"] = os.pathsep.join(pythonpath_entries) + + result = subprocess.run( + [sys.executable, "-m", "alicebot_worker.main"], + cwd=repo_root, + env=env, + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode == 0 + assert "Worker scaffold initialized; no background jobs are in scope for this sprint." 
in result.stderr diff --git a/tests/unit/test_workspaces.py b/tests/unit/test_workspaces.py new file mode 100644 index 0000000..67cb1bc --- /dev/null +++ b/tests/unit/test_workspaces.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from pathlib import Path +from uuid import UUID, uuid4 + +import pytest + +from alicebot_api.config import Settings +from alicebot_api.contracts import TaskWorkspaceCreateInput +from alicebot_api.tasks import TaskNotFoundError +from alicebot_api.workspaces import ( + TaskWorkspaceAlreadyExistsError, + TaskWorkspaceNotFoundError, + TaskWorkspaceProvisioningError, + build_task_workspace_path, + create_task_workspace_record, + ensure_workspace_path_is_rooted, + get_task_workspace_record, + list_task_workspace_records, + serialize_task_workspace_row, +) + + +class WorkspaceStoreStub: + def __init__(self) -> None: + self.base_time = datetime(2026, 3, 13, 10, 0, tzinfo=UTC) + self.tasks: list[dict[str, object]] = [] + self.workspaces: list[dict[str, object]] = [] + self.locked_task_ids: list[UUID] = [] + + def create_task(self, *, task_id: UUID, user_id: UUID) -> None: + self.tasks.append( + { + "id": task_id, + "user_id": user_id, + "thread_id": uuid4(), + "tool_id": uuid4(), + "status": "approved", + "request": {}, + "tool": {}, + "latest_approval_id": None, + "latest_execution_id": None, + "created_at": self.base_time, + "updated_at": self.base_time, + } + ) + + def get_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next((task for task in self.tasks if task["id"] == task_id), None) + + def lock_task_workspaces(self, task_id: UUID) -> None: + self.locked_task_ids.append(task_id) + + def get_active_task_workspace_for_task_optional(self, task_id: UUID) -> dict[str, object] | None: + return next( + ( + workspace + for workspace in self.workspaces + if workspace["task_id"] == task_id and workspace["status"] == "active" + ), + None, + ) + + def create_task_workspace( + self, + *, + task_id: UUID, + status: str, + local_path: str, + ) -> dict[str, object]: + workspace = { + "id": uuid4(), + "user_id": self.tasks[0]["user_id"], + "task_id": task_id, + "status": status, + "local_path": local_path, + "created_at": self.base_time + timedelta(minutes=len(self.workspaces)), + "updated_at": self.base_time + timedelta(minutes=len(self.workspaces)), + } + self.workspaces.append(workspace) + return workspace + + def list_task_workspaces(self) -> list[dict[str, object]]: + return sorted(self.workspaces, key=lambda workspace: (workspace["created_at"], workspace["id"])) + + def get_task_workspace_optional(self, task_workspace_id: UUID) -> dict[str, object] | None: + return next((workspace for workspace in self.workspaces if workspace["id"] == task_workspace_id), None) + + +def test_build_task_workspace_path_is_deterministic() -> None: + user_id = UUID("00000000-0000-0000-0000-000000000111") + task_id = UUID("00000000-0000-0000-0000-000000000222") + root = Path("/tmp/alicebot/task-workspaces") + + path = build_task_workspace_path( + workspace_root=root, + user_id=user_id, + task_id=task_id, + ) + + assert path == Path("/tmp/alicebot/task-workspaces") / str(user_id) / str(task_id) + + +def test_ensure_workspace_path_is_rooted_rejects_escape() -> None: + with pytest.raises(TaskWorkspaceProvisioningError, match="escapes configured root"): + ensure_workspace_path_is_rooted( + workspace_root=Path("/tmp/alicebot/task-workspaces"), + workspace_path=Path("/tmp/alicebot/task-workspaces/../escape"), + ) + + +def 
test_create_task_workspace_record_provisions_directory_and_returns_record(tmp_path) -> None: + store = WorkspaceStoreStub() + user_id = uuid4() + task_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + settings = Settings(task_workspace_root=str(tmp_path / "workspaces")) + + response = create_task_workspace_record( + store, + settings=settings, + user_id=user_id, + request=TaskWorkspaceCreateInput(task_id=task_id, status="active"), + ) + + expected_path = tmp_path / "workspaces" / str(user_id) / str(task_id) + assert response == { + "workspace": { + "id": response["workspace"]["id"], + "task_id": str(task_id), + "status": "active", + "local_path": str(expected_path.resolve()), + "created_at": "2026-03-13T10:00:00+00:00", + "updated_at": "2026-03-13T10:00:00+00:00", + } + } + assert expected_path.is_dir() + assert store.locked_task_ids == [task_id] + + +def test_create_task_workspace_record_rejects_duplicate_active_workspace(tmp_path) -> None: + store = WorkspaceStoreStub() + user_id = uuid4() + task_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + settings = Settings(task_workspace_root=str(tmp_path / "workspaces")) + create_task_workspace_record( + store, + settings=settings, + user_id=user_id, + request=TaskWorkspaceCreateInput(task_id=task_id, status="active"), + ) + + with pytest.raises(TaskWorkspaceAlreadyExistsError, match=f"task {task_id} already has active workspace"): + create_task_workspace_record( + store, + settings=settings, + user_id=user_id, + request=TaskWorkspaceCreateInput(task_id=task_id, status="active"), + ) + + +def test_create_task_workspace_record_requires_visible_task(tmp_path) -> None: + store = WorkspaceStoreStub() + + with pytest.raises(TaskNotFoundError, match="was not found"): + create_task_workspace_record( + store, + settings=Settings(task_workspace_root=str(tmp_path / "workspaces")), + user_id=uuid4(), + request=TaskWorkspaceCreateInput(task_id=uuid4(), status="active"), + ) + + +def test_list_and_get_task_workspace_records_are_deterministic() -> None: + store = WorkspaceStoreStub() + user_id = uuid4() + task_id = uuid4() + store.create_task(task_id=task_id, user_id=user_id) + workspace = store.create_task_workspace( + task_id=task_id, + status="active", + local_path="/tmp/alicebot/task-workspaces/user/task", + ) + + assert list_task_workspace_records(store, user_id=user_id) == { + "items": [serialize_task_workspace_row(workspace)], + "summary": { + "total_count": 1, + "order": ["created_at_asc", "id_asc"], + }, + } + assert get_task_workspace_record( + store, + user_id=user_id, + task_workspace_id=workspace["id"], + ) == {"workspace": serialize_task_workspace_row(workspace)} + + +def test_get_task_workspace_record_raises_when_workspace_is_missing() -> None: + with pytest.raises(TaskWorkspaceNotFoundError, match="was not found"): + get_task_workspace_record( + WorkspaceStoreStub(), + user_id=uuid4(), + task_workspace_id=uuid4(), + ) diff --git a/tests/unit/test_workspaces_main.py b/tests/unit/test_workspaces_main.py new file mode 100644 index 0000000..b5f19ff --- /dev/null +++ b/tests/unit/test_workspaces_main.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import json +from contextlib import contextmanager +from uuid import uuid4 + +import alicebot_api.main as main_module +from alicebot_api.config import Settings +from alicebot_api.tasks import TaskNotFoundError +from alicebot_api.workspaces import TaskWorkspaceAlreadyExistsError, TaskWorkspaceNotFoundError + + +def 
test_list_task_workspaces_endpoint_returns_payload(monkeypatch) -> None: + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr( + main_module, + "list_task_workspace_records", + lambda *_args, **_kwargs: { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + }, + ) + + response = main_module.list_task_workspaces(user_id) + + assert response.status_code == 200 + assert json.loads(response.body) == { + "items": [], + "summary": {"total_count": 0, "order": ["created_at_asc", "id_asc"]}, + } + + +def test_get_task_workspace_endpoint_maps_not_found_to_404(monkeypatch) -> None: + user_id = uuid4() + task_workspace_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_get_task_workspace_record(*_args, **_kwargs): + raise TaskWorkspaceNotFoundError(f"task workspace {task_workspace_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "get_task_workspace_record", fake_get_task_workspace_record) + + response = main_module.get_task_workspace(task_workspace_id, user_id) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task workspace {task_workspace_id} was not found"} + + +def test_create_task_workspace_endpoint_maps_task_not_found_to_404(monkeypatch) -> None: + task_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_task_workspace_record(*_args, **_kwargs): + raise TaskNotFoundError(f"task {task_id} was not found") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_task_workspace_record", fake_create_task_workspace_record) + + response = main_module.create_task_workspace( + task_id, + main_module.CreateTaskWorkspaceRequest(user_id=user_id), + ) + + assert response.status_code == 404 + assert json.loads(response.body) == {"detail": f"task {task_id} was not found"} + + +def test_create_task_workspace_endpoint_maps_duplicate_to_409(monkeypatch) -> None: + task_id = uuid4() + user_id = uuid4() + settings = Settings(database_url="postgresql://app") + + @contextmanager + def fake_user_connection(*_args, **_kwargs): + yield object() + + def fake_create_task_workspace_record(*_args, **_kwargs): + raise TaskWorkspaceAlreadyExistsError(f"task {task_id} already has active workspace workspace-123") + + monkeypatch.setattr(main_module, "get_settings", lambda: settings) + monkeypatch.setattr(main_module, "user_connection", fake_user_connection) + monkeypatch.setattr(main_module, "create_task_workspace_record", fake_create_task_workspace_record) + + response = main_module.create_task_workspace( + task_id, + main_module.CreateTaskWorkspaceRequest(user_id=user_id), + ) + + assert response.status_code == 409 + assert json.loads(response.body) == { + "detail": f"task {task_id} already has active workspace workspace-123" + } diff --git 
a/workers/.gitkeep b/workers/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/workers/.gitkeep @@ -0,0 +1 @@ + diff --git a/workers/alicebot_worker/__init__.py b/workers/alicebot_worker/__init__.py new file mode 100644 index 0000000..462d476 --- /dev/null +++ b/workers/alicebot_worker/__init__.py @@ -0,0 +1,2 @@ +"""Worker scaffold for future asynchronous jobs.""" + diff --git a/workers/alicebot_worker/main.py b/workers/alicebot_worker/main.py new file mode 100644 index 0000000..21d01ff --- /dev/null +++ b/workers/alicebot_worker/main.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import logging + + +LOGGER = logging.getLogger("alicebot.worker") + + +def run() -> None: + LOGGER.info("Worker scaffold initialized; no background jobs are in scope for this sprint.") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + run()
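For reference, the path helpers exercised by `tests/unit/test_workspaces.py` (`build_task_workspace_path` and `ensure_workspace_path_is_rooted`) are imported from `alicebot_api.workspaces`, but their bodies do not appear in this patch. A minimal sketch consistent with the test assertions — a deterministic `<root>/<user_id>/<task_id>` layout and rejection of any resolved path that escapes the configured root — could look like the following; this is an illustration of the assumed behavior, not the shipped implementation, and the real error type lives in `alicebot_api.workspaces` rather than being defined locally.

```python
from __future__ import annotations

from pathlib import Path
from uuid import UUID


class TaskWorkspaceProvisioningError(Exception):
    """Stand-in for the exception exported by alicebot_api.workspaces."""


def build_task_workspace_path(*, workspace_root: Path, user_id: UUID, task_id: UUID) -> Path:
    # Deterministic layout asserted by the unit tests: <root>/<user_id>/<task_id>.
    return workspace_root / str(user_id) / str(task_id)


def ensure_workspace_path_is_rooted(*, workspace_root: Path, workspace_path: Path) -> None:
    # Resolve both paths so ".." segments cannot move the workspace outside the root.
    resolved_root = workspace_root.resolve()
    resolved_path = workspace_path.resolve()
    if not resolved_path.is_relative_to(resolved_root):
        raise TaskWorkspaceProvisioningError(
            f"workspace path {resolved_path} escapes configured root {resolved_root}"
        )
```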