diff --git a/.github/workflows/claude-live-test.yml b/.github/workflows/claude-live-test.yml
index 7b6f3cad7..cb30491ef 100644
--- a/.github/workflows/claude-live-test.yml
+++ b/.github/workflows/claude-live-test.yml
@@ -216,6 +216,8 @@ jobs:
       - name: Run Claude with Playwright MCP
         uses: anthropics/claude-code-action@v1
+        continue-on-error: true
+        id: claude_test
         env:
           TEST_TOKEN: ${{ steps.token.outputs.token }}
           DEBUG: "*"
diff --git a/.github/workflows/components-build-deploy.yml b/.github/workflows/components-build-deploy.yml
index 04baca254..3681bc3b1 100644
--- a/.github/workflows/components-build-deploy.yml
+++ b/.github/workflows/components-build-deploy.yml
@@ -265,7 +265,8 @@ jobs:
         run: |
           oc set env deployment/frontend -n ambient-code -c frontend \
             GITHUB_APP_SLUG="ambient-code-stage" \
-            VTEAM_VERSION="${{ github.sha }}"
+            VTEAM_VERSION="${{ github.sha }}" \
+            FEEDBACK_URL="https://forms.gle/7XiWrvo6No922DUz6"
 
       - name: Update backend environment variables
         if: needs.detect-changes.outputs.backend == 'true'
@@ -328,7 +329,8 @@ jobs:
         run: |
           oc set env deployment/frontend -n ambient-code -c frontend \
             GITHUB_APP_SLUG="ambient-code-stage" \
-            VTEAM_VERSION="${{ github.sha }}"
+            VTEAM_VERSION="${{ github.sha }}" \
+            FEEDBACK_URL="https://forms.gle/7XiWrvo6No922DUz6"
 
       - name: Update backend environment variables
         run: |
diff --git a/.github/workflows/prod-release-deploy.yaml b/.github/workflows/prod-release-deploy.yaml
index e2e3235d6..694a755f4 100644
--- a/.github/workflows/prod-release-deploy.yaml
+++ b/.github/workflows/prod-release-deploy.yaml
@@ -265,7 +265,8 @@ jobs:
         run: |
           oc set env deployment/frontend -n ambient-code -c frontend \
             GITHUB_APP_SLUG="ambient-code" \
-            VTEAM_VERSION="${{ needs.release.outputs.new_tag }}"
+            VTEAM_VERSION="${{ needs.release.outputs.new_tag }}" \
+            FEEDBACK_URL="https://forms.gle/7XiWrvo6No922DUz6"
 
       - name: Update backend environment variables
         run: |
diff --git a/components/backend/handlers/content.go b/components/backend/handlers/content.go
index e0cd4aad5..ea7cd6cfb 100644
--- a/components/backend/handlers/content.go
+++ b/components/backend/handlers/content.go
@@ -663,15 +663,20 @@ func ContentWorkflowMetadata(c *gin.Context) {
 		log.Printf("ContentWorkflowMetadata: agents directory not found or unreadable: %v", err)
 	}
 
+	configResponse := gin.H{
+		"name":         ambientConfig.Name,
+		"description":  ambientConfig.Description,
+		"systemPrompt": ambientConfig.SystemPrompt,
+		"artifactsDir": ambientConfig.ArtifactsDir,
+	}
+	if ambientConfig.Rubric != nil {
+		configResponse["rubric"] = ambientConfig.Rubric
+	}
+
 	c.JSON(http.StatusOK, gin.H{
 		"commands": commands,
 		"agents":   agents,
-		"config": gin.H{
-			"name":         ambientConfig.Name,
-			"description":  ambientConfig.Description,
-			"systemPrompt": ambientConfig.SystemPrompt,
-			"artifactsDir": ambientConfig.ArtifactsDir,
-		},
+		"config": configResponse,
 	})
 }
 
@@ -713,12 +718,21 @@ func parseFrontmatter(filePath string) map[string]string {
 	return result
 }
 
+// RubricConfig represents the rubric evaluation configuration in ambient.json.
+// Schema is a JSON Schema object that defines the tool's input_schema for
+// additional metadata fields beyond final_score and reasoning.
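+//
+// A hypothetical ambient.json fragment (field values are illustrative only,
+// not taken from this repository):
+//
+//	"rubric": {
+//	    "activationPrompt": "Score the final artifact against the rubric.",
+//	    "schema": {
+//	        "completeness": {"type": "integer", "minimum": 1, "maximum": 5}
+//	    }
+//	}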
+type RubricConfig struct {
+	ActivationPrompt string                 `json:"activationPrompt,omitempty"`
+	Schema           map[string]interface{} `json:"schema,omitempty"`
+}
+
 // AmbientConfig represents the ambient.json configuration
 type AmbientConfig struct {
-	Name         string `json:"name"`
-	Description  string `json:"description"`
-	SystemPrompt string `json:"systemPrompt"`
-	ArtifactsDir string `json:"artifactsDir"`
+	Name         string        `json:"name"`
+	Description  string        `json:"description"`
+	SystemPrompt string        `json:"systemPrompt"`
+	ArtifactsDir string        `json:"artifactsDir"`
+	Rubric       *RubricConfig `json:"rubric,omitempty"`
 }
 
 // parseAmbientConfig reads and parses ambient.json from workflow directory
diff --git a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/accordions/mcp-integrations-accordion.tsx b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/accordions/mcp-integrations-accordion.tsx
index d6da85ae8..604cf2b5a 100644
--- a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/accordions/mcp-integrations-accordion.tsx
+++ b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/accordions/mcp-integrations-accordion.tsx
@@ -1,9 +1,8 @@
 'use client'
 
 import { useState, useEffect } from 'react'
-import type { ReactNode } from 'react'
 import Link from 'next/link'
-import { Plug, CheckCircle2, XCircle, AlertCircle, AlertTriangle } from 'lucide-react'
+import { Plug, Link2, CheckCircle2, XCircle, AlertCircle, AlertTriangle, Info, Check, X } from 'lucide-react'
 import {
   AccordionItem,
   AccordionTrigger,
@@ -16,87 +15,89 @@ import {
   TooltipProvider,
   TooltipTrigger,
 } from '@/components/ui/tooltip'
+import {
+  Popover,
+  PopoverContent,
+  PopoverTrigger,
+} from '@/components/ui/popover'
 import { Skeleton } from '@/components/ui/skeleton'
 import { useMcpStatus } from '@/services/queries/use-mcp'
-import { useProjectIntegrationStatus } from '@/services/queries/use-projects'
 import { useIntegrationsStatus } from '@/services/queries/use-integrations'
-import type { McpServer } from '@/services/api/sessions'
+import type { McpServer, McpTool } from '@/services/api/sessions'
 
-type McpIntegrationsAccordionProps = {
+// ─── MCP Servers Accordion ───────────────────────────────────────────────────
+
+type McpServersAccordionProps = {
   projectName: string
   sessionName: string
+  sessionPhase?: string
 }
 
-export function McpIntegrationsAccordion({
+export function McpServersAccordion({
   projectName,
   sessionName,
-}: McpIntegrationsAccordionProps) {
+  sessionPhase,
+}: McpServersAccordionProps) {
   const [placeholderTimedOut, setPlaceholderTimedOut] = useState(false)
 
-  // Fetch real MCP status from runner
-  const { data: mcpStatus, isPending: mcpPending } = useMcpStatus(projectName, sessionName)
+  // Only fetch MCP status once the session is actually running (runner pod ready)
+  const isRunning = sessionPhase === 'Running'
+  const { data: mcpStatus, isPending: mcpPending } = useMcpStatus(projectName, sessionName, isRunning)
   const mcpServers = mcpStatus?.servers || []
 
-  const { data: integrationStatus, isPending: integrationStatusPending } =
-    useProjectIntegrationStatus(projectName)
-  const githubConfigured = integrationStatus?.github ?? false
-
-  const { data: integrationsStatus } = useIntegrationsStatus()
-  const gitlabConfigured = integrationsStatus?.gitlab?.connected ??
false - - // Show skeleton cards until we have MCP servers or 2 min elapsed (backend returns empty when runner not ready) const showPlaceholders = - mcpPending || (mcpServers.length === 0 && !placeholderTimedOut) + !isRunning || mcpPending || (mcpServers.length === 0 && !placeholderTimedOut) useEffect(() => { if (mcpServers.length > 0) { setPlaceholderTimedOut(false) return } - if (!mcpStatus) return - const t = setTimeout(() => setPlaceholderTimedOut(true), 15 * 1000) // 15 seconds + if (!isRunning || !mcpStatus) return + const t = setTimeout(() => setPlaceholderTimedOut(true), 15 * 1000) return () => clearTimeout(t) - }, [mcpStatus, mcpServers.length]) + }, [mcpStatus, mcpServers.length, isRunning]) - // Collect all MCP servers - const allServers = [...mcpServers] - - // Ensure core integrations always appear (even if not in API response) - if (!showPlaceholders) { - // Webfetch - always available - const hasWebfetch = allServers.some((s) => s.name === 'webfetch') - if (!hasWebfetch) { - allServers.push({ - name: 'webfetch', - displayName: 'Webfetch', - status: 'disconnected', - authenticated: undefined, - authMessage: 'Fetches web content for the session.', - } as McpServer) - } - - // Google Workspace - show as not configured if missing - const hasGoogleWorkspace = allServers.some((s) => s.name === 'google-workspace') - if (!hasGoogleWorkspace) { - allServers.push({ - name: 'google-workspace', - displayName: 'Google Workspace', - status: 'disconnected', - authenticated: false, - authMessage: undefined, - } as McpServer) + const getStatusIcon = (server: McpServer) => { + switch (server.status) { + case 'configured': + case 'connected': + return + case 'error': + return + case 'disconnected': + default: + return } + } - // Jira - workspace-level integration - const hasJira = allServers.some((s) => s.name === 'mcp-atlassian') - if (!hasJira) { - allServers.push({ - name: 'mcp-atlassian', - displayName: 'Jira', - status: 'disconnected', - authenticated: false, - authMessage: undefined, - } as McpServer) + const getStatusBadge = (server: McpServer) => { + switch (server.status) { + case 'configured': + return ( + + Configured + + ) + case 'connected': + return ( + + Connected + + ) + case 'error': + return ( + + Error + + ) + case 'disconnected': + default: + return ( + + Disconnected + + ) } } @@ -115,65 +116,188 @@ export function McpIntegrationsAccordion({ ) - const renderGitHubCard = () => - integrationStatusPending ? ( - renderCardSkeleton() - ) : ( -
( + -
-
-
- {githubConfigured ? ( - - ) : ( - - - - - - - - -

not configured

-
-
-
+ {value ? : } + {key} + + ) + + const renderToolRow = (tool: McpTool) => { + const annotations = Object.entries(tool.annotations).filter( + ([, v]) => typeof v === 'boolean' + ) + return ( +
+ {tool.name} + {annotations.length > 0 && ( +
+ {annotations.map(([k, v]) => renderAnnotationBadge(k, v as boolean))} +
+ )} +
+ ) + } + + const renderServerCard = (server: McpServer) => { + const tools = server.tools ?? [] + const toolCount = tools.length + + return ( +
+
+
+
+ {getStatusIcon(server)} +
+

{server.displayName}

+
+
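+                  {/* Right side of the card: version tag, tool-count popover, and status badge */}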
+ {server.version && ( + v{server.version} + )} + {toolCount > 0 && ( + + + + + +
+

+ {server.displayName} — {toolCount} {toolCount === 1 ? 'tool' : 'tools'} +

+
+
+ {tools.map((tool) => renderToolRow(tool))} +
+
+
)}
-

GitHub

-

- {githubConfigured ? ( - 'MCP access to GitHub repositories.' - ) : ( +

+ {getStatusBadge(server)} +
+
+ ) + } + + return ( + + +
+ + MCP Servers + {!showPlaceholders && mcpServers.length > 0 && ( + + {mcpServers.length} + + )} +
+
+ +
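+                {/* Skeletons until the runner pod reports its MCP servers (15 s grace once Running) */}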
+ {showPlaceholders ? ( <> - Session started without GitHub MCP. Configure{' '} - - Integrations - {' '} - and start a new session. + {renderCardSkeleton()} + {renderCardSkeleton()} + ) : mcpServers.length > 0 ? ( + mcpServers.map((server) => renderServerCard(server)) + ) : ( +

+ No MCP servers available for this session. +

)} -

+
+
+
+ ) +} + +// ─── Integrations Accordion ────────────────────────────────────────────────── + +export function IntegrationsAccordion() { + const { data: integrationsStatus, isPending } = useIntegrationsStatus() + + const githubConfigured = integrationsStatus?.github?.active != null + const gitlabConfigured = integrationsStatus?.gitlab?.connected ?? false + const jiraConfigured = integrationsStatus?.jira?.connected ?? false + const googleConfigured = integrationsStatus?.google?.connected ?? false + + const integrations = [ + { + key: 'github', + name: 'GitHub', + configured: githubConfigured, + configuredMessage: 'Authenticated. Git push and repository access enabled.', + }, + { + key: 'gitlab', + name: 'GitLab', + configured: gitlabConfigured, + configuredMessage: 'Authenticated. Git push and repository access enabled.', + }, + { + key: 'google', + name: 'Google Workspace', + configured: googleConfigured, + configuredMessage: 'Authenticated. Drive, Calendar, and Gmail access enabled.', + }, + { + key: 'jira', + name: 'Jira', + configured: jiraConfigured, + configuredMessage: 'Authenticated. Issue and project access enabled.', + }, + ].sort((a, b) => a.name.localeCompare(b.name)) + + const configuredCount = integrations.filter((i) => i.configured).length + + const renderCardSkeleton = () => ( +
+
+
+ + +
+
- ) + ) - const renderGitLabCard = () => - integrationStatusPending ? ( - renderCardSkeleton() - ) : ( + const renderIntegrationCard = (integration: (typeof integrations)[number]) => (
- {gitlabConfigured ? ( + {integration.configured ? ( ) : ( @@ -184,244 +308,77 @@ export function McpIntegrationsAccordion({ -

not configured

+

Not configured

)}
-

GitLab

+

{integration.name}

- {gitlabConfigured ? ( - 'MCP access to GitLab repositories.' + {integration.configured ? ( + integration.configuredMessage ) : ( <> - Session started without GitLab MCP. Configure{' '} + Not connected.{' '} - Integrations + Set up {' '} - and start a new session. + to enable {integration.name} access. )}

- ) - - const renderServerCard = (server: McpServer) => ( -
-
-
-
- {server.authenticated === false ? ( - - - - {getStatusIcon(server)} - - -

not configured

-
-
-
- ) : ( - getStatusIcon(server) - )} -
-

{getDisplayName(server)}

-{server.name === 'mcp-atlassian' && server.authenticated === true && ( - - read only - - )} -
- {getDescription(server) && ( -

- {getDescription(server)} -

- )} -
-
- {getRightContent(server)} -
-
) - const getDisplayName = (server: McpServer) => - server.name === 'mcp-atlassian' ? 'Jira' : server.displayName - - const getDescription = (server: McpServer): ReactNode => { - if (server.name === 'webfetch') return 'Fetches web content for the session.' - if (server.name === 'mcp-atlassian') { - if (server.authenticated === false) { - return ( - <> - Session started without Jira MCP. Configure{' '} - - Integrations - {' '} - and start a new session. - - ) - } - return 'MCP access to Jira issues and projects.' - } - if (server.name === 'google-workspace') { - if (server.authenticated === false) { - return ( - <> - Session started without Google Workspace MCP. Configure{' '} - - Integrations - {' '} - and start a new session. - - ) - } - return 'MCP access to Google Drive files.' - } - return server.authMessage ?? null - } - - const getStatusIcon = (server: McpServer) => { - // If we have auth info, use that for the icon - if (server.authenticated !== undefined) { - if (server.authenticated === true) { - return - } else if (server.authenticated === null) { - // Null = needs refresh/uncertain state - return - } else { - // False = not authenticated/not configured - return - } - } - - // Fall back to status-based icons - switch (server.status) { - case 'configured': - case 'connected': - return - case 'error': - return - case 'disconnected': - default: - return - } - } - - const getRightContent = (server: McpServer) => { - // Webfetch: no badge - if (server.name === 'webfetch') return null - - // Jira not authenticated: no link (description explains to configure and start new session) - - // Google Workspace not authenticated: no link (description explains to configure and start new session) - - // Jira connected: no badge - if (server.name === 'mcp-atlassian' && server.authenticated === true) return null - - // Authenticated: show badge (with optional tooltip) - if (server.authenticated === true) { - const badge = ( - - - Authenticated - - ) - if (server.authMessage) { - return ( - - - {badge} - -

{server.authMessage}

-
-
-
- ) - } - return badge - } - - // Other servers with auth status but not authenticated: no badge (only Jira/Google get links above) - if (server.authenticated === false) return null - - // Fall back to status-based badges (for servers without auth info; webfetch already returns null) - switch (server.status) { - case 'configured': - return ( - - Configured - - ) - case 'connected': - return ( - - Connected - - ) - case 'error': - return ( - - Error - - ) - case 'disconnected': - default: - return ( - - Disconnected - - ) - } - } - - // Combine all integrations (GitHub + GitLab + all MCP servers) - type IntegrationItem = - | { type: 'github'; displayName: string } - | { type: 'gitlab'; displayName: string } - | { type: 'server'; displayName: string; server: McpServer } - const allIntegrations: IntegrationItem[] = [ - { type: 'github' as const, displayName: 'GitHub' }, - { type: 'gitlab' as const, displayName: 'GitLab' }, - ...allServers.map((server) => ({ type: 'server' as const, displayName: getDisplayName(server), server })), - ].sort((a, b) => a.displayName.localeCompare(b.displayName)) - return ( - <> - +
- + Integrations + {!isPending && ( + + {configuredCount}/{integrations.length} + + )}
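+          {/* Skeleton cards while the workspace integrations status query is pending */}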
- {showPlaceholders ? ( + {isPending ? ( <> {renderCardSkeleton()} {renderCardSkeleton()} + {renderCardSkeleton()} ) : ( - allIntegrations.map((item) => { - if (item.type === 'github') { - return
{renderGitHubCard()}
- } else if (item.type === 'gitlab') { - return
{renderGitLabCard()}
- } else { - return renderServerCard(item.server) - } - }) + integrations.map((integration) => renderIntegrationCard(integration)) )}
+ ) +} + +// ─── Legacy export (renders both) ──────────────────────────────────────────── + +type McpIntegrationsAccordionProps = { + projectName: string + sessionName: string +} + +/** @deprecated Use McpServersAccordion + IntegrationsAccordion separately */ +export function McpIntegrationsAccordion({ + projectName, + sessionName, +}: McpIntegrationsAccordionProps) { + return ( + <> + + ) } diff --git a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx index 169ec1697..b3fc6dcf5 100644 --- a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx +++ b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx @@ -60,7 +60,7 @@ import { ManageRemoteDialog } from "./components/modals/manage-remote-dialog"; import { WorkflowsAccordion } from "./components/accordions/workflows-accordion"; import { RepositoriesAccordion } from "./components/accordions/repositories-accordion"; import { ArtifactsAccordion } from "./components/accordions/artifacts-accordion"; -import { McpIntegrationsAccordion } from "./components/accordions/mcp-integrations-accordion"; +import { McpServersAccordion, IntegrationsAccordion } from "./components/accordions/mcp-integrations-accordion"; import { WelcomeExperience } from "./components/welcome-experience"; // Extracted hooks and utilities import { useGitOperations } from "./hooks/use-git-operations"; @@ -92,7 +92,7 @@ import { useOOTBWorkflows, useWorkflowMetadata, } from "@/services/queries/use-workflows"; -import { useProjectIntegrationStatus } from "@/services/queries/use-projects"; +import { useIntegrationsStatus } from "@/services/queries/use-integrations"; import { useMutation } from "@tanstack/react-query"; import { FeedbackProvider } from "@/contexts/FeedbackContext"; @@ -187,8 +187,8 @@ export default function ProjectSessionDetailPage({ const continueMutation = useContinueSession(); // Check integration status - const { data: integrationStatus } = useProjectIntegrationStatus(projectName); - const githubConfigured = integrationStatus?.github ?? false; + const { data: integrationsStatus } = useIntegrationsStatus(); + const githubConfigured = integrationsStatus?.github?.active != null; // Get current user for feedback context const { data: currentUser } = useCurrentUser(); @@ -1636,11 +1636,14 @@ export default function ProjectSessionDetailPage({ onNavigateBack={artifactsOps.navigateBack} /> - + + {/* File Explorer */} [s.key, s.value])) - : {}; - const atlassianConfigured = - !!(byKey.JIRA_URL?.trim() && byKey.JIRA_PROJECT?.trim() && byKey.JIRA_EMAIL?.trim() && byKey.JIRA_API_TOKEN?.trim()); + const atlassianConfigured = integrationsStatus?.jira?.connected ?? false; + const googleConfigured = integrationsStatus?.google?.connected ?? false; const form = useForm({ resolver: zodResolver(formSchema), @@ -201,24 +193,21 @@ export function CreateSessionDialog({ )} /> - {/* Integration status (same visual style as integrations accordion), alphabetical: Jira, GitHub, Google Workspace */} + {/* Integration auth status */}
Integrations - {/* Jira card */} - {atlassianConfigured ? ( + {/* GitHub card */} + {githubConfigured ? (
-

Jira

- - read only - +

GitHub

- MCP access to Jira issues and projects. + Authenticated. Git push and repository access enabled.

@@ -228,32 +217,29 @@ export function CreateSessionDialog({
-

Jira

+

GitHub

- Configure{" "} - - Integrations + Not connected.{" "} + + Set up {" "} - to access Jira MCP in this session. + to enable repository access.

)} - {/* GitHub card */} - {githubConfigured ? ( + {/* GitLab card */} + {gitlabConfigured ? (
-

GitHub

+

GitLab

- MCP access to GitHub repositories. + Authenticated. Git push and repository access enabled.

@@ -263,29 +249,29 @@ export function CreateSessionDialog({
-

GitHub

+

GitLab

- Configure{" "} + Not connected.{" "} - Integrations + Set up {" "} - to access GitHub MCP in this session. + to enable repository access.

)} - {/* GitLab card */} - {gitlabConfigured ? ( + {/* Google Workspace card */} + {googleConfigured ? (
-

GitLab

+

Google Workspace

- MCP access to GitLab repositories. + Authenticated. Drive, Calendar, and Gmail access enabled.

@@ -295,33 +281,52 @@ export function CreateSessionDialog({
-

GitLab

+

Google Workspace

- Configure{" "} + Not connected.{" "} - Integrations + Set up {" "} - to access GitLab MCP in this session. + to enable Drive, Calendar, and Gmail access.

)} - {/* Google Workspace card */} -
-
- + {/* Jira card */} + {atlassianConfigured ? ( +
+
+
+
+ +
+

Jira

+
+

+ Authenticated. Issue and project access enabled. +

+
-
-

Google Workspace

-

- Configure{" "} - - Integrations - {" "} - to access Google Workspace MCP in this session. -

+ ) : ( +
+
+ +
+
+

Jira

+

+ Not connected.{" "} + + Set up + {" "} + to enable issue and project access. +

+
-
+ )}
diff --git a/components/frontend/src/services/api/integrations.ts b/components/frontend/src/services/api/integrations.ts index 978b07cbc..58bd24317 100644 --- a/components/frontend/src/services/api/integrations.ts +++ b/components/frontend/src/services/api/integrations.ts @@ -10,6 +10,7 @@ export type IntegrationsStatus = { pat: { configured: boolean updatedAt?: string + valid?: boolean } active?: 'app' | 'pat' } diff --git a/components/frontend/src/services/api/sessions.ts b/components/frontend/src/services/api/sessions.ts index 2da194054..bd55d9d28 100644 --- a/components/frontend/src/services/api/sessions.ts +++ b/components/frontend/src/services/api/sessions.ts @@ -17,14 +17,25 @@ import type { PaginationParams, } from '@/types/api'; +export type McpToolAnnotations = { + readOnly?: boolean; + destructive?: boolean; + idempotent?: boolean; + openWorld?: boolean; + [key: string]: boolean | undefined; +}; + +export type McpTool = { + name: string; + annotations: McpToolAnnotations; +}; + export type McpServer = { name: string; displayName: string; - status: 'configured' | 'connected' | 'disconnected' | 'error'; - authenticated?: boolean | null; // true = valid, false = invalid, null = needs refresh/uncertain, undefined = not checked - authMessage?: string; - source?: string; - command?: string; + status: string; + version?: string; + tools?: McpTool[]; }; export type McpStatusResponse = { diff --git a/components/frontend/tsconfig.json b/components/frontend/tsconfig.json new file mode 100644 index 000000000..c1bd93ee4 --- /dev/null +++ b/components/frontend/tsconfig.json @@ -0,0 +1,30 @@ +{ + "compilerOptions": { + "target": "ES2017", + "lib": ["dom", "dom.iterable", "esnext"], + "allowJs": true, + "skipLibCheck": true, + "strict": true, + "noImplicitAny": true, + "strictNullChecks": true, + "noEmit": true, + "esModuleInterop": true, + "module": "esnext", + "moduleResolution": "bundler", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "preserve", + "incremental": true, + "plugins": [ + { + "name": "next" + } + ], + "baseUrl": ".", + "paths": { + "@/*": ["./src/*"] + } + }, + "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], + "exclude": ["node_modules"] +} diff --git a/components/runners/claude-code-runner/adapter.py b/components/runners/claude-code-runner/adapter.py index 3502abf0f..dbd7414ea 100644 --- a/components/runners/claude-code-runner/adapter.py +++ b/components/runners/claude-code-runner/adapter.py @@ -2,49 +2,59 @@ """ Claude Code Adapter for AG-UI Server. -Refactored from wrapper.py to use async generators that yield AG-UI events -instead of WebSocket messaging. This is the core adapter that wraps the -Claude Code SDK and produces a stream of AG-UI protocol events. +Core adapter that wraps the Claude Code SDK and produces a stream of +AG-UI protocol events. 
Business logic is delegated to focused modules:
+
+- ``auth`` — credential fetching and authentication setup
+- ``config`` — ambient.json, MCP, and repos configuration
+- ``workspace`` — path setup, validation, prerequisites
+- ``prompts`` — system prompt construction and constants
+- ``tools`` — MCP tool definitions (session, rubric)
+- ``utils`` — general utilities (redaction, URL parsing, subprocesses)
 """
 
-import asyncio
 import json as _json
 import logging
 import os
-import re
-import shutil
-import sys
 import uuid
-from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, AsyncIterator, Optional
-from urllib import error as _urllib_error
-from urllib import request as _urllib_request
-from urllib.parse import urlparse, urlunparse
 
 # Set umask to make files readable by content service container
 os.umask(0o022)
 
 # AG-UI Protocol Events
-from ag_ui.core import (BaseEvent, EventType, RawEvent, RunAgentInput,
-                        RunErrorEvent, RunFinishedEvent, RunStartedEvent,
-                        StateDeltaEvent, StateSnapshotEvent, StepFinishedEvent,
-                        StepStartedEvent, TextMessageContentEvent,
-                        TextMessageEndEvent, TextMessageStartEvent,
-                        ToolCallArgsEvent, ToolCallEndEvent,
-                        ToolCallStartEvent)
-
+from ag_ui.core import (
+    BaseEvent,
+    EventType,
+    RawEvent,
+    RunAgentInput,
+    RunErrorEvent,
+    RunFinishedEvent,
+    RunStartedEvent,
+    StateDeltaEvent,
+    StepFinishedEvent,
+    StepStartedEvent,
+    TextMessageContentEvent,
+    TextMessageEndEvent,
+    TextMessageStartEvent,
+    ToolCallArgsEvent,
+    ToolCallEndEvent,
+    ToolCallStartEvent,
+)
+
+import auth
+import config as runner_config
+import prompts
+import workspace
 from context import RunnerContext
+from tools import create_restart_session_tool, create_rubric_mcp_tool, load_rubric_content
+from utils import redact_secrets, run_cmd, url_with_token, parse_owner_repo
+from workspace import PrerequisiteError
 
 logger = logging.getLogger(__name__)
 
 
-class PrerequisiteError(RuntimeError):
-    """Raised when slash-command prerequisites are missing."""
-
-    pass
-
-
 class ClaudeCodeAdapter:
     """
     Adapter that wraps the Claude Code SDK for AG-UI server.
@@ -61,8 +71,6 @@ def __init__(self):
         self._turn_count = 0
 
         # AG-UI streaming state (per-run, not instance state)
-        # NOTE: _current_message_id and _current_tool_id are now local variables
-        # in _run_claude_agent_sdk to avoid race conditions with concurrent runs
         self._current_run_id: Optional[str] = None
         self._current_thread_id: Optional[str] = None
 
@@ -72,22 +80,19 @@
     async def initialize(self, context: RunnerContext):
         """Initialize the adapter with context."""
         self.context = context
-        logger.info(f"Initialized Claude Code adapter for session {context.session_id}")
+        logger.info(
+            f"Initialized Claude Code adapter for session {context.session_id}"
+        )
 
-        # NOTE: Credentials are now fetched at runtime from backend API
-        # No longer copying from mounted volumes or reading from env vars
-        # This ensures tokens are always fresh for long-running sessions
+        # Credentials are fetched on-demand from backend API
         logger.info("Credentials will be fetched on-demand from backend API")
 
         # Workspace is already prepared by init container (hydrate.sh)
-        # - Repos cloned to /workspace/repos/
-        # - Workflows cloned to /workspace/workflows/
-        # - State hydrated from S3 to .claude/, artifacts/, file-uploads/
         logger.info("Workspace prepared by init container, validating...")
 
-        # Validate prerequisite files exist for phase-based commands
+        # Validate prerequisite files for phase-based commands
         try:
-            await self._validate_prerequisites()
+            await workspace.validate_prerequisites(self.context)
         except PrerequisiteError as exc:
             self.last_exit_code = 2
             logger.error(
@@ -95,22 +100,12 @@
             )
             raise
 
-    def _timestamp(self) -> str:
-        """Return current UTC timestamp in ISO format."""
-        return datetime.now(timezone.utc).isoformat()
-
-    async def process_run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEvent]:
-        """
-        Process a run and yield AG-UI events.
+    async def process_run(
+        self, input_data: RunAgentInput
+    ) -> AsyncIterator[BaseEvent]:
+        """Process a run and yield AG-UI events.
 
         This is the main entry point called by the FastAPI server.
-
-        Args:
-            input_data: RunAgentInput with thread_id, run_id, messages, tools
-            app_state: Optional FastAPI app.state for persistent client storage/reuse
-
-        Yields:
-            AG-UI events (RunStartedEvent, TextMessageContentEvent, etc.)
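+
+        A minimal consumption sketch (illustrative; ``event_queue`` is an
+        assumed name, not part of this codebase):
+
+            async for event in adapter.process_run(input_data):
+                await event_queue.put(event)  # forward each AG-UI event to the client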
""" thread_id = input_data.thread_id or self.context.session_id run_id = input_data.run_id or str(uuid.uuid4()) @@ -118,9 +113,6 @@ async def process_run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEven self._current_thread_id = thread_id self._current_run_id = run_id - # NOTE: Credentials are now fetched on-demand at runtime, no need to pre-fetch - # Each tool call will get fresh credentials from the backend API - try: # Emit RUN_STARTED yield RunStartedEvent( @@ -134,7 +126,9 @@ async def process_run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEven msg_dict = ( msg if isinstance(msg, dict) - else (msg.model_dump() if hasattr(msg, "model_dump") else {}) + else ( + msg.model_dump() if hasattr(msg, "model_dump") else {} + ) ) role = msg_dict.get("role", "") @@ -143,18 +137,14 @@ async def process_run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEven content = msg_dict.get("content", "") msg_metadata = msg_dict.get("metadata", {}) - # Check if message should be hidden from UI - is_hidden = isinstance(msg_metadata, dict) and msg_metadata.get( - "hidden", False - ) + is_hidden = isinstance( + msg_metadata, dict + ) and msg_metadata.get("hidden", False) if is_hidden: logger.info( - f"Message {msg_id[:8]} marked as hidden (auto-sent initial/workflow prompt)" + f"Message {msg_id[:8]} marked as hidden " + "(auto-sent initial/workflow prompt)" ) - - # Emit user message as TEXT_MESSAGE events - # Include metadata in RAW event for frontend filtering - if is_hidden: yield RawEvent( type=EventType.RAW, thread_id=thread_id, @@ -193,11 +183,13 @@ async def process_run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEven # Extract user message from input logger.info( - f"Extracting user message from {len(input_data.messages)} messages" + f"Extracting user message from " + f"{len(input_data.messages)} messages" ) user_message = self._extract_user_message(input_data) logger.info( - f"Extracted user message: '{user_message[:100] if user_message else '(empty)'}...'" + f"Extracted user message: " + f"'{user_message[:100] if user_message else '(empty)'}...'" ) if not user_message: @@ -206,7 +198,10 @@ async def process_run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEven type=EventType.RAW, thread_id=thread_id, run_id=run_id, - event={"type": "system_log", "message": "No user message provided"}, + event={ + "type": "system_log", + "message": "No user message provided", + }, ) yield RunFinishedEvent( type=EventType.RUN_FINISHED, @@ -216,7 +211,9 @@ async def process_run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEven return # Run Claude SDK and yield events - logger.info(f"Starting Claude SDK with prompt: '{user_message[:50]}...'") + logger.info( + f"Starting Claude SDK with prompt: '{user_message[:50]}...'" + ) async for event in self._run_claude_agent_sdk( user_message, thread_id, run_id ): @@ -255,170 +252,173 @@ def _extract_user_message(self, input_data: RunAgentInput) -> str: """Extract user message text from RunAgentInput.""" messages = input_data.messages or [] logger.info( - f"Extracting from {len(messages)} messages, types: {[type(m).__name__ for m in messages]}" + f"Extracting from {len(messages)} messages, " + f"types: {[type(m).__name__ for m in messages]}" ) - # Find the last user message for msg in reversed(messages): - logger.debug( - f"Checking message: type={type(msg).__name__}, hasattr(role)={hasattr(msg, 'role')}" - ) - if hasattr(msg, "role") and msg.role == "user": - # Handle different content formats content = getattr(msg, "content", 
"") if isinstance(content, str): - logger.info( - f"Found user message (object format): '{content[:50]}...'" - ) return content elif isinstance(content, list): - # Content blocks format for block in content: if hasattr(block, "text"): return block.text elif isinstance(block, dict) and "text" in block: return block["text"] elif isinstance(msg, dict): - logger.debug( - f"Dict message: role={msg.get('role')}, content={msg.get('content', '')[:30]}..." - ) if msg.get("role") == "user": content = msg.get("content", "") if isinstance(content, str): - logger.info( - f"Found user message (dict format): '{content[:50]}...'" - ) return content logger.warning("No user message found!") return "" + # ------------------------------------------------------------------ + # SDK orchestration + # ------------------------------------------------------------------ + async def _run_claude_agent_sdk( self, prompt: str, thread_id: str, run_id: str ) -> AsyncIterator[BaseEvent]: - """Execute the Claude Code SDK with the given prompt and yield AG-UI events. - - Creates a fresh client for each run - simpler and more reliable than client reuse. - - Args: - prompt: The user prompt to send to Claude - thread_id: AG-UI thread identifier - run_id: AG-UI run identifier - """ - # Per-run state - NOT instance variables to avoid race conditions with concurrent runs + """Execute the Claude Code SDK with the given prompt and yield AG-UI events.""" current_message_id: Optional[str] = None logger.info( - f"_run_claude_agent_sdk called with prompt length={len(prompt)}, will create fresh client" + f"_run_claude_agent_sdk called with prompt length={len(prompt)}, " + "will create fresh client" ) try: - # NOTE: Credentials are now fetched at runtime via _populate_runtime_credentials() - # No need for manual refresh - backend API always returns fresh tokens - - # Check for authentication method + # --- Authentication --- logger.info("Checking authentication configuration...") api_key = self.context.get_env("ANTHROPIC_API_KEY", "") use_vertex = ( - self.context.get_env("CLAUDE_CODE_USE_VERTEX", "").strip() == "1" + self.context.get_env("CLAUDE_CODE_USE_VERTEX", "").strip() + == "1" ) logger.info( - f"Auth config: api_key={'set' if api_key else 'not set'}, use_vertex={use_vertex}" + f"Auth config: api_key={'set' if api_key else 'not set'}, " + f"use_vertex={use_vertex}" ) if not api_key and not use_vertex: raise RuntimeError( - "Either ANTHROPIC_API_KEY or CLAUDE_CODE_USE_VERTEX=1 must be set" + "Either ANTHROPIC_API_KEY or CLAUDE_CODE_USE_VERTEX=1 " + "must be set" ) - # Set environment variables BEFORE importing SDK if api_key: os.environ["ANTHROPIC_API_KEY"] = api_key logger.info("Using Anthropic API key authentication") - # Configure Vertex AI if requested if use_vertex: - vertex_credentials = await self._setup_vertex_credentials() + vertex_credentials = await auth.setup_vertex_credentials( + self.context + ) if "ANTHROPIC_API_KEY" in os.environ: - logger.info("Clearing ANTHROPIC_API_KEY to force Vertex AI mode") + logger.info( + "Clearing ANTHROPIC_API_KEY to force Vertex AI mode" + ) del os.environ["ANTHROPIC_API_KEY"] os.environ["CLAUDE_CODE_USE_VERTEX"] = "1" - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = vertex_credentials.get( - "credentials_path", "" + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ( + vertex_credentials.get("credentials_path", "") + ) + os.environ["ANTHROPIC_VERTEX_PROJECT_ID"] = ( + vertex_credentials.get("project_id", "") ) - os.environ["ANTHROPIC_VERTEX_PROJECT_ID"] = vertex_credentials.get( - 
"project_id", "" + os.environ["CLOUD_ML_REGION"] = vertex_credentials.get( + "region", "" ) - os.environ["CLOUD_ML_REGION"] = vertex_credentials.get("region", "") - - # NOW we can safely import the SDK - from claude_agent_sdk import (AssistantMessage, ClaudeAgentOptions, - ClaudeSDKClient, ResultMessage, - SystemMessage, TextBlock, - ThinkingBlock, ToolResultBlock, - ToolUseBlock, UserMessage, - create_sdk_mcp_server) + + # --- SDK imports (after env vars are set) --- + from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ClaudeSDKClient, + ResultMessage, + SystemMessage, + TextBlock, + ThinkingBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, + create_sdk_mcp_server, + ) from claude_agent_sdk import tool as sdk_tool from claude_agent_sdk.types import StreamEvent from observability import ObservabilityManager - # Extract and sanitize user context for observability + # --- Observability --- raw_user_id = os.getenv("USER_ID", "").strip() raw_user_name = os.getenv("USER_NAME", "").strip() - user_id, user_name = self._sanitize_user_context(raw_user_id, raw_user_name) + user_id, user_name = auth.sanitize_user_context( + raw_user_id, raw_user_name + ) - # Get model configuration model = self.context.get_env("LLM_MODEL") configured_model = model or "claude-sonnet-4-5@20250929" if use_vertex and model: - configured_model = self._map_to_vertex_model(model) + configured_model = auth.map_to_vertex_model(model) - # Initialize observability obs = ObservabilityManager( - session_id=self.context.session_id, user_id=user_id, user_name=user_name + session_id=self.context.session_id, + user_id=user_id, + user_name=user_name, ) await obs.initialize( prompt=prompt, - namespace=self.context.get_env("AGENTIC_SESSION_NAMESPACE", "unknown"), + namespace=self.context.get_env( + "AGENTIC_SESSION_NAMESPACE", "unknown" + ), model=configured_model, ) obs._pending_initial_prompt = prompt - # Check if this is a resume session via IS_RESUME env var - # This is set by the operator when restarting a stopped/completed/failed session + # --- Workspace paths --- is_continuation = ( - self.context.get_env("IS_RESUME", "").strip().lower() == "true" + self.context.get_env("IS_RESUME", "").strip().lower() + == "true" ) if is_continuation: logger.info("IS_RESUME=true - treating as continuation") - # Determine cwd and additional dirs - repos_cfg = self._get_repos_config() + repos_cfg = runner_config.get_repos_config() cwd_path = self.context.workspace_path add_dirs = [] derived_name = None - # Check for active workflow first - active_workflow_url = (os.getenv("ACTIVE_WORKFLOW_GIT_URL") or "").strip() + active_workflow_url = ( + os.getenv("ACTIVE_WORKFLOW_GIT_URL") or "" + ).strip() if active_workflow_url: - cwd_path, add_dirs, derived_name = self._setup_workflow_paths( - active_workflow_url, repos_cfg + cwd_path, add_dirs, derived_name = ( + workspace.setup_workflow_paths( + self.context, active_workflow_url, repos_cfg + ) ) elif repos_cfg: - cwd_path, add_dirs = self._setup_multi_repo_paths(repos_cfg) + cwd_path, add_dirs = workspace.setup_multi_repo_paths( + self.context, repos_cfg + ) else: - cwd_path = str(Path(self.context.workspace_path) / "artifacts") + cwd_path = str( + Path(self.context.workspace_path) / "artifacts" + ) - # Load ambient.json configuration + # --- Config --- ambient_config = ( - self._load_ambient_config(cwd_path) if active_workflow_url else {} + runner_config.load_ambient_config(cwd_path) + if active_workflow_url + else {} ) - # Ensure working directory exists cwd_path_obj = 
Path(cwd_path) if not cwd_path_obj.exists(): logger.warning( @@ -433,40 +433,38 @@ async def _run_claude_agent_sdk( logger.info(f"Claude SDK CWD: {cwd_path}") logger.info(f"Claude SDK additional directories: {add_dirs}") - # Fetch fresh credentials from backend and populate environment - # This ensures MCP servers get fresh tokens for long-running sessions - await self._populate_runtime_credentials() + # --- Credentials --- + await auth.populate_runtime_credentials(self.context) - # Load MCP server configuration (webfetch is included in static .mcp.json) - mcp_servers = self._load_mcp_config(cwd_path) or {} + # --- MCP servers --- + mcp_servers = ( + runner_config.load_mcp_config(self.context, cwd_path) or {} + ) - # Pre-flight check: Validate MCP server authentication status - # Import here to avoid circular dependency + # Pre-flight check: Validate MCP server authentication from main import _check_mcp_authentication mcp_auth_warnings = [] if mcp_servers: for server_name in mcp_servers.keys(): is_auth, msg = _check_mcp_authentication(server_name) - if is_auth is False: - # Authentication definitely failed - mcp_auth_warnings.append(f"⚠️ {server_name}: {msg}") + mcp_auth_warnings.append( + f"⚠️ {server_name}: {msg}" + ) elif is_auth is None: - # Authentication needs refresh or uncertain - mcp_auth_warnings.append(f"ℹ️ {server_name}: {msg}") + mcp_auth_warnings.append( + f"ℹ️ {server_name}: {msg}" + ) if mcp_auth_warnings: - warning_msg = "**MCP Server Authentication Issues:**\n\n" + "\n".join( - mcp_auth_warnings - ) - warning_msg += ( - "\n\nThese servers may not work correctly until re-authenticated." + warning_msg = ( + "**MCP Server Authentication Issues:**\n\n" + + "\n".join(mcp_auth_warnings) + + "\n\nThese servers may not work correctly " + "until re-authenticated." ) logger.warning(warning_msg) - - # Send as RAW event (not chat message) so UI can display as banner/notification - # Don't send as TextMessage - that shows up in chat history yield RawEvent( type=EventType.RAW, thread_id=thread_id, @@ -481,36 +479,41 @@ async def _run_claude_agent_sdk( }, ) - # Create custom session control tools - # Capture self reference for the restart tool closure - adapter_ref = self - - @sdk_tool( - "restart_session", - "Restart the Claude session to recover from issues, clear state, or get a fresh connection. Use this if you detect you're in a broken state or need to reset.", - {}, - ) - async def restart_session_tool(args: dict) -> dict: - """Tool that allows Claude to request a session restart.""" - adapter_ref._restart_requested = True - logger.info("🔄 Session restart requested by Claude via MCP tool") - return { - "content": [ - { - "type": "text", - "text": "Session restart has been requested. The current run will complete and a fresh session will be established. 
Your conversation context will be preserved on disk.", - } - ] - } - - # Create SDK MCP server for session tools + # --- MCP tools --- + # Session control tool + restart_tool = create_restart_session_tool(self, sdk_tool) session_tools_server = create_sdk_mcp_server( - name="session", version="1.0.0", tools=[restart_session_tool] + name="session", version="1.0.0", tools=[restart_tool] ) mcp_servers["session"] = session_tools_server - logger.info("Added custom session control MCP tools (restart_session)") + logger.info( + "Added custom session control MCP tools (restart_session)" + ) + + # Dynamic rubric evaluation tool + rubric_content, rubric_config = load_rubric_content(cwd_path) + if rubric_content or rubric_config: + rubric_tool = create_rubric_mcp_tool( + rubric_content=rubric_content or "", + rubric_config=rubric_config, + obs=obs, + session_id=self.context.session_id, + sdk_tool_decorator=sdk_tool, + ) + if rubric_tool: + rubric_server = create_sdk_mcp_server( + name="rubric", + version="1.0.0", + tools=[rubric_tool], + ) + mcp_servers["rubric"] = rubric_server + logger.info( + "Added dynamic rubric evaluation MCP tool " + f"(categories: " + f"{list(rubric_config.get('schema', {}).keys())})" + ) - # Disable built-in WebFetch in favor of WebFetch.MCP from config + # Tool permissions allowed_tools = [ "Read", "Write", @@ -525,28 +528,31 @@ async def restart_session_tool(args: dict) -> dict: for server_name in mcp_servers.keys(): allowed_tools.append(f"mcp__{server_name}") logger.info( - f"MCP tool permissions granted for servers: {list(mcp_servers.keys())}" + f"MCP tool permissions granted for servers: " + f"{list(mcp_servers.keys())}" ) - # Build workspace context system prompt - workspace_prompt = self._build_workspace_context_prompt( + # --- System prompt --- + workspace_prompt = prompts.build_workspace_context_prompt( repos_cfg=repos_cfg, - workflow_name=derived_name if active_workflow_url else None, + workflow_name=( + derived_name if active_workflow_url else None + ), artifacts_path="artifacts", ambient_config=ambient_config, + workspace_path=self.context.workspace_path, ) - # SystemPromptPreset format: uses claude_code preset with appended workspace context system_prompt_config = { "type": "preset", "preset": "claude_code", "append": workspace_prompt, } - # Capture stderr from the SDK to diagnose MCP server failures + # Capture stderr from the SDK def sdk_stderr_handler(line: str): logger.warning(f"[SDK stderr] {line.rstrip()}") - # Configure SDK options + # --- SDK options --- options = ClaudeAgentOptions( cwd=cwd_path, permission_mode="acceptEdits", @@ -561,7 +567,6 @@ def sdk_stderr_handler(line: str): if self._skip_resume_on_restart: self._skip_resume_on_restart = False - # Set additional options try: if add_dirs: options.add_dirs = add_dirs @@ -592,6 +597,7 @@ def sdk_stderr_handler(line: str): except Exception: pass + # --- Client creation --- result_payload = None current_message = None sdk_session_id = None @@ -601,16 +607,14 @@ def create_sdk_client(opts, disable_continue=False): opts.continue_conversation = False return ClaudeSDKClient(options=opts) - # Create fresh client for each run - # (Python SDK has issues with client reuse despite docs suggesting it should work) logger.info("Creating new ClaudeSDKClient for this run...") - # Enable continue_conversation to resume from disk state if not self._first_run or is_continuation: try: options.continue_conversation = True logger.info( - "Enabled continue_conversation (will resume from disk state)" + "Enabled 
continue_conversation " + "(will resume from disk state)" ) yield RawEvent( type=EventType.RAW, @@ -622,18 +626,27 @@ def create_sdk_client(opts, disable_continue=False): }, ) except Exception as e: - logger.warning(f"Failed to set continue_conversation: {e}") + logger.warning( + f"Failed to set continue_conversation: {e}" + ) try: logger.info("Creating ClaudeSDKClient...") client = create_sdk_client(options) - logger.info("Connecting ClaudeSDKClient (initializing subprocess)...") + logger.info( + "Connecting ClaudeSDKClient (initializing subprocess)..." + ) await client.connect() logger.info("ClaudeSDKClient connected successfully!") except Exception as resume_error: error_str = str(resume_error).lower() - if "no conversation found" in error_str or "session" in error_str: - logger.warning(f"Conversation continuation failed: {resume_error}") + if ( + "no conversation found" in error_str + or "session" in error_str + ): + logger.warning( + f"Conversation continuation failed: {resume_error}" + ) yield RawEvent( type=EventType.RAW, thread_id=thread_id, @@ -649,7 +662,6 @@ def create_sdk_client(opts, disable_continue=False): raise try: - # Store client reference for interrupt support self._active_client = client # Process the prompt @@ -662,18 +674,23 @@ def create_sdk_client(opts, disable_continue=False): step_name="processing_prompt", ) - logger.info(f"Sending query to Claude SDK: '{prompt[:100]}...'") + logger.info( + f"Sending query to Claude SDK: '{prompt[:100]}...'" + ) await client.query(prompt) logger.info("Query sent, waiting for response stream...") - # Process response stream - logger.info("Starting to consume receive_response() iterator...") + # --- Process response stream --- + logger.info( + "Starting to consume receive_response() iterator..." 
+ ) message_count = 0 async for message in client.receive_response(): message_count += 1 logger.info( - f"[ClaudeSDKClient Message #{message_count}]: {message}" + f"[ClaudeSDKClient Message #{message_count}]: " + f"{message}" ) # Handle StreamEvent for real-time streaming chunks @@ -705,19 +722,23 @@ def create_sdk_client(opts, disable_continue=False): ) continue - # Capture SDK session ID from init message + # Capture SDK session ID if isinstance(message, SystemMessage): - if message.subtype == "init" and message.data.get("session_id"): + if message.subtype == "init" and message.data.get( + "session_id" + ): sdk_session_id = message.data.get("session_id") - logger.info(f"Captured SDK session ID: {sdk_session_id}") + logger.info( + f"Captured SDK session ID: {sdk_session_id}" + ) if isinstance(message, (AssistantMessage, UserMessage)): if isinstance(message, AssistantMessage): current_message = message - obs.start_turn(configured_model, user_input=prompt) + obs.start_turn( + configured_model, user_input=prompt + ) - # Emit trace_id for feedback association - # Frontend can use this to link feedback to specific Langfuse traces trace_id = obs.get_current_trace_id() if trace_id: yield RawEvent( @@ -731,26 +752,34 @@ def create_sdk_client(opts, disable_continue=False): ) # Process all blocks in the message - for block in getattr(message, "content", []) or []: + for block in ( + getattr(message, "content", []) or [] + ): if isinstance(block, TextBlock): text_piece = getattr(block, "text", None) if text_piece: logger.info( - f"TextBlock received (complete), text length={len(text_piece)}" + f"TextBlock received (complete), " + f"text length={len(text_piece)}" ) elif isinstance(block, ToolUseBlock): - tool_name = getattr(block, "name", "") or "unknown" - tool_input = getattr(block, "input", {}) or {} - tool_id = getattr(block, "id", None) or str( - uuid.uuid4() + tool_name = ( + getattr(block, "name", "") or "unknown" + ) + tool_input = ( + getattr(block, "input", {}) or {} ) + tool_id = getattr( + block, "id", None + ) or str(uuid.uuid4()) parent_tool_use_id = getattr( message, "parent_tool_use_id", None ) logger.info( - f"ToolUseBlock detected: {tool_name} (id={tool_id[:12]})" + f"ToolUseBlock detected: {tool_name} " + f"(id={tool_id[:12]})" ) yield ToolCallStartEvent( @@ -772,20 +801,28 @@ def create_sdk_client(opts, disable_continue=False): delta=args_json, ) - obs.track_tool_use(tool_name, tool_id, tool_input) + obs.track_tool_use( + tool_name, tool_id, tool_input + ) elif isinstance(block, ToolResultBlock): - tool_use_id = getattr(block, "tool_use_id", None) + tool_use_id = getattr( + block, "tool_use_id", None + ) content = getattr(block, "content", None) is_error = getattr(block, "is_error", None) result_text = getattr(block, "text", None) result_content = ( - content if content is not None else result_text + content + if content is not None + else result_text ) if result_content is not None: try: - result_str = _json.dumps(result_content) + result_str = _json.dumps( + result_content + ) except (TypeError, ValueError): result_str = str(result_content) else: @@ -797,16 +834,28 @@ def create_sdk_client(opts, disable_continue=False): thread_id=thread_id, run_id=run_id, tool_call_id=tool_use_id, - result=result_str if not is_error else None, - error=result_str if is_error else None, + result=( + result_str + if not is_error + else None + ), + error=( + result_str + if is_error + else None + ), ) obs.track_tool_result( - tool_use_id, result_content, is_error or False + tool_use_id, + 
result_content, + is_error or False, ) elif isinstance(block, ThinkingBlock): - thinking_text = getattr(block, "thinking", "") + thinking_text = getattr( + block, "thinking", "" + ) signature = getattr(block, "signature", "") yield RawEvent( type=EventType.RAW, @@ -820,7 +869,10 @@ def create_sdk_client(opts, disable_continue=False): ) # End text message after processing all blocks - if getattr(message, "content", []) and current_message_id: + if ( + getattr(message, "content", []) + and current_message_id + ): yield TextMessageEndEvent( type=EventType.TEXT_MESSAGE_END, thread_id=thread_id, @@ -848,11 +900,14 @@ def create_sdk_client(opts, disable_continue=False): sdk_num_turns = getattr(message, "num_turns", None) logger.info( - f"ResultMessage: num_turns={sdk_num_turns}, usage={usage_raw}" + f"ResultMessage: num_turns={sdk_num_turns}, " + f"usage={usage_raw}" ) # Convert usage object to dict if needed - if usage_raw is not None and not isinstance(usage_raw, dict): + if usage_raw is not None and not isinstance( + usage_raw, dict + ): try: if hasattr(usage_raw, "__dict__"): usage_raw = usage_raw.__dict__ @@ -860,36 +915,42 @@ def create_sdk_client(opts, disable_continue=False): usage_raw = usage_raw.model_dump() except Exception as e: logger.warning( - f"Could not convert usage object to dict: {e}" + "Could not convert usage object " + f"to dict: {e}" ) - # Update turn count if ( sdk_num_turns is not None and sdk_num_turns > self._turn_count ): self._turn_count = sdk_num_turns - # Complete turn tracking if current_message: obs.end_turn( self._turn_count, current_message, - usage_raw if isinstance(usage_raw, dict) else None, + ( + usage_raw + if isinstance(usage_raw, dict) + else None + ), ) current_message = None result_payload = { "subtype": getattr(message, "subtype", None), - "duration_ms": getattr(message, "duration_ms", None), + "duration_ms": getattr( + message, "duration_ms", None + ), "is_error": getattr(message, "is_error", None), "num_turns": getattr(message, "num_turns", None), - "total_cost_usd": getattr(message, "total_cost_usd", None), + "total_cost_usd": getattr( + message, "total_cost_usd", None + ), "usage": usage_raw, "result": getattr(message, "result", None), } - # Emit state delta with result yield StateDeltaEvent( type=EventType.STATE_DELTA, thread_id=thread_id, @@ -913,31 +974,31 @@ def create_sdk_client(opts, disable_continue=False): ) logger.info( - f"Response iterator fully consumed ({message_count} messages total)" + f"Response iterator fully consumed " + f"({message_count} messages total)" ) - # Mark first run complete self._first_run = False - # Check if restart was requested by Claude + # Check if restart was requested if self._restart_requested: - logger.info("🔄 Restart was requested, emitting restart event") - self._restart_requested = False # Reset flag + logger.info( + "🔄 Restart was requested, emitting restart event" + ) + self._restart_requested = False yield RawEvent( type=EventType.RAW, thread_id=thread_id, run_id=run_id, event={ "type": "session_restart_requested", - "message": "Claude requested a session restart. Reconnecting...", + "message": "Claude requested a session restart. 
" + "Reconnecting...", }, ) finally: - # Clear active client reference self._active_client = None - - # Always disconnect client at end of run if client is not None: logger.info("Disconnecting client (end of run)") await client.disconnect() @@ -952,9 +1013,7 @@ def create_sdk_client(opts, disable_continue=False): raise async def interrupt(self) -> None: - """ - Interrupt the active Claude SDK execution. - """ + """Interrupt the active Claude SDK execution.""" if self._active_client is None: logger.warning("Interrupt requested but no active client") return @@ -965,834 +1024,3 @@ async def interrupt(self) -> None: logger.info("Interrupt signal sent successfully") except Exception as e: logger.error(f"Failed to interrupt Claude SDK: {e}") - - def _setup_workflow_paths( - self, active_workflow_url: str, repos_cfg: list - ) -> tuple[str, list, str]: - """Setup paths for workflow mode.""" - add_dirs = [] - derived_name = None - cwd_path = self.context.workspace_path - - try: - owner, repo, _ = self._parse_owner_repo(active_workflow_url) - derived_name = repo or "" - if not derived_name: - p = urlparse(active_workflow_url) - parts = [pt for pt in (p.path or "").split("/") if pt] - if parts: - derived_name = parts[-1] - derived_name = (derived_name or "").removesuffix(".git").strip() - - if derived_name: - workflow_path = str( - Path(self.context.workspace_path) / "workflows" / derived_name - ) - if Path(workflow_path).exists(): - cwd_path = workflow_path - logger.info(f"Using workflow as CWD: {derived_name}") - else: - logger.warning( - f"Workflow directory not found: {workflow_path}, using default" - ) - cwd_path = str( - Path(self.context.workspace_path) / "workflows" / "default" - ) - else: - cwd_path = str( - Path(self.context.workspace_path) / "workflows" / "default" - ) - except Exception as e: - logger.warning(f"Failed to derive workflow name: {e}, using default") - cwd_path = str(Path(self.context.workspace_path) / "workflows" / "default") - - # Add all repos as additional directories (repos are in /workspace/repos/{name}) - repos_base = Path(self.context.workspace_path) / "repos" - for r in repos_cfg: - name = (r.get("name") or "").strip() - if name: - repo_path = str(repos_base / name) - if repo_path not in add_dirs: - add_dirs.append(repo_path) - - # Add artifacts and file-uploads directories - artifacts_path = str(Path(self.context.workspace_path) / "artifacts") - if artifacts_path not in add_dirs: - add_dirs.append(artifacts_path) - - file_uploads_path = str(Path(self.context.workspace_path) / "file-uploads") - if file_uploads_path not in add_dirs: - add_dirs.append(file_uploads_path) - - return cwd_path, add_dirs, derived_name - - def _setup_multi_repo_paths(self, repos_cfg: list) -> tuple[str, list]: - """Setup paths for multi-repo mode. 
- - Repos are cloned to /workspace/repos/{name} by both: - - hydrate.sh (init container) - - clone_repo_at_runtime() (runtime addition) - """ - add_dirs = [] - repos_base = Path(self.context.workspace_path) / "repos" - - main_name = (os.getenv("MAIN_REPO_NAME") or "").strip() - if not main_name: - idx_raw = (os.getenv("MAIN_REPO_INDEX") or "").strip() - try: - idx_val = int(idx_raw) if idx_raw else 0 - except Exception: - idx_val = 0 - if idx_val < 0 or idx_val >= len(repos_cfg): - idx_val = 0 - main_name = (repos_cfg[idx_val].get("name") or "").strip() - - # Main repo path is /workspace/repos/{name} - cwd_path = ( - str(repos_base / main_name) if main_name else self.context.workspace_path - ) - - for r in repos_cfg: - name = (r.get("name") or "").strip() - if not name: - continue - # All repos are in /workspace/repos/{name} - p = str(repos_base / name) - if p != cwd_path: - add_dirs.append(p) - - # Add artifacts and file-uploads directories - artifacts_path = str(Path(self.context.workspace_path) / "artifacts") - if artifacts_path not in add_dirs: - add_dirs.append(artifacts_path) - - file_uploads_path = str(Path(self.context.workspace_path) / "file-uploads") - if file_uploads_path not in add_dirs: - add_dirs.append(file_uploads_path) - - return cwd_path, add_dirs - - @staticmethod - def _sanitize_user_context(user_id: str, user_name: str) -> tuple[str, str]: - """Validate and sanitize user context fields to prevent injection attacks.""" - if user_id: - user_id = str(user_id).strip() - if len(user_id) > 255: - user_id = user_id[:255] - sanitized_id = re.sub(r"[^a-zA-Z0-9@._-]", "", user_id) - user_id = sanitized_id - - if user_name: - user_name = str(user_name).strip() - if len(user_name) > 255: - user_name = user_name[:255] - sanitized_name = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", user_name) - user_name = sanitized_name - - return user_id, user_name - - def _map_to_vertex_model(self, model: str) -> str: - """Map Anthropic API model names to Vertex AI model names.""" - model_map = { - "claude-opus-4-6": "claude-opus-4-6", - "claude-opus-4-5": "claude-opus-4-5@20251101", - "claude-opus-4-1": "claude-opus-4-1@20250805", - "claude-sonnet-4-5": "claude-sonnet-4-5@20250929", - "claude-haiku-4-5": "claude-haiku-4-5@20251001", - } - return model_map.get(model, model) - - async def _setup_vertex_credentials(self) -> dict: - """Set up Google Cloud Vertex AI credentials from service account.""" - service_account_path = self.context.get_env( - "GOOGLE_APPLICATION_CREDENTIALS", "" - ).strip() - project_id = self.context.get_env("ANTHROPIC_VERTEX_PROJECT_ID", "").strip() - region = self.context.get_env("CLOUD_ML_REGION", "").strip() - - if not service_account_path: - raise RuntimeError( - "GOOGLE_APPLICATION_CREDENTIALS must be set when CLAUDE_CODE_USE_VERTEX=1" - ) - if not project_id: - raise RuntimeError( - "ANTHROPIC_VERTEX_PROJECT_ID must be set when CLAUDE_CODE_USE_VERTEX=1" - ) - if not region: - raise RuntimeError( - "CLOUD_ML_REGION must be set when CLAUDE_CODE_USE_VERTEX=1" - ) - - if not Path(service_account_path).exists(): - raise RuntimeError( - f"Service account key file not found at {service_account_path}" - ) - - logger.info(f"Vertex AI configured: project={project_id}, region={region}") - return { - "credentials_path": service_account_path, - "project_id": project_id, - "region": region, - } - - async def _prepare_workspace(self) -> AsyncIterator[BaseEvent]: - """Validate workspace prepared by init container. 
- - The init-hydrate container now handles: - - Downloading state from S3 (.claude/, artifacts/, file-uploads/) - - Cloning repos to /workspace/repos/ - - Cloning workflows to /workspace/workflows/ - - Runner just validates and logs what's ready. - """ - workspace = Path(self.context.workspace_path) - logger.info(f"Validating workspace at {workspace}") - - # Check what was hydrated - hydrated_paths = [] - for path_name in [".claude", "artifacts", "file-uploads"]: - path_dir = workspace / path_name - if path_dir.exists(): - file_count = len([f for f in path_dir.rglob("*") if f.is_file()]) - if file_count > 0: - hydrated_paths.append(f"{path_name} ({file_count} files)") - - if hydrated_paths: - logger.info(f"Hydrated from S3: {', '.join(hydrated_paths)}") - else: - logger.info("No state hydrated (fresh session)") - - # No further preparation needed - init container did the work - - async def _validate_prerequisites(self): - """Validate prerequisite files exist for phase-based slash commands.""" - prompt = self.context.get_env("INITIAL_PROMPT", "") - if not prompt: - return - - prompt_lower = prompt.strip().lower() - - prerequisites = { - "/speckit.plan": ( - "spec.md", - "Specification file (spec.md) not found. Please run /speckit.specify first.", - ), - "/speckit.tasks": ( - "plan.md", - "Planning file (plan.md) not found. Please run /speckit.plan first.", - ), - "/speckit.implement": ( - "tasks.md", - "Tasks file (tasks.md) not found. Please run /speckit.tasks first.", - ), - } - - for cmd, (required_file, error_msg) in prerequisites.items(): - if prompt_lower.startswith(cmd): - workspace = Path(self.context.workspace_path) - found = False - - if (workspace / required_file).exists(): - found = True - break - - for subdir in workspace.rglob("specs/*/"): - if (subdir / required_file).exists(): - found = True - break - - if not found: - raise PrerequisiteError(error_msg) - break - - async def _initialize_workflow_if_set(self) -> AsyncIterator[BaseEvent]: - """Validate workflow was cloned by init container.""" - active_workflow_url = (os.getenv("ACTIVE_WORKFLOW_GIT_URL") or "").strip() - if not active_workflow_url: - return - - try: - owner, repo, _ = self._parse_owner_repo(active_workflow_url) - derived_name = repo or "" - if not derived_name: - p = urlparse(active_workflow_url) - parts = [pt for pt in (p.path or "").split("/") if pt] - if parts: - derived_name = parts[-1] - derived_name = (derived_name or "").removesuffix(".git").strip() - - if not derived_name: - logger.warning("Could not derive workflow name from URL") - return - - # Check for cloned workflow (init container uses -clone-temp suffix) - workspace = Path(self.context.workspace_path) - workflow_temp_dir = workspace / "workflows" / f"{derived_name}-clone-temp" - workflow_dir = workspace / "workflows" / derived_name - - if workflow_temp_dir.exists(): - logger.info( - f"Workflow {derived_name} cloned by init container at {workflow_temp_dir.name}" - ) - elif workflow_dir.exists(): - logger.info(f"Workflow {derived_name} available at {workflow_dir.name}") - else: - logger.warning( - f"Workflow {derived_name} not found (init container may have failed to clone)" - ) - - except Exception as e: - logger.error(f"Failed to validate workflow: {e}") - - async def _run_cmd(self, cmd, cwd=None, capture_stdout=False, ignore_errors=False): - """Run a subprocess command asynchronously.""" - cmd_safe = [self._redact_secrets(str(arg)) for arg in cmd] - logger.info(f"Running command: {' '.join(cmd_safe)}") - - proc = await 
asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - cwd=cwd or self.context.workspace_path, - ) - stdout_data, stderr_data = await proc.communicate() - stdout_text = stdout_data.decode("utf-8", errors="replace") - stderr_text = stderr_data.decode("utf-8", errors="replace") - - if stdout_text.strip(): - logger.info(f"Command stdout: {self._redact_secrets(stdout_text.strip())}") - if stderr_text.strip(): - logger.info(f"Command stderr: {self._redact_secrets(stderr_text.strip())}") - - if proc.returncode != 0 and not ignore_errors: - raise RuntimeError(stderr_text or f"Command failed: {' '.join(cmd_safe)}") - - if capture_stdout: - return stdout_text - return "" - - def _url_with_token(self, url: str, token: str) -> str: - """Add authentication token to URL.""" - if not token or not url.lower().startswith("http"): - return url - try: - parsed = urlparse(url) - netloc = parsed.netloc - if "@" in netloc: - netloc = netloc.split("@", 1)[1] - - hostname = parsed.hostname or "" - if "gitlab" in hostname.lower(): - auth = f"oauth2:{token}@" - else: - auth = f"x-access-token:{token}@" - - new_netloc = auth + netloc - return urlunparse( - ( - parsed.scheme, - new_netloc, - parsed.path, - parsed.params, - parsed.query, - parsed.fragment, - ) - ) - except Exception: - return url - - def _redact_secrets(self, text: str) -> str: - """Redact tokens and secrets from text for safe logging.""" - if not text: - return text - - text = re.sub(r"gh[pousr]_[a-zA-Z0-9]{36,255}", "gh*_***REDACTED***", text) - text = re.sub(r"sk-ant-[a-zA-Z0-9\-_]{30,200}", "sk-ant-***REDACTED***", text) - text = re.sub(r"pk-lf-[a-zA-Z0-9\-_]{10,100}", "pk-lf-***REDACTED***", text) - text = re.sub(r"sk-lf-[a-zA-Z0-9\-_]{10,100}", "sk-lf-***REDACTED***", text) - text = re.sub( - r"x-access-token:[^@\s]+@", "x-access-token:***REDACTED***@", text - ) - text = re.sub(r"oauth2:[^@\s]+@", "oauth2:***REDACTED***@", text) - text = re.sub(r"://[^:@\s]+:[^@\s]+@", "://***REDACTED***@", text) - text = re.sub( - r'(ANTHROPIC_API_KEY|LANGFUSE_SECRET_KEY|LANGFUSE_PUBLIC_KEY|BOT_TOKEN|GIT_TOKEN)\s*=\s*[^\s\'"]+', - r"\1=***REDACTED***", - text, - ) - return text - - async def _fetch_token_for_url(self, url: str) -> str: - """Fetch appropriate token based on repository URL.""" - try: - parsed = urlparse(url) - hostname = parsed.hostname or "" - - if "gitlab" in hostname.lower(): - token = await self._fetch_gitlab_token() - if token: - logger.info(f"Using fresh GitLab token for {hostname}") - return token - else: - logger.warning(f"No GitLab credentials configured for {url}") - return "" - - # Always fetch fresh GitHub token (PAT or App) - token = await self._fetch_github_token() - if token: - logger.info(f"Using fresh GitHub token for {hostname}") - return token - - except Exception as e: - logger.warning( - f"Failed to parse URL {url}: {e}, falling back to GitHub token" - ) - return os.getenv("GITHUB_TOKEN") or await self._fetch_github_token() - - async def _populate_runtime_credentials(self) -> None: - """Fetch all credentials from backend and populate environment variables. - - This is called before each SDK run to ensure MCP servers have fresh tokens. 
- """ - logger.info("Fetching fresh credentials from backend API...") - - # Fetch Google credentials - google_creds = await self._fetch_google_credentials() - if google_creds.get("accessToken"): - # Write credentials to file for workspace-mcp - creds_dir = Path("/workspace/.google_workspace_mcp/credentials") - creds_dir.mkdir(parents=True, exist_ok=True) - creds_file = creds_dir / "credentials.json" - - # Get OAuth client config from env - client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID", "") - client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET", "") - - # Create credentials.json for workspace-mcp - creds_data = { - "token": google_creds.get("accessToken"), - "refresh_token": "", # Backend handles refresh - "token_uri": "https://oauth2.googleapis.com/token", - "client_id": client_id, - "client_secret": client_secret, - "scopes": google_creds.get("scopes", []), - "expiry": google_creds.get("expiresAt", ""), - } - - with open(creds_file, "w") as f: - _json.dump(creds_data, f, indent=2) - creds_file.chmod(0o644) - logger.info("✓ Updated Google credentials file for workspace-mcp") - - # Set USER_GOOGLE_EMAIL for MCP server (from backend API response) - user_email = google_creds.get("email", "") - if user_email and user_email != "user@example.com": - os.environ["USER_GOOGLE_EMAIL"] = user_email - logger.info( - f"✓ Set USER_GOOGLE_EMAIL to {user_email} for workspace-mcp" - ) - - # Fetch Jira credentials - jira_creds = await self._fetch_jira_credentials() - if jira_creds.get("apiToken"): - os.environ["JIRA_URL"] = jira_creds.get("url", "") - os.environ["JIRA_API_TOKEN"] = jira_creds.get("apiToken", "") - os.environ["JIRA_EMAIL"] = jira_creds.get("email", "") - logger.info("✓ Updated Jira credentials in environment") - - # Fetch GitLab token - gitlab_token = await self._fetch_gitlab_token() - if gitlab_token: - os.environ["GITLAB_TOKEN"] = gitlab_token - logger.info("✓ Updated GitLab token in environment") - - # Fetch GitHub token (PAT or App) - github_token = await self._fetch_github_token() - if github_token: - os.environ["GITHUB_TOKEN"] = github_token - logger.info("✓ Updated GitHub token in environment") - - logger.info("Runtime credentials populated successfully") - - async def _fetch_credential(self, credential_type: str) -> dict: - """Fetch credentials from backend API at runtime. 
- - Args: - credential_type: One of 'github', 'google', 'jira', 'gitlab' - - Returns: - Dictionary with credential data or empty dict if unavailable - """ - base = os.getenv("BACKEND_API_URL", "").rstrip("/") - project = os.getenv("PROJECT_NAME") or os.getenv( - "AGENTIC_SESSION_NAMESPACE", "" - ) - project = project.strip() - session_id = self.context.session_id - - if not base or not project or not session_id: - logger.warning( - f"Cannot fetch {credential_type} credentials: missing environment variables (base={base}, project={project}, session={session_id})" - ) - return {} - - url = f"{base}/projects/{project}/agentic-sessions/{session_id}/credentials/{credential_type}" - logger.info(f"Fetching fresh {credential_type} credentials from: {url}") - - req = _urllib_request.Request(url, method="GET") - bot = (os.getenv("BOT_TOKEN") or "").strip() - if bot: - req.add_header("Authorization", f"Bearer {bot}") - - loop = asyncio.get_event_loop() - - def _do_req(): - try: - with _urllib_request.urlopen(req, timeout=10) as resp: - return resp.read().decode("utf-8", errors="replace") - except Exception as e: - logger.warning(f"{credential_type} credential fetch failed: {e}") - return "" - - resp_text = await loop.run_in_executor(None, _do_req) - if not resp_text: - return {} - - try: - data = _json.loads(resp_text) - logger.info( - f"Successfully fetched {credential_type} credentials from backend" - ) - return data - except Exception as e: - logger.error(f"Failed to parse {credential_type} credential response: {e}") - return {} - - async def _fetch_github_token(self) -> str: - """Fetch GitHub token from backend API (always fresh - PAT or minted App token).""" - data = await self._fetch_credential("github") - token = data.get("token", "") - if token: - logger.info("Using fresh GitHub token from backend") - return token - - async def _fetch_google_credentials(self) -> dict: - """Fetch Google OAuth credentials from backend API.""" - data = await self._fetch_credential("google") - if data.get("accessToken"): - logger.info( - f"Using fresh Google credentials from backend (email: {data.get('email', 'unknown')})" - ) - return data - - async def _fetch_jira_credentials(self) -> dict: - """Fetch Jira credentials from backend API.""" - data = await self._fetch_credential("jira") - if data.get("apiToken"): - logger.info( - f"Using Jira credentials from backend (url: {data.get('url', 'unknown')})" - ) - return data - - async def _fetch_gitlab_token(self) -> str: - """Fetch GitLab token from backend API.""" - data = await self._fetch_credential("gitlab") - token = data.get("token", "") - if token: - logger.info( - f"Using fresh GitLab token from backend (instance: {data.get('instanceUrl', 'unknown')})" - ) - return token - - async def _fetch_github_token_legacy(self) -> str: - """Legacy method - kept for backward compatibility.""" - # Build mint URL from environment - base = os.getenv("BACKEND_API_URL", "").rstrip("/") - project = os.getenv("PROJECT_NAME") or os.getenv( - "AGENTIC_SESSION_NAMESPACE", "" - ) - project = project.strip() - session_id = self.context.session_id - - if not base or not project or not session_id: - logger.warning("Cannot fetch GitHub token: missing environment variables") - return "" - - url = f"{base}/projects/{project}/agentic-sessions/{session_id}/github/token" - logger.info(f"Fetching GitHub token from legacy endpoint: {url}") - - req = _urllib_request.Request( - url, data=b"{}", headers={"Content-Type": "application/json"}, method="POST" - ) - bot = (os.getenv("BOT_TOKEN") or 
"").strip() - if bot: - req.add_header("Authorization", f"Bearer {bot}") - - loop = asyncio.get_event_loop() - - def _do_req(): - try: - with _urllib_request.urlopen(req, timeout=10) as resp: - return resp.read().decode("utf-8", errors="replace") - except Exception as e: - logger.warning(f"GitHub token fetch failed: {e}") - return "" - - resp_text = await loop.run_in_executor(None, _do_req) - if not resp_text: - return "" - - try: - data = _json.loads(resp_text) - token = str(data.get("token") or "") - if token: - logger.info("Successfully fetched GitHub token from backend") - return token - except Exception as e: - logger.error(f"Failed to parse token response: {e}") - return "" - - def _parse_owner_repo(self, url: str) -> tuple[str, str, str]: - """Return (owner, name, host) from various URL formats.""" - s = (url or "").strip() - s = s.removesuffix(".git") - host = "github.com" - try: - if s.startswith("http://") or s.startswith("https://"): - p = urlparse(s) - host = p.netloc - parts = [pt for pt in p.path.split("/") if pt] - if len(parts) >= 2: - return parts[0], parts[1], host - if s.startswith("git@") or ":" in s: - s2 = s - if s2.startswith("git@"): - s2 = s2.replace(":", "/", 1) - s2 = s2.replace("git@", "ssh://git@", 1) - p = urlparse(s2) - host = p.hostname or host - parts = [pt for pt in (p.path or "").split("/") if pt] - if len(parts) >= 2: - return parts[-2], parts[-1], host - parts = [pt for pt in s.split("/") if pt] - if len(parts) == 2: - return parts[0], parts[1], host - except Exception: - return "", "", host - return "", "", host - - def _get_repos_config(self) -> list[dict]: - """Read repos mapping from REPOS_JSON env if present. - - Expected format: [{"url": "...", "branch": "main", "autoPush": true}, ...] - Returns: [{"name": "repo-name", "url": "...", "branch": "...", "autoPush": bool}, ...] 
- """ - try: - raw = os.getenv("REPOS_JSON", "").strip() - if not raw: - return [] - data = _json.loads(raw) - if isinstance(data, list): - out = [] - for it in data: - if not isinstance(it, dict): - continue - - # Extract simple format fields - url = str(it.get("url") or "").strip() - # Auto-generate branch from session name if not provided - branch_from_json = it.get("branch") - if branch_from_json and str(branch_from_json).strip(): - branch = str(branch_from_json).strip() - else: - # Fallback: use AGENTIC_SESSION_NAME to match backend logic - session_id = os.getenv("AGENTIC_SESSION_NAME", "").strip() - branch = f"ambient/{session_id}" if session_id else "main" - # Parse autoPush as boolean, defaulting to False for invalid types - auto_push_raw = it.get("autoPush", False) - auto_push = ( - auto_push_raw if isinstance(auto_push_raw, bool) else False - ) - - if not url: - continue - - # Derive repo name from URL if not provided - name = str(it.get("name") or "").strip() - if not name: - try: - owner, repo, _ = self._parse_owner_repo(url) - derived = repo or "" - if not derived: - p = urlparse(url) - parts = [pt for pt in (p.path or "").split("/") if pt] - if parts: - derived = parts[-1] - name = (derived or "").removesuffix(".git").strip() - except Exception: - name = "" - - if name and url: - out.append( - { - "name": name, - "url": url, - "branch": branch, - "autoPush": auto_push, - } - ) - return out - except Exception: - return [] - return [] - - def _expand_env_vars(self, value: Any) -> Any: - """Recursively expand ${VAR} and ${VAR:-default} patterns in config values.""" - if isinstance(value, str): - # Pattern: ${VAR} or ${VAR:-default} - pattern = r"\$\{([^}:]+)(?::-([^}]*))?\}" - - def replace_var(match): - var_name = match.group(1) - default_val = match.group(2) if match.group(2) is not None else "" - return os.environ.get(var_name, default_val) - - return re.sub(pattern, replace_var, value) - elif isinstance(value, dict): - return {k: self._expand_env_vars(v) for k, v in value.items()} - elif isinstance(value, list): - return [self._expand_env_vars(item) for item in value] - return value - - def _load_mcp_config(self, cwd_path: str) -> Optional[dict]: - """Load MCP server configuration from the ambient runner's .mcp.json file.""" - try: - # Allow override via MCP_CONFIG_FILE env var (useful for e2e with minimal MCPs) - mcp_config_file = self.context.get_env( - "MCP_CONFIG_FILE", "/app/claude-runner/.mcp.json" - ) - runner_mcp_file = Path(mcp_config_file) - - if runner_mcp_file.exists() and runner_mcp_file.is_file(): - logger.info(f"Loading MCP config from: {runner_mcp_file}") - with open(runner_mcp_file, "r") as f: - config = _json.load(f) - mcp_servers = config.get("mcpServers", {}) - # Expand environment variables in the config - expanded = self._expand_env_vars(mcp_servers) - logger.info( - f"Expanded MCP config env vars for {len(expanded)} servers" - ) - return expanded - else: - logger.info(f"No MCP config file found at: {runner_mcp_file}") - return None - - except _json.JSONDecodeError as e: - logger.error(f"Failed to parse MCP config: {e}") - return None - except Exception as e: - logger.error(f"Error loading MCP config: {e}") - return None - - def _load_ambient_config(self, cwd_path: str) -> dict: - """Load ambient.json configuration from workflow directory.""" - try: - config_path = Path(cwd_path) / ".ambient" / "ambient.json" - - if not config_path.exists(): - logger.info(f"No ambient.json found at {config_path}, using defaults") - return {} - - with open(config_path, 
"r") as f: - config = _json.load(f) - logger.info(f"Loaded ambient.json: name={config.get('name')}") - return config - - except _json.JSONDecodeError as e: - logger.error(f"Failed to parse ambient.json: {e}") - return {} - except Exception as e: - logger.error(f"Error loading ambient.json: {e}") - return {} - - def _build_workspace_context_prompt( - self, repos_cfg, workflow_name, artifacts_path, ambient_config - ): - """Generate concise system prompt describing workspace layout.""" - prompt = "# Workspace Structure\n\n" - - # Workflow directory (if active) - if workflow_name: - prompt += f"**Working Directory**: workflows/{workflow_name}/ (workflow logic - do not create files here)\n\n" - - # Artifacts - prompt += f"**Artifacts**: {artifacts_path} (create all output files here)\n\n" - - # Uploaded files - file_uploads_path = Path(self.context.workspace_path) / "file-uploads" - if file_uploads_path.exists() and file_uploads_path.is_dir(): - try: - files = sorted( - [f.name for f in file_uploads_path.iterdir() if f.is_file()] - ) - if files: - max_display = 10 - if len(files) <= max_display: - prompt += f"**Uploaded Files**: {', '.join(files)}\n\n" - else: - prompt += f"**Uploaded Files** ({len(files)} total): {', '.join(files[:max_display])}, and {len(files) - max_display} more\n\n" - except Exception: - pass - else: - prompt += "**Uploaded Files**: None\n\n" - - # Repositories - if repos_cfg: - session_id = os.getenv("AGENTIC_SESSION_NAME", "").strip() - feature_branch = f"ambient/{session_id}" if session_id else None - - repo_names = [ - repo.get("name", f"repo-{i}") for i, repo in enumerate(repos_cfg) - ] - if len(repo_names) <= 5: - prompt += f"**Repositories**: {', '.join([f'repos/{name}/' for name in repo_names])}\n" - else: - prompt += f"**Repositories** ({len(repo_names)} total): {', '.join([f'repos/{name}/' for name in repo_names[:5]])}, and {len(repo_names) - 5} more\n" - - if feature_branch: - prompt += f"**Working Branch**: `{feature_branch}` (all repos are on this feature branch)\n\n" - else: - prompt += "\n" - - # Add git push instructions for repos with autoPush enabled - auto_push_repos = [ - repo for repo in repos_cfg if repo.get("autoPush", False) - ] - if auto_push_repos: - push_branch = feature_branch or "ambient/" - - prompt += "## Git Push Instructions\n\n" - prompt += "The following repositories have auto-push enabled. When you make changes to these repositories, you MUST commit and push your changes:\n\n" - for repo in auto_push_repos: - repo_name = repo.get("name", "unknown") - prompt += f"- **repos/{repo_name}/**\n" - prompt += "\nAfter making changes to any auto-push repository:\n" - prompt += "1. Use `git add` to stage your changes\n" - prompt += '2. Use `git commit -m "description"` to commit with a descriptive message\n' - prompt += f"3. 
Use `git push origin {push_branch}` to push to the remote repository\n\n" - - # MCP Integration Setup Instructions - prompt += "## MCP Integrations\n" - prompt += "If you need Google Drive access: Ask user to go to Integrations page in Ambient and authenticate with Google Drive.\n" - prompt += "If you need Jira access: Ask user to go to Workspace Settings in Ambient and configure Jira credentials there.\n\n" - - # Workflow instructions (if any) - if ambient_config.get("systemPrompt"): - prompt += f"## Workflow Instructions\n{ambient_config['systemPrompt']}\n\n" - - return prompt - - # NOTE: Google credential copy functions removed - credentials now fetched at runtime via backend API - # This supersedes PR #562's volume mounting approach with just-in-time credential fetching - # See _populate_runtime_credentials() for new approach diff --git a/components/runners/claude-code-runner/auth.py b/components/runners/claude-code-runner/auth.py new file mode 100644 index 000000000..27d4cc77d --- /dev/null +++ b/components/runners/claude-code-runner/auth.py @@ -0,0 +1,351 @@ +""" +Authentication and credential management for the Claude Code runner. + +Handles Anthropic API keys, Vertex AI setup, and runtime credential +fetching from the backend API (GitHub, Google, Jira, GitLab). +""" + +import asyncio +import json as _json +import logging +import os +import re +from pathlib import Path +from urllib import request as _urllib_request +from urllib.parse import urlparse + +from context import RunnerContext + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# User context sanitization +# --------------------------------------------------------------------------- + +def sanitize_user_context(user_id: str, user_name: str) -> tuple[str, str]: + """Validate and sanitize user context fields to prevent injection attacks.""" + if user_id: + user_id = str(user_id).strip() + if len(user_id) > 255: + user_id = user_id[:255] + user_id = re.sub(r"[^a-zA-Z0-9@._-]", "", user_id) + + if user_name: + user_name = str(user_name).strip() + if len(user_name) > 255: + user_name = user_name[:255] + user_name = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", user_name) + + return user_id, user_name + + +# --------------------------------------------------------------------------- +# Model helpers +# --------------------------------------------------------------------------- + +# Anthropic API → Vertex AI model name mapping +VERTEX_MODEL_MAP: dict[str, str] = { + "claude-opus-4-5": "claude-opus-4-5@20251101", + "claude-opus-4-1": "claude-opus-4-1@20250805", + "claude-sonnet-4-5": "claude-sonnet-4-5@20250929", + "claude-haiku-4-5": "claude-haiku-4-5@20251001", +} + + +def map_to_vertex_model(model: str) -> str: + """Map Anthropic API model names to Vertex AI model names.""" + return VERTEX_MODEL_MAP.get(model, model) + + +async def setup_vertex_credentials(context: RunnerContext) -> dict: + """Set up Google Cloud Vertex AI credentials from service account. + + Returns: + Dict with credentials_path, project_id, region. + + Raises: + RuntimeError: If required environment variables are missing. 
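+
+    Example (illustrative sketch; assumes the three env vars above are set
+    and a valid ``RunnerContext`` is in scope)::
+
+        creds = await setup_vertex_credentials(context)
+        # creds carries credentials_path, project_id, and region for the
+        # Vertex-backed SDK run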
+ """ + service_account_path = context.get_env( + "GOOGLE_APPLICATION_CREDENTIALS", "" + ).strip() + project_id = context.get_env("ANTHROPIC_VERTEX_PROJECT_ID", "").strip() + region = context.get_env("CLOUD_ML_REGION", "").strip() + + if not service_account_path: + raise RuntimeError( + "GOOGLE_APPLICATION_CREDENTIALS must be set when CLAUDE_CODE_USE_VERTEX=1" + ) + if not project_id: + raise RuntimeError( + "ANTHROPIC_VERTEX_PROJECT_ID must be set when CLAUDE_CODE_USE_VERTEX=1" + ) + if not region: + raise RuntimeError( + "CLOUD_ML_REGION must be set when CLAUDE_CODE_USE_VERTEX=1" + ) + + if not Path(service_account_path).exists(): + raise RuntimeError( + f"Service account key file not found at {service_account_path}" + ) + + logger.info(f"Vertex AI configured: project={project_id}, region={region}") + return { + "credentials_path": service_account_path, + "project_id": project_id, + "region": region, + } + + +# --------------------------------------------------------------------------- +# Backend credential fetching +# --------------------------------------------------------------------------- + +async def _fetch_credential(context: RunnerContext, credential_type: str) -> dict: + """Fetch credentials from backend API at runtime. + + Args: + context: Runner context with session_id. + credential_type: One of 'github', 'google', 'jira', 'gitlab'. + + Returns: + Dictionary with credential data or empty dict if unavailable. + """ + base = os.getenv("BACKEND_API_URL", "").rstrip("/") + project = os.getenv("PROJECT_NAME") or os.getenv( + "AGENTIC_SESSION_NAMESPACE", "" + ) + project = project.strip() + session_id = context.session_id + + if not base or not project or not session_id: + logger.warning( + f"Cannot fetch {credential_type} credentials: missing environment " + f"variables (base={base}, project={project}, session={session_id})" + ) + return {} + + url = ( + f"{base}/projects/{project}/agentic-sessions/" + f"{session_id}/credentials/{credential_type}" + ) + logger.info(f"Fetching fresh {credential_type} credentials from: {url}") + + req = _urllib_request.Request(url, method="GET") + bot = (os.getenv("BOT_TOKEN") or "").strip() + if bot: + req.add_header("Authorization", f"Bearer {bot}") + + loop = asyncio.get_event_loop() + + def _do_req(): + try: + with _urllib_request.urlopen(req, timeout=10) as resp: + return resp.read().decode("utf-8", errors="replace") + except Exception as e: + logger.warning(f"{credential_type} credential fetch failed: {e}") + return "" + + resp_text = await loop.run_in_executor(None, _do_req) + if not resp_text: + return {} + + try: + data = _json.loads(resp_text) + logger.info( + f"Successfully fetched {credential_type} credentials from backend" + ) + return data + except Exception as e: + logger.error( + f"Failed to parse {credential_type} credential response: {e}" + ) + return {} + + +async def fetch_github_token(context: RunnerContext) -> str: + """Fetch GitHub token from backend API (always fresh — PAT or minted App token).""" + data = await _fetch_credential(context, "github") + token = data.get("token", "") + if token: + logger.info("Using fresh GitHub token from backend") + return token + + +async def fetch_google_credentials(context: RunnerContext) -> dict: + """Fetch Google OAuth credentials from backend API.""" + data = await _fetch_credential(context, "google") + if data.get("accessToken"): + logger.info( + f"Using fresh Google credentials from backend " + f"(email: {data.get('email', 'unknown')})" + ) + return data + + +async def 
fetch_jira_credentials(context: RunnerContext) -> dict: + """Fetch Jira credentials from backend API.""" + data = await _fetch_credential(context, "jira") + if data.get("apiToken"): + logger.info( + f"Using Jira credentials from backend " + f"(url: {data.get('url', 'unknown')})" + ) + return data + + +async def fetch_gitlab_token(context: RunnerContext) -> str: + """Fetch GitLab token from backend API.""" + data = await _fetch_credential(context, "gitlab") + token = data.get("token", "") + if token: + logger.info( + f"Using fresh GitLab token from backend " + f"(instance: {data.get('instanceUrl', 'unknown')})" + ) + return token + + +async def fetch_token_for_url(context: RunnerContext, url: str) -> str: + """Fetch appropriate token based on repository URL host.""" + try: + parsed = urlparse(url) + hostname = parsed.hostname or "" + + if "gitlab" in hostname.lower(): + token = await fetch_gitlab_token(context) + if token: + logger.info(f"Using fresh GitLab token for {hostname}") + return token + else: + logger.warning(f"No GitLab credentials configured for {url}") + return "" + + token = await fetch_github_token(context) + if token: + logger.info(f"Using fresh GitHub token for {hostname}") + return token + + except Exception as e: + logger.warning( + f"Failed to parse URL {url}: {e}, falling back to GitHub token" + ) + return os.getenv("GITHUB_TOKEN") or await fetch_github_token(context) + + +async def populate_runtime_credentials(context: RunnerContext) -> None: + """Fetch all credentials from backend and populate environment variables. + + Called before each SDK run to ensure MCP servers have fresh tokens. + """ + logger.info("Fetching fresh credentials from backend API...") + + # Google credentials + google_creds = await fetch_google_credentials(context) + if google_creds.get("accessToken"): + creds_dir = Path("/workspace/.google_workspace_mcp/credentials") + creds_dir.mkdir(parents=True, exist_ok=True) + creds_file = creds_dir / "credentials.json" + + client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID", "") + client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET", "") + + creds_data = { + "token": google_creds.get("accessToken"), + "refresh_token": "", + "token_uri": "https://oauth2.googleapis.com/token", + "client_id": client_id, + "client_secret": client_secret, + "scopes": google_creds.get("scopes", []), + "expiry": google_creds.get("expiresAt", ""), + } + + with open(creds_file, "w") as f: + _json.dump(creds_data, f, indent=2) + creds_file.chmod(0o644) + logger.info("✓ Updated Google credentials file for workspace-mcp") + + user_email = google_creds.get("email", "") + if user_email and user_email != "user@example.com": + os.environ["USER_GOOGLE_EMAIL"] = user_email + logger.info( + f"✓ Set USER_GOOGLE_EMAIL to {user_email} for workspace-mcp" + ) + + # Jira credentials + jira_creds = await fetch_jira_credentials(context) + if jira_creds.get("apiToken"): + os.environ["JIRA_URL"] = jira_creds.get("url", "") + os.environ["JIRA_API_TOKEN"] = jira_creds.get("apiToken", "") + os.environ["JIRA_EMAIL"] = jira_creds.get("email", "") + logger.info("✓ Updated Jira credentials in environment") + + # GitLab token + gitlab_token = await fetch_gitlab_token(context) + if gitlab_token: + os.environ["GITLAB_TOKEN"] = gitlab_token + logger.info("✓ Updated GitLab token in environment") + + # GitHub token + github_token = await fetch_github_token(context) + if github_token: + os.environ["GITHUB_TOKEN"] = github_token + logger.info("✓ Updated GitHub token in environment") + + logger.info("Runtime credentials 
populated successfully") + + +async def fetch_github_token_legacy(context: RunnerContext) -> str: + """Legacy method — kept for backward compatibility.""" + base = os.getenv("BACKEND_API_URL", "").rstrip("/") + project = os.getenv("PROJECT_NAME") or os.getenv( + "AGENTIC_SESSION_NAMESPACE", "" + ) + project = project.strip() + session_id = context.session_id + + if not base or not project or not session_id: + logger.warning("Cannot fetch GitHub token: missing environment variables") + return "" + + url = ( + f"{base}/projects/{project}/agentic-sessions/" + f"{session_id}/github/token" + ) + logger.info(f"Fetching GitHub token from legacy endpoint: {url}") + + req = _urllib_request.Request( + url, + data=b"{}", + headers={"Content-Type": "application/json"}, + method="POST", + ) + bot = (os.getenv("BOT_TOKEN") or "").strip() + if bot: + req.add_header("Authorization", f"Bearer {bot}") + + loop = asyncio.get_event_loop() + + def _do_req(): + try: + with _urllib_request.urlopen(req, timeout=10) as resp: + return resp.read().decode("utf-8", errors="replace") + except Exception as e: + logger.warning(f"GitHub token fetch failed: {e}") + return "" + + resp_text = await loop.run_in_executor(None, _do_req) + if not resp_text: + return "" + + try: + data = _json.loads(resp_text) + token = str(data.get("token") or "") + if token: + logger.info("Successfully fetched GitHub token from backend") + return token + except Exception as e: + logger.error(f"Failed to parse token response: {e}") + return "" diff --git a/components/runners/claude-code-runner/config.py b/components/runners/claude-code-runner/config.py new file mode 100644 index 000000000..18bd328fb --- /dev/null +++ b/components/runners/claude-code-runner/config.py @@ -0,0 +1,150 @@ +""" +Configuration loading for the Claude Code runner. + +Reads ambient.json, MCP server config, and repository configuration +from environment variables and the filesystem. +""" + +import json as _json +import logging +import os +from pathlib import Path +from typing import Optional + +from context import RunnerContext +from utils import expand_env_vars, parse_owner_repo + +logger = logging.getLogger(__name__) + + +def load_ambient_config(cwd_path: str) -> dict: + """Load ambient.json configuration from workflow directory. + + Returns: + Parsed config dict, or empty dict if not found / invalid. + """ + try: + config_path = Path(cwd_path) / ".ambient" / "ambient.json" + + if not config_path.exists(): + logger.info( + f"No ambient.json found at {config_path}, using defaults" + ) + return {} + + with open(config_path, "r") as f: + config = _json.load(f) + logger.info(f"Loaded ambient.json: name={config.get('name')}") + return config + + except _json.JSONDecodeError as e: + logger.error(f"Failed to parse ambient.json: {e}") + return {} + except Exception as e: + logger.error(f"Error loading ambient.json: {e}") + return {} + + +def load_mcp_config(context: RunnerContext, cwd_path: str) -> Optional[dict]: + """Load MCP server configuration from the ambient runner's .mcp.json file. + + Returns: + Dict of MCP server configs with env vars expanded, or None. 
+ """ + try: + mcp_config_file = context.get_env( + "MCP_CONFIG_FILE", "/app/claude-runner/.mcp.json" + ) + runner_mcp_file = Path(mcp_config_file) + + if runner_mcp_file.exists() and runner_mcp_file.is_file(): + logger.info(f"Loading MCP config from: {runner_mcp_file}") + with open(runner_mcp_file, "r") as f: + config = _json.load(f) + mcp_servers = config.get("mcpServers", {}) + expanded = expand_env_vars(mcp_servers) + logger.info( + f"Expanded MCP config env vars for {len(expanded)} servers" + ) + return expanded + else: + logger.info(f"No MCP config file found at: {runner_mcp_file}") + return None + + except _json.JSONDecodeError as e: + logger.error(f"Failed to parse MCP config: {e}") + return None + except Exception as e: + logger.error(f"Error loading MCP config: {e}") + return None + + +def get_repos_config() -> list[dict]: + """Read repos mapping from REPOS_JSON env if present. + + Expected format:: + + [{"url": "...", "branch": "main", "autoPush": true}, ...] + + Returns: + List of dicts: ``[{"name": ..., "url": ..., "branch": ..., "autoPush": bool}, ...]`` + """ + try: + raw = os.getenv("REPOS_JSON", "").strip() + if not raw: + return [] + data = _json.loads(raw) + if isinstance(data, list): + out: list[dict] = [] + for it in data: + if not isinstance(it, dict): + continue + + url = str(it.get("url") or "").strip() + branch_from_json = it.get("branch") + if branch_from_json and str(branch_from_json).strip(): + branch = str(branch_from_json).strip() + else: + session_id = os.getenv("AGENTIC_SESSION_NAME", "").strip() + branch = ( + f"ambient/{session_id}" if session_id else "main" + ) + auto_push_raw = it.get("autoPush", False) + auto_push = ( + auto_push_raw if isinstance(auto_push_raw, bool) else False + ) + + if not url: + continue + + name = str(it.get("name") or "").strip() + if not name: + try: + _owner, repo, _ = parse_owner_repo(url) + derived = repo or "" + if not derived: + from urllib.parse import urlparse + + p = urlparse(url) + parts = [ + pt for pt in (p.path or "").split("/") if pt + ] + if parts: + derived = parts[-1] + name = (derived or "").removesuffix(".git").strip() + except Exception: + name = "" + + if name and url: + out.append( + { + "name": name, + "url": url, + "branch": branch, + "autoPush": auto_push, + } + ) + return out + except Exception: + return [] + return [] diff --git a/components/runners/claude-code-runner/main.py b/components/runners/claude-code-runner/main.py index ea5a607b3..d609e1550 100644 --- a/components/runners/claude-code-runner/main.py +++ b/components/runners/claude-code-runner/main.py @@ -665,28 +665,24 @@ def _check_mcp_authentication(server_name: str) -> tuple[bool | None, str | None @app.get("/mcp/status") async def get_mcp_status(): """ - Returns MCP servers configured for this session with authentication status. - Goes straight to the source - uses adapter's _load_mcp_config() method. - - For known integrations (Google, Jira), also checks if credentials are present. + Returns MCP server connection status by using the SDK's get_mcp_status() method. + Spins up a minimal ClaudeSDKClient, queries MCP status, then tears it down. 
""" try: global adapter - if not adapter: + if not adapter or not adapter.context: return { "servers": [], "totalCount": 0, "message": "Adapter not initialized yet", } - mcp_servers_list = [] - - # Get the working directory (same logic as adapter uses) - workspace_path = ( - adapter.context.workspace_path if adapter.context else "/workspace" - ) + from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient + import config as runner_config + # Resolve working directory (same logic as adapter) + workspace_path = adapter.context.workspace_path or "/workspace" active_workflow_url = os.getenv("ACTIVE_WORKFLOW_GIT_URL", "").strip() cwd_path = workspace_path @@ -696,40 +692,63 @@ async def get_mcp_status(): if os.path.exists(workflow_path): cwd_path = workflow_path - # Use adapter's method to load MCP config (same as it does during runs) - mcp_config = adapter._load_mcp_config(cwd_path) - logger.info(f"MCP config: {mcp_config}") - - if mcp_config: - for server_name, server_config in mcp_config.items(): - # Check authentication status for known servers (Google, Jira) - is_authenticated, auth_message = _check_mcp_authentication(server_name) - - # Platform servers are built-in (webfetch), workflow servers come from config - is_platform = server_name == "webfetch" - - server_info = { - "name": server_name, - "displayName": server_name.replace("-", " ") - .replace("_", " ") - .title(), - "status": "configured", - "command": server_config.get("command", ""), - "source": "platform" if is_platform else "workflow", - } + # Load MCP server config (same config the adapter uses for runs) + mcp_servers = runner_config.load_mcp_config(adapter.context, cwd_path) or {} - # Only include auth fields for servers we know how to check - if is_authenticated is not None: - server_info["authenticated"] = is_authenticated - server_info["authMessage"] = auth_message + # Build minimal options — just enough to initialise MCP servers + options = ClaudeAgentOptions( + cwd=cwd_path, + permission_mode="acceptEdits", + mcp_servers=mcp_servers, + ) - mcp_servers_list.append(server_info) + client = ClaudeSDKClient(options=options) + try: + logger.info("MCP Status: Connecting ephemeral SDK client...") + await client.connect() + + # Use the SDK's public get_mcp_status() method (added in v0.1.23) + sdk_status = await client.get_mcp_status() + logger.info("MCP Status: SDK returned:\n%s", json.dumps(sdk_status, indent=2, default=str)) + + # SDK returns: { mcpServers: [{ name, status, serverInfo: { name, version }, scope, tools }] } + raw_servers = [] + if isinstance(sdk_status, dict): + raw_servers = sdk_status.get("mcpServers", []) + elif isinstance(sdk_status, list): + raw_servers = sdk_status + + servers_list = [] + for srv in raw_servers: + if not isinstance(srv, dict): + continue + server_info = srv.get("serverInfo") or {} + raw_tools = srv.get("tools") or [] + tools = [ + { + "name": t.get("name", ""), + "annotations": { + k: v for k, v in (t.get("annotations") or {}).items() + }, + } + for t in raw_tools + if isinstance(t, dict) + ] + servers_list.append({ + "name": srv.get("name", ""), + "displayName": server_info.get("name", srv.get("name", "")), + "status": srv.get("status", "unknown"), + "version": server_info.get("version", ""), + "tools": tools, + }) - return { - "servers": mcp_servers_list, - "totalCount": len(mcp_servers_list), - "note": "Status shows 'configured' - check 'authenticated' field for credential status", - } + return { + "servers": servers_list, + "totalCount": len(servers_list), + } + finally: + 
logger.info("MCP Status: Disconnecting ephemeral SDK client...") + await client.disconnect() except Exception as e: logger.error(f"Failed to get MCP status: {e}", exc_info=True) diff --git a/components/runners/claude-code-runner/prompts.py b/components/runners/claude-code-runner/prompts.py new file mode 100644 index 000000000..4008057dd --- /dev/null +++ b/components/runners/claude-code-runner/prompts.py @@ -0,0 +1,200 @@ +""" +System prompt construction and prompt constants for the Claude Code runner. + +All hardcoded prompt strings are defined as constants here, and the main +build function assembles them into the workspace context prompt that gets +appended to the Claude Code system prompt preset. +""" + +import logging +import os +from pathlib import Path + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Prompt constants +# --------------------------------------------------------------------------- + +WORKSPACE_STRUCTURE_HEADER = "# Workspace Structure\n\n" + +MCP_INTEGRATIONS_PROMPT = ( + "## MCP Integrations\n" + "If you need Google Drive access: Ask user to go to Integrations page " + "in Ambient and authenticate with Google Drive.\n" + "If you need Jira access: Ask user to go to Workspace Settings in Ambient " + "and configure Jira credentials there.\n\n" +) + +GIT_PUSH_INSTRUCTIONS_HEADER = "## Git Push Instructions\n\n" + +GIT_PUSH_INSTRUCTIONS_BODY = ( + "The following repositories have auto-push enabled. When you make changes " + "to these repositories, you MUST commit and push your changes:\n\n" +) + +GIT_PUSH_STEPS = ( + "\nAfter making changes to any auto-push repository:\n" + "1. Use `git add` to stage your changes\n" + '2. Use `git commit -m "description"` to commit with a descriptive message\n' + "3. Use `git push origin {branch}` to push to the remote repository\n\n" +) + +RUBRIC_EVALUATION_HEADER = "## Rubric Evaluation\n\n" + +RUBRIC_EVALUATION_INTRO = ( + "This workflow includes a scoring rubric for evaluating outputs. " + "The rubric is located at `.ambient/rubric.md`.\n\n" +) + +RUBRIC_EVALUATION_PROCESS = ( + "**Process**:\n" + "1. Read `.ambient/rubric.md` using the Read tool\n" + "2. Evaluate the output against each criterion\n" + "3. Call `evaluate_rubric` (via the rubric MCP server) " + "with your scores and reasoning\n\n" + "**Important**: Always read the rubric first before scoring. " + "Provide honest, calibrated scores with clear reasoning.\n\n" +) + +RESTART_TOOL_DESCRIPTION = ( + "Restart the Claude session to recover from issues, clear state, " + "or get a fresh connection. Use this if you detect you're in a " + "broken state or need to reset." +) + + +# --------------------------------------------------------------------------- +# Prompt builder +# --------------------------------------------------------------------------- + +def build_workspace_context_prompt( + repos_cfg: list, + workflow_name: str | None, + artifacts_path: str, + ambient_config: dict, + workspace_path: str, +) -> str: + """Generate the workspace context prompt appended to the Claude Code preset. + + Args: + repos_cfg: List of repo config dicts. + workflow_name: Active workflow name (or None). + artifacts_path: Relative path for output artifacts. + ambient_config: Parsed ambient.json dict. + workspace_path: Absolute workspace root path. + + Returns: + Formatted prompt string. 
+ """ + prompt = WORKSPACE_STRUCTURE_HEADER + + # Workflow directory + if workflow_name: + prompt += ( + f"**Working Directory**: workflows/{workflow_name}/ " + "(workflow logic - do not create files here)\n\n" + ) + + # Artifacts + prompt += f"**Artifacts**: {artifacts_path} (create all output files here)\n\n" + + # Uploaded files + file_uploads_path = Path(workspace_path) / "file-uploads" + if file_uploads_path.exists() and file_uploads_path.is_dir(): + try: + files = sorted( + [f.name for f in file_uploads_path.iterdir() if f.is_file()] + ) + if files: + max_display = 10 + if len(files) <= max_display: + prompt += f"**Uploaded Files**: {', '.join(files)}\n\n" + else: + prompt += ( + f"**Uploaded Files** ({len(files)} total): " + f"{', '.join(files[:max_display])}, " + f"and {len(files) - max_display} more\n\n" + ) + except Exception: + pass + else: + prompt += "**Uploaded Files**: None\n\n" + + # Repositories + if repos_cfg: + session_id = os.getenv("AGENTIC_SESSION_NAME", "").strip() + feature_branch = f"ambient/{session_id}" if session_id else None + + repo_names = [ + repo.get("name", f"repo-{i}") for i, repo in enumerate(repos_cfg) + ] + if len(repo_names) <= 5: + prompt += ( + f"**Repositories**: " + f"{', '.join([f'repos/{name}/' for name in repo_names])}\n" + ) + else: + prompt += ( + f"**Repositories** ({len(repo_names)} total): " + f"{', '.join([f'repos/{name}/' for name in repo_names[:5]])}, " + f"and {len(repo_names) - 5} more\n" + ) + + if feature_branch: + prompt += ( + f"**Working Branch**: `{feature_branch}` " + "(all repos are on this feature branch)\n\n" + ) + else: + prompt += "\n" + + # Git push instructions for auto-push repos + auto_push_repos = [ + repo for repo in repos_cfg if repo.get("autoPush", False) + ] + if auto_push_repos: + push_branch = feature_branch or "ambient/" + prompt += GIT_PUSH_INSTRUCTIONS_HEADER + prompt += GIT_PUSH_INSTRUCTIONS_BODY + for repo in auto_push_repos: + repo_name = repo.get("name", "unknown") + prompt += f"- **repos/{repo_name}/**\n" + prompt += GIT_PUSH_STEPS.format(branch=push_branch) + + # MCP integration setup instructions + prompt += MCP_INTEGRATIONS_PROMPT + + # Workflow instructions + if ambient_config.get("systemPrompt"): + prompt += ( + f"## Workflow Instructions\n" + f"{ambient_config['systemPrompt']}\n\n" + ) + + # Rubric evaluation instructions + prompt += _build_rubric_prompt_section(ambient_config) + + return prompt + + +def _build_rubric_prompt_section(ambient_config: dict) -> str: + """Build the rubric evaluation section for the system prompt. + + Returns empty string if no rubric config is present. 
+ """ + rubric_config = ambient_config.get("rubric", {}) + if not rubric_config: + return "" + + section = RUBRIC_EVALUATION_HEADER + section += RUBRIC_EVALUATION_INTRO + + activation_prompt = rubric_config.get("activationPrompt", "") + if activation_prompt: + section += f"**When to evaluate**: {activation_prompt}\n\n" + + section += RUBRIC_EVALUATION_PROCESS + + return section diff --git a/components/runners/claude-code-runner/pyproject.toml b/components/runners/claude-code-runner/pyproject.toml index a36cafa92..093186cf3 100644 --- a/components/runners/claude-code-runner/pyproject.toml +++ b/components/runners/claude-code-runner/pyproject.toml @@ -43,7 +43,8 @@ dev-dependencies = [ ] [tool.setuptools] -py-modules = ["main", "adapter", "context", "observability", "security_utils"] +py-modules = ["main", "adapter", "auth", "config", "context", "observability", "prompts", "security_utils", "utils", "workspace"] +packages = ["tools"] [build-system] requires = ["setuptools>=61.0"] diff --git a/components/runners/claude-code-runner/tests/test_auto_push.py b/components/runners/claude-code-runner/tests/test_auto_push.py index 1fd600ad6..4351eadff 100644 --- a/components/runners/claude-code-runner/tests/test_auto_push.py +++ b/components/runners/claude-code-runner/tests/test_auto_push.py @@ -1,27 +1,28 @@ -"""Unit tests for autoPush functionality in adapter.py.""" +"""Unit tests for autoPush functionality.""" import json import os import sys +from pathlib import Path from unittest.mock import MagicMock, Mock, patch import pytest -# Mock ag_ui module before importing adapter +# Add parent directory to path +runner_dir = Path(__file__).parent.parent +if str(runner_dir) not in sys.path: + sys.path.insert(0, str(runner_dir)) + +# Mock ag_ui module before importing modules sys.modules["ag_ui"] = Mock() sys.modules["ag_ui.core"] = Mock() -sys.modules["context"] = Mock() - -class TestGetReposConfig: - """Tests for _get_repos_config method.""" +from config import get_repos_config # type: ignore[import] +from prompts import build_workspace_context_prompt # type: ignore[import] - def setup_method(self): - """Set up test fixtures.""" - # Import here after mocking dependencies - from adapter import ClaudeCodeAdapter - self.adapter = ClaudeCodeAdapter() +class TestGetReposConfig: + """Tests for config.get_repos_config function.""" def test_parse_simple_repo_with_autopush_true(self): """Test parsing repo with autoPush=true.""" @@ -36,7 +37,7 @@ def test_parse_simple_repo_with_autopush_true(self): ) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 1 assert result[0]["url"] == "https://github.com/owner/repo.git" @@ -57,7 +58,7 @@ def test_parse_simple_repo_with_autopush_false(self): ) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 1 assert result[0]["autoPush"] is False @@ -69,7 +70,7 @@ def test_parse_repo_without_autopush(self): ) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 1 assert result[0]["autoPush"] is False @@ -97,7 +98,7 @@ def test_parse_multiple_repos_mixed_autopush(self): ) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 3 assert result[0]["autoPush"] is True @@ -118,7 +119,7 @@ def 
test_parse_repo_with_explicit_name(self): ) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 1 assert result[0]["name"] == "my-custom-repo" @@ -127,21 +128,21 @@ def test_parse_repo_with_explicit_name(self): def test_parse_empty_repos_json(self): """Test parsing empty REPOS_JSON.""" with patch.dict(os.environ, {"REPOS_JSON": ""}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert result == [] def test_parse_missing_repos_json(self): """Test parsing when REPOS_JSON not set.""" with patch.dict(os.environ, {}, clear=True): - result = self.adapter._get_repos_config() + result = get_repos_config() assert result == [] def test_parse_invalid_json(self): """Test parsing invalid JSON returns empty list.""" with patch.dict(os.environ, {"REPOS_JSON": "invalid-json{"}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert result == [] @@ -150,7 +151,7 @@ def test_parse_non_list_json(self): repos_json = json.dumps({"url": "https://github.com/owner/repo.git"}) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert result == [] @@ -159,7 +160,7 @@ def test_parse_repo_without_url(self): repos_json = json.dumps([{"branch": "main", "autoPush": True}]) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert result == [] @@ -175,7 +176,7 @@ def test_derive_repo_name_from_url(self): repos_json = json.dumps([{"url": url, "autoPush": True}]) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 1 assert result[0]["name"] == expected_name @@ -192,7 +193,7 @@ def test_autopush_with_invalid_string_type(self): ) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 1 # Invalid type should default to False @@ -210,7 +211,7 @@ def test_autopush_with_invalid_number_type(self): ) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 1 # Invalid type should default to False @@ -228,7 +229,7 @@ def test_autopush_with_null_value(self): ) with patch.dict(os.environ, {"REPOS_JSON": repos_json}): - result = self.adapter._get_repos_config() + result = get_repos_config() assert len(result) == 1 # null should default to False @@ -236,19 +237,7 @@ def test_autopush_with_null_value(self): class TestBuildWorkspaceContextPrompt: - """Tests for _build_workspace_context_prompt method.""" - - def setup_method(self): - """Set up test fixtures.""" - # Import here after mocking dependencies - from adapter import ClaudeCodeAdapter - - self.adapter = ClaudeCodeAdapter() - - # Create a mock context - mock_context = MagicMock() - mock_context.workspace_path = "/workspace" - self.adapter.context = mock_context + """Tests for prompts.build_workspace_context_prompt function.""" def test_prompt_includes_git_instructions_with_autopush(self): """Test that git push instructions are included when autoPush=true.""" @@ -261,11 +250,12 @@ def test_prompt_includes_git_instructions_with_autopush(self): } ] - prompt = self.adapter._build_workspace_context_prompt( + prompt = build_workspace_context_prompt( repos_cfg=repos_cfg, 
workflow_name=None, artifacts_path="artifacts", ambient_config={}, + workspace_path="/workspace", ) # Verify git instructions are present @@ -287,11 +277,12 @@ def test_prompt_excludes_git_instructions_without_autopush(self): } ] - prompt = self.adapter._build_workspace_context_prompt( + prompt = build_workspace_context_prompt( repos_cfg=repos_cfg, workflow_name=None, artifacts_path="artifacts", ambient_config={}, + workspace_path="/workspace", ) # Verify git instructions are NOT present @@ -323,11 +314,12 @@ def test_prompt_includes_multiple_autopush_repos(self): }, ] - prompt = self.adapter._build_workspace_context_prompt( + prompt = build_workspace_context_prompt( repos_cfg=repos_cfg, workflow_name=None, artifacts_path="artifacts", ambient_config={}, + workspace_path="/workspace", ) # Verify both autoPush repos are listed @@ -364,11 +356,12 @@ def test_prompt_with_workflow(self): } ] - prompt = self.adapter._build_workspace_context_prompt( + prompt = build_workspace_context_prompt( repos_cfg=repos_cfg, workflow_name="test-workflow", artifacts_path="artifacts", ambient_config={}, + workspace_path="/workspace", ) # Should include both workflow info and git instructions diff --git a/components/runners/claude-code-runner/tests/test_model_mapping.py b/components/runners/claude-code-runner/tests/test_model_mapping.py index 3e8decfe9..f4ac24610 100644 --- a/components/runners/claude-code-runner/tests/test_model_mapping.py +++ b/components/runners/claude-code-runner/tests/test_model_mapping.py @@ -1,5 +1,5 @@ """ -Test cases for ClaudeCodeAdapter._map_to_vertex_model() +Test cases for auth.map_to_vertex_model() This module tests the model name mapping from Anthropic API model names to Vertex AI model identifiers. @@ -10,12 +10,12 @@ import pytest -# Add parent directory to path for importing adapter module -adapter_dir = Path(__file__).parent.parent -if str(adapter_dir) not in sys.path: - sys.path.insert(0, str(adapter_dir)) +# Add parent directory to path for importing auth module +runner_dir = Path(__file__).parent.parent +if str(runner_dir) not in sys.path: + sys.path.insert(0, str(runner_dir)) -from adapter import ClaudeCodeAdapter # type: ignore[import] +from auth import map_to_vertex_model # type: ignore[import] class TestMapToVertexModel: @@ -23,77 +23,68 @@ class TestMapToVertexModel: def test_map_opus_4_5(self): """Test mapping for Claude Opus 4.5""" - adapter = ClaudeCodeAdapter() - result = adapter._map_to_vertex_model("claude-opus-4-5") + result = map_to_vertex_model("claude-opus-4-5") assert result == "claude-opus-4-5@20251101" def test_map_opus_4_1(self): """Test mapping for Claude Opus 4.1""" - adapter = ClaudeCodeAdapter() - result = adapter._map_to_vertex_model("claude-opus-4-1") + result = map_to_vertex_model("claude-opus-4-1") assert result == "claude-opus-4-1@20250805" def test_map_sonnet_4_5(self): """Test mapping for Claude Sonnet 4.5""" - adapter = ClaudeCodeAdapter() - result = adapter._map_to_vertex_model("claude-sonnet-4-5") + result = map_to_vertex_model("claude-sonnet-4-5") assert result == "claude-sonnet-4-5@20250929" def test_map_haiku_4_5(self): """Test mapping for Claude Haiku 4.5""" - adapter = ClaudeCodeAdapter() - result = adapter._map_to_vertex_model("claude-haiku-4-5") + result = map_to_vertex_model("claude-haiku-4-5") assert result == "claude-haiku-4-5@20251001" def test_unknown_model_returns_unchanged(self): """Test that unknown model names are returned unchanged""" - adapter = ClaudeCodeAdapter() unknown_model = "claude-unknown-model-99" - result = 
adapter._map_to_vertex_model(unknown_model) + result = map_to_vertex_model(unknown_model) assert result == unknown_model def test_empty_string_returns_unchanged(self): """Test that empty string is returned unchanged""" - adapter = ClaudeCodeAdapter() - result = adapter._map_to_vertex_model("") + result = map_to_vertex_model("") assert result == "" def test_case_sensitive_mapping(self): """Test that model mapping is case-sensitive""" - adapter = ClaudeCodeAdapter() + # Uppercase should not match - result = adapter._map_to_vertex_model("CLAUDE-OPUS-4-1") + result = map_to_vertex_model("CLAUDE-OPUS-4-1") assert result == "CLAUDE-OPUS-4-1" # Should return unchanged def test_whitespace_in_model_name(self): """Test handling of whitespace in model names""" - adapter = ClaudeCodeAdapter() + # Model name with whitespace should not match - result = adapter._map_to_vertex_model(" claude-opus-4-1 ") + result = map_to_vertex_model(" claude-opus-4-1 ") assert result == " claude-opus-4-1 " # Should return unchanged def test_partial_model_name_no_match(self): """Test that partial model names don't match""" - adapter = ClaudeCodeAdapter() - result = adapter._map_to_vertex_model("claude-opus") + result = map_to_vertex_model("claude-opus") assert result == "claude-opus" # Should return unchanged def test_vertex_model_id_passthrough(self): """Test that Vertex AI model IDs are returned unchanged""" - adapter = ClaudeCodeAdapter() vertex_id = "claude-opus-4-1@20250805" - result = adapter._map_to_vertex_model(vertex_id) + result = map_to_vertex_model(vertex_id) # If already a Vertex ID, should return unchanged assert result == vertex_id def test_all_frontend_models_have_mapping(self): """Test that all models from frontend dropdown have valid mappings""" - adapter = ClaudeCodeAdapter() + # These are the exact model values from the frontend dropdown frontend_models = [ "claude-sonnet-4-5", - "claude-opus-4-6", "claude-opus-4-5", "claude-opus-4-1", "claude-haiku-4-5", @@ -101,33 +92,31 @@ def test_all_frontend_models_have_mapping(self): expected_mappings = { "claude-sonnet-4-5": "claude-sonnet-4-5@20250929", - "claude-opus-4-6": "claude-opus-4-6", "claude-opus-4-5": "claude-opus-4-5@20251101", "claude-opus-4-1": "claude-opus-4-1@20250805", "claude-haiku-4-5": "claude-haiku-4-5@20251001", } for model in frontend_models: - result = adapter._map_to_vertex_model(model) + result = map_to_vertex_model(model) assert ( result == expected_mappings[model] ), f"Model {model} should map to {expected_mappings[model]}, got {result}" def test_mapping_includes_version_date(self): - """Test that mapped models include version dates (except Opus 4.6)""" - adapter = ClaudeCodeAdapter() + """Test that all mapped models include version dates""" + - # Opus 4.6 is the exception - uses simplified naming without @date - models_with_dates = [ + models = [ "claude-opus-4-5", "claude-opus-4-1", "claude-sonnet-4-5", "claude-haiku-4-5", ] - for model in models_with_dates: - result = adapter._map_to_vertex_model(model) - # All Vertex AI models (except Opus 4.6) should have @YYYYMMDD format + for model in models: + result = map_to_vertex_model(model) + # All Vertex AI models should have @YYYYMMDD format assert "@" in result, f"Mapped model {result} should include @ version date" assert ( len(result.split("@")) == 2 @@ -140,34 +129,27 @@ def test_mapping_includes_version_date(self): version_date.isdigit() ), f"Version date {version_date} should be all digits" - def test_opus_4_6_no_date_suffix(self): - """Test that Opus 4.6 uses simplified naming 
(no @date suffix)""" - adapter = ClaudeCodeAdapter() - result = adapter._map_to_vertex_model("claude-opus-4-6") - assert result == "claude-opus-4-6" - assert "@" not in result, "Opus 4.6 should NOT have @date suffix" - def test_none_input_handling(self): """Test that None input raises TypeError (invalid type per signature)""" - adapter = ClaudeCodeAdapter() + # Function signature specifies str -> str, so None should raise with pytest.raises((TypeError, AttributeError)): - adapter._map_to_vertex_model(None) # type: ignore[arg-type] + map_to_vertex_model(None) # type: ignore[arg-type] def test_numeric_input_handling(self): """Test that numeric input raises TypeError (invalid type per signature)""" - adapter = ClaudeCodeAdapter() + # Function signature specifies str -> str, so int should raise with pytest.raises((TypeError, AttributeError)): - adapter._map_to_vertex_model(123) # type: ignore[arg-type] + map_to_vertex_model(123) # type: ignore[arg-type] def test_mapping_consistency(self): """Test that mapping is consistent across multiple calls""" - adapter = ClaudeCodeAdapter() + model = "claude-sonnet-4-5" # Call multiple times - results = [adapter._map_to_vertex_model(model) for _ in range(5)] + results = [map_to_vertex_model(model) for _ in range(5)] # All results should be identical assert all(r == results[0] for r in results) @@ -179,11 +161,10 @@ class TestModelMappingIntegration: def test_mapping_matches_available_vertex_models(self): """Test that mapped model IDs match the expected Vertex AI format""" - adapter = ClaudeCodeAdapter() - # Expected Vertex AI model ID format: model-name@YYYYMMDD (except Opus 4.6) + + # Expected Vertex AI model ID format: model-name@YYYYMMDD models_to_test = [ - ("claude-opus-4-6", "claude-opus-4-6"), ("claude-opus-4-5", "claude-opus-4-5@20251101"), ("claude-opus-4-1", "claude-opus-4-1@20250805"), ("claude-sonnet-4-5", "claude-sonnet-4-5@20250929"), @@ -191,78 +172,59 @@ def test_mapping_matches_available_vertex_models(self): ] for input_model, expected_vertex_id in models_to_test: - result = adapter._map_to_vertex_model(input_model) + result = map_to_vertex_model(input_model) assert ( result == expected_vertex_id ), f"Expected {input_model} to map to {expected_vertex_id}, got {result}" def test_ui_to_vertex_round_trip(self): """Test that UI model selection properly maps to Vertex AI""" - adapter = ClaudeCodeAdapter() + # Simulate user selecting from UI dropdown ui_selections = [ "claude-sonnet-4-5", # User selects Sonnet 4.5 - "claude-opus-4-6", # User selects Opus 4.6 (newest) "claude-opus-4-5", # User selects Opus 4.5 "claude-opus-4-1", # User selects Opus 4.1 "claude-haiku-4-5", # User selects Haiku 4.5 ] for selection in ui_selections: - vertex_model = adapter._map_to_vertex_model(selection) + vertex_model = map_to_vertex_model(selection) # Verify it maps to a valid Vertex AI model ID assert vertex_model.startswith("claude-") - - # Opus 4.6 is the exception - no @date suffix - if selection == "claude-opus-4-6": - assert "@" not in vertex_model - else: - assert "@" in vertex_model + assert "@" in vertex_model # Verify the base model name is preserved - if "@" in vertex_model: - base_name = vertex_model.split("@")[0] - assert selection in vertex_model or base_name in selection - else: - assert selection == vertex_model + base_name = vertex_model.split("@")[0] + assert selection in vertex_model or base_name in selection def test_end_to_end_vertex_mapping_flow(self): """Test complete flow: UI selection → model mapping → Vertex AI call""" - adapter = 
ClaudeCodeAdapter() + # Simulate complete flow for each model test_scenarios = [ - { - "ui_selection": "claude-opus-4-6", - "expected_vertex_id": "claude-opus-4-6", - "description": "Newest Opus model (simplified naming)", - "has_date_suffix": False, - }, { "ui_selection": "claude-opus-4-5", "expected_vertex_id": "claude-opus-4-5@20251101", "description": "Latest Opus model", - "has_date_suffix": True, }, { "ui_selection": "claude-opus-4-1", "expected_vertex_id": "claude-opus-4-1@20250805", "description": "Previous Opus model", - "has_date_suffix": True, }, { "ui_selection": "claude-sonnet-4-5", "expected_vertex_id": "claude-sonnet-4-5@20250929", "description": "Balanced model", - "has_date_suffix": True, }, { "ui_selection": "claude-haiku-4-5", "expected_vertex_id": "claude-haiku-4-5@20251001", "description": "Fastest model", - "has_date_suffix": True, }, ] @@ -271,7 +233,7 @@ def test_end_to_end_vertex_mapping_flow(self): ui_model = scenario["ui_selection"] # Step 2: Backend maps to Vertex AI model ID - vertex_model_id = adapter._map_to_vertex_model(ui_model) + vertex_model_id = map_to_vertex_model(ui_model) # Step 3: Verify correct mapping assert ( @@ -279,27 +241,21 @@ def test_end_to_end_vertex_mapping_flow(self): ), f"{scenario['description']}: Expected {scenario['expected_vertex_id']}, got {vertex_model_id}" # Step 4: Verify Vertex AI model ID format is valid - if scenario["has_date_suffix"]: - assert "@" in vertex_model_id - parts = vertex_model_id.split("@") - assert len(parts) == 2 - model_name, version_date = parts - assert model_name.startswith("claude-") - assert len(version_date) == 8 # YYYYMMDD format - assert version_date.isdigit() - else: - # Opus 4.6 uses simplified naming - no @date suffix - assert "@" not in vertex_model_id - assert vertex_model_id.startswith("claude-") + assert "@" in vertex_model_id + parts = vertex_model_id.split("@") + assert len(parts) == 2 + model_name, version_date = parts + assert model_name.startswith("claude-") + assert len(version_date) == 8 # YYYYMMDD format + assert version_date.isdigit() def test_model_ordering_consistency(self): """Test that model ordering is consistent between frontend and backend""" - adapter = ClaudeCodeAdapter() - # Expected ordering: Sonnet → Opus 4.6 → Opus 4.5 → Opus 4.1 → Haiku (matches frontend dropdown) + + # Expected ordering: Sonnet → Opus 4.5 → Opus 4.1 → Haiku (matches frontend dropdown) expected_order = [ "claude-sonnet-4-5", - "claude-opus-4-6", "claude-opus-4-5", "claude-opus-4-1", "claude-haiku-4-5", @@ -307,20 +263,11 @@ def test_model_ordering_consistency(self): # Verify all models map successfully in order for model in expected_order: - vertex_id = adapter._map_to_vertex_model(model) - # Opus 4.6 is the exception - no @date suffix - if model == "claude-opus-4-6": - assert ( - "@" not in vertex_id - ), f"Model {model} should use simplified naming" - else: - assert ( - "@" in vertex_id - ), f"Model {model} should map to valid Vertex AI ID" + vertex_id = map_to_vertex_model(model) + assert "@" in vertex_id, f"Model {model} should map to valid Vertex AI ID" # Verify ordering matches frontend dropdown assert expected_order[0] == "claude-sonnet-4-5" # Balanced (default) - assert expected_order[1] == "claude-opus-4-6" # Newest Opus - assert expected_order[2] == "claude-opus-4-5" # Latest Opus - assert expected_order[3] == "claude-opus-4-1" # Previous Opus - assert expected_order[4] == "claude-haiku-4-5" # Fastest + assert expected_order[1] == "claude-opus-4-5" # Latest Opus + assert expected_order[2] == 
"claude-opus-4-1" # Previous Opus + assert expected_order[3] == "claude-haiku-4-5" # Fastest diff --git a/components/runners/claude-code-runner/tests/test_wrapper_vertex.py b/components/runners/claude-code-runner/tests/test_wrapper_vertex.py index bd14180e3..e5829f36f 100644 --- a/components/runners/claude-code-runner/tests/test_wrapper_vertex.py +++ b/components/runners/claude-code-runner/tests/test_wrapper_vertex.py @@ -1,17 +1,24 @@ """ -Test cases for wrapper._setup_vertex_credentials() +Test cases for auth.setup_vertex_credentials() This module tests all error cases and validation logic for Vertex AI credential setup. """ import asyncio import os +import sys import tempfile from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest -from claude_code_runner.wrapper import ClaudeCodeWrapper + +# Add parent directory to path +runner_dir = Path(__file__).parent.parent +if str(runner_dir) not in sys.path: + sys.path.insert(0, str(runner_dir)) + +from auth import setup_vertex_credentials # type: ignore[import] class TestSetupVertexCredentials: @@ -48,10 +55,8 @@ async def test_success_all_valid_credentials( "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) - # Execute - result = await wrapper._setup_vertex_credentials() + result = await setup_vertex_credentials(mock_context) # Verify assert result is not None @@ -71,11 +76,11 @@ async def test_error_missing_google_application_credentials(self, mock_context): "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "GOOGLE_APPLICATION_CREDENTIALS" in str(exc_info.value) assert "not set" in str(exc_info.value) @@ -90,11 +95,11 @@ async def test_error_empty_google_application_credentials(self, mock_context): "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "GOOGLE_APPLICATION_CREDENTIALS" in str(exc_info.value) @@ -109,11 +114,11 @@ async def test_error_missing_anthropic_vertex_project_id( "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "ANTHROPIC_VERTEX_PROJECT_ID" in str(exc_info.value) assert "not set" in str(exc_info.value) @@ -130,11 +135,11 @@ async def test_error_empty_anthropic_vertex_project_id( "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "ANTHROPIC_VERTEX_PROJECT_ID" in str(exc_info.value) @@ -149,11 +154,11 @@ async def test_error_missing_cloud_ml_region( "ANTHROPIC_VERTEX_PROJECT_ID": "test-project-123", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "CLOUD_ML_REGION" in str(exc_info.value) assert "not set" in str(exc_info.value) @@ -170,11 +175,11 @@ async def 
test_error_empty_cloud_ml_region( "CLOUD_ML_REGION": "", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "CLOUD_ML_REGION" in str(exc_info.value) @@ -189,11 +194,11 @@ async def test_error_credentials_file_does_not_exist(self, mock_context): "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "Service account file" in str(exc_info.value) assert "does not exist" in str(exc_info.value) @@ -205,11 +210,11 @@ async def test_error_all_env_vars_missing(self, mock_context): # Setup - all vars missing mock_context.get_env.side_effect = lambda key: None - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify - should fail on first check with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "GOOGLE_APPLICATION_CREDENTIALS" in str(exc_info.value) @@ -222,11 +227,11 @@ async def test_validation_order_checks_credentials_path_first(self, mock_context "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Should fail on GOOGLE_APPLICATION_CREDENTIALS first with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "GOOGLE_APPLICATION_CREDENTIALS" in str(exc_info.value) @@ -241,11 +246,11 @@ async def test_validation_order_checks_project_id_second( "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Should fail on ANTHROPIC_VERTEX_PROJECT_ID second with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "ANTHROPIC_VERTEX_PROJECT_ID" in str(exc_info.value) @@ -260,11 +265,11 @@ async def test_validation_order_checks_region_third( "ANTHROPIC_VERTEX_PROJECT_ID": "test-project-123", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Should fail on CLOUD_ML_REGION third with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "CLOUD_ML_REGION" in str(exc_info.value) @@ -279,11 +284,11 @@ async def test_validation_checks_file_existence_last(self, mock_context): "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Should fail on file existence check last with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "Service account file" in str(exc_info.value) assert "does not exist" in str(exc_info.value) @@ -300,10 +305,10 @@ async def test_logging_output_includes_config_details( "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) # Verify logging was called with details assert mock_context.send_log.called @@ -330,11 +335,11 @@ async def test_whitespace_in_env_vars_is_not_trimmed( "CLOUD_ML_REGION": " us-central1 ", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute - depending on implementation, this might succeed or fail # If 
the code doesn't strip whitespace, the values should work - result = await wrapper._setup_vertex_credentials() + result = await setup_vertex_credentials(mock_context) # Verify that whitespace is preserved (not stripped) assert result["project_id"] == " test-project-123 " @@ -350,11 +355,11 @@ async def test_none_value_from_get_env(self, mock_context, temp_credentials_file key ) # Returns None for other keys - wrapper = ClaudeCodeWrapper(mock_context) + # Should fail when checking for None values with pytest.raises(RuntimeError) as exc_info: - await wrapper._setup_vertex_credentials() + await setup_vertex_credentials(mock_context) assert "not set" in str(exc_info.value) @@ -371,12 +376,12 @@ async def test_directory_instead_of_file(self, mock_context, tmp_path): "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute and verify # Path.exists() returns True for directories, so this might not fail # depending on implementation - result = await wrapper._setup_vertex_credentials() + result = await setup_vertex_credentials(mock_context) # If implementation only checks exists(), this will pass # If it checks is_file(), this should fail @@ -399,10 +404,10 @@ async def test_relative_path_credentials_file(self, mock_context): "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute - should work if file exists in current directory - result = await wrapper._setup_vertex_credentials() + result = await setup_vertex_credentials(mock_context) assert result is not None assert result["credentials_path"] == relative_path @@ -424,10 +429,8 @@ async def test_special_characters_in_project_id( "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) - # Execute - result = await wrapper._setup_vertex_credentials() + result = await setup_vertex_credentials(mock_context) # Should accept special characters assert result["project_id"] == special_project_id @@ -452,10 +455,10 @@ async def test_international_region_codes( "CLOUD_ML_REGION": region, }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) + # Execute - result = await wrapper._setup_vertex_credentials() + result = await setup_vertex_credentials(mock_context) # Should accept all valid region codes assert result["region"] == region @@ -470,10 +473,8 @@ async def test_return_value_structure(self, mock_context, temp_credentials_file) "CLOUD_ML_REGION": "us-central1", }.get(key) - wrapper = ClaudeCodeWrapper(mock_context) - # Execute - result = await wrapper._setup_vertex_credentials() + result = await setup_vertex_credentials(mock_context) # Verify structure assert isinstance(result, dict) @@ -506,10 +507,8 @@ async def test_integration_with_real_file_creation(self): ) context.send_log = AsyncMock() - wrapper = ClaudeCodeWrapper(context) - # Execute - result = await wrapper._setup_vertex_credentials() + result = await setup_vertex_credentials(context) # Verify assert Path(temp_path).exists() @@ -544,9 +543,8 @@ async def test_concurrent_calls_to_setup_vertex_credentials(self, tmp_path): contexts.append(context) # Execute concurrently - wrappers = [ClaudeCodeWrapper(ctx) for ctx in contexts] results = await asyncio.gather( - *[wrapper._setup_vertex_credentials() for wrapper in wrappers] + *[setup_vertex_credentials(ctx) for ctx in contexts] ) # Verify all succeeded diff --git a/components/runners/claude-code-runner/tools/__init__.py b/components/runners/claude-code-runner/tools/__init__.py new file mode 100644 index 000000000..0c732c8f2 --- 
/dev/null +++ b/components/runners/claude-code-runner/tools/__init__.py @@ -0,0 +1,15 @@ +""" +MCP tool definitions for the Claude Code runner. + +Tools are created dynamically per-run and registered as in-process +MCP servers alongside the Claude Agent SDK. +""" + +from tools.rubric import create_rubric_mcp_tool, load_rubric_content +from tools.session import create_restart_session_tool + +__all__ = [ + "create_restart_session_tool", + "load_rubric_content", + "create_rubric_mcp_tool", +] diff --git a/components/runners/claude-code-runner/tools/rubric.py b/components/runners/claude-code-runner/tools/rubric.py new file mode 100644 index 000000000..1a4d08df7 --- /dev/null +++ b/components/runners/claude-code-runner/tools/rubric.py @@ -0,0 +1,218 @@ +""" +Rubric evaluation MCP tool — logs a single score to Langfuse. + +Scans the workflow's .ambient/ folder for a rubric.md file, then creates +an evaluate_rubric tool that accepts a score, comment, and metadata. +The tool makes one ``langfuse.create_score()`` call with the trace ID +from the current observability context. +""" + +import json as _json +import logging +import os +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +def load_rubric_content(cwd_path: str) -> tuple: + """Load rubric content from the workflow's .ambient/ folder. + + Looks for ``.ambient/rubric.md`` — a single markdown file containing + the evaluation criteria. + + Returns: + Tuple of ``(rubric_content, rubric_config)`` where rubric_content + is the markdown string and rubric_config is the ``rubric`` key + from ambient.json. Returns ``(None, {})`` if no rubric found. + """ + ambient_dir = Path(cwd_path) / ".ambient" + rubric_content = None + + single_rubric = ambient_dir / "rubric.md" + if single_rubric.exists() and single_rubric.is_file(): + try: + rubric_content = single_rubric.read_text(encoding="utf-8") + logger.info(f"Loaded rubric from {single_rubric}") + except Exception as e: + logger.error(f"Failed to read rubric.md: {e}") + + rubric_config: dict = {} + try: + config_path = ambient_dir / "ambient.json" + if config_path.exists(): + with open(config_path, "r") as f: + config = _json.load(f) + rubric_config = config.get("rubric", {}) + except Exception as e: + logger.error(f"Failed to load rubric config from ambient.json: {e}") + + return rubric_content, rubric_config + + +def create_rubric_mcp_tool( + rubric_content: str, + rubric_config: dict, + obs: Any, + session_id: str, + sdk_tool_decorator, +): + """Create a dynamic MCP tool for rubric-based evaluation. + + The tool accepts a score, comment, and optional metadata, then makes + a single ``langfuse.create_score()`` call. The ``rubric.schema`` from + ambient.json is passed through as the ``metadata`` field's JSON Schema + in the tool's input_schema. + + Args: + rubric_content: Markdown rubric instructions (for reference only). + rubric_config: Config dict with ``activationPrompt`` and ``schema``. + obs: ObservabilityManager instance for trace ID. + session_id: Current session ID. + sdk_tool_decorator: The ``tool`` decorator from ``claude_agent_sdk``. + + Returns: + Decorated async tool function. 
+ """ + # JSON Schema format per Claude Agent SDK docs: + # https://platform.claude.com/docs/en/agent-sdk/python#tool + user_schema = rubric_config.get("schema", {}) + + properties: dict = { + "score": {"type": "number", "description": "Overall evaluation score."}, + "comment": {"type": "string", "description": "Evaluation reasoning and commentary."}, + } + if user_schema: + properties["metadata"] = user_schema + + required = ["score", "comment"] + if user_schema: + required.append("metadata") + + input_schema: dict = { + "type": "object", + "properties": properties, + "required": required, + } + + tool_description = ( + "Log a rubric evaluation score to Langfuse. " + "Read .ambient/rubric.md FIRST, evaluate the output " + "against the criteria, then call this tool with your " + "score, comment, and metadata." + ) + + # Capture references for closure + _obs = obs + _session_id = session_id + + @sdk_tool_decorator( + "evaluate_rubric", + tool_description, + input_schema, + ) + async def evaluate_rubric_tool(args: dict) -> dict: + """Log a single rubric evaluation score to Langfuse.""" + score = args.get("score") + comment = args.get("comment", "") + metadata = args.get("metadata") + + # Log to Langfuse + success, error = _log_to_langfuse( + score=score, + comment=comment, + metadata=metadata, + obs=_obs, + session_id=_session_id, + ) + + if success: + return { + "content": [ + {"type": "text", "text": f"Score {score} logged to Langfuse."} + ] + } + else: + return { + "content": [ + {"type": "text", "text": f"Failed to log score: {error}"} + ], + "isError": True, + } + + return evaluate_rubric_tool + + +def _log_to_langfuse( + score: float | None, + comment: str, + metadata: Any, + obs: Any, + session_id: str, +) -> tuple[bool, str | None]: + """Make a single langfuse.create_score() call. + + Uses the existing Langfuse client from the ObservabilityManager + if available, otherwise creates a new one. + + Returns: + (True, None) on success, (False, error_message) on failure. + """ + try: + # Try to reuse the session's Langfuse client + langfuse_client = getattr(obs, "langfuse_client", None) if obs else None + + if not langfuse_client: + # Fall back to creating our own + langfuse_enabled = os.getenv( + "LANGFUSE_ENABLED", "" + ).strip().lower() in ("1", "true", "yes") + if not langfuse_enabled: + return False, "Langfuse not enabled." + + from langfuse import Langfuse + + public_key = os.getenv("LANGFUSE_PUBLIC_KEY", "").strip() + secret_key = os.getenv("LANGFUSE_SECRET_KEY", "").strip() + host = os.getenv("LANGFUSE_HOST", "").strip() + + if not (public_key and secret_key and host): + return False, "Langfuse credentials missing." + + langfuse_client = Langfuse( + public_key=public_key, + secret_key=secret_key, + host=host, + ) + + trace_id = obs.get_current_trace_id() if obs else None + + if score is None: + return False, "Score value is required (got None)." + + kwargs: dict = { + "name": "rubric-evaluation", + "value": score, + "data_type": "NUMERIC", + "comment": comment[:500] if comment else None, + "metadata": metadata, + } + if trace_id: + kwargs["trace_id"] = trace_id + + langfuse_client.create_score(**kwargs) + langfuse_client.flush() + + logger.info( + f"Rubric score logged to Langfuse: " + f"value={score}, trace_id={trace_id}" + ) + return True, None + + except ImportError: + return False, "Langfuse package not installed." 
+ except Exception as e: + msg = str(e) + logger.error(f"Failed to log rubric score to Langfuse: {msg}") + return False, msg diff --git a/components/runners/claude-code-runner/tools/session.py b/components/runners/claude-code-runner/tools/session.py new file mode 100644 index 000000000..0c61c94ec --- /dev/null +++ b/components/runners/claude-code-runner/tools/session.py @@ -0,0 +1,46 @@ +""" +Session control MCP tool — allows Claude to request a session restart. +""" + +import logging + +from prompts import RESTART_TOOL_DESCRIPTION + +logger = logging.getLogger(__name__) + + +def create_restart_session_tool(adapter_ref, sdk_tool_decorator): + """Create the restart_session MCP tool. + + Args: + adapter_ref: Reference to the ClaudeCodeAdapter instance + (used to set _restart_requested flag). + sdk_tool_decorator: The ``tool`` decorator from ``claude_agent_sdk``. + + Returns: + Decorated async tool function. + """ + + @sdk_tool_decorator( + "restart_session", + RESTART_TOOL_DESCRIPTION, + {}, + ) + async def restart_session_tool(args: dict) -> dict: + """Tool that allows Claude to request a session restart.""" + adapter_ref._restart_requested = True + logger.info("🔄 Session restart requested by Claude via MCP tool") + return { + "content": [ + { + "type": "text", + "text": ( + "Session restart has been requested. The current run " + "will complete and a fresh session will be established. " + "Your conversation context will be preserved on disk." + ), + } + ] + } + + return restart_session_tool diff --git a/components/runners/claude-code-runner/utils.py b/components/runners/claude-code-runner/utils.py new file mode 100644 index 000000000..55b4c63d2 --- /dev/null +++ b/components/runners/claude-code-runner/utils.py @@ -0,0 +1,174 @@ +""" +General utility functions for the Claude Code runner. + +Pure functions with no business-logic dependencies — URL parsing, +secret redaction, subprocess helpers, environment variable expansion. +""" + +import asyncio +import logging +import os +import re +from datetime import datetime, timezone +from typing import Any +from urllib.parse import urlparse, urlunparse + +logger = logging.getLogger(__name__) + + +def timestamp() -> str: + """Return current UTC timestamp in ISO format.""" + return datetime.now(timezone.utc).isoformat() + + +def redact_secrets(text: str) -> str: + """Redact tokens and secrets from text for safe logging.""" + if not text: + return text + + text = re.sub(r"gh[pousr]_[a-zA-Z0-9]{36,255}", "gh*_***REDACTED***", text) + text = re.sub(r"sk-ant-[a-zA-Z0-9\-_]{30,200}", "sk-ant-***REDACTED***", text) + text = re.sub(r"pk-lf-[a-zA-Z0-9\-_]{10,100}", "pk-lf-***REDACTED***", text) + text = re.sub(r"sk-lf-[a-zA-Z0-9\-_]{10,100}", "sk-lf-***REDACTED***", text) + text = re.sub( + r"x-access-token:[^@\s]+@", "x-access-token:***REDACTED***@", text + ) + text = re.sub(r"oauth2:[^@\s]+@", "oauth2:***REDACTED***@", text) + text = re.sub(r"://[^:@\s]+:[^@\s]+@", "://***REDACTED***@", text) + text = re.sub( + r'(ANTHROPIC_API_KEY|LANGFUSE_SECRET_KEY|LANGFUSE_PUBLIC_KEY|BOT_TOKEN|GIT_TOKEN)\s*=\s*[^\s\'"]+', + r"\1=***REDACTED***", + text, + ) + return text + + +def url_with_token(url: str, token: str) -> str: + """Add authentication token to a git URL. + + Uses x-access-token for GitHub, oauth2 for GitLab. 
+ """ + if not token or not url.lower().startswith("http"): + return url + try: + parsed = urlparse(url) + netloc = parsed.netloc + if "@" in netloc: + netloc = netloc.split("@", 1)[1] + + hostname = parsed.hostname or "" + if "gitlab" in hostname.lower(): + auth = f"oauth2:{token}@" + else: + auth = f"x-access-token:{token}@" + + new_netloc = auth + netloc + return urlunparse( + ( + parsed.scheme, + new_netloc, + parsed.path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + except Exception: + return url + + +def parse_owner_repo(url: str) -> tuple[str, str, str]: + """Return (owner, name, host) from various git URL formats. + + Supports HTTPS, SSH, and shorthand owner/repo formats. + """ + s = (url or "").strip() + s = s.removesuffix(".git") + host = "github.com" + try: + if s.startswith("http://") or s.startswith("https://"): + p = urlparse(s) + host = p.netloc + parts = [pt for pt in p.path.split("/") if pt] + if len(parts) >= 2: + return parts[0], parts[1], host + if s.startswith("git@") or ":" in s: + s2 = s + if s2.startswith("git@"): + s2 = s2.replace(":", "/", 1) + s2 = s2.replace("git@", "ssh://git@", 1) + p = urlparse(s2) + host = p.hostname or host + parts = [pt for pt in (p.path or "").split("/") if pt] + if len(parts) >= 2: + return parts[-2], parts[-1], host + parts = [pt for pt in s.split("/") if pt] + if len(parts) == 2: + return parts[0], parts[1], host + except Exception: + return "", "", host + return "", "", host + + +def expand_env_vars(value: Any) -> Any: + """Recursively expand ${VAR} and ${VAR:-default} patterns in config values.""" + if isinstance(value, str): + pattern = r"\$\{([^}:]+)(?::-([^}]*))?\}" + + def replace_var(match): + var_name = match.group(1) + default_val = match.group(2) if match.group(2) is not None else "" + return os.environ.get(var_name, default_val) + + return re.sub(pattern, replace_var, value) + elif isinstance(value, dict): + return {k: expand_env_vars(v) for k, v in value.items()} + elif isinstance(value, list): + return [expand_env_vars(item) for item in value] + return value + + +async def run_cmd( + cmd: list, + cwd: str | None = None, + capture_stdout: bool = False, + ignore_errors: bool = False, +) -> str: + """Run a subprocess command asynchronously. + + Args: + cmd: Command and arguments list. + cwd: Working directory (defaults to current directory). + capture_stdout: If True, return stdout text. + ignore_errors: If True, don't raise on non-zero exit. + + Returns: + stdout text if capture_stdout is True, else empty string. + + Raises: + RuntimeError: If command fails and ignore_errors is False. 
+ """ + cmd_safe = [redact_secrets(str(arg)) for arg in cmd] + logger.info(f"Running command: {' '.join(cmd_safe)}") + + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + ) + stdout_data, stderr_data = await proc.communicate() + stdout_text = stdout_data.decode("utf-8", errors="replace") + stderr_text = stderr_data.decode("utf-8", errors="replace") + + if stdout_text.strip(): + logger.info(f"Command stdout: {redact_secrets(stdout_text.strip())}") + if stderr_text.strip(): + logger.info(f"Command stderr: {redact_secrets(stderr_text.strip())}") + + if proc.returncode != 0 and not ignore_errors: + raise RuntimeError(stderr_text or f"Command failed: {' '.join(cmd_safe)}") + + if capture_stdout: + return stdout_text + return "" diff --git a/components/runners/claude-code-runner/workspace.py b/components/runners/claude-code-runner/workspace.py new file mode 100644 index 000000000..0b01e84aa --- /dev/null +++ b/components/runners/claude-code-runner/workspace.py @@ -0,0 +1,251 @@ +""" +Workspace and path management for the Claude Code runner. + +Handles workflow/repo directory setup, workspace validation, +and prerequisite checking for phase-based commands. +""" + +import logging +import os +from pathlib import Path +from typing import AsyncIterator +from urllib.parse import urlparse + +from context import RunnerContext +from utils import parse_owner_repo + +logger = logging.getLogger(__name__) + + +class PrerequisiteError(RuntimeError): + """Raised when slash-command prerequisites are missing.""" + + pass + + +def setup_workflow_paths( + context: RunnerContext, active_workflow_url: str, repos_cfg: list +) -> tuple[str, list, str]: + """Setup CWD and additional directories for workflow mode. 
+ + Returns: + (cwd_path, additional_dirs, derived_workflow_name) + """ + add_dirs: list[str] = [] + derived_name = None + cwd_path = context.workspace_path + + try: + _owner, repo, _ = parse_owner_repo(active_workflow_url) + derived_name = repo or "" + if not derived_name: + p = urlparse(active_workflow_url) + parts = [pt for pt in (p.path or "").split("/") if pt] + if parts: + derived_name = parts[-1] + derived_name = (derived_name or "").removesuffix(".git").strip() + + if derived_name: + workflow_path = str( + Path(context.workspace_path) / "workflows" / derived_name + ) + if Path(workflow_path).exists(): + cwd_path = workflow_path + logger.info(f"Using workflow as CWD: {derived_name}") + else: + logger.warning( + f"Workflow directory not found: {workflow_path}, using default" + ) + cwd_path = str( + Path(context.workspace_path) / "workflows" / "default" + ) + else: + cwd_path = str( + Path(context.workspace_path) / "workflows" / "default" + ) + except Exception as e: + logger.warning(f"Failed to derive workflow name: {e}, using default") + cwd_path = str( + Path(context.workspace_path) / "workflows" / "default" + ) + + # Add all repos as additional directories + repos_base = Path(context.workspace_path) / "repos" + for r in repos_cfg: + name = (r.get("name") or "").strip() + if name: + repo_path = str(repos_base / name) + if repo_path not in add_dirs: + add_dirs.append(repo_path) + + # Add artifacts and file-uploads directories + artifacts_path = str(Path(context.workspace_path) / "artifacts") + if artifacts_path not in add_dirs: + add_dirs.append(artifacts_path) + + file_uploads_path = str(Path(context.workspace_path) / "file-uploads") + if file_uploads_path not in add_dirs: + add_dirs.append(file_uploads_path) + + return cwd_path, add_dirs, derived_name + + +def setup_multi_repo_paths( + context: RunnerContext, repos_cfg: list +) -> tuple[str, list]: + """Set up the CWD and additional directories for multi-repo mode. + + Repos are cloned to /workspace/repos/{name} by both + hydrate.sh (init container) and clone_repo_at_runtime(). + + Returns: + (cwd_path, additional_dirs) + """ + add_dirs: list[str] = [] + repos_base = Path(context.workspace_path) / "repos" + + main_name = (os.getenv("MAIN_REPO_NAME") or "").strip() + if not main_name: + idx_raw = (os.getenv("MAIN_REPO_INDEX") or "").strip() + try: + idx_val = int(idx_raw) if idx_raw else 0 + except Exception: + idx_val = 0 + if idx_val < 0 or idx_val >= len(repos_cfg): + idx_val = 0 + main_name = (repos_cfg[idx_val].get("name") or "").strip() if repos_cfg else "" + + cwd_path = ( + str(repos_base / main_name) if main_name else context.workspace_path + ) + + for r in repos_cfg: + name = (r.get("name") or "").strip() + if not name: + continue + p = str(repos_base / name) + if p != cwd_path: + add_dirs.append(p) + + # Add artifacts and file-uploads directories + artifacts_path = str(Path(context.workspace_path) / "artifacts") + if artifacts_path not in add_dirs: + add_dirs.append(artifacts_path) + + file_uploads_path = str(Path(context.workspace_path) / "file-uploads") + if file_uploads_path not in add_dirs: + add_dirs.append(file_uploads_path) + + return cwd_path, add_dirs + + +async def prepare_workspace(context: RunnerContext) -> None: + """Validate workspace prepared by init container. + + The init-hydrate container handles downloading state from S3, + cloning repos, and cloning workflows. This just validates and logs. 
+ """ + workspace = Path(context.workspace_path) + logger.info(f"Validating workspace at {workspace}") + + hydrated_paths = [] + for path_name in [".claude", "artifacts", "file-uploads"]: + path_dir = workspace / path_name + if path_dir.exists(): + file_count = len([f for f in path_dir.rglob("*") if f.is_file()]) + if file_count > 0: + hydrated_paths.append(f"{path_name} ({file_count} files)") + + if hydrated_paths: + logger.info(f"Hydrated from S3: {', '.join(hydrated_paths)}") + else: + logger.info("No state hydrated (fresh session)") + + +async def validate_prerequisites(context: RunnerContext) -> None: + """Validate prerequisite files exist for phase-based slash commands. + + Raises: + PrerequisiteError: If a required file is missing. + """ + prompt = context.get_env("INITIAL_PROMPT", "") + if not prompt: + return + + prompt_lower = prompt.strip().lower() + + prerequisites = { + "/speckit.plan": ( + "spec.md", + "Specification file (spec.md) not found. Please run /speckit.specify first.", + ), + "/speckit.tasks": ( + "plan.md", + "Planning file (plan.md) not found. Please run /speckit.plan first.", + ), + "/speckit.implement": ( + "tasks.md", + "Tasks file (tasks.md) not found. Please run /speckit.tasks first.", + ), + } + + for cmd, (required_file, error_msg) in prerequisites.items(): + if prompt_lower.startswith(cmd): + workspace = Path(context.workspace_path) + found = False + + if (workspace / required_file).exists(): + found = True + break + + for subdir in workspace.rglob("specs/*/"): + if (subdir / required_file).exists(): + found = True + break + + if not found: + raise PrerequisiteError(error_msg) + break + + +async def initialize_workflow_if_set(context: RunnerContext) -> None: + """Validate workflow was cloned by init container.""" + active_workflow_url = (os.getenv("ACTIVE_WORKFLOW_GIT_URL") or "").strip() + if not active_workflow_url: + return + + try: + _owner, repo, _ = parse_owner_repo(active_workflow_url) + derived_name = repo or "" + if not derived_name: + p = urlparse(active_workflow_url) + parts = [pt for pt in (p.path or "").split("/") if pt] + if parts: + derived_name = parts[-1] + derived_name = (derived_name or "").removesuffix(".git").strip() + + if not derived_name: + logger.warning("Could not derive workflow name from URL") + return + + workspace = Path(context.workspace_path) + workflow_temp_dir = workspace / "workflows" / f"{derived_name}-clone-temp" + workflow_dir = workspace / "workflows" / derived_name + + if workflow_temp_dir.exists(): + logger.info( + f"Workflow {derived_name} cloned by init container " + f"at {workflow_temp_dir.name}" + ) + elif workflow_dir.exists(): + logger.info( + f"Workflow {derived_name} available at {workflow_dir.name}" + ) + else: + logger.warning( + f"Workflow {derived_name} not found " + "(init container may have failed to clone)" + ) + + except Exception as e: + logger.error(f"Failed to validate workflow: {e}")