Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openviking/client/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ async def commit_session(

async def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
"""Query background task status."""
return await self._service.sessions.get_commit_task(task_id)
return await self._service.sessions.get_commit_task(task_id, self._ctx)

async def add_message(
self,
Expand Down
14 changes: 12 additions & 2 deletions openviking/server/routers/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,12 @@ async def reindex(

if request.wait:
# Synchronous path: block until reindex completes
if tracker.has_running(REINDEX_TASK_TYPE, uri):
if tracker.has_running(
REINDEX_TASK_TYPE,
uri,
owner_account_id=_ctx.account_id,
owner_user_id=_ctx.user.user_id,
):
return Response(
status="error",
error=ErrorInfo(
Expand All @@ -190,7 +195,12 @@ async def reindex(
return Response(status="ok", result=result)
else:
# Async path: run in background, return task_id for polling
task = tracker.create_if_no_running(REINDEX_TASK_TYPE, uri)
task = tracker.create_if_no_running(
REINDEX_TASK_TYPE,
uri,
owner_account_id=_ctx.account_id,
owner_user_id=_ctx.user.user_id,
)
if task is None:
return Response(
status="error",
Expand Down
18 changes: 15 additions & 3 deletions openviking/server/routers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,28 @@

from typing import Optional

from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query

from openviking.server.auth import get_request_context
from openviking.server.identity import RequestContext
from openviking.server.models import Response
from openviking.service.task_tracker import get_task_tracker

router = APIRouter(prefix="/api/v1", tags=["tasks"])


@router.get("/tasks/{task_id}")
async def get_task(task_id: str):
async def get_task(
task_id: str,
_ctx: RequestContext = Depends(get_request_context),
):
"""Get the status of a single background task."""
tracker = get_task_tracker()
task = tracker.get(task_id)
task = tracker.get(
task_id,
owner_account_id=_ctx.account_id,
owner_user_id=_ctx.user.user_id,
)
if not task:
raise HTTPException(status_code=404, detail="Task not found or expired")
return Response(status="ok", result=task.to_dict())
Expand All @@ -35,6 +44,7 @@ async def list_tasks(
),
resource_id: Optional[str] = Query(None, description="Filter by resource ID (e.g. session_id)"),
limit: int = Query(50, le=200, description="Max results"),
_ctx: RequestContext = Depends(get_request_context),
):
"""List background tasks with optional filters."""
tracker = get_task_tracker()
Expand All @@ -43,5 +53,7 @@ async def list_tasks(
status=status,
resource_id=resource_id,
limit=limit,
owner_account_id=_ctx.account_id,
owner_user_id=_ctx.user.user_id,
)
return Response(status="ok", result=[t.to_dict() for t in tasks])
10 changes: 7 additions & 3 deletions openviking/service/session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,13 @@ async def commit_async(self, session_id: str, ctx: RequestContext) -> Dict[str,
session = await self.get(session_id, ctx)
return await session.commit_async()

async def get_commit_task(self, task_id: str) -> Optional[Dict[str, Any]]:
"""Query background commit task status by task_id."""
task = get_task_tracker().get(task_id)
async def get_commit_task(self, task_id: str, ctx: RequestContext) -> Optional[Dict[str, Any]]:
"""Query background commit task status by task_id for the calling owner."""
task = get_task_tracker().get(
task_id,
owner_account_id=ctx.account_id,
owner_user_id=ctx.user.user_id,
)
return task.to_dict() if task else None

async def extract(self, session_id: str, ctx: RequestContext) -> List[Any]:
Expand Down
76 changes: 70 additions & 6 deletions openviking/service/task_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,17 @@ class TaskRecord:
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
resource_id: Optional[str] = None # e.g. session_id
owner_account_id: Optional[str] = None
owner_user_id: Optional[str] = None
result: Optional[Dict[str, Any]] = None
error: Optional[str] = None

def to_dict(self) -> Dict[str, Any]:
    """Build the JSON-ready dictionary for this record.

    The ``status`` enum is flattened to its raw value, and the owner
    identifiers are omitted from the serialized output.
    """
    payload = asdict(self)
    # Owner fields are dropped from the serialized form.
    for hidden_key in ("owner_account_id", "owner_user_id"):
        payload.pop(hidden_key, None)
    payload["status"] = self.status.value
    return payload


Expand Down Expand Up @@ -170,14 +174,43 @@ def _evict_expired(self) -> None:
if to_delete:
logger.debug("[TaskTracker] Evicted %d expired tasks", len(to_delete))

@staticmethod
def _matches_owner(
    task: TaskRecord,
    owner_account_id: Optional[str] = None,
    owner_user_id: Optional[str] = None,
) -> bool:
    """Check a task against an optional owner filter.

    A filter component left as ``None`` is not applied. Every component
    that is supplied must equal the corresponding owner field on ``task``.
    """
    return (owner_account_id is None or task.owner_account_id == owner_account_id) and (
        owner_user_id is None or task.owner_user_id == owner_user_id
    )

@staticmethod
def _validate_owner(owner_account_id: str, owner_user_id: str) -> None:
    """Ensure both owner identifiers are present before a task is created.

    Raises:
        ValueError: If either identifier is empty (or otherwise falsy).
    """
    if owner_account_id and owner_user_id:
        return
    raise ValueError("Task ownership requires non-empty owner_account_id and owner_user_id")

# ── CRUD ──

def create(self, task_type: str, resource_id: Optional[str] = None) -> TaskRecord:
def create(
self,
task_type: str,
resource_id: Optional[str] = None,
*,
owner_account_id: str,
owner_user_id: str,
) -> TaskRecord:
"""Register a new pending task. Returns a snapshot copy."""
self._validate_owner(owner_account_id, owner_user_id)
task = TaskRecord(
task_id=str(uuid4()),
task_type=task_type,
resource_id=resource_id,
owner_account_id=owner_account_id,
owner_user_id=owner_user_id,
)
with self._lock:
self._tasks[task.task_id] = task
Expand All @@ -189,17 +222,26 @@ def create(self, task_type: str, resource_id: Optional[str] = None) -> TaskRecor
)
return self._copy(task)

def create_if_no_running(self, task_type: str, resource_id: str) -> Optional[TaskRecord]:
def create_if_no_running(
self,
task_type: str,
resource_id: str,
*,
owner_account_id: str,
owner_user_id: str,
) -> Optional[TaskRecord]:
"""Atomically check for running tasks and create a new one if none exist.

Returns TaskRecord on success, None if a running task already exists.
This eliminates the race condition between has_running() and create().
"""
self._validate_owner(owner_account_id, owner_user_id)
with self._lock:
# Check for existing running tasks
has_active = any(
t.task_type == task_type
and t.resource_id == resource_id
and self._matches_owner(t, owner_account_id, owner_user_id)
and t.status in (TaskStatus.PENDING, TaskStatus.RUNNING)
for t in self._tasks.values()
)
Expand All @@ -210,6 +252,8 @@ def create_if_no_running(self, task_type: str, resource_id: str) -> Optional[Tas
task_id=str(uuid4()),
task_type=task_type,
resource_id=resource_id,
owner_account_id=owner_account_id,
owner_user_id=owner_user_id,
)
self._tasks[task.task_id] = task
logger.debug(
Expand Down Expand Up @@ -248,22 +292,35 @@ def fail(self, task_id: str, error: str) -> None:
task.updated_at = time.time()
logger.warning("[TaskTracker] Task %s failed: %s", task_id, _sanitize_error(error))

def get(self, task_id: str) -> Optional[TaskRecord]:
def get(
self,
task_id: str,
owner_account_id: Optional[str] = None,
owner_user_id: Optional[str] = None,
) -> Optional[TaskRecord]:
"""Look up a single task. Returns a snapshot copy (None if not found)."""
with self._lock:
task = self._tasks.get(task_id)
return self._copy(task) if task else None
if task is None or not self._matches_owner(task, owner_account_id, owner_user_id):
return None
return self._copy(task)

def list_tasks(
self,
task_type: Optional[str] = None,
status: Optional[str] = None,
resource_id: Optional[str] = None,
limit: int = 50,
owner_account_id: Optional[str] = None,
owner_user_id: Optional[str] = None,
) -> List[TaskRecord]:
"""List tasks with optional filters. Most-recent first. Returns snapshot copies."""
with self._lock:
tasks = [self._copy(t) for t in self._tasks.values()]
tasks = [
self._copy(t)
for t in self._tasks.values()
if self._matches_owner(t, owner_account_id, owner_user_id)
]
if task_type:
tasks = [t for t in tasks if t.task_type == task_type]
if status:
Expand All @@ -273,12 +330,19 @@ def list_tasks(
tasks.sort(key=lambda t: t.created_at, reverse=True)
return tasks[:limit]

def has_running(self, task_type: str, resource_id: str) -> bool:
def has_running(
self,
task_type: str,
resource_id: str,
owner_account_id: Optional[str] = None,
owner_user_id: Optional[str] = None,
) -> bool:
"""Check if there is already a running task for the given type+resource."""
with self._lock:
return any(
t.task_type == task_type
and t.resource_id == resource_id
and self._matches_owner(t, owner_account_id, owner_user_id)
and t.status in (TaskStatus.PENDING, TaskStatus.RUNNING)
for t in self._tasks.values()
)
Expand Down
7 changes: 6 additions & 1 deletion openviking/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,12 @@ async def commit_async(self) -> Dict[str, Any]:

# Create TaskRecord for tracking Phase 2
tracker = get_task_tracker()
task = tracker.create("session_commit", resource_id=self.session_id)
task = tracker.create(
"session_commit",
resource_id=self.session_id,
owner_account_id=self.ctx.account_id,
owner_user_id=self.ctx.user.user_id,
)

asyncio.create_task(
self._run_memory_extraction(
Expand Down
2 changes: 1 addition & 1 deletion tests/server/test_api_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def exists(self, uri, ctx=None):
return True

class FakeTracker:
def has_running(self, task_type, uri):
def has_running(self, task_type, uri, owner_account_id=None, owner_user_id=None):
return False

async def fake_do_reindex(service, uri, regenerate, ctx):
Expand Down
73 changes: 73 additions & 0 deletions tests/server/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from openviking.server.identity import ResolvedIdentity, Role
from openviking.server.models import ERROR_CODE_TO_HTTP_STATUS, ErrorInfo, Response
from openviking.service.core import OpenVikingService
from openviking.service.task_tracker import get_task_tracker, reset_task_tracker
from openviking_cli.exceptions import InvalidArgumentError, OpenVikingError
from openviking_cli.session.user_id import UserIdentifier

Expand Down Expand Up @@ -128,6 +129,15 @@ async def debug_vector_scroll(ctx=Depends(get_request_context)):
return app


def _build_task_http_test_app(identity: ResolvedIdentity | None) -> FastAPI:
    """Create the auth-enabled test app and attach the real task router to it."""
    from openviking.server.routers import tasks as tasks_router

    task_app = _build_auth_http_test_app(
        identity=identity,
        auth_enabled=True,
        root_api_key=ROOT_KEY,
    )
    task_app.include_router(tasks_router.router)
    return task_app


@pytest_asyncio.fixture(scope="function")
async def auth_service(temp_dir):
"""Service for auth tests."""
Expand Down Expand Up @@ -276,6 +286,69 @@ async def test_auth_on_multiple_endpoints(auth_client: httpx.AsyncClient):
assert tenant_resp.status_code == 200


async def test_task_endpoints_require_auth():
    """Task endpoints must reject unauthenticated callers before lookup/filtering."""
    reset_task_tracker()
    try:
        app = _build_task_http_test_app(identity=None)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
            for url in ("/api/v1/tasks", "/api/v1/tasks/nonexistent-id"):
                resp = await client.get(url)
                assert resp.status_code == 401, url
    finally:
        # The tracker is shared state; reset it even when an assertion fails
        # so a failure here cannot leak tasks into later tests.
        reset_task_tracker()


async def test_task_endpoints_are_user_scoped():
    """Authenticated callers must not see another user's background tasks."""
    reset_task_tracker()
    try:
        account_id = _uid()
        tracker = get_task_tracker()
        # Two users in the same account, each owning exactly one task.
        alice_task = tracker.create(
            "session_commit",
            resource_id="alice-session",
            owner_account_id=account_id,
            owner_user_id="alice",
        )
        bob_task = tracker.create(
            "session_commit",
            resource_id="bob-session",
            owner_account_id=account_id,
            owner_user_id="bob",
        )

        alice_app = _build_task_http_test_app(
            ResolvedIdentity(role=Role.ADMIN, account_id=account_id, user_id="alice")
        )
        bob_app = _build_task_http_test_app(
            ResolvedIdentity(role=Role.ADMIN, account_id=account_id, user_id="bob")
        )
        alice_transport = httpx.ASGITransport(app=alice_app)
        bob_transport = httpx.ASGITransport(app=bob_app)

        async with httpx.AsyncClient(
            transport=alice_transport, base_url="http://testserver"
        ) as alice_client:
            # Alice can fetch and list her own task only.
            alice_get = await alice_client.get(f"/api/v1/tasks/{alice_task.task_id}")
            assert alice_get.status_code == 200
            assert alice_get.json()["result"]["resource_id"] == "alice-session"

            alice_list = await alice_client.get("/api/v1/tasks")
            assert alice_list.status_code == 200
            assert {task["task_id"] for task in alice_list.json()["result"]} == {
                alice_task.task_id
            }

        async with httpx.AsyncClient(
            transport=bob_transport, base_url="http://testserver"
        ) as bob_client:
            # Bob asking for Alice's task gets 404, the same status the router
            # uses for a missing task.
            bob_get_other = await bob_client.get(f"/api/v1/tasks/{alice_task.task_id}")
            assert bob_get_other.status_code == 404

            bob_list = await bob_client.get("/api/v1/tasks")
            assert bob_list.status_code == 200
            assert {task["task_id"] for task in bob_list.json()["result"]} == {bob_task.task_id}
    finally:
        # This test registers tasks in the shared tracker; reset it even when
        # an assertion fails so they cannot leak into later tests.
        reset_task_tracker()


# ---- Role-based access tests ----


Expand Down
7 changes: 6 additions & 1 deletion tests/test_session_task_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def _make_tracked_commit(behavior="instant", result_overrides=None, gate=None, s

async def mock_commit(_sid, _ctx):
tracker = get_task_tracker()
task = tracker.create("session_commit", resource_id=_sid)
task = tracker.create(
"session_commit",
resource_id=_sid,
owner_account_id=_ctx.account_id,
owner_user_id=_ctx.user.user_id,
)
archive_uri = f"viking://session/test/{_sid}/history/archive_001"

async def _background():
Expand Down
Loading
Loading