Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions python/packages/core/agent_framework/_workflows/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import types
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar, overload
from typing import Any, TypeVar, get_type_hints, overload

from ..observability import create_processing_span
from ._events import (
Expand Down Expand Up @@ -508,7 +508,7 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Hook called when the workflow is restored from a checkpoint.

Override this method in subclasses to implement custom logic that should
run when the workflow is restored from a checkpoint.
run when the workflow is restored from the checkpoint.

Args:
state: The state dictionary that was saved during checkpointing.
Expand Down Expand Up @@ -722,21 +722,55 @@ def _validate_handler_signature(
if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty:
raise ValueError(f"Handler {func.__name__} must have a type annotation for the message parameter")

# Validate ctx parameter is WorkflowContext and extract type args
# Build locals for forward-ref resolution.
# - For methods defined in a class body, local forward refs live in the class namespace.
# - For nested definitions, use the function's globals; those locals aren't reliably available.
# - Always include the function object itself to support refs like "cls" patterns.
localns: dict[str, Any] = {func.__name__: func}
qualname = getattr(func, "__qualname__", "")
if "." in qualname:
cls_name = qualname.split(".", 1)[0]
globalns = getattr(func, "__globals__", {})
cls_obj = globalns.get(cls_name)
if isinstance(cls_obj, type):
localns.update(vars(cls_obj))

# When possible, resolve annotations using typing.get_type_hints to:
# - correctly handle non-__future__ (already evaluated) typing objects
# - resolve string forward refs using *both* globalns and localns
# - avoid custom eval logic for introspection path
hints: dict[str, Any] = {}
try:
hints = get_type_hints(
func,
globalns=getattr(func, "__globals__", None),
localns=localns,
include_extras=True,
)
except (NameError, TypeError) as e:
raise ValueError(
f"Handler {func.__name__} type annotations could not be resolved. "
"Make sure all referenced types are defined/imported at runtime (not only under TYPE_CHECKING)."
) from e

# Validate ctx parameter is WorkflowContext and extract type args.
message_type = hints.get(message_param.name, message_param.annotation)
if message_type == inspect.Parameter.empty:
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regression risk: when the annotation is not a str, ensure you propagate the actual annotation object (e.g., message_param.annotation) and not the inspect.Parameter itself; otherwise downstream message-type handling will break for non-__future__ code.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 66d8d3f0457e3c9d811828641c73f2c1bfa6575b.

message_type = None

ctx_param = params[2]
if skip_message_annotation and ctx_param.annotation == inspect.Parameter.empty:
ctx_annotation: Any = hints.get(ctx_param.name, ctx_param.annotation)

if skip_message_annotation and ctx_annotation == inspect.Parameter.empty:
# When explicit types are provided via @handler(input=..., output=...),
# the ctx parameter doesn't need a type annotation - types come from the decorator.
output_types: list[type[Any] | types.UnionType] = []
workflow_output_types: list[type[Any] | types.UnionType] = []
else:
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocking: for non-string annotations, ensure ctx_annotation is the annotation value (e.g., ctx_param.annotation) and not the inspect.Parameter object; validate_workflow_context_annotation() expects a typing annotation like WorkflowContext[...] and will fail for the common non-future case if it receives the wrong type.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 66d8d3f0457e3c9d811828641c73f2c1bfa6575b.

output_types, workflow_output_types = validate_workflow_context_annotation(
ctx_param.annotation, f"parameter '{ctx_param.name}'", "Handler"
ctx_annotation, f"parameter '{ctx_param.name}'", "Handler"
)

message_type = message_param.annotation if message_param.annotation != inspect.Parameter.empty else None
ctx_annotation = ctx_param.annotation

return message_type, ctx_annotation, output_types, workflow_output_types


Expand Down
114 changes: 114 additions & 0 deletions python/packages/core/tests/workflow/test_executor_future.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations

import pytest

from agent_framework import Executor, WorkflowContext, handler


class TypeA:
pass


class TypeB:
pass


class TestExecutorFutureAnnotations:
"""Test suite for Executor/@handler with from __future__ import annotations."""

def test_handler_decorator_future_annotations(self) -> None:
class MyExecutor(Executor):
@handler
async def example(self, input: str, ctx: WorkflowContext[TypeA, TypeB]) -> None:
pass

e = MyExecutor(id="test")

# Ensure handler was registered correctly
assert str in e._handlers

spec = e._handler_specs[0]
assert spec["message_type"] is str
# OutT should be TypeA; W_OutT should be TypeB
assert spec["output_types"] == [TypeA]
assert spec["workflow_output_types"] == [TypeB]

def test_handler_decorator_future_annotations_unresolvable_forward_ref_raises_clear_error(self) -> None:
with pytest.raises(ValueError, match=r"type annotations could not be resolved"):

class BadExecutor(Executor):
@handler
async def example(
self, input: str, ctx: "WorkflowContext['DoesNotExist']" # type: ignore[name-defined] # noqa: F821
) -> None:
pass


def test_handler_decorator_non_future_annotations_preserve_typing_objects() -> None:
"""Regression test: non-__future__ typing objects must not be stringified/mis-propagated."""

class MyExecutor(Executor):
@handler
async def example(self, input: str, ctx: WorkflowContext[TypeA, TypeB]) -> None:
pass

e = MyExecutor(id="test")
spec = e._handler_specs[0]

assert spec["message_type"] is str
assert spec["ctx_annotation"].__origin__ is WorkflowContext # type: ignore[attr-defined]
assert spec["output_types"] == [TypeA]
assert spec["workflow_output_types"] == [TypeB]


def test_handler_explicit_types_allows_missing_ctx_annotation() -> None:
class MyExecutor(Executor):
@handler(input=str, output=int)
async def example(self, input, ctx) -> None: # type: ignore[no-untyped-def]
pass

e = MyExecutor(id="test")
spec = e._handler_specs[0]

assert spec["message_type"] is str
assert spec["output_types"] == [int]
assert spec["workflow_output_types"] == []


def test_handler_future_annotations_forward_ref_requires_local_scope_resolves() -> None:
# Forward refs to *function-local* names cannot be resolved by get_type_hints,
# but module-level names are resolvable at decoration time.
from agent_framework import WorkflowMessage

e = Executor(id="test", defer_discovery=True)

class MyExecutor(Executor):
@handler
async def example(self, input: "TypeA", ctx: WorkflowContext) -> None:
pass

e = MyExecutor(id="test")
assert e.can_handle(WorkflowMessage(data=TypeA(), source_id="mock"))


def test_handler_future_annotations_missing_name_resolution_failure_is_clear() -> None:
with pytest.raises(ValueError, match=r"type annotations could not be resolved"):

class BadExecutor(Executor):
@handler
async def example(
self, input: "MissingType", ctx: WorkflowContext # type: ignore[name-defined] # noqa: F821
) -> None:
pass


def test_handler_future_annotations_message_param_forward_ref_failure_is_clear() -> None:
with pytest.raises(ValueError, match=r"type annotations could not be resolved"):

class BadExecutor(Executor):
@handler
async def example(
self, input: "MissingMsg", ctx: WorkflowContext # type: ignore[name-defined] # noqa: F821
) -> None:
pass
14 changes: 3 additions & 11 deletions python/packages/core/tests/workflow/test_full_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,7 @@ async def test_run_request_with_full_history_clears_service_session_id() -> None
"""Replaying a full conversation (including function calls) via AgentExecutorRequest must
clear service_session_id so the API does not receive both previous_response_id and the
same function-call items in input — which would cause a 'Duplicate item' API error."""
tool_agent = _ToolHistoryAgent(
id="tool_agent", name="ToolAgent", summary_text="Done."
)
tool_agent = _ToolHistoryAgent(id="tool_agent", name="ToolAgent", summary_text="Done.")
tool_exec = AgentExecutor(tool_agent, id="tool_agent")

spy_agent = _SessionIdCapturingAgent(id="spy_agent", name="SpyAgent")
Expand Down Expand Up @@ -393,21 +391,15 @@ async def test_from_response_preserves_service_session_id() -> None:
"""from_response hands off a prior agent's full conversation to the next executor.
The receiving executor's service_session_id is preserved so the API can continue
the conversation using previous_response_id."""
tool_agent = _ToolHistoryAgent(
id="tool_agent2", name="ToolAgent", summary_text="Done."
)
tool_agent = _ToolHistoryAgent(id="tool_agent2", name="ToolAgent", summary_text="Done.")
tool_exec = AgentExecutor(tool_agent, id="tool_agent2")

spy_agent = _SessionIdCapturingAgent(id="spy_agent2", name="SpyAgent")
spy_exec = AgentExecutor(spy_agent, id="spy_agent2")
# Simulate a prior run on the spy executor.
spy_exec._session.service_session_id = "resp_PREVIOUS_RUN" # pyright: ignore[reportPrivateUsage]

wf = (
WorkflowBuilder(start_executor=tool_exec, output_executors=[spy_exec])
.add_edge(tool_exec, spy_exec)
.build()
)
wf = WorkflowBuilder(start_executor=tool_exec, output_executors=[spy_exec]).add_edge(tool_exec, spy_exec).build()

result = await wf.run("start")
assert result.get_outputs() is not None
Expand Down