diff --git a/dash-chatbot-app/model_serving_utils.py b/dash-chatbot-app/model_serving_utils.py index 5b6beef0..684c2761 100644 --- a/dash-chatbot-app/model_serving_utils.py +++ b/dash-chatbot-app/model_serving_utils.py @@ -11,7 +11,7 @@ def _get_endpoint_task_type(endpoint_name: str) -> str: def is_endpoint_supported(endpoint_name: str) -> bool: """Check if the endpoint has a supported task type.""" task_type = _get_endpoint_task_type(endpoint_name) - supported_task_types = ["agent/v1/chat", "agent/v2/chat", "llm/v1/chat"] + supported_task_types = ["agent/v1/chat", "agent/v2/chat", "llm/v1/chat", "agent/v1/responses"] return task_type in supported_task_types def _validate_endpoint_task_type(endpoint_name: str) -> None: diff --git a/e2e-chatbot-app-next/client/src/components/message.tsx b/e2e-chatbot-app-next/client/src/components/message.tsx index 2a296d7f..198fa8ae 100644 --- a/e2e-chatbot-app-next/client/src/components/message.tsx +++ b/e2e-chatbot-app-next/client/src/components/message.tsx @@ -105,6 +105,11 @@ const PurePreviewMessage = ({ [message.parts], ); + const renderBlocks = React.useMemo( + () => groupConsecutiveToolSegments(partSegments), + [partSegments], + ); + // Check if message only contains non-OAuth errors (no other content) const hasOnlyErrors = React.useMemo(() => { const nonErrorParts = message.parts.filter( @@ -158,7 +163,22 @@ const PurePreviewMessage = ({ )} - {partSegments?.map((parts, index) => { + {renderBlocks.map((block) => { + if (block.kind === 'tool-group') { + return ( + + ); + } + + const parts = block.parts; + const index = block.index; const [part] = parts; const { type } = part; const key = `message-${message.id}-part-${index}`; @@ -223,119 +243,7 @@ const PurePreviewMessage = ({ } } - // Render Databricks tool calls and results - if (part.type === `dynamic-tool`) { - const { toolCallId, input, state, errorText, output, toolName } = - part; - - // Check if this is an MCP tool call by looking for approvalRequestId in metadata - // This works across all states (approval-requested, approval-denied, output-available) - const isMcpApproval = - part.callProviderMetadata?.databricks?.approvalRequestId != - null; - const mcpServerName = - part.callProviderMetadata?.databricks?.mcpServerName?.toString(); - - // Extract approval outcome for 'approval-responded' state - // When addToolApprovalResponse is called, AI SDK sets the `approval` property - // on the tool-call part and changes state to 'approval-responded' - const approved: boolean | undefined = - 'approval' in part ? part.approval?.approved : undefined; - - // When approved but only have approval status (not actual output), show as input-available - const effectiveState: ToolState = (() => { - if ( - part.providerExecuted && - !isLoading && - state === 'input-available' - ) { - return 'output-available'; - } - return state; - })(); - - // Render MCP tool calls with special styling - if (isMcpApproval) { - return ( - - - - - {state === 'approval-requested' && ( - - submitApproval({ - approvalRequestId: toolCallId, - approve: true, - }) - } - onDeny={() => - submitApproval({ - approvalRequestId: toolCallId, - approve: false, - }) - } - isSubmitting={ - isSubmitting && pendingApprovalId === toolCallId - } - /> - )} - {state === 'output-available' && output != null && ( - - Error: {errorText} - - ) : ( - - {typeof output === 'string' - ? output - : JSON.stringify(output, null, 2)} - - ) - } - errorText={undefined} - /> - )} - - - ); - } - - // Render regular tool calls - return ( - - - - - {state === 'output-available' && ( - - Error: {errorText} - - ) : ( - - {typeof output === 'string' - ? output - : JSON.stringify(output, null, 2)} - - ) - } - errorText={undefined} - /> - )} - - - ); - } + // dynamic-tool parts are rendered by MessageToolGroup above. // Support for citations/annotations if (type === 'source-url') { @@ -417,6 +325,182 @@ export const PreviewMessage = memo( }, ); +type ChatPart = ChatMessage['parts'][number]; +type ToolPart = Extract; + +type RenderBlock = + | { kind: 'segment'; parts: ChatPart[]; index: number } + | { kind: 'tool-group'; tools: ToolPart[]; startIndex: number }; + +const groupConsecutiveToolSegments = ( + partSegments: ChatPart[][], +): RenderBlock[] => { + const blocks: RenderBlock[] = []; + let i = 0; + while (i < partSegments.length) { + const segment = partSegments[i]; + const firstPart = segment[0]; + if (firstPart?.type === 'dynamic-tool') { + const startIndex = i; + const tools: ToolPart[] = [firstPart as ToolPart]; + i++; + while ( + i < partSegments.length && + partSegments[i][0]?.type === 'dynamic-tool' + ) { + tools.push(partSegments[i][0] as ToolPart); + i++; + } + blocks.push({ kind: 'tool-group', tools, startIndex }); + } else { + blocks.push({ kind: 'segment', parts: segment, index: i }); + i++; + } + } + return blocks; +}; + +const MessageToolGroup = ({ + tools, + isLoading, + submitApproval, + isSubmitting, + pendingApprovalId, +}: { + tools: ToolPart[]; + isLoading: boolean; + submitApproval: ReturnType['submitApproval']; + isSubmitting: boolean; + pendingApprovalId: string | null; +}) => { + const isMultiple = tools.length > 1; + return ( + + {tools.map((tool) => ( + + ))} + + ); +}; + +const ToolPartRenderer = ({ + part, + isLoading, + submitApproval, + isSubmitting, + pendingApprovalId, +}: { + part: ToolPart; + isLoading: boolean; + submitApproval: ReturnType['submitApproval']; + isSubmitting: boolean; + pendingApprovalId: string | null; +}) => { + const { toolCallId, input, state, errorText, output, toolName } = part; + + const isMcpApproval = + part.callProviderMetadata?.databricks?.approvalRequestId != null; + const mcpServerName = + part.callProviderMetadata?.databricks?.mcpServerName?.toString(); + + const approved: boolean | undefined = + 'approval' in part ? part.approval?.approved : undefined; + + const effectiveState: ToolState = (() => { + if (part.providerExecuted && !isLoading && state === 'input-available') { + return 'output-available'; + } + return state; + })(); + + if (isMcpApproval) { + return ( + + + + + {state === 'approval-requested' && ( + + submitApproval({ approvalRequestId: toolCallId, approve: true }) + } + onDeny={() => + submitApproval({ + approvalRequestId: toolCallId, + approve: false, + }) + } + isSubmitting={isSubmitting && pendingApprovalId === toolCallId} + /> + )} + {state === 'output-available' && output != null && ( + + Error: {errorText} + + ) : ( + + {typeof output === 'string' + ? output + : JSON.stringify(output, null, 2)} + + ) + } + errorText={undefined} + /> + )} + + + ); + } + + return ( + + + + + {state === 'output-available' && ( + + Error: {errorText} + + ) : ( + + {typeof output === 'string' + ? output + : JSON.stringify(output, null, 2)} + + ) + } + errorText={undefined} + /> + )} + + + ); +}; + export const AwaitingResponseMessage = () => { const role = 'assistant'; diff --git a/e2e-chatbot-app-next/packages/ai-sdk-providers/src/request-context.ts b/e2e-chatbot-app-next/packages/ai-sdk-providers/src/request-context.ts index 4f08882a..2ec7b939 100644 --- a/e2e-chatbot-app-next/packages/ai-sdk-providers/src/request-context.ts +++ b/e2e-chatbot-app-next/packages/ai-sdk-providers/src/request-context.ts @@ -7,7 +7,7 @@ * * Context is injected when: * 1. Using API_PROXY environment variable, OR - * 2. Endpoint task type is 'agent/v2/chat' or 'agent/v1/responses' + * 2. Endpoint task type is 'agent/v2/chat' * * @param endpointTask - The task type of the serving endpoint (optional) * @returns Whether to inject context into requests @@ -21,7 +21,5 @@ export function shouldInjectContextForEndpoint( return true; } - return ( - endpointTask === 'agent/v2/chat' || endpointTask === 'agent/v1/responses' - ); + return endpointTask === 'agent/v2/chat'; } diff --git a/e2e-chatbot-app-next/tests/ai-sdk-provider/request-context.test.ts b/e2e-chatbot-app-next/tests/ai-sdk-provider/request-context.test.ts index 9f8e549f..c7531ba6 100644 --- a/e2e-chatbot-app-next/tests/ai-sdk-provider/request-context.test.ts +++ b/e2e-chatbot-app-next/tests/ai-sdk-provider/request-context.test.ts @@ -26,10 +26,6 @@ test.describe("Request Context Utils", () => { expect(shouldInjectContextForEndpoint("agent/v2/chat")).toBe(true); }); - test("returns true for agent/v1/responses endpoint task", () => { - expect(shouldInjectContextForEndpoint("agent/v1/responses")).toBe(true); - }); - test("returns false for llm/v1/chat endpoint task", () => { expect(shouldInjectContextForEndpoint("llm/v1/chat")).toBe(false); }); diff --git a/e2e-chatbot-app-next/tests/api-mocking/api-mock-handlers.ts b/e2e-chatbot-app-next/tests/api-mocking/api-mock-handlers.ts index f318dbc8..fd899747 100644 --- a/e2e-chatbot-app-next/tests/api-mocking/api-mock-handlers.ts +++ b/e2e-chatbot-app-next/tests/api-mocking/api-mock-handlers.ts @@ -323,13 +323,13 @@ export const handlers = [ }), // Mock fetching endpoint details - // Returns agent/v1/responses to enable context injection testing + // Returns agent/v2/chat to enable context injection testing // Includes auth_policy to simulate an OBO-enabled endpoint http.get(/\/api\/2\.0\/serving-endpoints\/([^/]+)$/, ({ params }) => { const endpointName = (params as Record)[0] ?? ''; return HttpResponse.json({ name: endpointName || 'test-endpoint', - task: 'agent/v1/responses', + task: 'agent/v2/chat', auth_policy: { user_auth_policy: { api_scopes: ['serving.serving-endpoints'], diff --git a/e2e-chatbot-app-next/tests/routes/context-injection.test.ts b/e2e-chatbot-app-next/tests/routes/context-injection.test.ts index 57876f3a..bc98da1f 100644 --- a/e2e-chatbot-app-next/tests/routes/context-injection.test.ts +++ b/e2e-chatbot-app-next/tests/routes/context-injection.test.ts @@ -7,9 +7,9 @@ import { TEST_PROMPTS } from '../prompts/routes'; * * Context (conversation_id and user_id) should be injected when: * 1. API_PROXY environment variable is set, OR - * 2. Endpoint task type is 'agent/v2/chat' or 'agent/v1/responses' + * 2. Endpoint task type is 'agent/v2/chat' * - * The default mock returns 'agent/v1/responses', so context should be injected + * The default mock returns 'agent/v2/chat', so context should be injected * in all tests by default. */ @@ -31,7 +31,7 @@ test.describe await adaContext.request.post('/api/test/reset-captured-requests'); }); - test.describe('agent/v1/responses endpoints', () => { + test.describe('agent/v2/chat endpoints', () => { test('injects context with conversation_id and user_id', async ({ adaContext, }) => { diff --git a/gradio-chatbot-app/model_serving_utils.py b/gradio-chatbot-app/model_serving_utils.py index 5b6beef0..684c2761 100644 --- a/gradio-chatbot-app/model_serving_utils.py +++ b/gradio-chatbot-app/model_serving_utils.py @@ -11,7 +11,7 @@ def _get_endpoint_task_type(endpoint_name: str) -> str: def is_endpoint_supported(endpoint_name: str) -> bool: """Check if the endpoint has a supported task type.""" task_type = _get_endpoint_task_type(endpoint_name) - supported_task_types = ["agent/v1/chat", "agent/v2/chat", "llm/v1/chat"] + supported_task_types = ["agent/v1/chat", "agent/v2/chat", "llm/v1/chat", "agent/v1/responses"] return task_type in supported_task_types def _validate_endpoint_task_type(endpoint_name: str) -> None: diff --git a/shiny-chatbot-app/model_serving_utils.py b/shiny-chatbot-app/model_serving_utils.py index c8ccbe6b..0c5793b1 100644 --- a/shiny-chatbot-app/model_serving_utils.py +++ b/shiny-chatbot-app/model_serving_utils.py @@ -11,7 +11,7 @@ def _get_endpoint_task_type(endpoint_name: str) -> str: def is_endpoint_supported(endpoint_name: str) -> bool: """Check if the endpoint has a supported task type.""" task_type = _get_endpoint_task_type(endpoint_name) - supported_task_types = ["agent/v1/chat", "agent/v2/chat", "llm/v1/chat"] + supported_task_types = ["agent/v1/chat", "agent/v2/chat", "llm/v1/chat", "agent/v1/responses"] return task_type in supported_task_types def _validate_endpoint_task_type(endpoint_name: str) -> None: diff --git a/streamlit-chatbot-app/model_serving_utils.py b/streamlit-chatbot-app/model_serving_utils.py index 18dddf09..acf256b7 100644 --- a/streamlit-chatbot-app/model_serving_utils.py +++ b/streamlit-chatbot-app/model_serving_utils.py @@ -11,7 +11,7 @@ def _get_endpoint_task_type(endpoint_name: str) -> str: def is_endpoint_supported(endpoint_name: str) -> bool: """Check if the endpoint has a supported task type.""" task_type = _get_endpoint_task_type(endpoint_name) - supported_task_types = ["agent/v1/chat", "agent/v2/chat", "llm/v1/chat"] + supported_task_types = ["agent/v1/chat", "agent/v2/chat", "llm/v1/chat", "agent/v1/responses"] return task_type in supported_task_types def _validate_endpoint_task_type(endpoint_name: str) -> None: