diff --git a/flo_ai/models/flo_node.py b/flo_ai/models/flo_node.py index 67303dbe..4dbb86f5 100644 --- a/flo_ai/models/flo_node.py +++ b/flo_ai/models/flo_node.py @@ -25,15 +25,22 @@ def __init__( name: str, kind: ExecutableType, delegate: Optional[Delegate] = None, + async_func: functools.partial = None, ) -> None: self.name = name self.func = func self.kind: ExecutableType = kind self.delegate = delegate + self.async_func = async_func def invoke(self, query, config): return self.func({STATE_NAME_MESSAGES: [HumanMessage(content=query)]}) + async def ainvoke(self, query, config): + return await self.async_func( + {STATE_NAME_MESSAGES: [HumanMessage(content=query)]} + ) + class Builder: def __init__(self, session: FloSession) -> None: self.session = session @@ -47,7 +54,17 @@ def build_from_agent(self, flo_agent: FloAgent) -> 'FloNode': model_name=flo_agent.model_name, data_collector=flo_agent.data_collector, ) - return FloNode(agent_func, flo_agent.name, flo_agent.type) + agent_func_async = functools.partial( + FloNode.Builder.__async_teamflo_agent_node, + agent=flo_agent.runnable, + name=flo_agent.name, + session=self.session, + model_name=flo_agent.model_name, + data_collector=flo_agent.data_collector, + ) + return FloNode( + agent_func, flo_agent.name, flo_agent.type, async_func=agent_func_async + ) def build_from_reflection(self, flo_agent: FloReflectionAgent) -> 'FloNode': agent_func = functools.partial( @@ -154,6 +171,57 @@ def __teamflo_agent_node( ] return {STATE_NAME_MESSAGES: [AIMessage(content=output, name=name)]} + @staticmethod + async def __async_teamflo_agent_node( + state: TeamFloAgentState, + agent: AgentExecutor, + name: str, + session: FloSession, + model_name: str, + data_collector: Optional[FloDataCollector] = None, + ): + agent_cbs: List[FloAgentCallback] = FloNode.Builder.__filter_callbacks( + session, FloAgentCallback + ) + flo_cbs: List[FloCallback] = FloNode.Builder.__filter_callbacks( + session, FloCallback + ) + [ + callback.on_agent_start(name, model_name, state['messages'], **{}) + for callback in agent_cbs + ] + [ + callback.on_agent_start(name, model_name, state['messages'], **{}) + for callback in flo_cbs + ] + try: + result = await agent.ainvoke(state) + output = result if isinstance(result, str) else result['output'] + if data_collector is not None: + get_logger().info( + 'appending output to data collector', session=session + ) + data_collector.append(output) + except Exception as e: + [ + callback.on_agent_error(name, model_name, e, **{}) + for callback in agent_cbs + ] + [ + callback.on_agent_error(name, model_name, e, **{}) + for callback in flo_cbs + ] + raise e + [ + callback.on_agent_end(name, model_name, output, **{}) + for callback in agent_cbs + ] + [ + callback.on_agent_start(name, model_name, output, **{}) + for callback in flo_cbs + ] + return {STATE_NAME_MESSAGES: [AIMessage(content=output, name=name)]} + @staticmethod def __filter_callbacks(session: FloSession, type: Type): cbs = session.callbacks diff --git a/pyproject.toml b/pyproject.toml index 2c70774e..98b2b6e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "flo-ai" -version = "0.0.5-dev4" +version = "0.0.5-dev5" description = "A easy way to create structured AI agents" authors = ["vizsatiz "] license = "MIT" diff --git a/setup.py b/setup.py index 30d2caba..56ab90c1 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='flo-ai', - version='0.0.5-dev4', + version='0.0.5-dev5', author='Rootflo', description='Create composable AI agents', long_description=long_description,