diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index f219c0c28f..1b3d4f2a26 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -7,7 +7,7 @@ import logging import types from collections.abc import Awaitable, Callable -from typing import Any, TypeVar, overload +from typing import Any, TypeVar, get_type_hints, overload from ..observability import create_processing_span from ._events import ( @@ -508,7 +508,7 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: """Hook called when the workflow is restored from a checkpoint. Override this method in subclasses to implement custom logic that should - run when the workflow is restored from a checkpoint. + run when the workflow is restored from the checkpoint. Args: state: The state dictionary that was saved during checkpointing. @@ -722,21 +722,55 @@ def _validate_handler_signature( if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty: raise ValueError(f"Handler {func.__name__} must have a type annotation for the message parameter") - # Validate ctx parameter is WorkflowContext and extract type args + # Build locals for forward-ref resolution. + # - For methods defined in a class body, local forward refs live in the class namespace. + # - For nested definitions, use the function's globals; those locals aren't reliably available. + # - Always include the function object itself to support refs like "cls" patterns. + localns: dict[str, Any] = {func.__name__: func} + qualname = getattr(func, "__qualname__", "") + if "." in qualname: + cls_name = qualname.split(".", 1)[0] + globalns = getattr(func, "__globals__", {}) + cls_obj = globalns.get(cls_name) + if isinstance(cls_obj, type): + localns.update(vars(cls_obj)) + + # When possible, resolve annotations using typing.get_type_hints to: + # - correctly handle non-__future__ (already evaluated) typing objects + # - resolve string forward refs using *both* globalns and localns + # - avoid custom eval logic for introspection path + hints: dict[str, Any] = {} + try: + hints = get_type_hints( + func, + globalns=getattr(func, "__globals__", None), + localns=localns, + include_extras=True, + ) + except (NameError, TypeError) as e: + raise ValueError( + f"Handler {func.__name__} type annotations could not be resolved. " + "Make sure all referenced types are defined/imported at runtime (not only under TYPE_CHECKING)." + ) from e + + # Validate ctx parameter is WorkflowContext and extract type args. + message_type = hints.get(message_param.name, message_param.annotation) + if message_type == inspect.Parameter.empty: + message_type = None + ctx_param = params[2] - if skip_message_annotation and ctx_param.annotation == inspect.Parameter.empty: + ctx_annotation: Any = hints.get(ctx_param.name, ctx_param.annotation) + + if skip_message_annotation and ctx_annotation == inspect.Parameter.empty: # When explicit types are provided via @handler(input=..., output=...), # the ctx parameter doesn't need a type annotation - types come from the decorator. output_types: list[type[Any] | types.UnionType] = [] workflow_output_types: list[type[Any] | types.UnionType] = [] else: output_types, workflow_output_types = validate_workflow_context_annotation( - ctx_param.annotation, f"parameter '{ctx_param.name}'", "Handler" + ctx_annotation, f"parameter '{ctx_param.name}'", "Handler" ) - message_type = message_param.annotation if message_param.annotation != inspect.Parameter.empty else None - ctx_annotation = ctx_param.annotation - return message_type, ctx_annotation, output_types, workflow_output_types diff --git a/python/packages/core/tests/workflow/test_executor_future.py b/python/packages/core/tests/workflow/test_executor_future.py new file mode 100644 index 0000000000..7061733c5a --- /dev/null +++ b/python/packages/core/tests/workflow/test_executor_future.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft. All rights reserved. +from __future__ import annotations + +import pytest + +from agent_framework import Executor, WorkflowContext, handler + + +class TypeA: + pass + + +class TypeB: + pass + + +class TestExecutorFutureAnnotations: + """Test suite for Executor/@handler with from __future__ import annotations.""" + + def test_handler_decorator_future_annotations(self) -> None: + class MyExecutor(Executor): + @handler + async def example(self, input: str, ctx: WorkflowContext[TypeA, TypeB]) -> None: + pass + + e = MyExecutor(id="test") + + # Ensure handler was registered correctly + assert str in e._handlers + + spec = e._handler_specs[0] + assert spec["message_type"] is str + # OutT should be TypeA; W_OutT should be TypeB + assert spec["output_types"] == [TypeA] + assert spec["workflow_output_types"] == [TypeB] + + def test_handler_decorator_future_annotations_unresolvable_forward_ref_raises_clear_error(self) -> None: + with pytest.raises(ValueError, match=r"type annotations could not be resolved"): + + class BadExecutor(Executor): + @handler + async def example( + self, input: str, ctx: "WorkflowContext['DoesNotExist']" # type: ignore[name-defined] # noqa: F821 + ) -> None: + pass + + +def test_handler_decorator_non_future_annotations_preserve_typing_objects() -> None: + """Regression test: non-__future__ typing objects must not be stringified/mis-propagated.""" + + class MyExecutor(Executor): + @handler + async def example(self, input: str, ctx: WorkflowContext[TypeA, TypeB]) -> None: + pass + + e = MyExecutor(id="test") + spec = e._handler_specs[0] + + assert spec["message_type"] is str + assert spec["ctx_annotation"].__origin__ is WorkflowContext # type: ignore[attr-defined] + assert spec["output_types"] == [TypeA] + assert spec["workflow_output_types"] == [TypeB] + + +def test_handler_explicit_types_allows_missing_ctx_annotation() -> None: + class MyExecutor(Executor): + @handler(input=str, output=int) + async def example(self, input, ctx) -> None: # type: ignore[no-untyped-def] + pass + + e = MyExecutor(id="test") + spec = e._handler_specs[0] + + assert spec["message_type"] is str + assert spec["output_types"] == [int] + assert spec["workflow_output_types"] == [] + + +def test_handler_future_annotations_forward_ref_requires_local_scope_resolves() -> None: + # Forward refs to *function-local* names cannot be resolved by get_type_hints, + # but module-level names are resolvable at decoration time. + from agent_framework import WorkflowMessage + + e = Executor(id="test", defer_discovery=True) + + class MyExecutor(Executor): + @handler + async def example(self, input: "TypeA", ctx: WorkflowContext) -> None: + pass + + e = MyExecutor(id="test") + assert e.can_handle(WorkflowMessage(data=TypeA(), source_id="mock")) + + +def test_handler_future_annotations_missing_name_resolution_failure_is_clear() -> None: + with pytest.raises(ValueError, match=r"type annotations could not be resolved"): + + class BadExecutor(Executor): + @handler + async def example( + self, input: "MissingType", ctx: WorkflowContext # type: ignore[name-defined] # noqa: F821 + ) -> None: + pass + + +def test_handler_future_annotations_message_param_forward_ref_failure_is_clear() -> None: + with pytest.raises(ValueError, match=r"type annotations could not be resolved"): + + class BadExecutor(Executor): + @handler + async def example( + self, input: "MissingMsg", ctx: WorkflowContext # type: ignore[name-defined] # noqa: F821 + ) -> None: + pass diff --git a/python/packages/core/tests/workflow/test_full_conversation.py b/python/packages/core/tests/workflow/test_full_conversation.py index 23861ecc69..20d9abd8c0 100644 --- a/python/packages/core/tests/workflow/test_full_conversation.py +++ b/python/packages/core/tests/workflow/test_full_conversation.py @@ -362,9 +362,7 @@ async def test_run_request_with_full_history_clears_service_session_id() -> None """Replaying a full conversation (including function calls) via AgentExecutorRequest must clear service_session_id so the API does not receive both previous_response_id and the same function-call items in input — which would cause a 'Duplicate item' API error.""" - tool_agent = _ToolHistoryAgent( - id="tool_agent", name="ToolAgent", summary_text="Done." - ) + tool_agent = _ToolHistoryAgent(id="tool_agent", name="ToolAgent", summary_text="Done.") tool_exec = AgentExecutor(tool_agent, id="tool_agent") spy_agent = _SessionIdCapturingAgent(id="spy_agent", name="SpyAgent") @@ -393,9 +391,7 @@ async def test_from_response_preserves_service_session_id() -> None: """from_response hands off a prior agent's full conversation to the next executor. The receiving executor's service_session_id is preserved so the API can continue the conversation using previous_response_id.""" - tool_agent = _ToolHistoryAgent( - id="tool_agent2", name="ToolAgent", summary_text="Done." - ) + tool_agent = _ToolHistoryAgent(id="tool_agent2", name="ToolAgent", summary_text="Done.") tool_exec = AgentExecutor(tool_agent, id="tool_agent2") spy_agent = _SessionIdCapturingAgent(id="spy_agent2", name="SpyAgent") @@ -403,11 +399,7 @@ async def test_from_response_preserves_service_session_id() -> None: # Simulate a prior run on the spy executor. spy_exec._session.service_session_id = "resp_PREVIOUS_RUN" # pyright: ignore[reportPrivateUsage] - wf = ( - WorkflowBuilder(start_executor=tool_exec, output_executors=[spy_exec]) - .add_edge(tool_exec, spy_exec) - .build() - ) + wf = WorkflowBuilder(start_executor=tool_exec, output_executors=[spy_exec]).add_edge(tool_exec, spy_exec).build() result = await wf.run("start") assert result.get_outputs() is not None