Skip to content

Commit 0ef1ebc

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: GenAI Client(evals) - Apply async function for agent run
PiperOrigin-RevId: 825073570
1 parent c81f912 commit 0ef1ebc

File tree

2 files changed

+38
-22
lines changed

2 files changed

+38
-22
lines changed

tests/unit/vertexai/genai/test_evals.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,9 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
10701070
)
10711071

10721072
mock_agent_engine = mock.Mock()
1073-
mock_agent_engine.create_session.return_value = {"id": "session1"}
1073+
mock_agent_engine.async_create_session = mock.AsyncMock(
1074+
return_value={"id": "session1"}
1075+
)
10741076
stream_query_return_value = [
10751077
{
10761078
"id": "1",
@@ -1086,7 +1088,13 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
10861088
},
10871089
]
10881090

1089-
mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
1091+
async def _async_iterator(iterable):
1092+
for item in iterable:
1093+
yield item
1094+
1095+
mock_agent_engine.async_stream_query.return_value = _async_iterator(
1096+
stream_query_return_value
1097+
)
10901098
mock_vertexai_client.return_value.agent_engines.get.return_value = (
10911099
mock_agent_engine
10921100
)
@@ -1100,10 +1108,10 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
11001108
mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
11011109
name="projects/test-project/locations/us-central1/reasoningEngines/123"
11021110
)
1103-
mock_agent_engine.create_session.assert_called_once_with(
1111+
mock_agent_engine.async_create_session.assert_called_once_with(
11041112
user_id="123", state={"a": "1"}
11051113
)
1106-
mock_agent_engine.stream_query.assert_called_once_with(
1114+
mock_agent_engine.async_stream_query.assert_called_once_with(
11071115
user_id="123", session_id="session1", message="agent prompt"
11081116
)
11091117

@@ -1154,7 +1162,9 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
11541162
)
11551163

11561164
mock_agent_engine = mock.Mock()
1157-
mock_agent_engine.create_session.return_value = {"id": "session1"}
1165+
mock_agent_engine.async_create_session = mock.AsyncMock(
1166+
return_value={"id": "session1"}
1167+
)
11581168
stream_query_return_value = [
11591169
{
11601170
"id": "1",
@@ -1170,7 +1180,13 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
11701180
},
11711181
]
11721182

1173-
mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
1183+
async def _async_iterator(iterable):
1184+
for item in iterable:
1185+
yield item
1186+
1187+
mock_agent_engine.async_stream_query.return_value = _async_iterator(
1188+
stream_query_return_value
1189+
)
11741190
mock_vertexai_client.return_value.agent_engines.get.return_value = (
11751191
mock_agent_engine
11761192
)
@@ -1184,10 +1200,10 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
11841200
mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
11851201
name="projects/test-project/locations/us-central1/reasoningEngines/123"
11861202
)
1187-
mock_agent_engine.create_session.assert_called_once_with(
1203+
mock_agent_engine.async_create_session.assert_called_once_with(
11881204
user_id="123", state={"a": "1"}
11891205
)
1190-
mock_agent_engine.stream_query.assert_called_once_with(
1206+
mock_agent_engine.async_stream_query.assert_called_once_with(
11911207
user_id="123", session_id="session1", message="agent prompt"
11921208
)
11931209

vertexai/_genai/_evals_common.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
from . import evals
4343
from . import types
44+
from . import agent_engines
4445

4546
try:
4647
import litellm
@@ -62,12 +63,9 @@ def _get_agent_engine_instance(
6263
if not hasattr(_thread_local_data, "agent_engine_instances"):
6364
_thread_local_data.agent_engine_instances = {}
6465
if agent_name not in _thread_local_data.agent_engine_instances:
65-
client = vertexai.Client(
66-
project=api_client.project,
67-
location=api_client.location,
68-
)
66+
agent_engines_module = agent_engines.AgentEngines(api_client_=api_client)
6967
_thread_local_data.agent_engine_instances[agent_name] = (
70-
client.agent_engines.get(name=agent_name)
68+
agent_engines_module.get(name=agent_name)
7169
)
7270
return _thread_local_data.agent_engine_instances[agent_name]
7371

@@ -278,10 +276,12 @@ def agent_run_wrapper(
278276
and type(agent_engine).__name__ == "AgentEngine"
279277
):
280278
agent_engine_instance = agent_engine
281-
return inference_fn_arg(
282-
row=row_arg,
283-
contents=contents_arg,
284-
agent_engine=agent_engine_instance,
279+
return asyncio.run(
280+
inference_fn_arg(
281+
row=row_arg,
282+
contents=contents_arg,
283+
agent_engine=agent_engine_instance,
284+
)
285285
)
286286

287287
future = executor.submit(
@@ -1262,7 +1262,7 @@ def _run_agent(
12621262
)
12631263

12641264

1265-
def _execute_agent_run_with_retry(
1265+
async def _execute_agent_run_with_retry(
12661266
row: pd.Series,
12671267
contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict],
12681268
agent_engine: types.AgentEngine,
@@ -1284,7 +1284,7 @@ def _execute_agent_run_with_retry(
12841284
)
12851285
user_id = session_inputs.user_id
12861286
session_state = session_inputs.state
1287-
session = agent_engine.create_session(
1287+
session = await agent_engine.async_create_session(
12881288
user_id=user_id,
12891289
state=session_state,
12901290
)
@@ -1295,7 +1295,7 @@ def _execute_agent_run_with_retry(
12951295
for attempt in range(max_retries):
12961296
try:
12971297
responses = []
1298-
for event in agent_engine.stream_query(
1298+
async for event in agent_engine.async_stream_query(
12991299
user_id=user_id,
13001300
session_id=session["id"],
13011301
message=contents,
@@ -1314,7 +1314,7 @@ def _execute_agent_run_with_retry(
13141314
)
13151315
if attempt == max_retries - 1:
13161316
return {"error": f"Resource exhausted after retries: {e}"}
1317-
time.sleep(2**attempt)
1317+
await asyncio.sleep(2**attempt)
13181318
except Exception as e: # pylint: disable=broad-exception-caught
13191319
logger.error(
13201320
"Unexpected error during generate_content on attempt %d/%d: %s",
@@ -1325,7 +1325,7 @@ def _execute_agent_run_with_retry(
13251325

13261326
if attempt == max_retries - 1:
13271327
return {"error": f"Failed after retries: {e}"}
1328-
time.sleep(1)
1328+
await asyncio.sleep(1)
13291329
return {"error": f"Failed to get agent run results after {max_retries} retries"}
13301330

13311331

0 commit comments

Comments
 (0)