From fa16364aee4e6cc8aac0594773dbfe0b6dbf6a7d Mon Sep 17 00:00:00 2001 From: Copilot Date: Tue, 31 Mar 2026 22:33:06 +0000 Subject: [PATCH 1/9] Fix GitHubCopilotAgent not calling context provider hooks (#3984) GitHubCopilotAgent accepted context_providers in its constructor but never called before_run()/after_run() on them in _run_impl() or _stream_updates(), silently ignoring all context providers. Add _run_before_providers() helper to create SessionContext and invoke before_run on each provider. Both _run_impl() and _stream_updates() now run the full provider lifecycle: before_run before sending the prompt (with provider instructions prepended) and after_run after receiving the response. This follows the same pattern used by A2AAgent. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_github_copilot/_agent.py | 63 +++++- .../tests/test_github_copilot_agent.py | 201 ++++++++++++++++++ 2 files changed, 263 insertions(+), 1 deletion(-) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 69f3bf20d4..a66703d813 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -16,9 +16,11 @@ AgentSession, BaseAgent, BaseContextProvider, + BaseHistoryProvider, Content, Message, ResponseStream, + SessionContext, normalize_messages, ) from agent_framework._settings import load_settings @@ -381,7 +383,14 @@ async def _run_impl( copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) input_messages = normalize_messages(messages) + + session_context = await self._run_before_providers( + session=session, input_messages=input_messages, options=opts + ) + prompt = "\n".join([message.text for message in input_messages]) + if session_context.instructions: + prompt = "\n".join(session_context.instructions) + "\n" + prompt message_options = cast(MessageOptions, {"prompt": prompt}) try: @@ -408,7 +417,10 @@ async def _run_impl( ) response_id = message_id - return AgentResponse(messages=response_messages, response_id=response_id) + response = AgentResponse(messages=response_messages, response_id=response_id) + session_context._response = response # type: ignore[assignment] + await self._run_after_providers(session=session, context=session_context) + return response async def _stream_updates( self, @@ -442,7 +454,14 @@ async def _stream_updates( copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts) input_messages = normalize_messages(messages) + + session_context = await self._run_before_providers( + session=session, input_messages=input_messages, options=opts + ) + prompt = "\n".join([message.text for message in input_messages]) + if session_context.instructions: + prompt = "\n".join(session_context.instructions) + "\n" + prompt message_options = cast(MessageOptions, {"prompt": prompt}) queue: asyncio.Queue[AgentResponseUpdate | Exception | None] = asyncio.Queue() @@ -502,6 +521,7 @@ def event_handler(event: SessionEvent) -> None: queue.put_nowait(AgentException(f"GitHub Copilot session error: {error_msg}")) unsubscribe = copilot_session.on(event_handler) + all_updates: list[AgentResponseUpdate] = [] try: await copilot_session.send(message_options) @@ -509,10 +529,51 @@ def event_handler(event: SessionEvent) -> None: while (item := await queue.get()) is not None: if isinstance(item, Exception): raise item + all_updates.append(item) yield item finally: unsubscribe() + if all_updates: + session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment] + await self._run_after_providers(session=session, context=session_context) + + async def _run_before_providers( + self, + *, + session: AgentSession, + input_messages: list[Message], + options: dict[str, Any], + ) -> SessionContext: + """Run before_run on all context providers and return the session context. + + Keyword Args: + session: The conversation session. + input_messages: The normalized input messages. + options: Runtime options dict. + + Returns: + The SessionContext with provider context populated. + """ + session_context = SessionContext( + session_id=session.session_id, + service_session_id=session.service_session_id, + input_messages=input_messages, + options=options, + ) + + for provider in self.context_providers: + if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + continue + await provider.before_run( + agent=self, # type: ignore[arg-type] + session=session, + context=session_context, + state=session.state.setdefault(provider.source_id, {}), + ) + + return session_context + @staticmethod def _prepare_system_message( instructions: str | None, diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index d3aae50bba..f085916b37 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -16,6 +16,7 @@ AgentResponse, AgentResponseUpdate, AgentSession, + BaseContextProvider, Content, Message, ) @@ -1367,3 +1368,203 @@ async def test_session_config_excludes_permission_handler_when_not_set( call_args = mock_client.create_session.call_args config = call_args[0][0] assert "on_permission_request" not in config + + +class SpyContextProvider(BaseContextProvider): + """A context provider that records whether its hooks are called.""" + + def __init__(self) -> None: + super().__init__(source_id="spy-provider") + self.before_run_called = False + self.after_run_called = False + self.before_run_context: Any = None + self.after_run_context: Any = None + + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + self.before_run_called = True + self.before_run_context = context + context.instructions.append("Injected by spy provider") + + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: + self.after_run_called = True + self.after_run_context = context + + +class TestGitHubCopilotAgentContextProviders: + """Test cases for context provider integration.""" + + async def test_before_run_called_on_run( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that before_run is called on context providers during run().""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert spy.before_run_called + + async def test_after_run_called_on_run( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that after_run is called on context providers after run().""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert spy.after_run_called + + async def test_provider_instructions_included_in_prompt( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that instructions added by context providers are included in the prompt.""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + await agent.run("Hello", session=session) + + sent_prompt = mock_session.send_and_wait.call_args[0][0]["prompt"] + assert "Injected by spy provider" in sent_prompt + + async def test_after_run_receives_response( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that after_run context contains the agent response.""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert spy.after_run_context is not None + assert spy.after_run_context.response is not None + + async def test_before_run_called_on_streaming( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that before_run is called on context providers during streaming.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.before_run_called + + async def test_after_run_called_on_streaming( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that after_run is called on context providers after streaming.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.after_run_called + + async def test_provider_instructions_included_in_streaming_prompt( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that instructions from context providers are included in the streaming prompt.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + sent_prompt = mock_session.send.call_args[0][0]["prompt"] + assert "Injected by spy provider" in sent_prompt + + async def test_context_preserved_across_runs( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that provider state is preserved across multiple runs with the same session.""" + mock_session.send_and_wait.return_value = assistant_message_event + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + + await agent.run("Hello", session=session) + assert spy.before_run_called + + spy.before_run_called = False + await agent.run("Hello again", session=session) + assert spy.before_run_called From 9c331acc9359a6758054748e992f8b2d19f3a016 Mon Sep 17 00:00:00 2001 From: Copilot Date: Tue, 31 Mar 2026 22:37:33 +0000 Subject: [PATCH 2/9] Python: Fix GitHubCopilotAgent to invoke context provider before_run/after_run hooks Fixes #3984 --- .../agent_framework_github_copilot/_agent.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index a66703d813..7e368425eb 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -384,9 +384,7 @@ async def _run_impl( copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) input_messages = normalize_messages(messages) - session_context = await self._run_before_providers( - session=session, input_messages=input_messages, options=opts - ) + session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts) prompt = "\n".join([message.text for message in input_messages]) if session_context.instructions: @@ -455,9 +453,7 @@ async def _stream_updates( copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts) input_messages = normalize_messages(messages) - session_context = await self._run_before_providers( - session=session, input_messages=input_messages, options=opts - ) + session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts) prompt = "\n".join([message.text for message in input_messages]) if session_context.instructions: From 1e6d1be93277ad67e5c3b4a7c016b02fe4dc0773 Mon Sep 17 00:00:00 2001 From: Copilot Date: Tue, 31 Mar 2026 22:49:04 +0000 Subject: [PATCH 3/9] fix(#3984): address review feedback for context provider integration - Build prompt from session_context.get_messages(include_input=True) so provider-injected context_messages are included in both non-streaming and streaming paths (review comments #1, #2) - Preserve timeout in opts (use get instead of pop) so providers can observe it via context.options (review comment #3) - Eliminate streaming double-buffer: move after_run invocation to a ResponseStream result_hook (matching Agent class pattern) instead of maintaining a separate updates list in the generator (review comment #4) - Improve _run_before_providers docstring Add tests for: - Context messages included in prompt (non-streaming + streaming) - Error path: after_run NOT called when send_and_wait/streaming raises - Multiple providers: forward before_run, reverse after_run ordering - BaseHistoryProvider with load_messages=False is skipped - Streaming after_run response contains aggregated updates - Streaming with no updates still sets empty response - Timeout preserved in session context options for providers Note: _run_before_providers remains on GitHubCopilotAgent for now. A follow-up PR should extract it to BaseAgent so subclasses can reuse it without duplicating the provider iteration logic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_github_copilot/_agent.py | 39 ++- .../tests/test_github_copilot_agent.py | 259 ++++++++++++++++++ 2 files changed, 289 insertions(+), 9 deletions(-) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 7e368425eb..284bc6be52 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -354,13 +354,22 @@ def run( AgentException: If the request fails. """ if stream: + ctx_holder: dict[str, Any] = {} + + async def _after_run_hook(response: AgentResponse) -> None: + session_context = ctx_holder.get("session_context") + sess = ctx_holder.get("session") + if session_context is not None and sess is not None: + session_context._response = response + await self._run_after_providers(session=sess, context=session_context) def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: return AgentResponse.from_updates(updates) return ResponseStream( - self._stream_updates(messages=messages, session=session, options=options), + self._stream_updates(messages=messages, session=session, options=options, _ctx_holder=ctx_holder), finalizer=_finalize, + result_hooks=[_after_run_hook], ) return self._run_impl(messages=messages, session=session, options=options) @@ -379,14 +388,17 @@ async def _run_impl( session = self.create_session() opts: dict[str, Any] = dict(options) if options else {} - timeout = opts.pop("timeout", None) or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS + timeout = opts.get("timeout") or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) input_messages = normalize_messages(messages) session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts) - prompt = "\n".join([message.text for message in input_messages]) + # Build the prompt from the full set of messages in the session context, + # so that any context/history provider-injected messages are included. + context_messages = session_context.get_messages(include_input=True) + prompt = "\n".join([message.text for message in context_messages]) if session_context.instructions: prompt = "\n".join(session_context.instructions) + "\n" + prompt message_options = cast(MessageOptions, {"prompt": prompt}) @@ -426,6 +438,7 @@ async def _stream_updates( *, session: AgentSession | None = None, options: OptionsT | None = None, + _ctx_holder: dict[str, Any] | None = None, ) -> AsyncIterable[AgentResponseUpdate]: """Internal method to stream updates from GitHub Copilot. @@ -435,6 +448,9 @@ async def _stream_updates( Keyword Args: session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). + _ctx_holder: Internal dict populated with session_context and session + so that the caller (via a ResponseStream result_hook) can run + after_run providers without duplicating the updates buffer. Yields: AgentResponseUpdate items. @@ -455,7 +471,13 @@ async def _stream_updates( session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts) - prompt = "\n".join([message.text for message in input_messages]) + if _ctx_holder is not None: + _ctx_holder["session_context"] = session_context + _ctx_holder["session"] = session + + # Build the prompt from the full session context so provider-injected messages are included. + context_messages = session_context.get_messages(include_input=True) + prompt = "\n".join([message.text for message in context_messages]) if session_context.instructions: prompt = "\n".join(session_context.instructions) + "\n" + prompt message_options = cast(MessageOptions, {"prompt": prompt}) @@ -517,7 +539,6 @@ def event_handler(event: SessionEvent) -> None: queue.put_nowait(AgentException(f"GitHub Copilot session error: {error_msg}")) unsubscribe = copilot_session.on(event_handler) - all_updates: list[AgentResponseUpdate] = [] try: await copilot_session.send(message_options) @@ -525,14 +546,10 @@ def event_handler(event: SessionEvent) -> None: while (item := await queue.get()) is not None: if isinstance(item, Exception): raise item - all_updates.append(item) yield item finally: unsubscribe() - if all_updates: - session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment] - await self._run_after_providers(session=session, context=session_context) async def _run_before_providers( self, @@ -543,6 +560,10 @@ async def _run_before_providers( ) -> SessionContext: """Run before_run on all context providers and return the session context. + Creates a SessionContext and invokes ``before_run`` on each provider in + forward order. ``BaseHistoryProvider`` instances with + ``load_messages=False`` are skipped. + Keyword Args: session: The conversation session. input_messages: The normalized input messages. diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index f085916b37..b7d00d26a5 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -17,8 +17,10 @@ AgentResponseUpdate, AgentSession, BaseContextProvider, + BaseHistoryProvider, Content, Message, + SessionContext, ) from agent_framework.exceptions import AgentException from copilot.generated.session_events import Data, ErrorClass, Result, SessionEvent, SessionEventType @@ -1568,3 +1570,260 @@ async def test_context_preserved_across_runs( spy.before_run_called = False await agent.run("Hello again", session=session) assert spy.before_run_called + + async def test_context_messages_included_in_prompt( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that context messages added by providers via extend_messages are included in the prompt.""" + mock_session.send_and_wait.return_value = assistant_message_event + + class MessageInjectingProvider(BaseContextProvider): + def __init__(self) -> None: + super().__init__(source_id="msg-injector") + + async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])]) + + async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + pass + + provider = MessageInjectingProvider() + agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider]) + session = agent.create_session() + await agent.run("Hello", session=session) + + sent_prompt = mock_session.send_and_wait.call_args[0][0]["prompt"] + assert "History message" in sent_prompt + assert "Hello" in sent_prompt + + async def test_context_messages_included_in_streaming_prompt( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that context messages added by providers are included in the streaming prompt.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + + class MessageInjectingProvider(BaseContextProvider): + def __init__(self) -> None: + super().__init__(source_id="msg-injector") + + async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])]) + + async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + pass + + provider = MessageInjectingProvider() + agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + sent_prompt = mock_session.send.call_args[0][0]["prompt"] + assert "History message" in sent_prompt + assert "Hello" in sent_prompt + + async def test_after_run_not_called_on_error( + self, + mock_client: MagicMock, + mock_session: MagicMock, + ) -> None: + """Test that after_run is NOT called when send_and_wait raises.""" + mock_session.send_and_wait.side_effect = Exception("Request failed") + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + with pytest.raises(AgentException): + await agent.run("Hello", session=session) + + assert spy.before_run_called + assert not spy.after_run_called + + async def test_after_run_not_called_on_streaming_error( + self, + mock_client: MagicMock, + mock_session: MagicMock, + session_error_event: SessionEvent, + ) -> None: + """Test that after_run is NOT called when streaming encounters an error.""" + events = [session_error_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + with pytest.raises(AgentException): + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.before_run_called + assert not spy.after_run_called + + async def test_multiple_providers_ordering( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that before_run is called in forward order and after_run in reverse order.""" + mock_session.send_and_wait.return_value = assistant_message_event + call_order: list[str] = [] + + class OrderedProvider(BaseContextProvider): + def __init__(self, name: str) -> None: + super().__init__(source_id=name) + self.name = name + + async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + call_order.append(f"before:{self.name}") + + async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + call_order.append(f"after:{self.name}") + + providers = [OrderedProvider("A"), OrderedProvider("B"), OrderedProvider("C")] + agent = GitHubCopilotAgent(client=mock_client, context_providers=providers) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert call_order == ["before:A", "before:B", "before:C", "after:C", "after:B", "after:A"] + + async def test_history_provider_skip_when_load_messages_false( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that BaseHistoryProvider with load_messages=False is skipped in before_run.""" + mock_session.send_and_wait.return_value = assistant_message_event + + class StubHistoryProvider(BaseHistoryProvider): + def __init__(self, *, load_messages: bool = True) -> None: + super().__init__(source_id="stub-history", load_messages=load_messages) + self.before_run_called = False + + async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + self.before_run_called = True + + async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + pass + + async def get_messages(self, *, session_id: str, **kwargs: Any) -> list[Message]: + return [] + + async def save_messages(self, *, session_id: str, messages: list[Message], **kwargs: Any) -> None: + pass + + skipped_provider = StubHistoryProvider(load_messages=False) + active_provider = StubHistoryProvider(load_messages=True) + # Use unique source_ids + skipped_provider._source_id = "skipped-history" + active_provider._source_id = "active-history" + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[skipped_provider, active_provider]) + session = agent.create_session() + await agent.run("Hello", session=session) + + assert not skipped_provider.before_run_called + assert active_provider.before_run_called + + async def test_streaming_after_run_response_has_updates( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that streaming after_run context.response contains the aggregated updates.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.after_run_context is not None + assert spy.after_run_context.response is not None + assert len(spy.after_run_context.response.messages) > 0 + assert spy.after_run_context.response.messages[0].text == "Hello" + + async def test_streaming_after_run_sets_empty_response_on_no_updates( + self, + mock_client: MagicMock, + mock_session: MagicMock, + session_idle_event: SessionEvent, + ) -> None: + """Test that streaming after_run sets an empty response when no updates are yielded.""" + events = [session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + spy = SpyContextProvider() + + agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy]) + session = agent.create_session() + async for _ in agent.run("Hello", stream=True, session=session): + pass + + assert spy.after_run_called + assert spy.after_run_context.response is not None + assert len(spy.after_run_context.response.messages) == 0 + + async def test_timeout_preserved_in_session_context_options( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that timeout is preserved in session context options for providers.""" + mock_session.send_and_wait.return_value = assistant_message_event + observed_options: dict[str, Any] = {} + + class OptionsObserverProvider(BaseContextProvider): + def __init__(self) -> None: + super().__init__(source_id="options-observer") + + async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + observed_options.update(context.options) + + async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + pass + + provider = OptionsObserverProvider() + agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider]) + session = agent.create_session() + await agent.run("Hello", session=session, options={"timeout": 120}) + + assert observed_options.get("timeout") == 120 + From acc85d125ba25b1c3b819e9273ec73f5259e17d7 Mon Sep 17 00:00:00 2001 From: Copilot Date: Tue, 31 Mar 2026 22:57:23 +0000 Subject: [PATCH 4/9] Address review feedback for #3984: Python: [Bug]: GitHubCopilotAgent Memory Example --- .../agent_framework_github_copilot/_agent.py | 1 - .../tests/test_github_copilot_agent.py | 92 ++++++++++++++++--- 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 284bc6be52..c7e2825b4a 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -550,7 +550,6 @@ def event_handler(event: SessionEvent) -> None: finally: unsubscribe() - async def _run_before_providers( self, *, diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index b7d00d26a5..dd2d7c82db 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -20,7 +20,6 @@ BaseHistoryProvider, Content, Message, - SessionContext, ) from agent_framework.exceptions import AgentException from copilot.generated.session_events import Data, ErrorClass, Result, SessionEvent, SessionEventType @@ -1584,10 +1583,24 @@ class MessageInjectingProvider(BaseContextProvider): def __init__(self) -> None: super().__init__(source_id="msg-injector") - async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])]) - async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: pass provider = MessageInjectingProvider() @@ -1620,10 +1633,24 @@ class MessageInjectingProvider(BaseContextProvider): def __init__(self) -> None: super().__init__(source_id="msg-injector") - async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])]) - async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: pass provider = MessageInjectingProvider() @@ -1694,10 +1721,24 @@ def __init__(self, name: str) -> None: super().__init__(source_id=name) self.name = name - async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: call_order.append(f"before:{self.name}") - async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: call_order.append(f"after:{self.name}") providers = [OrderedProvider("A"), OrderedProvider("B"), OrderedProvider("C")] @@ -1721,10 +1762,24 @@ def __init__(self, *, load_messages: bool = True) -> None: super().__init__(source_id="stub-history", load_messages=load_messages) self.before_run_called = False - async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: self.before_run_called = True - async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: pass async def get_messages(self, *, session_id: str, **kwargs: Any) -> list[Message]: @@ -1814,10 +1869,24 @@ class OptionsObserverProvider(BaseContextProvider): def __init__(self) -> None: super().__init__(source_id="options-observer") - async def before_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def before_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: observed_options.update(context.options) - async def after_run(self, *, agent: Any, session: AgentSession, context: Any, state: dict[str, Any]) -> None: + async def after_run( + self, + *, + agent: Any, + session: AgentSession, + context: Any, + state: dict[str, Any], + ) -> None: pass provider = OptionsObserverProvider() @@ -1826,4 +1895,3 @@ async def after_run(self, *, agent: Any, session: AgentSession, context: Any, st await agent.run("Hello", session=session, options={"timeout": 120}) assert observed_options.get("timeout") == 120 - From 8042220865490742053b8c5390848740b7e21f57 Mon Sep 17 00:00:00 2001 From: Copilot Date: Wed, 1 Apr 2026 00:10:45 +0000 Subject: [PATCH 5/9] refactor(#3984): promote _run_before_providers to BaseAgent Move _run_before_providers from GitHubCopilotAgent into BaseAgent, mirroring the existing _run_after_providers helper. Agent's _prepare_session_and_messages now delegates to the shared base method, eliminating the near-duplicate provider iteration logic that could drift as the provider contract evolves. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../packages/core/agent_framework/_agents.py | 70 +++++++++++------ .../packages/core/tests/core/test_agents.py | 76 +++++++++++++++++++ .../agent_framework_github_copilot/_agent.py | 42 ---------- 3 files changed, 124 insertions(+), 64 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 1868742111..76236911a1 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -439,6 +439,52 @@ def get_session(self, service_session_id: str, *, session_id: str | None = None) """ return AgentSession(session_id=session_id, service_session_id=service_session_id) + async def _run_before_providers( + self, + *, + session: AgentSession | None, + input_messages: list[Message] | None = None, + options: dict[str, Any] | None = None, + ) -> SessionContext: + """Run before_run on all context providers and return the session context. + + Creates a SessionContext and invokes ``before_run`` on each provider in + forward order. ``BaseHistoryProvider`` instances with + ``load_messages=False`` are skipped. + + Keyword Args: + session: The conversation session (None for stateless invocation). + input_messages: The normalized input messages. + options: Runtime options dict. + + Returns: + The SessionContext with provider context populated. + """ + provider_session = session + if provider_session is None and self.context_providers: + provider_session = AgentSession() + + session_context = SessionContext( + session_id=provider_session.session_id if provider_session else None, + service_session_id=provider_session.service_session_id if provider_session else None, + input_messages=input_messages or [], + options=options or {}, + ) + + for provider in self.context_providers: + if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + continue + if provider_session is None: + raise RuntimeError("Provider session must be available when context providers are configured.") + await provider.before_run( + agent=self, # type: ignore[arg-type] + session=provider_session, + context=session_context, + state=provider_session.state.setdefault(provider.source_id, {}), + ) + + return session_context + async def _run_after_providers( self, *, @@ -1273,30 +1319,10 @@ async def _prepare_session_and_messages( else: chat_options = {} - provider_session = session - if provider_session is None and self.context_providers: - provider_session = AgentSession() - - session_context = SessionContext( - session_id=provider_session.session_id if provider_session else None, - service_session_id=provider_session.service_session_id if provider_session else None, - input_messages=input_messages or [], - options=options or {}, + session_context = await self._run_before_providers( + session=session, input_messages=input_messages, options=options, ) - # Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False) - for provider in self.context_providers: - if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: - continue - if provider_session is None: - raise RuntimeError("Provider session must be available when context providers are configured.") - await provider.before_run( - agent=self, # type: ignore[arg-type] - session=provider_session, - context=session_context, - state=provider_session.state.setdefault(provider.source_id, {}), - ) - # Merge provider-contributed tools into chat_options if session_context.tools: if chat_options.get("tools") is not None: diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 94253b3c34..85e27c81f9 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -2085,3 +2085,79 @@ async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetRespo assert len(exc_info.value.contents) == 1 assert exc_info.value.contents[0].type == "oauth_consent_request" assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent" + + +async def test_base_agent_run_before_providers_creates_session_context( + client: SupportsChatGetResponse, +) -> None: + """Test that BaseAgent._run_before_providers creates a SessionContext and calls providers.""" + mock_provider = MockContextProvider(messages=[Message(role="system", text="Injected context")]) + agent = Agent(client=client, context_providers=[mock_provider]) + session = agent.create_session() + + session_context = await agent._run_before_providers( # type: ignore[reportPrivateUsage] + session=session, + input_messages=[Message(role="user", text="Hello")], + options={"temperature": 0.5}, + ) + + assert mock_provider.before_run_called + assert session_context.session_id == session.session_id + messages = session_context.get_messages(include_input=True) + assert len(messages) == 2 + assert messages[0].text == "Injected context" + assert messages[1].text == "Hello" + assert session_context.options.get("temperature") == 0.5 + + +async def test_base_agent_run_before_providers_creates_session_when_none( + client: SupportsChatGetResponse, +) -> None: + """Test that _run_before_providers creates a session when None is passed with providers.""" + mock_provider = MockContextProvider() + agent = Agent(client=client, context_providers=[mock_provider]) + + session_context = await agent._run_before_providers( # type: ignore[reportPrivateUsage] + session=None, + input_messages=[Message(role="user", text="Hello")], + ) + + assert mock_provider.before_run_called + assert session_context.session_id is not None + + +async def test_base_agent_run_before_providers_skips_history_provider_load_false( + client: SupportsChatGetResponse, +) -> None: + """Test that _run_before_providers skips BaseHistoryProvider with load_messages=False.""" + from agent_framework import BaseHistoryProvider + + class StubHistoryProvider(BaseHistoryProvider): + def __init__(self, *, load_messages: bool = True) -> None: + super().__init__(source_id=f"stub-{load_messages}", load_messages=load_messages) + self.before_run_called = False + + async def before_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None: + self.before_run_called = True + + async def after_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None: + pass + + async def get_messages(self, session_id: Any, **kwargs: Any) -> list[Message]: + return [] + + async def save_messages(self, session_id: Any, messages: Any, **kwargs: Any) -> None: + pass + + skipped = StubHistoryProvider(load_messages=False) + active = StubHistoryProvider(load_messages=True) + agent = Agent(client=client, context_providers=[skipped, active]) + session = agent.create_session() + + await agent._run_before_providers( # type: ignore[reportPrivateUsage] + session=session, + input_messages=[Message(role="user", text="Hello")], + ) + + assert not skipped.before_run_called + assert active.before_run_called diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index c7e2825b4a..352d446f34 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -16,11 +16,9 @@ AgentSession, BaseAgent, BaseContextProvider, - BaseHistoryProvider, Content, Message, ResponseStream, - SessionContext, normalize_messages, ) from agent_framework._settings import load_settings @@ -550,46 +548,6 @@ def event_handler(event: SessionEvent) -> None: finally: unsubscribe() - async def _run_before_providers( - self, - *, - session: AgentSession, - input_messages: list[Message], - options: dict[str, Any], - ) -> SessionContext: - """Run before_run on all context providers and return the session context. - - Creates a SessionContext and invokes ``before_run`` on each provider in - forward order. ``BaseHistoryProvider`` instances with - ``load_messages=False`` are skipped. - - Keyword Args: - session: The conversation session. - input_messages: The normalized input messages. - options: Runtime options dict. - - Returns: - The SessionContext with provider context populated. - """ - session_context = SessionContext( - session_id=session.session_id, - service_session_id=session.service_session_id, - input_messages=input_messages, - options=options, - ) - - for provider in self.context_providers: - if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: - continue - await provider.before_run( - agent=self, # type: ignore[arg-type] - session=session, - context=session_context, - state=session.state.setdefault(provider.source_id, {}), - ) - - return session_context - @staticmethod def _prepare_system_message( instructions: str | None, From c8487ae52a982b4e05420e7949ec867016428e08 Mon Sep 17 00:00:00 2001 From: Copilot Date: Wed, 1 Apr 2026 00:12:51 +0000 Subject: [PATCH 6/9] Address review feedback for #3984: Python: [Bug]: GitHubCopilotAgent Memory Example --- python/packages/core/agent_framework/_agents.py | 4 +++- python/packages/core/agent_framework/_evaluation.py | 9 +++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 76236911a1..4dd934fded 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -1320,7 +1320,9 @@ async def _prepare_session_and_messages( chat_options = {} session_context = await self._run_before_providers( - session=session, input_messages=input_messages, options=options, + session=session, + input_messages=input_messages, + options=options, ) # Merge provider-contributed tools into chat_options diff --git a/python/packages/core/agent_framework/_evaluation.py b/python/packages/core/agent_framework/_evaluation.py index 92a694cc36..80bc1ecffb 100644 --- a/python/packages/core/agent_framework/_evaluation.py +++ b/python/packages/core/agent_framework/_evaluation.py @@ -96,6 +96,7 @@ def split_before_memory(conversation): # Fallback: split at last user message return EvalItem._split_last_turn_static(conversation) + item.split_messages(split=split_before_memory) """ @@ -468,10 +469,7 @@ def raise_for_status(self, msg: str | None = None) -> None: """ if not self.all_passed: errored = (self.result_counts or {}).get("errored", 0) - detail = msg or ( - f"Eval run {self.run_id} {self.status}: " - f"{self.passed} passed, {self.failed} failed." - ) + detail = msg or (f"Eval run {self.run_id} {self.status}: {self.passed} passed, {self.failed} failed.") if errored: detail += f" {errored} errored." if self.report_url: @@ -1188,8 +1186,7 @@ def _coerce_result(value: Any, check_name: str) -> CheckResult: score = float(d["score"]) except (TypeError, ValueError) as exc: raise TypeError( - f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value:" - f" {d['score']!r}" + f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value: {d['score']!r}" ) from exc # Honour an explicit 'passed' override; otherwise threshold-based. passed = bool(d["passed"]) if "passed" in d else score >= float(d.get("threshold", 0.5)) From d7b772a1f7759fe7d86819a99e6627c092b56d96 Mon Sep 17 00:00:00 2001 From: Copilot Date: Wed, 1 Apr 2026 21:17:49 +0000 Subject: [PATCH 7/9] revert: keep _run_before_providers in GitHubCopilotAgent only Undo the promotion of _run_before_providers to BaseAgent. The method stays in GitHubCopilotAgent where it is needed, and _agents.py retains its original inline provider iteration in RawAgent. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../packages/core/agent_framework/_agents.py | 72 ++++++------------ .../packages/core/tests/core/test_agents.py | 76 ------------------- .../agent_framework_github_copilot/_agent.py | 42 ++++++++++ 3 files changed, 64 insertions(+), 126 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 4dd934fded..1868742111 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -439,52 +439,6 @@ def get_session(self, service_session_id: str, *, session_id: str | None = None) """ return AgentSession(session_id=session_id, service_session_id=service_session_id) - async def _run_before_providers( - self, - *, - session: AgentSession | None, - input_messages: list[Message] | None = None, - options: dict[str, Any] | None = None, - ) -> SessionContext: - """Run before_run on all context providers and return the session context. - - Creates a SessionContext and invokes ``before_run`` on each provider in - forward order. ``BaseHistoryProvider`` instances with - ``load_messages=False`` are skipped. - - Keyword Args: - session: The conversation session (None for stateless invocation). - input_messages: The normalized input messages. - options: Runtime options dict. - - Returns: - The SessionContext with provider context populated. - """ - provider_session = session - if provider_session is None and self.context_providers: - provider_session = AgentSession() - - session_context = SessionContext( - session_id=provider_session.session_id if provider_session else None, - service_session_id=provider_session.service_session_id if provider_session else None, - input_messages=input_messages or [], - options=options or {}, - ) - - for provider in self.context_providers: - if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: - continue - if provider_session is None: - raise RuntimeError("Provider session must be available when context providers are configured.") - await provider.before_run( - agent=self, # type: ignore[arg-type] - session=provider_session, - context=session_context, - state=provider_session.state.setdefault(provider.source_id, {}), - ) - - return session_context - async def _run_after_providers( self, *, @@ -1319,12 +1273,30 @@ async def _prepare_session_and_messages( else: chat_options = {} - session_context = await self._run_before_providers( - session=session, - input_messages=input_messages, - options=options, + provider_session = session + if provider_session is None and self.context_providers: + provider_session = AgentSession() + + session_context = SessionContext( + session_id=provider_session.session_id if provider_session else None, + service_session_id=provider_session.service_session_id if provider_session else None, + input_messages=input_messages or [], + options=options or {}, ) + # Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False) + for provider in self.context_providers: + if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + continue + if provider_session is None: + raise RuntimeError("Provider session must be available when context providers are configured.") + await provider.before_run( + agent=self, # type: ignore[arg-type] + session=provider_session, + context=session_context, + state=provider_session.state.setdefault(provider.source_id, {}), + ) + # Merge provider-contributed tools into chat_options if session_context.tools: if chat_options.get("tools") is not None: diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 85e27c81f9..94253b3c34 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -2085,79 +2085,3 @@ async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetRespo assert len(exc_info.value.contents) == 1 assert exc_info.value.contents[0].type == "oauth_consent_request" assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent" - - -async def test_base_agent_run_before_providers_creates_session_context( - client: SupportsChatGetResponse, -) -> None: - """Test that BaseAgent._run_before_providers creates a SessionContext and calls providers.""" - mock_provider = MockContextProvider(messages=[Message(role="system", text="Injected context")]) - agent = Agent(client=client, context_providers=[mock_provider]) - session = agent.create_session() - - session_context = await agent._run_before_providers( # type: ignore[reportPrivateUsage] - session=session, - input_messages=[Message(role="user", text="Hello")], - options={"temperature": 0.5}, - ) - - assert mock_provider.before_run_called - assert session_context.session_id == session.session_id - messages = session_context.get_messages(include_input=True) - assert len(messages) == 2 - assert messages[0].text == "Injected context" - assert messages[1].text == "Hello" - assert session_context.options.get("temperature") == 0.5 - - -async def test_base_agent_run_before_providers_creates_session_when_none( - client: SupportsChatGetResponse, -) -> None: - """Test that _run_before_providers creates a session when None is passed with providers.""" - mock_provider = MockContextProvider() - agent = Agent(client=client, context_providers=[mock_provider]) - - session_context = await agent._run_before_providers( # type: ignore[reportPrivateUsage] - session=None, - input_messages=[Message(role="user", text="Hello")], - ) - - assert mock_provider.before_run_called - assert session_context.session_id is not None - - -async def test_base_agent_run_before_providers_skips_history_provider_load_false( - client: SupportsChatGetResponse, -) -> None: - """Test that _run_before_providers skips BaseHistoryProvider with load_messages=False.""" - from agent_framework import BaseHistoryProvider - - class StubHistoryProvider(BaseHistoryProvider): - def __init__(self, *, load_messages: bool = True) -> None: - super().__init__(source_id=f"stub-{load_messages}", load_messages=load_messages) - self.before_run_called = False - - async def before_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None: - self.before_run_called = True - - async def after_run(self, *, agent: Any, session: Any, context: Any, state: Any) -> None: - pass - - async def get_messages(self, session_id: Any, **kwargs: Any) -> list[Message]: - return [] - - async def save_messages(self, session_id: Any, messages: Any, **kwargs: Any) -> None: - pass - - skipped = StubHistoryProvider(load_messages=False) - active = StubHistoryProvider(load_messages=True) - agent = Agent(client=client, context_providers=[skipped, active]) - session = agent.create_session() - - await agent._run_before_providers( # type: ignore[reportPrivateUsage] - session=session, - input_messages=[Message(role="user", text="Hello")], - ) - - assert not skipped.before_run_called - assert active.before_run_called diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 352d446f34..c7e2825b4a 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -16,9 +16,11 @@ AgentSession, BaseAgent, BaseContextProvider, + BaseHistoryProvider, Content, Message, ResponseStream, + SessionContext, normalize_messages, ) from agent_framework._settings import load_settings @@ -548,6 +550,46 @@ def event_handler(event: SessionEvent) -> None: finally: unsubscribe() + async def _run_before_providers( + self, + *, + session: AgentSession, + input_messages: list[Message], + options: dict[str, Any], + ) -> SessionContext: + """Run before_run on all context providers and return the session context. + + Creates a SessionContext and invokes ``before_run`` on each provider in + forward order. ``BaseHistoryProvider`` instances with + ``load_messages=False`` are skipped. + + Keyword Args: + session: The conversation session. + input_messages: The normalized input messages. + options: Runtime options dict. + + Returns: + The SessionContext with provider context populated. + """ + session_context = SessionContext( + session_id=session.session_id, + service_session_id=session.service_session_id, + input_messages=input_messages, + options=options, + ) + + for provider in self.context_providers: + if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + continue + await provider.before_run( + agent=self, # type: ignore[arg-type] + session=session, + context=session_context, + state=session.state.setdefault(provider.source_id, {}), + ) + + return session_context + @staticmethod def _prepare_system_message( instructions: str | None, From fea3324f1fb9aff044d4037449d9ea56fd9183e6 Mon Sep 17 00:00:00 2001 From: Copilot Date: Wed, 1 Apr 2026 21:42:09 +0000 Subject: [PATCH 8/9] fix: replace deprecated BaseContextProvider/BaseHistoryProvider with ContextProvider/HistoryProvider Update imports and usages in GitHubCopilotAgent and its tests to use the new non-deprecated class names from the core package. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_github_copilot/_agent.py | 6 +++--- .../tests/test_github_copilot_agent.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 81d12dfe6b..faf57d565d 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -15,9 +15,9 @@ AgentResponseUpdate, AgentSession, BaseAgent, - BaseHistoryProvider, Content, ContextProvider, + HistoryProvider, Message, ResponseStream, SessionContext, @@ -560,7 +560,7 @@ async def _run_before_providers( """Run before_run on all context providers and return the session context. Creates a SessionContext and invokes ``before_run`` on each provider in - forward order. ``BaseHistoryProvider`` instances with + forward order. ``HistoryProvider`` instances with ``load_messages=False`` are skipped. Keyword Args: @@ -579,7 +579,7 @@ async def _run_before_providers( ) for provider in self.context_providers: - if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + if isinstance(provider, HistoryProvider) and not provider.load_messages: continue await provider.before_run( agent=self, # type: ignore[arg-type] diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index dd2d7c82db..6527fcdef1 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -16,9 +16,9 @@ AgentResponse, AgentResponseUpdate, AgentSession, - BaseContextProvider, - BaseHistoryProvider, Content, + ContextProvider, + HistoryProvider, Message, ) from agent_framework.exceptions import AgentException @@ -1371,7 +1371,7 @@ async def test_session_config_excludes_permission_handler_when_not_set( assert "on_permission_request" not in config -class SpyContextProvider(BaseContextProvider): +class SpyContextProvider(ContextProvider): """A context provider that records whether its hooks are called.""" def __init__(self) -> None: @@ -1579,7 +1579,7 @@ async def test_context_messages_included_in_prompt( """Test that context messages added by providers via extend_messages are included in the prompt.""" mock_session.send_and_wait.return_value = assistant_message_event - class MessageInjectingProvider(BaseContextProvider): + class MessageInjectingProvider(ContextProvider): def __init__(self) -> None: super().__init__(source_id="msg-injector") @@ -1629,7 +1629,7 @@ def mock_on(handler: Any) -> Any: mock_session.on = mock_on - class MessageInjectingProvider(BaseContextProvider): + class MessageInjectingProvider(ContextProvider): def __init__(self) -> None: super().__init__(source_id="msg-injector") @@ -1716,7 +1716,7 @@ async def test_multiple_providers_ordering( mock_session.send_and_wait.return_value = assistant_message_event call_order: list[str] = [] - class OrderedProvider(BaseContextProvider): + class OrderedProvider(ContextProvider): def __init__(self, name: str) -> None: super().__init__(source_id=name) self.name = name @@ -1754,10 +1754,10 @@ async def test_history_provider_skip_when_load_messages_false( mock_session: MagicMock, assistant_message_event: SessionEvent, ) -> None: - """Test that BaseHistoryProvider with load_messages=False is skipped in before_run.""" + """Test that HistoryProvider with load_messages=False is skipped in before_run.""" mock_session.send_and_wait.return_value = assistant_message_event - class StubHistoryProvider(BaseHistoryProvider): + class StubHistoryProvider(HistoryProvider): def __init__(self, *, load_messages: bool = True) -> None: super().__init__(source_id="stub-history", load_messages=load_messages) self.before_run_called = False @@ -1865,7 +1865,7 @@ async def test_timeout_preserved_in_session_context_options( mock_session.send_and_wait.return_value = assistant_message_event observed_options: dict[str, Any] = {} - class OptionsObserverProvider(BaseContextProvider): + class OptionsObserverProvider(ContextProvider): def __init__(self) -> None: super().__init__(source_id="options-observer") From 4ec8ccaaab94e767244a119cb8bbaae52d0c66a4 Mon Sep 17 00:00:00 2001 From: Copilot Date: Wed, 1 Apr 2026 22:29:48 +0000 Subject: [PATCH 9/9] fix: address review feedback - reorder providers before session, wrap streaming after_run in try/except, assert after_run on skipped HistoryProvider - Move _run_before_providers before _get_or_create_session so provider contributions can affect session configuration - Wrap _run_after_providers in try/except in streaming _after_run_hook to prevent provider errors from replacing successful responses - Add after_run assertion to test_history_provider_skip_when_load_messages_false Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_github_copilot/_agent.py | 15 ++++++++++++--- .../tests/test_github_copilot_agent.py | 5 ++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index faf57d565d..4599cd3526 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -361,7 +361,10 @@ async def _after_run_hook(response: AgentResponse) -> None: sess = ctx_holder.get("session") if session_context is not None and sess is not None: session_context._response = response - await self._run_after_providers(session=sess, context=session_context) + try: + await self._run_after_providers(session=sess, context=session_context) + except Exception: + logger.exception("Error running after_run providers in streaming result hook") def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: return AgentResponse.from_updates(updates) @@ -390,11 +393,14 @@ async def _run_impl( opts: dict[str, Any] = dict(options) if options else {} timeout = opts.get("timeout") or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS - copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) input_messages = normalize_messages(messages) session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts) + # NOTE: session is created after providers run so that future provider-contributed + # tools/config could be folded into runtime_options before session creation. + copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) + # Build the prompt from the full set of messages in the session context, # so that any context/history provider-injected messages are included. context_messages = session_context.get_messages(include_input=True) @@ -466,11 +472,14 @@ async def _stream_updates( opts: dict[str, Any] = dict(options) if options else {} - copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts) input_messages = normalize_messages(messages) session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts) + # NOTE: session is created after providers run so that future provider-contributed + # tools/config could be folded into runtime_options before session creation. + copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts) + if _ctx_holder is not None: _ctx_holder["session_context"] = session_context _ctx_holder["session"] = session diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index 6527fcdef1..e91a725765 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -1780,7 +1780,7 @@ async def after_run( context: Any, state: dict[str, Any], ) -> None: - pass + self.after_run_called = True async def get_messages(self, *, session_id: str, **kwargs: Any) -> list[Message]: return [] @@ -1800,6 +1800,9 @@ async def save_messages(self, *, session_id: str, messages: list[Message], **kwa assert not skipped_provider.before_run_called assert active_provider.before_run_called + # after_run should still be called even when load_messages=False + assert skipped_provider.after_run_called + assert active_provider.after_run_called async def test_streaming_after_run_response_has_updates( self,