diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 8bec66737a..4599cd3526 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -17,8 +17,10 @@ BaseAgent, Content, ContextProvider, + HistoryProvider, Message, ResponseStream, + SessionContext, normalize_messages, ) from agent_framework._settings import load_settings @@ -352,13 +354,25 @@ def run( AgentException: If the request fails. """ if stream: + ctx_holder: dict[str, Any] = {} + + async def _after_run_hook(response: AgentResponse) -> None: + session_context = ctx_holder.get("session_context") + sess = ctx_holder.get("session") + if session_context is not None and sess is not None: + session_context._response = response + try: + await self._run_after_providers(session=sess, context=session_context) + except Exception: + logger.exception("Error running after_run providers in streaming result hook") def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: return AgentResponse.from_updates(updates) return ResponseStream( - self._stream_updates(messages=messages, session=session, options=options), + self._stream_updates(messages=messages, session=session, options=options, _ctx_holder=ctx_holder), finalizer=_finalize, + result_hooks=[_after_run_hook], ) return self._run_impl(messages=messages, session=session, options=options) @@ -377,11 +391,22 @@ async def _run_impl( session = self.create_session() opts: dict[str, Any] = dict(options) if options else {} - timeout = opts.pop("timeout", None) or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS + timeout = opts.get("timeout") or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS - copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) input_messages = normalize_messages(messages) - prompt = "\n".join([message.text for message in input_messages]) + + session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts) + + # NOTE: session is created after providers run so that future provider-contributed + # tools/config could be folded into runtime_options before session creation. + copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) + + # Build the prompt from the full set of messages in the session context, + # so that any context/history provider-injected messages are included. + context_messages = session_context.get_messages(include_input=True) + prompt = "\n".join([message.text for message in context_messages]) + if session_context.instructions: + prompt = "\n".join(session_context.instructions) + "\n" + prompt message_options = cast(MessageOptions, {"prompt": prompt}) try: @@ -408,7 +433,10 @@ async def _run_impl( ) response_id = message_id - return AgentResponse(messages=response_messages, response_id=response_id) + response = AgentResponse(messages=response_messages, response_id=response_id) + session_context._response = response # type: ignore[assignment] + await self._run_after_providers(session=session, context=session_context) + return response async def _stream_updates( self, @@ -416,6 +444,7 @@ async def _stream_updates( *, session: AgentSession | None = None, options: OptionsT | None = None, + _ctx_holder: dict[str, Any] | None = None, ) -> AsyncIterable[AgentResponseUpdate]: """Internal method to stream updates from GitHub Copilot. @@ -425,6 +454,9 @@ async def _stream_updates( Keyword Args: session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). + _ctx_holder: Internal dict populated with session_context and session + so that the caller (via a ResponseStream result_hook) can run + after_run providers without duplicating the updates buffer. Yields: AgentResponseUpdate items. @@ -440,9 +472,23 @@ async def _stream_updates( opts: dict[str, Any] = dict(options) if options else {} - copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts) input_messages = normalize_messages(messages) - prompt = "\n".join([message.text for message in input_messages]) + + session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts) + + # NOTE: session is created after providers run so that future provider-contributed + # tools/config could be folded into runtime_options before session creation. + copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts) + + if _ctx_holder is not None: + _ctx_holder["session_context"] = session_context + _ctx_holder["session"] = session + + # Build the prompt from the full session context so provider-injected messages are included. + context_messages = session_context.get_messages(include_input=True) + prompt = "\n".join([message.text for message in context_messages]) + if session_context.instructions: + prompt = "\n".join(session_context.instructions) + "\n" + prompt message_options = cast(MessageOptions, {"prompt": prompt}) queue: asyncio.Queue[AgentResponseUpdate | Exception | None] = asyncio.Queue() @@ -513,6 +559,46 @@ def event_handler(event: SessionEvent) -> None: finally: unsubscribe() + async def _run_before_providers( + self, + *, + session: AgentSession, + input_messages: list[Message], + options: dict[str, Any], + ) -> SessionContext: + """Run before_run on all context providers and return the session context. + + Creates a SessionContext and invokes ``before_run`` on each provider in + forward order. ``HistoryProvider`` instances with + ``load_messages=False`` are skipped. + + Keyword Args: + session: The conversation session. + input_messages: The normalized input messages. + options: Runtime options dict. + + Returns: + The SessionContext with provider context populated. + """ + session_context = SessionContext( + session_id=session.session_id, + service_session_id=session.service_session_id, + input_messages=input_messages, + options=options, + ) + + for provider in self.context_providers: + if isinstance(provider, HistoryProvider) and not provider.load_messages: + continue + await provider.before_run( + agent=self, # type: ignore[arg-type] + session=session, + context=session_context, + state=session.state.setdefault(provider.source_id, {}), + ) + + return session_context + @staticmethod def _prepare_system_message( instructions: str | None, diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index d3aae50bba..e91a725765 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -17,6 +17,8 @@ AgentResponseUpdate, AgentSession, Content, + ContextProvider, + HistoryProvider, Message, ) from agent_framework.exceptions import AgentException @@ -1367,3 +1369,532 @@ async def test_session_config_excludes_permission_handler_when_not_set( call_args = mock_client.create_session.call_args config = call_args[0][0] assert "on_permission_request" not in config + + +class SpyContextProvider(ContextProvider): + """A context provider that records whether its hooks are called.""" + + def __init__(self) -> None: + super().__init__(source_id="spy-provider") + self.before_run_called = False + self.after_run_called = False + self.before_run_context: Any = None + self.after_run_context: Any = None + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + self.before_run_called = True + self.before_run_context = context + context.instructions.append("Injected by spy provider") + + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + self.after_run_called = True + self.after_run_context = context + + +class TestGitHubCopilotAgentContextProviders: + """Test cases for context provider integration.""" + + async def test_before_run_called_on_run( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that before_run is called on context providers during run().""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert spy.before_run_called + + async def test_after_run_called_on_run( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that after_run is called on context providers after run().""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert spy.after_run_called + + async def test_provider_instructions_included_in_prompt( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that instructions added by context providers are included in the prompt.""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + await agent.run("Hello", session=session) + + sent_prompt = mock_session.send_and_wait.call_args[0][0]["prompt"] + assert "Injected by spy provider" in sent_prompt + + async def test_after_run_receives_response( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that after_run context contains the agent response.""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert spy.after_run_context is not None + assert spy.after_run_context.response is not None + + async def test_before_run_called_on_streaming( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that before_run is called on context providers during streaming.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.before_run_called + + async def test_after_run_called_on_streaming( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that after_run is called on context providers after streaming.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.after_run_called + + async def test_provider_instructions_included_in_streaming_prompt( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that instructions from context providers are included in the streaming prompt.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + sent_prompt = mock_session.send.call_args[0][0]["prompt"] + assert "Injected by spy provider" in sent_prompt + + async def test_context_preserved_across_runs( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that provider state is preserved across multiple runs with the same session.""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + + await agent.run("Hello", session=session) + assert spy.before_run_called + + spy.before_run_called = False + await agent.run("Hello again", session=session) + assert spy.before_run_called + + async def test_context_messages_included_in_prompt( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that context messages added by providers via extend_messages are included in the prompt.""" + mock_session.send_and_wait.return_value = assistant_message_event + + class MessageInjectingProvider(ContextProvider): + def __init__(self) -> None: + super().__init__(source_id="msg-injector") + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])]) + + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + pass + + provider = MessageInjectingProvider() + agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider]) + session = agent.create_session() + await agent.run("Hello", session=session) + + sent_prompt = mock_session.send_and_wait.call_args[0][0]["prompt"] + assert "History message" in sent_prompt + assert "Hello" in sent_prompt + + async def test_context_messages_included_in_streaming_prompt( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that context messages added by providers are included in the streaming prompt.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + + class MessageInjectingProvider(ContextProvider): + def __init__(self) -> None: + super().__init__(source_id="msg-injector") + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])]) + + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + pass + + provider = MessageInjectingProvider() + agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + sent_prompt = mock_session.send.call_args[0][0]["prompt"] + assert "History message" in sent_prompt + assert "Hello" in sent_prompt + + async def test_after_run_not_called_on_error( + self, + mock_client: MagicMock, + mock_session: MagicMock, + ) -> None: + """Test that after_run is NOT called when send_and_wait raises.""" + mock_session.send_and_wait.side_effect = Exception("Request failed") + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + with pytest.raises(AgentException): + await agent.run("Hello", session=session) + + assert spy.before_run_called + assert not spy.after_run_called + + async def test_after_run_not_called_on_streaming_error( + self, + mock_client: MagicMock, + mock_session: MagicMock, + session_error_event: SessionEvent, + ) -> None: + """Test that after_run is NOT called when streaming encounters an error.""" + events = [session_error_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + with pytest.raises(AgentException): + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.before_run_called + assert not spy.after_run_called + + async def test_multiple_providers_ordering( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that before_run is called in forward order and after_run in reverse order.""" + mock_session.send_and_wait.return_value = assistant_message_event + call_order: list[str] = [] + + class OrderedProvider(ContextProvider): + def __init__(self, name: str) -> None: + super().__init__(source_id=name) + self.name = name + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + call_order.append(f"before:{self.name}") + + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + call_order.append(f"after:{self.name}") + + providers = [OrderedProvider("A"), OrderedProvider("B"), OrderedProvider("C")] + agent = GitHubCopilotAgent(client=mock_client, context_providers=providers) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert call_order == ["before:A", "before:B", "before:C", "after:C", "after:B", "after:A"] + + async def test_history_provider_skip_when_load_messages_false( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that HistoryProvider with load_messages=False is skipped in before_run.""" + mock_session.send_and_wait.return_value = assistant_message_event + + class StubHistoryProvider(HistoryProvider): + def __init__(self, *, load_messages: bool = True) -> None: + super().__init__(source_id="stub-history", load_messages=load_messages) + self.before_run_called = False + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + self.before_run_called = True + + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + self.after_run_called = True + + async def get_messages(self, *, session_id: str, **kwargs: Any) -> list[Message]: + return [] + + async def save_messages(self, *, session_id: str, messages: list[Message], **kwargs: Any) -> None: + pass + + skipped_provider = StubHistoryProvider(load_messages=False) + active_provider = StubHistoryProvider(load_messages=True) + # Use unique source_ids + skipped_provider._source_id = "skipped-history" + active_provider._source_id = "active-history" + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[skipped_provider, active_provider]) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert not skipped_provider.before_run_called + assert active_provider.before_run_called + # after_run should still be called even when load_messages=False + assert skipped_provider.after_run_called + assert active_provider.after_run_called + + async def test_streaming_after_run_response_has_updates( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that streaming after_run context.response contains the aggregated updates.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.after_run_context is not None + assert spy.after_run_context.response is not None + assert len(spy.after_run_context.response.messages) > 0 + assert spy.after_run_context.response.messages[0].text == "Hello" + + async def test_streaming_after_run_sets_empty_response_on_no_updates( + self, + mock_client: MagicMock, + mock_session: MagicMock, + session_idle_event: SessionEvent, + ) -> None: + """Test that streaming after_run sets an empty response when no updates are yielded.""" + events = [session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.after_run_called + assert spy.after_run_context.response is not None + assert len(spy.after_run_context.response.messages) == 0 + + async def test_timeout_preserved_in_session_context_options( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that timeout is preserved in session context options for providers.""" + mock_session.send_and_wait.return_value = assistant_message_event + observed_options: dict[str, Any] = {} + + class OptionsObserverProvider(ContextProvider): + def __init__(self) -> None: + super().__init__(source_id="options-observer") + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + observed_options.update(context.options) + + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + pass + + provider = OptionsObserverProvider() + agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider]) + session = agent.create_session() + await agent.run("Hello", session=session, options={"timeout": 120}) + + assert observed_options.get("timeout") == 120