diff --git a/agent-langgraph-advanced/agent_server/agent.py b/agent-langgraph-advanced/agent_server/agent.py
index 22d0e8bc..49316d76 100644
--- a/agent-langgraph-advanced/agent_server/agent.py
+++ b/agent-langgraph-advanced/agent_server/agent.py
@@ -23,6 +23,7 @@ from agent_server.prompts import SYSTEM_PROMPT
 from agent_server.utils import (
     _get_or_create_thread_id,
+    deduplicate_input,
     get_user_workspace_client,
     init_mcp_client,
     process_agent_astream_events,
@@ -110,10 +111,7 @@ async def stream_handler(
     if user_id:
         config["configurable"]["user_id"] = user_id

-    input_state: dict[str, Any] = {
-        "messages": to_chat_completions_input([i.model_dump() for i in request.input]),
-        "custom_inputs": dict(request.custom_inputs or {}),
-    }
+    incoming_messages = to_chat_completions_input([i.model_dump() for i in request.input])

     try:
         async with lakebase_context(LAKEBASE_CONFIG) as (checkpointer, store):
@@ -123,6 +121,11 @@
             # For on-behalf-of user authentication, pass get_user_workspace_client() to init_agent.
             agent = await init_agent(store=store, checkpointer=checkpointer)

+            input_state: dict[str, Any] = {
+                "messages": await deduplicate_input(agent, config, incoming_messages),
+                "custom_inputs": dict(request.custom_inputs or {}),
+            }
+
             async for event in process_agent_astream_events(
                 agent.astream(input_state, config, stream_mode=["updates", "messages"])
             ):
diff --git a/agent-langgraph-advanced/agent_server/start_server.py b/agent-langgraph-advanced/agent_server/start_server.py
index a82b4004..d21ce9d8 100644
--- a/agent-langgraph-advanced/agent_server/start_server.py
+++ b/agent-langgraph-advanced/agent_server/start_server.py
@@ -11,12 +11,28 @@ load_dotenv(dotenv_path=Path(__file__).parent.parent / ".env", override=True)

 import logging
+import sys

 from databricks_ai_bridge.long_running import LongRunningAgentServer
 from mlflow.genai.agent_server import setup_mlflow_git_based_version_tracking

 logger = logging.getLogger(__name__)

+# Surface databricks_ai_bridge INFO logs (durable-execution lifecycle:
+# task spawn, resume, poison-recovery, terminal status, stale-scan claims)
+# in app stdout. Uvicorn's default logging config drops INFO from
+# non-uvicorn loggers, so attach a stream handler explicitly.
+_bridge_logger = logging.getLogger("databricks_ai_bridge")
+if _bridge_logger.level == logging.NOTSET or _bridge_logger.level > logging.INFO:
+    _bridge_logger.setLevel(logging.INFO)
+if not any(isinstance(h, logging.StreamHandler) for h in _bridge_logger.handlers):
+    _bridge_handler = logging.StreamHandler(sys.stdout)
+    _bridge_handler.setFormatter(
+        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
+    )
+    _bridge_logger.addHandler(_bridge_handler)
+    _bridge_logger.propagate = False
+
 # Need to import the agent to register the functions with the server
 import agent_server.agent  # noqa: F401
diff --git a/agent-langgraph-advanced/agent_server/utils.py b/agent-langgraph-advanced/agent_server/utils.py
index 75b92de2..2d04b39f 100644
--- a/agent-langgraph-advanced/agent_server/utils.py
+++ b/agent-langgraph-advanced/agent_server/utils.py
@@ -40,6 +40,27 @@ def _is_databricks_app_env() -> bool:
     return bool(os.getenv("DATABRICKS_APP_NAME"))


+async def deduplicate_input(
+    agent: Any, config: dict[str, Any], messages: list[dict[str, Any]]
+) -> list[dict[str, Any]]:
+    """Drop UI-echoed history when the checkpointer already holds the thread.
+
+    The chatbot UI replays the full conversation on each turn, but LangGraph's
+    checkpointer already has the prior messages keyed by ``thread_id``. Sending
+    them again duplicates everything in the agent's view. When we detect an
+    existing checkpoint for this thread, keep only the latest user message.
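+
+    Example (hypothetical turn-two payload): with a checkpoint present,
+    ``[user_1, assistant_1, user_2]`` collapses to ``[user_2]``; a brand-new
+    thread passes through unchanged.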
+    """
+    if not messages:
+        return messages
+    try:
+        state = await agent.aget_state(config)
+    except Exception:
+        return messages
+    if state and state.values.get("messages"):
+        return messages[-1:]
+    return messages
+
+
 def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerMCPClient:
     host_name = get_databricks_host_from_env()
     return DatabricksMultiServerMCPClient(
diff --git a/agent-openai-advanced/agent_server/start_server.py b/agent-openai-advanced/agent_server/start_server.py
index 802ac13a..4d902dc6 100644
--- a/agent-openai-advanced/agent_server/start_server.py
+++ b/agent-openai-advanced/agent_server/start_server.py
@@ -11,6 +11,7 @@ load_dotenv(dotenv_path=Path(__file__).parent.parent / ".env", override=True)

 import logging
+import sys

 from databricks_ai_bridge.long_running import LongRunningAgentServer
 from databricks_openai.agents import AsyncDatabricksSession
@@ -23,6 +24,21 @@
 logger = logging.getLogger(__name__)

+# Surface databricks_ai_bridge INFO logs (durable-execution lifecycle:
+# task spawn, resume, poison-recovery, terminal status, stale-scan claims)
+# in app stdout. Uvicorn's default logging config drops INFO from
+# non-uvicorn loggers, so attach a stream handler explicitly.
+_bridge_logger = logging.getLogger("databricks_ai_bridge")
+if _bridge_logger.level == logging.NOTSET or _bridge_logger.level > logging.INFO:
+    _bridge_logger.setLevel(logging.INFO)
+if not any(isinstance(h, logging.StreamHandler) for h in _bridge_logger.handlers):
+    _bridge_handler = logging.StreamHandler(sys.stdout)
+    _bridge_handler.setFormatter(
+        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
+    )
+    _bridge_logger.addHandler(_bridge_handler)
+    _bridge_logger.propagate = False
+

 async def run_lakebase_session_setup() -> None:
     """Create session tables at startup so per-request _ensure_tables is a no-op."""
diff --git a/agent-openai-advanced/agent_server/utils.py b/agent-openai-advanced/agent_server/utils.py
index 7cd07e8c..30a6c73c 100644
--- a/agent-openai-advanced/agent_server/utils.py
+++ b/agent-openai-advanced/agent_server/utils.py
@@ -206,8 +206,13 @@ async def deduplicate_input(request: ResponsesAgentRequest, session: AsyncDatabr
             and isinstance(msg.get("content"), str)
         ):
             msg["content"] = [{"type": "output_text", "text": msg["content"], "annotations": []}]
+    # Session is authoritative for cross-turn history when non-empty; only
+    # forward the latest message. Count-based heuristics break under
+    # poison-recovery (rotated session is fresh while the UI echo accumulated
+    # events from both attempts), forwarding duplicates that Anthropic-backed
+    # models reject as malformed tool_use/tool_result pairs.
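+    # Worked example (hypothetical counts): a recovered turn leaves 4 items in
+    # the rotated session while the UI echoes 7 messages; the old check
+    # (4 >= 7 - 1) is False, so all 7 were forwarded on top of the 4 already
+    # stored. With any non-empty session we now forward only messages[-1].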
     session_items = await session.get_items()
-    if len(session_items) >= len(messages) - 1:
+    if session_items and len(messages) > 1:
         return [messages[-1]]
     return messages
diff --git a/e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts b/e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts
index f19dad35..39b5761f 100644
--- a/e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts
+++ b/e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts
@@ -73,6 +73,47 @@
 const LOG_SSE_EVENTS = process.env.LOG_SSE_EVENTS === 'true';

 const API_PROXY = process.env.API_PROXY;

+// Durable-execution support: when talking to a `LongRunningAgentServer` agent
+// (the case when API_PROXY is set in our advanced templates), we
+// 1. inject `background: true` so the server persists every SSE frame to its
+//    durable store and our retrieve endpoint can resume mid-stream;
+// 2. capture the rotated `conversation_id` from the `response.resumed`
+//    sentinel and replay it on the next user turn — without this, the next
+//    turn lands on the orphan-poisoned session;
+// 3. on connection close without `[DONE]`, transparently re-stream from the
+//    retrieve endpoint using the last seen sequence number.
+//
+// All three live here because `databricksFetch` is the single boundary the
+// Vercel AI SDK pipes every agent request through.
+const MAX_RESUME_ATTEMPTS = 5;
+const conversationAliasMap = new Map<string, string>();
+
+function captureRotation(
+  json: Record<string, unknown> | null,
+  originalChatId: string | null,
+): void {
+  if (!json || !originalChatId) return;
+  if (json.type !== 'response.resumed') return;
+  const rotated = json.conversation_id;
+  if (typeof rotated === 'string' && rotated.length > 0) {
+    conversationAliasMap.set(originalChatId, rotated);
+  }
+}
+
+function extractResponseId(json: Record<string, unknown> | null): string | null {
+  if (!json) return null;
+  if (typeof json.response_id === 'string') return json.response_id;
+  const resp = json.response as { id?: unknown } | undefined;
+  if (resp && typeof resp.id === 'string') return resp.id;
+  return null;
+}
+
+function buildRetrieveUrl(invocationsUrl: string, responseId: string): string {
+  // The bridge mounts GET /responses/{id} on the same origin as POST /invocations.
+  const base = invocationsUrl.replace(/\/invocations\/?$/, '');
+  return `${base}/responses/${encodeURIComponent(responseId)}`;
+}
+
 // Cache for endpoint details to check task type and OBO scopes
 const endpointDetailsCache = new Map<
   string,
@@ -110,28 +151,53 @@ export const databricksFetch: typeof fetch = async (input, init) => {
   headers.delete(CONTEXT_HEADER_USER_ID);
   requestInit = { ...requestInit, headers };

-  // Inject context into request body if appropriate
-  if (
-    conversationId &&
-    userId &&
-    requestInit?.body &&
-    typeof requestInit.body === 'string'
-  ) {
-    if (shouldInjectContext()) {
-      try {
-        const body = JSON.parse(requestInit.body);
-        const enhancedBody = {
-          ...body,
-          context: {
-            ...body.context,
-            conversation_id: conversationId,
-            user_id: userId,
-          },
+  // Mutate the request body for durable execution (when we have a body to
+  // mutate). Three things happen here, all conditional:
+  // - Inject context.conversation_id / context.user_id from headers when the
+  //   endpoint expects it (existing behavior).
+  // - Substitute conversation_id with any previously-captured rotated alias
+  //   so subsequent turns land on the right (post-resume) session.
+  // - Set body.background = true on streaming requests when API_PROXY is
+  //   set, so the long-running server persists the stream to its store.
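+  // Illustration (hypothetical values): a body of {"stream": true, "context":
+  // {"conversation_id": "chat-42"}} leaves here as {"stream": true,
+  // "background": true, "context": {"conversation_id": "chat-42.r1"}} once a
+  // rotated alias ("chat-42.r1") has been captured for chat-42.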
+ // - Set body.background = true on streaming requests when API_PROXY is + // set, so the long-running server persists the stream to its store. + let originalChatId: string | null = null; + if (requestInit?.body && typeof requestInit.body === 'string') { + try { + const body = JSON.parse(requestInit.body); + let mutated = false; + + if (conversationId && userId && shouldInjectContext()) { + body.context = { + ...(body.context ?? {}), + conversation_id: conversationId, + user_id: userId, }; - requestInit = { ...requestInit, body: JSON.stringify(enhancedBody) }; - } catch { - // If JSON parsing fails, pass through unchanged + mutated = true; } + + const ctx = body.context as { conversation_id?: unknown } | undefined; + const ctxConvId = + ctx && typeof ctx.conversation_id === 'string' + ? ctx.conversation_id + : null; + if (ctxConvId) { + originalChatId = ctxConvId; + const aliased = conversationAliasMap.get(ctxConvId); + if (aliased && aliased !== ctxConvId) { + body.context = { ...body.context, conversation_id: aliased }; + mutated = true; + } + } + + if (API_PROXY && body.stream === true && body.background !== true) { + body.background = true; + mutated = true; + } + + if (mutated) { + requestInit = { ...requestInit, body: JSON.stringify(body) }; + } + } catch { + // If JSON parsing fails, pass through unchanged } } @@ -161,58 +227,20 @@ export const databricksFetch: typeof fetch = async (input, init) => { const response = await fetch(url, requestInit); - // If SSE logging is enabled and this is a streaming response, wrap the body to log events - if (LOG_SSE_EVENTS && response.body) { + if (response.body) { const contentType = response.headers.get('content-type') || ''; const isSSE = contentType.includes('text/event-stream') || contentType.includes('application/x-ndjson'); if (isSSE) { - const originalBody = response.body; - const reader = originalBody.getReader(); - const decoder = new TextDecoder(); - let eventCounter = 0; - - const loggingStream = new ReadableStream({ - async pull(controller) { - const { done, value } = await reader.read(); - - if (done) { - console.log('[SSE] Stream ended'); - controller.close(); - return; - } - - // Decode and log the chunk - const text = decoder.decode(value, { stream: true }); - const lines = text.split('\n').filter((line) => line.trim()); - - for (const line of lines) { - eventCounter++; - if (line.startsWith('data:')) { - const data = line.slice(5).trim(); - try { - const parsed = JSON.parse(data); - console.log(`[SSE #${eventCounter}]`, JSON.stringify(parsed)); - } catch { - console.log(`[SSE #${eventCounter}] (raw)`, data); - } - } else if (line.trim()) { - console.log(`[SSE #${eventCounter}] (line)`, line); - } - } - - // Pass the original data through - controller.enqueue(value); - }, - cancel() { - reader.cancel(); - }, - }); - - // Create a new response with the logging stream - return new Response(loggingStream, { + const wrapped = wrapDurableSseStream( + response.body, + url, + requestInit?.headers, + originalChatId, + ); + return new Response(wrapped, { status: response.status, statusText: response.statusText, headers: response.headers, @@ -223,6 +251,116 @@ export const databricksFetch: typeof fetch = async (input, init) => { return response; }; +/** + * Wrap a long-running-server SSE response so we can: + * - sniff `response.resumed` frames and update the conversation alias map, + * - track the last sequence number and response_id we observed, + * - if the upstream stream closes before `[DONE]`, transparently re-stream + * from `GET 
+/**
+ * Wrap a long-running-server SSE response so we can:
+ * - sniff `response.resumed` frames and update the conversation alias map,
+ * - track the last sequence number and response_id we observed,
+ * - if the upstream stream closes before `[DONE]`, transparently re-stream
+ *   from `GET /responses/{id}?stream=true&starting_after=<lastSeq>`.
+ *
+ * Bytes are passed through untouched; we only sniff data frames.
+ */
+function wrapDurableSseStream(
+  initialBody: ReadableStream<Uint8Array>,
+  invocationsUrl: string,
+  reqHeaders: HeadersInit | undefined,
+  originalChatId: string | null,
+): ReadableStream<Uint8Array> {
+  const decoder = new TextDecoder();
+  let buffer = '';
+  let eventCounter = 0;
+  let responseId: string | null = null;
+  let lastSeq = -1;
+  let sawDone = false;
+  let attemptsLeft = MAX_RESUME_ATTEMPTS;
+
+  function processChunk(value: Uint8Array): void {
+    buffer += decoder.decode(value, { stream: true });
+    while (true) {
+      const nl = buffer.indexOf('\n');
+      if (nl === -1) break;
+      const line = buffer.slice(0, nl);
+      buffer = buffer.slice(nl + 1);
+      const trimmed = line.trim();
+      if (!trimmed) continue;
+      if (!trimmed.startsWith('data:')) continue;
+      const data = trimmed.slice(5).trim();
+      if (data === '[DONE]') {
+        sawDone = true;
+        if (LOG_SSE_EVENTS) console.log(`[SSE #${++eventCounter}] [DONE]`);
+        continue;
+      }
+      let json: Record<string, unknown> | null = null;
+      try {
+        json = JSON.parse(data) as Record<string, unknown>;
+      } catch {
+        if (LOG_SSE_EVENTS) console.log(`[SSE #${++eventCounter}] (raw)`, data);
+        continue;
+      }
+      if (LOG_SSE_EVENTS) {
+        console.log(`[SSE #${++eventCounter}]`, JSON.stringify(json));
+      }
+      captureRotation(json, originalChatId);
+      const rid = extractResponseId(json);
+      if (rid) responseId = rid;
+      const seq = json.sequence_number;
+      if (typeof seq === 'number' && seq > lastSeq) lastSeq = seq;
+    }
+  }
+
+  return new ReadableStream<Uint8Array>({
+    async start(controller) {
+      let currentBody: ReadableStream<Uint8Array> | null = initialBody;
+
+      while (currentBody) {
+        const reader = currentBody.getReader();
+        try {
+          while (true) {
+            const { done, value } = await reader.read();
+            if (done) break;
+            controller.enqueue(value);
+            processChunk(value);
+          }
+        } catch (err) {
+          if (LOG_SSE_EVENTS) console.warn('[SSE] read error', err);
+        } finally {
+          reader.releaseLock();
+        }

+        if (sawDone) break;
+        if (!responseId || attemptsLeft <= 0) break;
+
+        attemptsLeft -= 1;
+        const startingAfter = Math.max(lastSeq, 0);
+        const resumeUrl =
+          `${buildRetrieveUrl(invocationsUrl, responseId)}` +
+          `?stream=true&starting_after=${startingAfter}`;
+        console.log(
+          `[SSE] upstream closed without [DONE], resuming response_id=${responseId} from seq=${startingAfter}`,
+        );
+        try {
+          const resp = await fetch(resumeUrl, {
+            method: 'GET',
+            headers: reqHeaders,
+          });
+          if (!resp.ok || !resp.body) {
+            console.warn(
+              `[SSE] resume request failed status=${resp.status}, giving up`,
+            );
+            break;
+          }
+          currentBody = resp.body;
+        } catch (err) {
+          console.warn('[SSE] resume fetch threw, giving up', err);
+          break;
+        }
+      }
+
+      controller.close();
+    },
+  });
+}
+
 type CachedProvider = ReturnType;
 let oauthProviderCache: CachedProvider | null = null;
 let oauthProviderCacheTime = 0;
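Outside the proxy, the resume path can be sanity-checked by hand. A minimal sketch (not part of the diff; the base URL, response id, and starting sequence are placeholders, and the URL shape is the one `buildRetrieveUrl` assumes: `GET /responses/{id}` mounted beside `POST /invocations`):

// replay-response.ts: hypothetical helper for manual testing (Node 18+).
async function replayResponse(base: string, responseId: string, fromSeq = 0): Promise<void> {
  // Re-stream persisted SSE frames after `fromSeq`, mirroring what
  // wrapDurableSseStream does when a live connection drops before [DONE].
  const url = `${base}/responses/${encodeURIComponent(responseId)}?stream=true&starting_after=${fromSeq}`;
  const res = await fetch(url);
  if (!res.ok || !res.body) throw new Error(`retrieve failed: ${res.status}`);
  const reader = res.body.getReader();
  const decoder = new TextDecoder();
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    process.stdout.write(decoder.decode(value, { stream: true }));
  }
}

// e.g. replayResponse('http://localhost:8000', 'resp_123').catch(console.error);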