Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
BaseAgent,
Content,
ContextProvider,
HistoryProvider,
Message,
ResponseStream,
SessionContext,
normalize_messages,
)
from agent_framework._settings import load_settings
Expand Down Expand Up @@ -352,13 +354,25 @@ def run(
AgentException: If the request fails.
"""
if stream:
ctx_holder: dict[str, Any] = {}

async def _after_run_hook(response: AgentResponse) -> None:
session_context = ctx_holder.get("session_context")
sess = ctx_holder.get("session")
if session_context is not None and sess is not None:
session_context._response = response
try:
await self._run_after_providers(session=sess, context=session_context)
except Exception:
logger.exception("Error running after_run providers in streaming result hook")

def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
return AgentResponse.from_updates(updates)

return ResponseStream(
self._stream_updates(messages=messages, session=session, options=options),
self._stream_updates(messages=messages, session=session, options=options, _ctx_holder=ctx_holder),
finalizer=_finalize,
result_hooks=[_after_run_hook],
)
return self._run_impl(messages=messages, session=session, options=options)

Expand All @@ -377,11 +391,22 @@ async def _run_impl(
session = self.create_session()

opts: dict[str, Any] = dict(options) if options else {}
timeout = opts.pop("timeout", None) or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS
timeout = opts.get("timeout") or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS

Comment thread
giles17 marked this conversation as resolved.
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
input_messages = normalize_messages(messages)
prompt = "\n".join([message.text for message in input_messages])

session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)

# NOTE: session is created after providers run so that future provider-contributed
# tools/config could be folded into runtime_options before session creation.
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)

# Build the prompt from the full set of messages in the session context,
# so that any context/history provider-injected messages are included.
context_messages = session_context.get_messages(include_input=True)
prompt = "\n".join([message.text for message in context_messages])
if session_context.instructions:
prompt = "\n".join(session_context.instructions) + "\n" + prompt
message_options = cast(MessageOptions, {"prompt": prompt})

try:
Expand All @@ -408,14 +433,18 @@ async def _run_impl(
)
response_id = message_id

return AgentResponse(messages=response_messages, response_id=response_id)
response = AgentResponse(messages=response_messages, response_id=response_id)
session_context._response = response # type: ignore[assignment]
await self._run_after_providers(session=session, context=session_context)
return response

async def _stream_updates(
self,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
options: OptionsT | None = None,
_ctx_holder: dict[str, Any] | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
"""Internal method to stream updates from GitHub Copilot.

Expand All @@ -425,6 +454,9 @@ async def _stream_updates(
Keyword Args:
session: The conversation session associated with the message(s).
options: Runtime options (model, timeout, etc.).
_ctx_holder: Internal dict populated with session_context and session
so that the caller (via a ResponseStream result_hook) can run
after_run providers without duplicating the updates buffer.

Yields:
AgentResponseUpdate items.
Expand All @@ -440,9 +472,23 @@ async def _stream_updates(

opts: dict[str, Any] = dict(options) if options else {}

copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
input_messages = normalize_messages(messages)
prompt = "\n".join([message.text for message in input_messages])

session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)

# NOTE: session is created after providers run so that future provider-contributed
# tools/config could be folded into runtime_options before session creation.
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)

if _ctx_holder is not None:
_ctx_holder["session_context"] = session_context
_ctx_holder["session"] = session

# Build the prompt from the full session context so provider-injected messages are included.
context_messages = session_context.get_messages(include_input=True)
prompt = "\n".join([message.text for message in context_messages])
if session_context.instructions:
prompt = "\n".join(session_context.instructions) + "\n" + prompt
message_options = cast(MessageOptions, {"prompt": prompt})

queue: asyncio.Queue[AgentResponseUpdate | Exception | None] = asyncio.Queue()
Expand Down Expand Up @@ -513,6 +559,46 @@ def event_handler(event: SessionEvent) -> None:
finally:
Comment thread
giles17 marked this conversation as resolved.
unsubscribe()

async def _run_before_providers(
self,
*,
session: AgentSession,
input_messages: list[Message],
options: dict[str, Any],
) -> SessionContext:
"""Run before_run on all context providers and return the session context.

Creates a SessionContext and invokes ``before_run`` on each provider in
forward order. ``HistoryProvider`` instances with
``load_messages=False`` are skipped.

Keyword Args:
session: The conversation session.
input_messages: The normalized input messages.
options: Runtime options dict.

Returns:
The SessionContext with provider context populated.
"""
session_context = SessionContext(
session_id=session.session_id,
service_session_id=session.service_session_id,
input_messages=input_messages,
options=options,
)

for provider in self.context_providers:
if isinstance(provider, HistoryProvider) and not provider.load_messages:
continue
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=session,
context=session_context,
state=session.state.setdefault(provider.source_id, {}),
)

return session_context

@staticmethod
def _prepare_system_message(
instructions: str | None,
Expand Down
Loading
Loading