diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py index c9f3fe3a12..ce227ac0aa 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py @@ -25,8 +25,10 @@ from langchain_core.callbacks import AsyncCallbackHandler from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessageChunk from langchain_core.messages import BaseMessage from langchain_core.messages import ToolMessage +from langchain_core.messages.utils import convert_to_openai_messages from langchain_core.runnables import Runnable from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool @@ -35,6 +37,30 @@ logger = logging.getLogger(__name__) + +def _chunk_to_message(chunk: AIMessageChunk) -> AIMessage: + """Convert an accumulated AIMessageChunk into an AIMessage, preserving tool_calls. + + When streaming chunks are accumulated via ``+``, the result has ``tool_calls`` + but ``additional_kwargs["tool_calls"]`` (the OpenAI wire format) is left empty. + LLM providers read the wire format when the message is sent back in conversation + history, so we reconstruct it here using ``convert_to_openai_messages``. + """ + additional_kwargs = dict(chunk.additional_kwargs) + if chunk.tool_calls and not additional_kwargs.get("tool_calls"): + openai_msg = convert_to_openai_messages([chunk])[0] + if "tool_calls" in openai_msg: + additional_kwargs["tool_calls"] = openai_msg["tool_calls"] + + return AIMessage( + content=chunk.content, + additional_kwargs=additional_kwargs, + response_metadata=chunk.response_metadata, + id=chunk.id, + usage_metadata=chunk.usage_metadata, + ) + + TOOL_NOT_FOUND_ERROR_MESSAGE = "There is no tool named {tool_name}. Tool must be one of {tools}." INPUT_SCHEMA_MESSAGE = ". Arguments must be provided as a valid JSON object following this format: {schema}" NO_INPUT_ERROR_MESSAGE = "No human input received to the agent, Please ask a valid question." @@ -88,6 +114,11 @@ async def _stream_llm(self, runnable: Any, inputs: dict[str, Any]) -> AIMessage: """ Stream from LLM runnable. Retry logic is handled automatically by the underlying LLM client. + Accumulates streamed chunks using LangChain's ``+`` operator which preserves + ``tool_calls`` and ``tool_call_chunks``, then converts the result to an + ``AIMessage`` via ``_chunk_to_message``. This ensures that native tool calling + (``use_native_tool_calling=True``) works correctly with the ReAct agent. + Parameters ---------- runnable : Any @@ -98,23 +129,21 @@ async def _stream_llm(self, runnable: Any, inputs: dict[str, Any]) -> AIMessage: Returns ------- AIMessage - The LLM response + The LLM response, including any tool_calls from native tool calling. """ - content_parts = [] - reasoning_parts = [] - async for event in runnable.astream(inputs, config=self._runnable_config): - content_parts.append(event.content) - extra = getattr(event, 'additional_kwargs', None) - if isinstance(extra, dict): - reasoning = extra.get('reasoning_content', '') - if reasoning: - reasoning_parts.append(reasoning) - - additional_kwargs: dict[str, Any] = {} - if reasoning_parts: - additional_kwargs['reasoning_content'] = "".join(reasoning_parts) - - return AIMessage(content="".join(content_parts), additional_kwargs=additional_kwargs) + chunks: list[AIMessageChunk] = [] + async for chunk in runnable.astream(inputs, config=self._runnable_config): + chunks.append(chunk) + + if not chunks: + return AIMessage(content="") + + # Accumulate using LangChain's + operator (preserves tool_call_chunks) + accumulated = chunks[0] + for c in chunks[1:]: + accumulated = accumulated + c + + return _chunk_to_message(accumulated) async def _call_llm(self, llm: Runnable, inputs: dict[str, Any]) -> AIMessage: """ diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/tool_calling_agent/agent.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/tool_calling_agent/agent.py index b6eb73e062..ac48137b34 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/tool_calling_agent/agent.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/tool_calling_agent/agent.py @@ -24,7 +24,6 @@ from langchain_core.messages import ToolMessage from langchain_core.messages.ai import UsageMetadata from langchain_core.messages.base import BaseMessage -from langchain_core.messages.utils import convert_to_openai_messages from langchain_core.runnables import RunnableLambda from langchain_core.tools import BaseTool from langgraph.graph import StateGraph @@ -37,6 +36,7 @@ from nat.plugins.langchain.agent.base import AGENT_CALL_LOG_MESSAGE from nat.plugins.langchain.agent.base import AGENT_LOG_PREFIX from nat.plugins.langchain.agent.base import AgentDecision +from nat.plugins.langchain.agent.base import _chunk_to_message from nat.plugins.langchain.agent.dual_node import DualNodeAgent if typing.TYPE_CHECKING: @@ -45,29 +45,6 @@ logger = logging.getLogger(__name__) -def _chunk_to_message(chunk: "AIMessageChunk") -> AIMessage: - """Convert an accumulated AIMessageChunk into an AIMessage. - - When streaming chunks are accumulated via ``+``, the result has ``tool_calls`` - but ``additional_kwargs["tool_calls"]`` (the OpenAI wire format) is left empty. - LLM providers read the wire format when the message is sent back in conversation - history, so we reconstruct it here using ``convert_to_openai_messages``. - """ - additional_kwargs = dict(chunk.additional_kwargs) - if chunk.tool_calls and not additional_kwargs.get("tool_calls"): - openai_msg = convert_to_openai_messages([chunk])[0] - if "tool_calls" in openai_msg: - additional_kwargs["tool_calls"] = openai_msg["tool_calls"] - - return AIMessage( - content=chunk.content, - additional_kwargs=additional_kwargs, - response_metadata=chunk.response_metadata, - id=chunk.id, - usage_metadata=chunk.usage_metadata, - ) - - class ToolCallAgentGraphState(BaseModel): """State schema for the Tool Calling Agent Graph""" messages: list[BaseMessage] = Field(default_factory=list) # input and output of the Agent diff --git a/packages/nvidia_nat_langchain/tests/agent/test_base.py b/packages/nvidia_nat_langchain/tests/agent/test_base.py index 3263e912ed..0ea6fb78e9 100644 --- a/packages/nvidia_nat_langchain/tests/agent/test_base.py +++ b/packages/nvidia_nat_langchain/tests/agent/test_base.py @@ -20,6 +20,7 @@ import pytest from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessageChunk from langchain_core.messages import HumanMessage from langchain_core.messages import ToolMessage from langchain_core.runnables import RunnableConfig @@ -65,14 +66,10 @@ class TestStreamLLM: async def test_successful_streaming(self, base_agent): """Test successful streaming without retries.""" mock_runnable = Mock() - mock_event1 = Mock() - mock_event1.content = "Hello " - mock_event2 = Mock() - mock_event2.content = "world!" async def mock_astream(inputs, **kwargs): - for event in [mock_event1, mock_event2]: - yield event + yield AIMessageChunk(content="Hello ") + yield AIMessageChunk(content="world!") mock_runnable.astream = mock_astream @@ -101,11 +98,9 @@ async def mock_astream(inputs, **kwargs): async def test_streaming_empty_content(self, base_agent): """Test streaming with empty content.""" mock_runnable = Mock() - mock_event = Mock() - mock_event.content = "" async def mock_astream(inputs, **kwargs): - yield mock_event + yield AIMessageChunk(content="") mock_runnable.astream = mock_astream @@ -116,6 +111,48 @@ async def mock_astream(inputs, **kwargs): assert isinstance(result, AIMessage) assert result.content == "" + async def test_streaming_preserves_tool_calls(self, base_agent): + """Test that tool_calls from native tool calling are preserved.""" + mock_runnable = Mock() + + async def mock_astream(inputs, **kwargs): + yield AIMessageChunk( + content="I'll check the time.", + tool_call_chunks=[{ + "name": "get_time", + "args": '{"tz": "UTC"}', + "id": "call_123", + "index": 0, + "type": "tool_call_chunk", + }], + ) + + mock_runnable.astream = mock_astream + + inputs = {"messages": [HumanMessage(content="test")]} + result = await base_agent._stream_llm(mock_runnable, inputs) + + assert isinstance(result, AIMessage) + assert result.content == "I'll check the time." + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_time" + + async def test_streaming_no_chunks_returns_empty(self, base_agent): + """Test that empty stream returns empty AIMessage.""" + mock_runnable = Mock() + + async def mock_astream(inputs, **kwargs): + return + yield # makes this an async generator + + mock_runnable.astream = mock_astream + + inputs = {"messages": [HumanMessage(content="test")]} + result = await base_agent._stream_llm(mock_runnable, inputs) + + assert isinstance(result, AIMessage) + assert result.content == "" + class TestCallLLM: """Test the _call_llm method."""