diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py
index 405cd458..9097ccd2 100644
--- a/src/bub/builtin/agent.py
+++ b/src/bub/builtin/agent.py
@@ -29,6 +29,11 @@
 
 CONTINUE_PROMPT = "Continue the task."
 HINT_RE = re.compile(r"\$([A-Za-z0-9_.-]+)")
+_CONTEXT_LENGTH_PATTERNS = re.compile(
+    r"context.{0,20}length|maximum.{0,20}context|token.{0,10}limit|prompt.{0,10}too long|tokens? > \d+ maximum",
+    re.IGNORECASE,
+)
+MAX_AUTO_HANDOFF_RETRIES = 1
 
 
 class Agent:
@@ -121,6 +126,7 @@ async def _agent_loop(
     ) -> str:
         next_prompt: str | list[dict] = prompt
         display_model = model or self.settings.model
+        auto_handoff_remaining = MAX_AUTO_HANDOFF_RETRIES
         await self.tapes.append_event(
             tape.name,
             "loop.start",
@@ -188,6 +194,35 @@ async def _agent_loop(
                     },
                 )
                 continue
+
+            # Check if this is a context-length error that can be recovered via auto-handoff
+            if auto_handoff_remaining > 0 and _is_context_length_error(outcome.error):
+                auto_handoff_remaining -= 1
+                logger.warning(
+                    "auto_handoff: context length exceeded, performing automatic handoff. tape={} step={}",
+                    tape.name,
+                    step,
+                )
+                await self.tapes.handoff(
+                    tape.name,
+                    name="auto_handoff/context_overflow",
+                    state={"reason": "context_length_exceeded", "error": outcome.error},
+                )
+                await self.tapes.append_event(
+                    tape.name,
+                    "loop.step",
+                    {
+                        "step": step,
+                        "elapsed_ms": elapsed_ms,
+                        "status": "auto_handoff",
+                        "error": outcome.error,
+                        "date": datetime.now(UTC).isoformat(),
+                    },
+                )
+                # Retry with original prompt — the handoff anchor will truncate history
+                next_prompt = prompt
+                continue
+
             await self.tapes.append_event(
                 tape.name,
                 "loop.step",
@@ -318,6 +353,17 @@ def _parse_args(args_tokens: list[str]) -> Args:
     return Args(positional=positional, kwargs=kwargs)
 
 
+def _is_context_length_error(error_msg: str | None) -> bool:
+    """Check whether an error message indicates a context-length / prompt-too-long failure.
+
+    A missing or empty message is never treated as a context-length error, and a
+    non-string error is matched against its ``str()`` form instead of raising.
+    """
+    if not error_msg:
+        return False
+    return bool(_CONTEXT_LENGTH_PATTERNS.search(str(error_msg)))
+
+
 def _extract_text_from_parts(parts: list[dict]) -> str:
     """Extract text content from multimodal content parts."""
     return "\n".join(p.get("text", "") for p in parts if p.get("type") == "text")