Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions python/packages/core/agent_framework/_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def split_before_memory(conversation):
# Fallback: split at last user message
return EvalItem._split_last_turn_static(conversation)


item.split_messages(split=split_before_memory)
"""

Expand Down Expand Up @@ -468,10 +469,7 @@ def raise_for_status(self, msg: str | None = None) -> None:
"""
if not self.all_passed:
errored = (self.result_counts or {}).get("errored", 0)
detail = msg or (
f"Eval run {self.run_id} {self.status}: "
f"{self.passed} passed, {self.failed} failed."
)
detail = msg or (f"Eval run {self.run_id} {self.status}: {self.passed} passed, {self.failed} failed.")
if errored:
detail += f" {errored} errored."
if self.report_url:
Expand Down Expand Up @@ -1188,8 +1186,7 @@ def _coerce_result(value: Any, check_name: str) -> CheckResult:
score = float(d["score"])
except (TypeError, ValueError) as exc:
raise TypeError(
f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value:"
f" {d['score']!r}"
f"Function evaluator '{check_name}' returned dict with non-numeric 'score' value: {d['score']!r}"
) from exc
# Honour an explicit 'passed' override; otherwise threshold-based.
passed = bool(d["passed"]) if "passed" in d else score >= float(d.get("threshold", 0.5))
Expand Down
99 changes: 74 additions & 25 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import sys
import uuid
from collections.abc import AsyncIterable, Awaitable, Sequence
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
Expand Down Expand Up @@ -152,7 +152,8 @@ def run(
session: AgentSession | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...

@overload
Expand All @@ -164,7 +165,8 @@ async def run(
session: AgentSession | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AgentResponse: ...

def run(
Expand All @@ -175,7 +177,8 @@ def run(
session: AgentSession | None = None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]:
"""Get a response from the workflow agent.

Expand All @@ -192,8 +195,12 @@ def run(
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
used to load and restore the checkpoint. When provided without checkpoint_id,
enables checkpointing for this run.
**kwargs: Additional keyword arguments passed through to underlying workflow
and tool functions.
function_invocation_kwargs: Keyword arguments forwarded to tool invocations in
subagents. Either a mapping of agent name/executor id to kwargs, or a flat
mapping of kwargs for all tool invocations.
client_kwargs: Keyword arguments forwarded to chat client calls in
subagents. Either a mapping of agent name/executor id to kwargs, or a flat
mapping of kwargs for all chat client calls.

Returns:
When stream=True: An AsyncIterable[AgentResponseUpdate] for streaming updates.
Expand All @@ -208,10 +215,26 @@ def run(
response_id = str(uuid.uuid4())
if stream:
return ResponseStream(
self._run_stream_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs),
self._run_stream_impl(
messages,
response_id,
session,
checkpoint_id,
checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
),
finalizer=AgentResponse.from_updates,
)
return self._run_impl(messages, response_id, session, checkpoint_id, checkpoint_storage, **kwargs)
return self._run_impl(
messages,
response_id,
session,
checkpoint_id,
checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)

async def _run_impl(
self,
Expand All @@ -220,7 +243,8 @@ async def _run_impl(
session: AgentSession | None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AgentResponse:
"""Internal implementation of non-streaming execution.

Expand All @@ -230,8 +254,8 @@ async def _run_impl(
session: The agent session for conversation context.
checkpoint_id: ID of checkpoint to restore from.
checkpoint_storage: Runtime checkpoint storage.
**kwargs: Additional keyword arguments passed through to the underlying
workflow and tool functions.
function_invocation_kwargs: Optional kwargs for tool invocations.
client_kwargs: Optional kwargs for chat client calls.

Returns:
An AgentResponse representing the workflow execution results.
Expand Down Expand Up @@ -264,7 +288,12 @@ async def _run_impl(

output_events: list[WorkflowEvent[Any]] = []
async for event in self._run_core(
session_messages, checkpoint_id, checkpoint_storage, streaming=False, **kwargs
session_messages,
checkpoint_id,
checkpoint_storage,
streaming=False,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
if event.type == "output" or event.type == "request_info":
output_events.append(event)
Expand All @@ -285,7 +314,8 @@ async def _run_stream_impl(
session: AgentSession | None,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
"""Internal implementation of streaming execution.

Expand All @@ -295,8 +325,8 @@ async def _run_stream_impl(
session: The agent session for conversation context.
checkpoint_id: ID of checkpoint to restore from.
checkpoint_storage: Runtime checkpoint storage.
**kwargs: Additional keyword arguments passed through to the underlying
workflow and tool functions.
function_invocation_kwargs: Optional kwargs for tool invocations.
client_kwargs: Optional kwargs for chat client calls.

Yields:
AgentResponseUpdate objects representing the workflow execution progress.
Expand Down Expand Up @@ -329,7 +359,12 @@ async def _run_stream_impl(
session_messages: list[Message] = session_context.get_messages(include_input=True)
all_updates: list[AgentResponseUpdate] = []
async for event in self._run_core(
session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs
session_messages,
checkpoint_id,
checkpoint_storage,
streaming=True,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
updates = self._convert_workflow_event_to_agent_response_updates(response_id, event)
for update in updates:
Expand All @@ -349,7 +384,8 @@ async def _run_core(
checkpoint_id: str | None,
checkpoint_storage: CheckpointStorage | None,
streaming: bool,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None,
) -> AsyncIterable[WorkflowEvent]:
"""Core implementation that yields workflow events for both streaming and non-streaming modes.

Expand All @@ -358,8 +394,8 @@ async def _run_core(
checkpoint_id: ID of checkpoint to restore from.
checkpoint_storage: Runtime checkpoint storage.
streaming: Whether to use streaming workflow methods.
**kwargs: Additional keyword arguments passed through to the underlying
workflow and tool functions.
function_invocation_kwargs: Optional kwargs for tool invocations.
client_kwargs: Optional kwargs for chat client calls.

Yields:
WorkflowEvent objects from the workflow execution.
Expand All @@ -371,10 +407,19 @@ async def _run_core(
if bool(self.pending_requests):
function_responses = self._process_pending_requests(input_messages)
if streaming:
async for event in self.workflow.run(responses=function_responses, stream=True, **kwargs):
async for event in self.workflow.run(
responses=function_responses,
stream=True,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(responses=function_responses, **kwargs):
for event in await self.workflow.run(
responses=function_responses,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event

elif checkpoint_id is not None:
Expand All @@ -383,14 +428,16 @@ async def _run_core(
stream=True,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event

Expand All @@ -400,14 +447,16 @@ async def _run_core(
message=input_messages,
stream=True,
checkpoint_storage=checkpoint_storage,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
message=input_messages,
checkpoint_storage=checkpoint_storage,
**kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event

Expand Down
Loading
Loading