Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolMessage
from langchain_core.messages.utils import convert_to_openai_messages
from langchain_core.runnables import Runnable
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
Expand All @@ -35,6 +37,30 @@

logger = logging.getLogger(__name__)


def _chunk_to_message(chunk: AIMessageChunk) -> AIMessage:
    """Build an AIMessage from an accumulated AIMessageChunk, keeping tool_calls intact.

    Accumulating streamed chunks with ``+`` produces a message whose parsed
    ``tool_calls`` field is populated while ``additional_kwargs["tool_calls"]``
    (the OpenAI wire format) remains empty. Providers read the wire format when
    the message is replayed as conversation history, so it is rebuilt here with
    ``convert_to_openai_messages``.
    """
    extra = dict(chunk.additional_kwargs)
    wire_format_missing = bool(chunk.tool_calls) and not extra.get("tool_calls")
    if wire_format_missing:
        wire_msg = convert_to_openai_messages([chunk])[0]
        if "tool_calls" in wire_msg:
            extra["tool_calls"] = wire_msg["tool_calls"]

    return AIMessage(
        content=chunk.content,
        additional_kwargs=extra,
        response_metadata=chunk.response_metadata,
        id=chunk.id,
        usage_metadata=chunk.usage_metadata,
    )


TOOL_NOT_FOUND_ERROR_MESSAGE = "There is no tool named {tool_name}. Tool must be one of {tools}."
INPUT_SCHEMA_MESSAGE = ". Arguments must be provided as a valid JSON object following this format: {schema}"
NO_INPUT_ERROR_MESSAGE = "No human input received to the agent, Please ask a valid question."
Expand Down Expand Up @@ -88,6 +114,11 @@ async def _stream_llm(self, runnable: Any, inputs: dict[str, Any]) -> AIMessage:
"""
Stream from LLM runnable. Retry logic is handled automatically by the underlying LLM client.

Accumulates streamed chunks using LangChain's ``+`` operator which preserves
``tool_calls`` and ``tool_call_chunks``, then converts the result to an
``AIMessage`` via ``_chunk_to_message``. This ensures that native tool calling
(``use_native_tool_calling=True``) works correctly with the ReAct agent.

Parameters
----------
runnable : Any
Expand All @@ -98,23 +129,21 @@ async def _stream_llm(self, runnable: Any, inputs: dict[str, Any]) -> AIMessage:
Returns
-------
AIMessage
The LLM response
The LLM response, including any tool_calls from native tool calling.
"""
content_parts = []
reasoning_parts = []
async for event in runnable.astream(inputs, config=self._runnable_config):
content_parts.append(event.content)
extra = getattr(event, 'additional_kwargs', None)
if isinstance(extra, dict):
reasoning = extra.get('reasoning_content', '')
if reasoning:
reasoning_parts.append(reasoning)

additional_kwargs: dict[str, Any] = {}
if reasoning_parts:
additional_kwargs['reasoning_content'] = "".join(reasoning_parts)

return AIMessage(content="".join(content_parts), additional_kwargs=additional_kwargs)
chunks: list[AIMessageChunk] = []
async for chunk in runnable.astream(inputs, config=self._runnable_config):
chunks.append(chunk)

if not chunks:
return AIMessage(content="")

# Accumulate using LangChain's + operator (preserves tool_call_chunks)
accumulated = chunks[0]
for c in chunks[1:]:
accumulated = accumulated + c

return _chunk_to_message(accumulated)

async def _call_llm(self, llm: Runnable, inputs: dict[str, Any]) -> AIMessage:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from langchain_core.messages import ToolMessage
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.base import BaseMessage
from langchain_core.messages.utils import convert_to_openai_messages
from langchain_core.runnables import RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import StateGraph
Expand All @@ -37,6 +36,7 @@
from nat.plugins.langchain.agent.base import AGENT_CALL_LOG_MESSAGE
from nat.plugins.langchain.agent.base import AGENT_LOG_PREFIX
from nat.plugins.langchain.agent.base import AgentDecision
from nat.plugins.langchain.agent.base import _chunk_to_message
from nat.plugins.langchain.agent.dual_node import DualNodeAgent

if typing.TYPE_CHECKING:
Expand All @@ -45,29 +45,6 @@
logger = logging.getLogger(__name__)


def _chunk_to_message(chunk: "AIMessageChunk") -> AIMessage:
    """Convert an accumulated AIMessageChunk into an AIMessage, preserving tool_calls.

    When streaming chunks are accumulated via ``+``, the result has ``tool_calls``
    but ``additional_kwargs["tool_calls"]`` (the OpenAI wire format) is left empty.
    LLM providers read the wire format when the message is sent back in conversation
    history, so we reconstruct it here using ``convert_to_openai_messages``.
    """
    # Copy so the chunk's own additional_kwargs dict is never mutated.
    additional_kwargs = dict(chunk.additional_kwargs)
    # Only rebuild the wire format when tool calls exist and it is absent/empty.
    if chunk.tool_calls and not additional_kwargs.get("tool_calls"):
        openai_msg = convert_to_openai_messages([chunk])[0]
        if "tool_calls" in openai_msg:
            additional_kwargs["tool_calls"] = openai_msg["tool_calls"]

    return AIMessage(
        content=chunk.content,
        additional_kwargs=additional_kwargs,
        response_metadata=chunk.response_metadata,
        id=chunk.id,
        usage_metadata=chunk.usage_metadata,
    )


class ToolCallAgentGraphState(BaseModel):
"""State schema for the Tool Calling Agent Graph"""
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the Agent
Expand Down
55 changes: 46 additions & 9 deletions packages/nvidia_nat_langchain/tests/agent/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import pytest
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -65,14 +66,10 @@ class TestStreamLLM:
async def test_successful_streaming(self, base_agent):
"""Test successful streaming without retries."""
mock_runnable = Mock()
mock_event1 = Mock()
mock_event1.content = "Hello "
mock_event2 = Mock()
mock_event2.content = "world!"

async def mock_astream(inputs, **kwargs):
for event in [mock_event1, mock_event2]:
yield event
yield AIMessageChunk(content="Hello ")
yield AIMessageChunk(content="world!")

mock_runnable.astream = mock_astream

Expand Down Expand Up @@ -101,11 +98,9 @@ async def mock_astream(inputs, **kwargs):
async def test_streaming_empty_content(self, base_agent):
"""Test streaming with empty content."""
mock_runnable = Mock()
mock_event = Mock()
mock_event.content = ""

async def mock_astream(inputs, **kwargs):
yield mock_event
yield AIMessageChunk(content="")

mock_runnable.astream = mock_astream

Expand All @@ -116,6 +111,48 @@ async def mock_astream(inputs, **kwargs):
assert isinstance(result, AIMessage)
assert result.content == ""

async def test_streaming_preserves_tool_calls(self, base_agent):
    """Test that tool_calls from native tool calling are preserved.

    Verifies both the parsed ``tool_calls`` list and the reconstructed OpenAI
    wire format under ``additional_kwargs["tool_calls"]``, which downstream
    providers read when the message is sent back as conversation history.
    """
    mock_runnable = Mock()

    async def mock_astream(inputs, **kwargs):
        yield AIMessageChunk(
            content="I'll check the time.",
            tool_call_chunks=[{
                "name": "get_time",
                "args": '{"tz": "UTC"}',
                "id": "call_123",
                "index": 0,
                "type": "tool_call_chunk",
            }],
        )

    mock_runnable.astream = mock_astream

    inputs = {"messages": [HumanMessage(content="test")]}
    result = await base_agent._stream_llm(mock_runnable, inputs)

    assert isinstance(result, AIMessage)
    assert result.content == "I'll check the time."
    assert len(result.tool_calls) == 1
    assert result.tool_calls[0]["name"] == "get_time"
    # Regression check: _chunk_to_message must rebuild the OpenAI wire format,
    # otherwise providers silently drop the tool calls on the next turn.
    wire_calls = result.additional_kwargs.get("tool_calls")
    assert wire_calls, "Expected OpenAI wire-format tool_calls in additional_kwargs"
    assert wire_calls[0]["function"]["name"] == "get_time"

Comment on lines +114 to +139
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Assert OpenAI wire-format tool_calls in this regression test.

This test confirms result.tool_calls, but it does not verify result.additional_kwargs["tool_calls"], which is the critical reconstructed wire format for downstream provider compatibility.

Suggested test addition
         assert isinstance(result, AIMessage)
         assert result.content == "I'll check the time."
         assert len(result.tool_calls) == 1
         assert result.tool_calls[0]["name"] == "get_time"
+        assert result.additional_kwargs.get("tool_calls"), "Expected OpenAI wire-format tool_calls to be preserved"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@packages/nvidia_nat_langchain/tests/agent/test_base.py` around lines 114 -
139, Update the test_streaming_preserves_tool_calls regression test to also
assert that the reconstructed OpenAI wire-format is present in the returned
message: after calling base_agent._stream_llm(mock_runnable, inputs) and
validating result.tool_calls, add an assertion that
result.additional_kwargs["tool_calls"] exists, is a list, and contains an entry
whose "name" == "get_time" (or otherwise matches the expected wire-format
payload produced by mock_astream). This change ensures _stream_llm preserves
both the parsed tool_calls and the serialized wire-format under
additional_kwargs["tool_calls"] for downstream compatibility.

async def test_streaming_no_chunks_returns_empty(self, base_agent):
    """An LLM stream that yields nothing should produce an empty AIMessage."""
    runnable = Mock()

    async def empty_astream(inputs, **kwargs):
        # Returning before the (unreachable) yield makes this an async
        # generator that terminates immediately instead of a plain coroutine.
        return
        yield

    runnable.astream = empty_astream

    result = await base_agent._stream_llm(
        runnable, {"messages": [HumanMessage(content="test")]})

    assert isinstance(result, AIMessage)
    assert result.content == ""


class TestCallLLM:
"""Test the _call_llm method."""
Expand Down