11 changes: 7 additions & 4 deletions agent-langgraph-advanced/agent_server/agent.py
@@ -23,6 +23,7 @@
from agent_server.prompts import SYSTEM_PROMPT
from agent_server.utils import (
_get_or_create_thread_id,
deduplicate_input,
get_user_workspace_client,
init_mcp_client,
process_agent_astream_events,
@@ -110,10 +111,7 @@ async def stream_handler(
if user_id:
config["configurable"]["user_id"] = user_id

input_state: dict[str, Any] = {
"messages": to_chat_completions_input([i.model_dump() for i in request.input]),
"custom_inputs": dict(request.custom_inputs or {}),
}
incoming_messages = to_chat_completions_input([i.model_dump() for i in request.input])

try:
async with lakebase_context(LAKEBASE_CONFIG) as (checkpointer, store):
@@ -123,6 +121,11 @@
# For on-behalf-of user authentication, pass get_user_workspace_client() to init_agent.
agent = await init_agent(store=store, checkpointer=checkpointer)

input_state: dict[str, Any] = {
"messages": await deduplicate_input(agent, config, incoming_messages),
"custom_inputs": dict(request.custom_inputs or {}),
}

async for event in process_agent_astream_events(
agent.astream(input_state, config, stream_mode=["updates", "messages"])
):
16 changes: 16 additions & 0 deletions agent-langgraph-advanced/agent_server/start_server.py
@@ -11,12 +11,28 @@
load_dotenv(dotenv_path=Path(__file__).parent.parent / ".env", override=True)

import logging
import sys

from databricks_ai_bridge.long_running import LongRunningAgentServer
from mlflow.genai.agent_server import setup_mlflow_git_based_version_tracking

logger = logging.getLogger(__name__)

# Surface databricks_ai_bridge INFO logs (durable-execution lifecycle:
Contributor Author: @bbqiu do we want to maybe gate this behind some debug env var? (my PR description mentions the env var used to enable the debug kill endpoint; maybe we can gate the logging behind that too?)
One way this could look is sketched at the end of this hunk.

# task spawn, resume, prose-recovery, terminal status, stale-scan claims)
# in app stdout. Uvicorn's default logging config drops INFO from
# non-uvicorn loggers, so attach a stream handler explicitly.
_bridge_logger = logging.getLogger("databricks_ai_bridge")
if _bridge_logger.level == logging.NOTSET or _bridge_logger.level > logging.INFO:
_bridge_logger.setLevel(logging.INFO)
if not any(isinstance(h, logging.StreamHandler) for h in _bridge_logger.handlers):
_bridge_handler = logging.StreamHandler(sys.stdout)
_bridge_handler.setFormatter(
logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
)
_bridge_logger.addHandler(_bridge_handler)
_bridge_logger.propagate = False

# Need to import the agent to register the functions with the server
import agent_server.agent # noqa: F401
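A possible follow-up to the reviewer's suggestion above, as a minimal sketch: wrap the bridge-logger setup in an env-var check. The variable name AGENT_DEBUG_ENDPOINTS is a placeholder; the actual debug-kill-endpoint variable named in the PR description is not shown in this diff.

import logging
import os
import sys

# Sketch only: gate the verbose bridge logging behind the same flag that
# enables the debug kill endpoint (placeholder name below).
if os.getenv("AGENT_DEBUG_ENDPOINTS"):
    _bridge_logger = logging.getLogger("databricks_ai_bridge")
    if _bridge_logger.level == logging.NOTSET or _bridge_logger.level > logging.INFO:
        _bridge_logger.setLevel(logging.INFO)
    if not any(isinstance(h, logging.StreamHandler) for h in _bridge_logger.handlers):
        _handler = logging.StreamHandler(sys.stdout)
        _handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
        )
        _bridge_logger.addHandler(_handler)
    _bridge_logger.propagate = False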

21 changes: 21 additions & 0 deletions agent-langgraph-advanced/agent_server/utils.py
@@ -40,6 +40,27 @@ def _is_databricks_app_env() -> bool:
return bool(os.getenv("DATABRICKS_APP_NAME"))


async def deduplicate_input(
agent: Any, config: dict[str, Any], messages: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Drop UI-echoed history when the checkpointer already holds the thread.

The chatbot UI replays the full conversation on each turn, but LangGraph's
checkpointer already has the prior messages keyed by ``thread_id``. Sending
them again duplicates everything in the agent's view. When we detect an
existing checkpoint for this thread, keep only the latest user message.
"""
if not messages:
return messages
try:
state = await agent.aget_state(config)
except Exception:
return messages
if state and state.values.get("messages"):
return messages[-1:]
return messages


def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerMCPClient:
host_name = get_databricks_host_from_env()
return DatabricksMultiServerMCPClient(
16 changes: 16 additions & 0 deletions agent-openai-advanced/agent_server/start_server.py
@@ -11,6 +11,7 @@
load_dotenv(dotenv_path=Path(__file__).parent.parent / ".env", override=True)

import logging
import sys

from databricks_ai_bridge.long_running import LongRunningAgentServer
from databricks_openai.agents import AsyncDatabricksSession
@@ -23,6 +24,21 @@

logger = logging.getLogger(__name__)

# Surface databricks_ai_bridge INFO logs (durable-execution lifecycle:
# task spawn, resume, prose-recovery, terminal status, stale-scan claims)
# in app stdout. Uvicorn's default logging config drops INFO from
# non-uvicorn loggers, so attach a stream handler explicitly.
_bridge_logger = logging.getLogger("databricks_ai_bridge")
if _bridge_logger.level == logging.NOTSET or _bridge_logger.level > logging.INFO:
_bridge_logger.setLevel(logging.INFO)
if not any(isinstance(h, logging.StreamHandler) for h in _bridge_logger.handlers):
_bridge_handler = logging.StreamHandler(sys.stdout)
_bridge_handler.setFormatter(
logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
)
_bridge_logger.addHandler(_bridge_handler)
_bridge_logger.propagate = False


async def run_lakebase_session_setup() -> None:
"""Create session tables at startup so per-request _ensure_tables is a no-op."""
7 changes: 6 additions & 1 deletion agent-openai-advanced/agent_server/utils.py
@@ -206,8 +206,13 @@ async def deduplicate_input(request: ResponsesAgentRequest, session: AsyncDatabr
and isinstance(msg.get("content"), str)
):
msg["content"] = [{"type": "output_text", "text": msg["content"], "annotations": []}]
# Session is authoritative for cross-turn history when non-empty; only
# forward the latest message. Count-based heuristics break under
# prose-recovery (rotated session is fresh while the UI echo accumulated
# events from both attempts), forwarding duplicates that Anthropic-backed
# models reject as malformed tool_use/tool_result pairs.
session_items = await session.get_items()
if len(session_items) >= len(messages) - 1:
if session_items and len(messages) > 1:
return [messages[-1]]
return messages
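To make the comment above concrete, here is a small, self-contained illustration of the two predicates under the prose-recovery scenario it describes; the item counts are invented for the example.

# Illustrative only: the counts below are hypothetical.
# Prose-recovery rotated the session, so it holds little history, while the
# UI echo accumulated messages from both attempts.
session_items = ["item"] * 2   # fresh, rotated session
messages = ["msg"] * 6         # UI replayed history from both attempts

old_check = len(session_items) >= len(messages) - 1    # 2 >= 5 -> False: all 6 messages forwarded (duplicates)
new_check = bool(session_items) and len(messages) > 1  # True: only messages[-1] is forwarded
print(old_check, new_check)  # False True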
