diff --git a/.github/workflows/build-project.yml b/.github/workflows/build-project.yml index a5c9622a..89103371 100644 --- a/.github/workflows/build-project.yml +++ b/.github/workflows/build-project.yml @@ -29,8 +29,12 @@ jobs: - name: Install dependencies run: poetry lock && poetry install + - name: Run Ruff + run: poetry run pre-commit run --all-files + - name: Run build run: poetry build - name: Run tests - run: poetry run pytest \ No newline at end of file + run: poetry run pytest + \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..44cfaecd --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.7.3 + hooks: + # Run the linter. + - id: ruff + files: \.py$ + # Run the formatter. + - id: ruff-format + files: \.py$ + args: [--config, format.quote-style = 'single'] diff --git a/examples/delegator_example.py b/examples/delegator_example.py index d1a3b4cf..33cc592d 100644 --- a/examples/delegator_example.py +++ b/examples/delegator_example.py @@ -3,6 +3,7 @@ from langchain_openai import ChatOpenAI from langchain_community.tools.tavily_search.tool import TavilySearchResults from dotenv import load_dotenv + load_dotenv() yaml_data = """ @@ -41,11 +42,10 @@ llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') session = FloSession(llm).register_tool( - name="TavilySearchResults", - tool=TavilySearchResults() + name='TavilySearchResults', tool=TavilySearchResults() ) flo: Flo = Flo.build(session, yaml=yaml_data) -flo.draw_to_file("delegate.png", xray=True) +flo.draw_to_file('delegate.png', xray=True) # data = flo.invoke(input_prompt) -# print((data['messages'][-1]).content) \ No newline at end of file +# print((data['messages'][-1]).content) diff --git a/examples/hierarchical_blogging_team.py b/examples/hierarchical_blogging_team.py index 5b4c0eaa..07252624 100644 --- a/examples/hierarchical_blogging_team.py +++ b/examples/hierarchical_blogging_team.py @@ -1,9 +1,9 @@ -from flo_ai import Flo from flo_ai.core import Flo from langchain_openai import ChatOpenAI from flo_ai import FloSession from langchain_community.tools.tavily_search.tool import TavilySearchResults from dotenv import load_dotenv + load_dotenv() yaml_data = """ @@ -47,6 +47,6 @@ llm = ChatOpenAI(temperature=0, model_name='gpt-4o') session = FloSession(llm).register_tool( - name="TavilySearchResults", tool=TavilySearchResults() + name='TavilySearchResults', tool=TavilySearchResults() ) -flo: Flo = Flo.build(session, yaml=yaml_data) \ No newline at end of file +flo: Flo = Flo.build(session, yaml=yaml_data) diff --git a/examples/linear_router_team.py b/examples/linear_router_team.py index 3ba9f30e..ba26eed2 100644 --- a/examples/linear_router_team.py +++ b/examples/linear_router_team.py @@ -1,36 +1,44 @@ -from flo_ai import Flo from flo_ai.core import Flo from langchain_openai import ChatOpenAI from flo_ai import FloSession from langchain_community.tools.tavily_search.tool import TavilySearchResults from dotenv import load_dotenv -load_dotenv() + from typing import Optional, Type from pydantic import BaseModel, Field -from langchain.tools import BaseTool, StructuredTool, tool +from langchain.tools import BaseTool from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) +load_dotenv() + + class FetchTrxInput(BaseModel): - reference_number: str = Field(description="The transaction reference number") + 
reference_number: str = Field(description='The transaction reference number') + class FetchTransactionTool(BaseTool): - name = "fetch_transactions" - description = "useful for when you want to fetch the transaction details given reference number" + name = 'fetch_transactions' + description = 'useful for when you want to fetch the transaction details given reference number' args_schema: Type[BaseModel] = FetchTrxInput def _run( - self, reference_number: str, run_manager: Optional[CallbackManagerForToolRun] = None + self, + reference_number: str, + run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: - return "The transaction happened on 23/07/2024 IST and it failed because there was not enough balance in the account" + return 'The transaction happened on 23/07/2024 IST and it failed because there was not enough balance in the account' async def _arun( - self, reference_number: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None + self, + reference_number: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: - return "The transaction happened on 23/07/2024 IST and it failed because there was not enough balance in the account" + return 'The transaction happened on 23/07/2024 IST and it failed because there was not enough balance in the account' + yaml_data = """ apiVersion: flo/alpha-v1 @@ -59,11 +67,11 @@ async def _arun( llm = ChatOpenAI(temperature=0, model_name='gpt-4o') session = FloSession(llm).register_tool( - name="TavilySearchResults", tool=TavilySearchResults() + name='TavilySearchResults', tool=TavilySearchResults() ) flo: Flo = Flo.build(session, yaml=yaml_data) for event in flo.stream(input_prompt): for k, v in event.items(): - if k != "__end__": + if k != '__end__': print(v) diff --git a/examples/llm_extensibility.py b/examples/llm_extensibility.py index da32c72b..6176d97f 100644 --- a/examples/llm_extensibility.py +++ b/examples/llm_extensibility.py @@ -1,14 +1,15 @@ -import os from flo_ai import Flo from flo_ai import FloSession from pydantic import BaseModel, Field -from langchain_openai import ChatOpenAI, AzureChatOpenAI +from langchain_openai import ChatOpenAI from flo_ai.tools.flo_tool import flotool from dotenv import load_dotenv +import warnings + load_dotenv() -import warnings + warnings.simplefilter('default', DeprecationWarning) gpt35 = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo') @@ -16,19 +17,28 @@ gpt_4o = ChatOpenAI(temperature=0, model_name='gpt-4o') session = FloSession(gpt35) -session.register_model("bronze", gpt35) -session.register_model("silver", gpt_4o_mini) -session.register_model("gold", gpt_4o) +session.register_model('bronze', gpt35) +session.register_model('silver', gpt_4o_mini) +session.register_model('gold', gpt_4o) + class SendEmailInput(BaseModel): - to: str = Field(description="Comma seperared list of users emails to which email needs to be sent") - message: str = Field(description="The email text to be sent") + to: str = Field( + description='Comma seperared list of users emails to which email needs to be sent' + ) + message: str = Field(description='The email text to be sent') + -@flotool("email_triage", "useful for when you need to send an email to someone", argument_contract=SendEmailInput) +@flotool( + 'email_triage', + 'useful for when you need to send an email to someone', + argument_contract=SendEmailInput, +) def email_tool(to: str, message: str): - return f"Email sent successfully to: {to}" + return f'Email sent successfully to: {to}' -session.register_tool("SendEmailTool", 
email_tool) + +session.register_tool('SendEmailTool', email_tool) agent_yaml = """ apiVersion: flo/alpha-v1 @@ -70,10 +80,6 @@ def email_tool(to: str, message: str): flo: Flo = Flo.build(session, yaml=agent_yaml) for s in flo.stream(input_prompt): - if "__end__" not in s: + if '__end__' not in s: print(s) - print("----") - - - - + print('----') diff --git a/examples/rag_tool.py b/examples/rag_tool.py index f2ea7a8a..6df5ccff 100644 --- a/examples/rag_tool.py +++ b/examples/rag_tool.py @@ -1,7 +1,6 @@ from flo_ai import Flo from flo_ai import FloSession from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_chroma import Chroma from langchain_community.document_loaders import TextLoader from langchain_community.embeddings.sentence_transformer import ( @@ -10,19 +9,19 @@ from langchain_text_splitters import CharacterTextSplitter from dotenv import load_dotenv -load_dotenv() +from flo_ai.retrievers.flo_retriever import FloRagBuilder +from flo_ai.retrievers.flo_compression_pipeline import FloCompressionPipeline + +load_dotenv() llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') -session = FloSession( - llm, - log_level="ERROR" -) +session = FloSession(llm, log_level='ERROR') # load the document and split it into chunks -loader = TextLoader("./examples/rag_document.txt") +loader = TextLoader('./examples/rag_document.txt') documents = loader.load() # split it into chunks @@ -30,26 +29,27 @@ docs = text_splitter.split_documents(documents) # create the open-source embedding function -embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2") +embedding_function = SentenceTransformerEmbeddings(model_name='all-MiniLM-L6-v2') # load it into Chroma db = Chroma.from_documents(docs, embedding_function) -from flo_ai.retrievers.flo_retriever import FloRagBuilder -from flo_ai.retrievers.flo_compression_pipeline import FloCompressionPipeline llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') session = FloSession(llm) builder = FloRagBuilder(session, db.as_retriever()) -compression_pipeline = FloCompressionPipeline(OpenAIEmbeddings(model="text-embedding-3-small")) +compression_pipeline = FloCompressionPipeline( + OpenAIEmbeddings(model='text-embedding-3-small') +) compression_pipeline.add_embedding_reduntant_filter() compression_pipeline.add_embedding_relevant_filter() # Reranking -retriever_tool = builder.with_compression(compression_pipeline).build_rag_tool(name="HousingLoanRetreiver", - description="Tool to fetch data around housing loans") -session.register_tool(name="HousingLoanTool", tool=retriever_tool) - +retriever_tool = builder.with_compression(compression_pipeline).build_rag_tool( + name='HousingLoanRetreiver', description='Tool to fetch data around housing loans' +) +session.register_tool(name='HousingLoanTool', tool=retriever_tool) + simple_tool_agent = """ apiVersion: flo/alpha-v1 kind: FloAgent @@ -64,4 +64,4 @@ flo = Flo.build(session, simple_tool_agent) -print(flo.invoke("Whats interest rate on loan")) \ No newline at end of file +print(flo.invoke('Whats interest rate on loan')) diff --git a/examples/rag_with_reranking.py b/examples/rag_with_reranking.py index 91999dd1..f3bc9eb3 100644 --- a/examples/rag_with_reranking.py +++ b/examples/rag_with_reranking.py @@ -4,26 +4,31 @@ from langchain_openai import OpenAIEmbeddings from langchain_openai import ChatOpenAI from dotenv import load_dotenv -load_dotenv() + from flo_ai import FloSession from flo_ai.retrievers.flo_retriever import 
FloRagBuilder from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from flo_ai.retrievers.flo_compression_pipeline import FloCompressionPipeline +import logging -db_url = os.getenv("MONGO_DB_URL") + +load_dotenv() +db_url = os.getenv('MONGO_DB_URL') connection_timeout = 60000 -mongo_client = MongoClient(db_url, connectTimeoutMS=connection_timeout, socketTimeoutMS=connection_timeout) -mongo_embedding_collection = (mongo_client - .get_database("dohabank") - .get_collection("products")) +mongo_client = MongoClient( + db_url, connectTimeoutMS=connection_timeout, socketTimeoutMS=connection_timeout +) +mongo_embedding_collection = mongo_client.get_database('dohabank').get_collection( + 'products' +) store = MongoDBAtlasVectorSearch( collection=mongo_embedding_collection, - embedding_key="embedding", - embedding=OpenAIEmbeddings(model="text-embedding-3-small"), - index_name="bank-products-index", + embedding_key='embedding', + embedding=OpenAIEmbeddings(model='text-embedding-3-small'), + index_name='bank-products-index', ) @@ -31,31 +36,35 @@ session = FloSession(llm) rag_builder = FloRagBuilder(session, store.as_retriever()) -import logging logging.basicConfig() -logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO) +logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO) custom_prompt = ChatPromptTemplate.from_messages( - [ - ("system", """You are an assistant for banking employees, of Doha Bank. + [ + ( + 'system', + """You are an assistant for banking employees, of Doha Bank. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. - Try to answer questions as bullet points that are easy to read"""), - MessagesPlaceholder(variable_name="chat_history"), - ("human", "{question}"), - ] - ) - -from langchain.schema import BaseMessage -compression_pipeline = FloCompressionPipeline(OpenAIEmbeddings(model="text-embedding-3-small")) + Try to answer questions as bullet points that are easy to read""", + ), + MessagesPlaceholder(variable_name='chat_history'), + ('human', '{question}'), + ] +) + + +compression_pipeline = FloCompressionPipeline( + OpenAIEmbeddings(model='text-embedding-3-small') +) compression_pipeline.add_embedding_reduntant_filter() compression_pipeline.add_embedding_relevant_filter() -rag = rag_builder.with_prompt( - custom_prompt -).with_multi_query().with_compression( - compression_pipeline - ).build_rag() -print(rag.invoke({ "question": "What are the documents applying for housing loan" })) - +rag = ( + rag_builder.with_prompt(custom_prompt) + .with_multi_query() + .with_compression(compression_pipeline) + .build_rag() +) +print(rag.invoke({'question': 'What are the documents applying for housing loan'})) diff --git a/examples/reflection_example.py b/examples/reflection_example.py index ba893af8..7a9ad7b9 100644 --- a/examples/reflection_example.py +++ b/examples/reflection_example.py @@ -3,6 +3,7 @@ from langchain_openai import ChatOpenAI from langchain_community.tools.tavily_search.tool import TavilySearchResults from dotenv import load_dotenv + load_dotenv() yaml_data = """ @@ -40,11 +41,10 @@ llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') session = FloSession(llm).register_tool( - name="TavilySearchResults", - tool=TavilySearchResults() + name='TavilySearchResults', tool=TavilySearchResults() ) flo: Flo = Flo.build(session, yaml=yaml_data) -flo.draw_to_file("event.png", xray=True) +flo.draw_to_file('event.png', xray=True) data = 
flo.invoke(input_prompt) -print((data['messages'][-1]).content) \ No newline at end of file +print((data['messages'][-1]).content) diff --git a/examples/simple_blogging_team.py b/examples/simple_blogging_team.py index 006ffc41..9692f287 100644 --- a/examples/simple_blogging_team.py +++ b/examples/simple_blogging_team.py @@ -2,6 +2,7 @@ from langchain_openai import ChatOpenAI from langchain_community.tools.tavily_search.tool import TavilySearchResults from dotenv import load_dotenv + load_dotenv() yaml_data = """ @@ -31,15 +32,16 @@ """ llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') -session = FloSession(llm, log_level="INFO").register_tool( - name="TavilySearchResults", - tool=TavilySearchResults() -).register_tool( - name="DummyTool", - tool=TavilySearchResults(description="Tool is a dummy tool, dont use this") +session = ( + FloSession(llm, log_level='INFO') + .register_tool(name='TavilySearchResults', tool=TavilySearchResults()) + .register_tool( + name='DummyTool', + tool=TavilySearchResults(description='Tool is a dummy tool, dont use this'), + ) ) -Flo.set_log_level("INFO") +Flo.set_log_level('INFO') flo: Flo = Flo.build(session, yaml=yaml_data) # data = flo.invoke(input_prompt) -# print((data['messages'][-1]).content) \ No newline at end of file +# print((data['messages'][-1]).content) diff --git a/examples/tool_agent.py b/examples/tool_agent.py index a97fa722..5ae43c31 100644 --- a/examples/tool_agent.py +++ b/examples/tool_agent.py @@ -1,37 +1,28 @@ from flo_ai import Flo from flo_ai import FloSession -from flo_ai.common.flo_langchain_logger import FloLangchainLogger -from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain_openai import ChatOpenAI from dotenv import load_dotenv +from langchain.tools import BaseTool + load_dotenv() -from langchain_community.tools.tavily_search.tool import TavilySearchResults -from flo_ai.common.flo_langchain_logger import FloLangchainLogger llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') -session = FloSession( - llm, - log_level="ERROR" -) +session = FloSession(llm, log_level='ERROR') -from langchain.tools import BaseTool class PrintStateTool(BaseTool): - name = "printStateTool" - description = "Just print the state" - - def _run( - self, **kwargs - ) -> str: - return "Print tool call success" - -session.register_tool( - name="printStateTool", - tool=PrintStateTool() -) - + name = 'printStateTool' + description = 'Just print the state' + + def _run(self, **kwargs) -> str: + return 'Print tool call success' + + +session.register_tool(name='printStateTool', tool=PrintStateTool()) + simple_tool_agent = """ apiVersion: flo/alpha-v1 kind: FloRoutedTeam @@ -48,6 +39,6 @@ def _run( - name: printStateTool """ -flo = Flo.build(session, simple_tool_agent, log_level="ERROR") +flo = Flo.build(session, simple_tool_agent, log_level='ERROR') -print(flo.invoke("Testing ....")) \ No newline at end of file +print(flo.invoke('Testing ....')) diff --git a/examples/tool_error_handling.py b/examples/tool_error_handling.py index d54dded4..5686df5a 100644 --- a/examples/tool_error_handling.py +++ b/examples/tool_error_handling.py @@ -5,39 +5,39 @@ from typing import List from dotenv import load_dotenv from flo_ai.tools.flo_tool import flotool +import asyncio load_dotenv() llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') -session = FloSession( - llm, - log_level="ERROR" -) +session = FloSession(llm, log_level='ERROR') + class AdditionToolInput(BaseModel): - numbers: List[int] = Field(..., description="List of numbers to add") 
+ numbers: List[int] = Field(..., description='List of numbers to add') + -import asyncio # Use flotool to define the tool function -@flotool(name="AdditionTool", description="Tool to add numbers") +@flotool(name='AdditionTool', description='Tool to add numbers') async def addition_tool(numbers: List[int]) -> str: result = sum(numbers) - await asyncio.sleep(1) - return f"The sum is {result}" + await asyncio.sleep(1) + return f'The sum is {result}' + -@flotool(name="MultiplicationTool", description="Tool to multiply numbers to get product of numbers") +@flotool( + name='MultiplicationTool', + description='Tool to multiply numbers to get product of numbers', +) def mul_tool(numbers: List[int]) -> str: result = sum(numbers) - # await asyncio.sleep(1) - return f"The product is {result}" - -session.register_tool( - name="Adder", - tool=addition_tool -).register_tool( - name="Multiplier", - tool=mul_tool + # await asyncio.sleep(1) + return f'The product is {result}' + + +session.register_tool(name='Adder', tool=addition_tool).register_tool( + name='Multiplier', tool=mul_tool ) simple_weather_checking_agent = """ @@ -54,16 +54,18 @@ def mul_tool(numbers: List[int]) -> str: - name: Multiplier """ -from IPython.display import Image, display -flo = Flo.build(session, simple_weather_checking_agent, log_level="ERROR") -import asyncio +flo = Flo.build(session, simple_weather_checking_agent, log_level='ERROR') + # Assuming flo.ainvoke is your async method for invoking the tool or chain async def invoke_main(): - result = await flo.async_invoke("Whats the sum of 1, 3, 4, 5 and 6, and their product") + result = await flo.async_invoke( + 'Whats the sum of 1, 3, 4, 5 and 6, and their product' + ) print(result) + asyncio.run(invoke_main()) @@ -77,6 +79,3 @@ async def invoke_main(): # print("----") # asyncio.run(stream_main()) - - - diff --git a/flo_ai/__init__.py b/flo_ai/__init__.py index 9cd58afa..d7f3fe39 100644 --- a/flo_ai/__init__.py +++ b/flo_ai/__init__.py @@ -1,9 +1,9 @@ -from flo_ai.core import Flo -from flo_ai.models.flo_agent import FloAgent -from flo_ai.router.flo_supervisor import FloSupervisor -from flo_ai.models.flo_team import FloTeam -from flo_ai.router.flo_linear import FloLinear -from flo_ai.state.flo_session import FloSession -from flo_ai.retrievers.flo_retriever import FloRagBuilder -from flo_ai.common.flo_logger import get_logger -from flo_ai.common.flo_langchain_logger import FloLangchainLogger +from flo_ai.core import Flo as Flo +from flo_ai.models.flo_agent import FloAgent as FloAgent +from flo_ai.router.flo_supervisor import FloSupervisor as FloSupervisor +from flo_ai.models.flo_team import FloTeam as FloTeam +from flo_ai.router.flo_linear import FloLinear as FloLinear +from flo_ai.state.flo_session import FloSession as FloSession +from flo_ai.retrievers.flo_retriever import FloRagBuilder as FloRagBuilder +from flo_ai.common.flo_logger import get_logger as get_logger +from flo_ai.common.flo_langchain_logger import FloLangchainLogger as FloLangchainLogger diff --git a/flo_ai/builders/yaml_builder.py b/flo_ai/builders/yaml_builder.py index 71cda50f..81cee6c8 100644 --- a/flo_ai/builders/yaml_builder.py +++ b/flo_ai/builders/yaml_builder.py @@ -1,5 +1,10 @@ from flo_ai.models.flo_team import FloTeam -from flo_ai.yaml.config import (FloRoutedTeamConfig, TeamConfig, AgentConfig, FloAgentConfig) +from flo_ai.yaml.config import ( + FloRoutedTeamConfig, + TeamConfig, + AgentConfig, + FloAgentConfig, +) from flo_ai.models.flo_executable import ExecutableFlo from flo_ai.state.flo_session 
import FloSession from flo_ai.router.flo_router_factory import FloRouterFactory @@ -8,6 +13,7 @@ from flo_ai.yaml.validators import raise_for_name_error from flo_ai.common.flo_logger import get_logger + def build_supervised_team(session: FloSession) -> ExecutableFlo: name_set = set() flo_config = session.config @@ -20,11 +26,15 @@ def build_supervised_team(session: FloSession) -> ExecutableFlo: validate_names(name_set, agent_config.name, session) agent = AgentFactory.create(session, agent_config) return agent - + + def validate_team(name_set: set, team_config: TeamConfig, session: FloSession): validate_names(name_set, team_config.name, session) -def parse_and_build_subteams(session: FloSession, team_config: TeamConfig, name_set = set()) -> ExecutableFlo: + +def parse_and_build_subteams( + session: FloSession, team_config: TeamConfig, name_set=set() +) -> ExecutableFlo: flo_team = None validate_team(name_set, team_config, session) if team_config.agents: @@ -42,9 +52,12 @@ def parse_and_build_subteams(session: FloSession, team_config: TeamConfig, name_ flo_routed_team = router.build_routed_team() return flo_routed_team + def validate_names(name_set: set, name, session: FloSession): raise_for_name_error(name) if name in name_set: get_logger().error(f"Duplicate name found: '{name}'", session) - raise FloException(f"The name '{name}' is duplicate in the config. Make sure all teams and agents have unique names") - name_set.add(name) \ No newline at end of file + raise FloException( + f"The name '{name}' is duplicate in the config. Make sure all teams and agents have unique names" + ) + name_set.add(name) diff --git a/flo_ai/common/flo_langchain_logger.py b/flo_ai/common/flo_langchain_logger.py index d45b7558..3ea56f8a 100644 --- a/flo_ai/common/flo_langchain_logger.py +++ b/flo_ai/common/flo_langchain_logger.py @@ -3,47 +3,58 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult from flo_ai.common.flo_logger import get_logger -class FloLangchainLogger(BaseCallbackHandler): - def __init__(self, - session_id: str): +class FloLangchainLogger(BaseCallbackHandler): + def __init__(self, session_id: str): self.session_id = session_id - def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None: - get_logger().debug(f"onLLMStart: {prompts}", self) + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + get_logger().debug(f'onLLMStart: {prompts}', self) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - get_logger().debug(f"onNewToken: {token}", self) + get_logger().debug(f'onNewToken: {token}', self) def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - get_logger().debug(f"onLLMEnd: {response.generations}", self) + get_logger().debug(f'onLLMEnd: {response.generations}', self) - def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: - get_logger().debug(f"onLLMEnd: {error}", self) + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + get_logger().debug(f'onLLMEnd: {error}', self) - def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None: - get_logger().debug(f"onChainStart: {inputs}", self) + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + get_logger().debug(f'onChainStart: {inputs}', self) def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - get_logger().debug(f"onChainEnd: 
{outputs}", self) + get_logger().debug(f'onChainEnd: {outputs}', self) - def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: - get_logger().debug(f"onChainError: {error}", self) + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + get_logger().debug(f'onChainError: {error}', self) - def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None: - get_logger().debug(f"onToolStart: {input_str}", self) + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + get_logger().debug(f'onToolStart: {input_str}', self) def on_tool_end(self, output: str, **kwargs: Any) -> None: - get_logger().debug(f"onToolEnd: {output}", self) + get_logger().debug(f'onToolEnd: {output}', self) - def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: - get_logger().debug(f"onToolError: {error}", self) + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + get_logger().debug(f'onToolError: {error}', self) def on_text(self, text: str, **kwargs: Any) -> None: - get_logger().debug(f"onText: {text}", self) + get_logger().debug(f'onText: {text}', self) def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - get_logger().debug(f"onAgentAction: {action.tool} - {action.tool_input}", self) + get_logger().debug(f'onAgentAction: {action.tool} - {action.tool_input}', self) def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - get_logger().debug(f"onAgentFinish: {finish.return_values}", self) \ No newline at end of file + get_logger().debug(f'onAgentFinish: {finish.return_values}', self) diff --git a/flo_ai/common/flo_logger.py b/flo_ai/common/flo_logger.py index b59e5903..eadc2ebe 100644 --- a/flo_ai/common/flo_logger.py +++ b/flo_ai/common/flo_logger.py @@ -4,17 +4,18 @@ from logging.handlers import RotatingFileHandler from dataclasses import dataclass -DEFAULT_LOGGER_NAME = "FloAI" -DEFAULT_LOG_LEVEL = "ERROR" +DEFAULT_LOGGER_NAME = 'FloAI' +DEFAULT_LOG_LEVEL = 'ERROR' LEVEL_MAP = { 'DEBUG': logging.DEBUG, 'INFO': logging.INFO, 'WARNING': logging.WARNING, 'ERROR': logging.ERROR, - 'CRITICAL': logging.CRITICAL + 'CRITICAL': logging.CRITICAL, } + @dataclass class FloLogConfig: name: str @@ -28,8 +29,8 @@ def get_level(self) -> int: return LEVEL_MAP.get(self.level.upper(), logging.ERROR) return self.level -class FloLoggerUtil(logging.Logger): +class FloLoggerUtil(logging.Logger): def __init__(self, config: FloLogConfig): level = config.get_level() super().__init__(config.name, level) @@ -39,15 +40,19 @@ def __init__(self, config: FloLogConfig): self.setConfig(config) def setConfig(self, config: FloLogConfig): - formatter = logging.Formatter('%(asctime)s - %(name)s - %(session)s - %(levelname)s - %(message)s') - + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(session)s - %(levelname)s - %(message)s' + ) + console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) console_handler.setLevel(self.level) self.addHandler(console_handler) if config.file_path: - file_handler = RotatingFileHandler(config.file_path, maxBytes=config.max_bytes) + file_handler = RotatingFileHandler( + config.file_path, maxBytes=config.max_bytes + ) file_handler.setFormatter(formatter) file_handler.setLevel(self.level) self.addHandler(file_handler) @@ -57,15 +62,17 @@ def setLevel(self, level: Union[str, int]) -> None: level = 
LEVEL_MAP.get(level.upper(), logging.ERROR) super().setLevel(level) for handler in self.handlers: - print("Setting levels in handler: " + str(level)) + print('Setting levels in handler: ' + str(level)) handler.setLevel(level) - def _log(self, level: int, msg: str, session: Optional[str] = None, *args, **kwargs): + def _log( + self, level: int, msg: str, session: Optional[str] = None, *args, **kwargs + ): if not self.isEnabledFor(level): return if kwargs.get('extra') is None: kwargs['extra'] = {} - kwargs['extra']['session'] = f"[{session.session_id}]" if session else "[-]" + kwargs['extra']['session'] = f'[{session.session_id}]' if session else '[-]' super()._log(level, msg, args, **kwargs) def debug(self, msg: str, session: Optional[Any] = None, *args, **kwargs): @@ -83,22 +90,33 @@ def error(self, msg: str, session: Optional[Any] = None, *args, **kwargs): def critical(self, msg: str, session: Optional[Any] = None, *args, **kwargs): self._log(logging.CRITICAL, msg, session, *args, **kwargs) -logging_cache: Dict[str, FloLoggerUtil] = dict({ - DEFAULT_LOGGER_NAME: FloLoggerUtil(FloLogConfig(DEFAULT_LOGGER_NAME, DEFAULT_LOG_LEVEL)) -}) -def get_logger(config: FloLogConfig = FloLogConfig(DEFAULT_LOGGER_NAME)) -> FloLoggerUtil: +logging_cache: Dict[str, FloLoggerUtil] = dict( + { + DEFAULT_LOGGER_NAME: FloLoggerUtil( + FloLogConfig(DEFAULT_LOGGER_NAME, DEFAULT_LOG_LEVEL) + ) + } +) + + +def get_logger( + config: FloLogConfig = FloLogConfig(DEFAULT_LOGGER_NAME), +) -> FloLoggerUtil: if config.name not in logging_cache: logging_cache[config.name] = FloLoggerUtil(config) return logging_cache[config.name] + def set_log_level_internal(level: Union[str, int]) -> None: updated_logger = FloLoggerUtil(FloLogConfig(DEFAULT_LOGGER_NAME, level)) logging_cache[DEFAULT_LOGGER_NAME] = updated_logger + def set_log_config_internal(config: FloLogConfig): updated_logger = FloLoggerUtil(config) logging_cache[DEFAULT_LOGGER_NAME] = updated_logger + def set_logger_internal(logger: logging.Logger): - logging_cache[DEFAULT_LOGGER_NAME] = logger \ No newline at end of file + logging_cache[DEFAULT_LOGGER_NAME] = logger diff --git a/flo_ai/constants/common_constants.py b/flo_ai/constants/common_constants.py index 3b0bfc3b..f0b809a2 100644 --- a/flo_ai/constants/common_constants.py +++ b/flo_ai/constants/common_constants.py @@ -1,4 +1,4 @@ -DOCUMENTATION_WEBSITE = "https://flo-ai.rootflo.ai" -DOCUMENTATION_ROUTER_ANCHOR = f"{DOCUMENTATION_WEBSITE}/basics/routers" -DOCUMENTATION_AGENT_ANCHOR = f"{DOCUMENTATION_WEBSITE}/basics/agents" -DOCUMENTATION_AGENT_TOOLS = f"{DOCUMENTATION_WEBSITE}/basics/tools" \ No newline at end of file +DOCUMENTATION_WEBSITE = 'https://flo-ai.rootflo.ai' +DOCUMENTATION_ROUTER_ANCHOR = f'{DOCUMENTATION_WEBSITE}/basics/routers' +DOCUMENTATION_AGENT_ANCHOR = f'{DOCUMENTATION_WEBSITE}/basics/agents' +DOCUMENTATION_AGENT_TOOLS = f'{DOCUMENTATION_WEBSITE}/basics/tools' diff --git a/flo_ai/constants/flo_node_contants.py b/flo_ai/constants/flo_node_contants.py index 6106d5c3..e9d17bcc 100644 --- a/flo_ai/constants/flo_node_contants.py +++ b/flo_ai/constants/flo_node_contants.py @@ -1,2 +1,2 @@ -INTERNAL_NODE_REFLECTION_MANAGER = "f/ReflectionManager" -INTERNAL_NODE_DELEGATION_MANAGER = "f/DelegationManager" \ No newline at end of file +INTERNAL_NODE_REFLECTION_MANAGER = 'f/ReflectionManager' +INTERNAL_NODE_DELEGATION_MANAGER = 'f/DelegationManager' diff --git a/flo_ai/constants/prompt_constants.py b/flo_ai/constants/prompt_constants.py index dfb12ac4..c8ca9388 100644 --- 
a/flo_ai/constants/prompt_constants.py +++ b/flo_ai/constants/prompt_constants.py @@ -1 +1 @@ -FLO_FINISH = "FINISH" \ No newline at end of file +FLO_FINISH = 'FINISH' diff --git a/flo_ai/core.py b/flo_ai/core.py index 624d9a7e..d583230e 100644 --- a/flo_ai/core.py +++ b/flo_ai/core.py @@ -9,42 +9,42 @@ from flo_ai.models.flo_executable import ExecutableFlo from flo_ai.error.flo_exception import FloException from flo_ai.constants.common_constants import DOCUMENTATION_WEBSITE -from flo_ai.common.flo_logger import get_logger, set_log_level_internal, set_log_config_internal, set_logger_internal, FloLogConfig +from flo_ai.common.flo_logger import ( + get_logger, + set_log_level_internal, + set_log_config_internal, + set_logger_internal, + FloLogConfig, +) -class Flo: - def __init__(self, - session: FloSession, - config: FloRoutedTeamConfig) -> None: +class Flo: + def __init__(self, session: FloSession, config: FloRoutedTeamConfig) -> None: self.session = session self.config = config session.config = config self.runnable: ExecutableFlo = build_supervised_team(session) - + self.langchain_logger = session.langchain_logger - get_logger().info(f"Flo instance created ...", session) + get_logger().info('Flo instance created ...', session) - def stream(self, query, config = None) -> Iterator[Union[dict[str, Any], Any]]: + def stream(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: self.validate_invoke(self.session) get_logger().info(f"streaming query requested: '{query}'", self.session) return self.runnable.stream(query, config) - - def async_stream(self, query, config = None) -> Iterator[Union[dict[str, Any], Any]]: + + def async_stream(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: get_logger().info(f"Streaming async query requested: '{query}'", self.session) return self.runnable.astream(query, config) - - def invoke(self, query, config = None) -> Iterator[Union[dict[str, Any], Any]]: - config = { - 'callbacks' : [self.session.langchain_logger] - } + + def invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: + config = {'callbacks': [self.session.langchain_logger]} self.validate_invoke(self.session) get_logger().info(f"Invoking query: '{query}'", self.session) return self.runnable.invoke(query, config) - - def async_invoke(self, query, config = None) -> Iterator[Union[dict[str, Any], Any]]: - config = { - 'callbacks' : [self.session.langchain_logger] - } + + def async_invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: + config = {'callbacks': [self.session.langchain_logger]} get_logger().info(f"Invoking async query: '{query}'", self.session) return self.runnable.ainvoke(query, config) @@ -52,15 +52,15 @@ def async_invoke(self, query, config = None) -> Iterator[Union[dict[str, Any], A def build(session: FloSession, yaml: str, log_level: Optional[str] = None): if log_level: warnings.warn( - "`log_level` is deprecated and will be removed in a future version. " - "Please use `Flo.set_log_level()` instead.", + '`log_level` is deprecated and will be removed in a future version. 
' + 'Please use `Flo.set_log_level()` instead.', DeprecationWarning, stacklevel=2, ) Flo.set_log_level(log_level) - get_logger().info("Building Flo instance from YAML ...", session) + get_logger().info('Building Flo instance from YAML ...', session) return Flo(session, to_supervised_team(yaml)) - + @staticmethod def set_log_level(log_level: str): set_log_level_internal(log_level) @@ -75,17 +75,26 @@ def set_logger(logging_config: logging.Logger): def draw(self, xray=True): from IPython.display import Image, display + return display(Image(self.runnable.draw(xray))) - + def draw_to_file(self, filename: str, xray=True): from PIL import Image as PILImage import io + byte_image = self.runnable.draw(xray) with io.BytesIO(byte_image) as image_io: image = PILImage.open(image_io) image.save(filename) def validate_invoke(self, session: FloSession): - async_coroutines = filter(lambda x: (hasattr(x, "coroutine") and asyncio.iscoroutinefunction(x.coroutine)), session.tools.values()) + async_coroutines = filter( + lambda x: ( + hasattr(x, 'coroutine') and asyncio.iscoroutinefunction(x.coroutine) + ), + session.tools.values(), + ) if len(list(async_coroutines)) > 0: - raise FloException(f"""You seem to have atleast one async tool registered in this session. Please use flo.async_invoke or flo.async_stream. Checkout {DOCUMENTATION_WEBSITE}""") \ No newline at end of file + raise FloException( + f"""You seem to have atleast one async tool registered in this session. Please use flo.async_invoke or flo.async_stream. Checkout {DOCUMENTATION_WEBSITE}""" + ) diff --git a/flo_ai/error/flo_exception.py b/flo_ai/error/flo_exception.py index f37d0e69..2418b521 100644 --- a/flo_ai/error/flo_exception.py +++ b/flo_ai/error/flo_exception.py @@ -2,15 +2,15 @@ class FloException(Exception): def __init__(self, message: str, error_code: int = -1): """ Initialize the FloException with a message and optional error code. - + :param message: Error message to be displayed. :param error_code: Optional error code to be associated with the exception. 
""" self.message = message self.error_code = error_code super().__init__(message) - + def __str__(self): if self.error_code is not None: - return f"[Error {self.error_code}] {self.message}" - return self.message \ No newline at end of file + return f'[Error {self.error_code}] {self.message}' + return self.message diff --git a/flo_ai/factory/agent_factory.py b/flo_ai/factory/agent_factory.py index 831669e7..0e400bf2 100644 --- a/flo_ai/factory/agent_factory.py +++ b/flo_ai/factory/agent_factory.py @@ -1,6 +1,6 @@ from typing import Optional from flo_ai.state.flo_session import FloSession -from flo_ai.yaml.config import (AgentConfig) +from flo_ai.yaml.config import AgentConfig from flo_ai.models.flo_agent import FloAgent from flo_ai.models.flo_llm_agent import FloLLMAgent from flo_ai.models.flo_reflection_agent import FloReflectionAgent @@ -10,16 +10,17 @@ from flo_ai.constants.common_constants import DOCUMENTATION_AGENT_ANCHOR from enum import Enum + class AgentKinds(Enum): - agentic = "agentic" - llm = "llm" - tool = "tool" - function = "function" - reflection = "reflection" - delegator = "delegator" + agentic = 'agentic' + llm = 'llm' + tool = 'tool' + function = 'function' + reflection = 'reflection' + delegator = 'delegator' -class AgentFactory(): +class AgentFactory: @staticmethod def create(session: FloSession, agent: AgentConfig): kind = agent.kind @@ -29,7 +30,7 @@ def create(session: FloSession, agent: AgentConfig): if agent_kind is None: raise FloException(f"""Unknown agent kind: `{kind}`. The supported types are llm, tool, reflection, delegator or agentic. Check the documentation @ {DOCUMENTATION_AGENT_ANCHOR}""") - match(agent_kind): + match agent_kind: case AgentKinds.llm: return AgentFactory.__create_llm_agent(session, agent) case AgentKinds.tool: @@ -39,7 +40,7 @@ def create(session: FloSession, agent: AgentConfig): case AgentKinds.delegator: return AgentFactory.__create_delegator_agent(session, agent) return AgentFactory.__create_agentic_agent(session, agent, tool_map) - + @staticmethod def __resolve_model(session: FloSession, model_name: Optional[str] = None): if model_name is None: @@ -48,19 +49,18 @@ def __resolve_model(session: FloSession, model_name: Optional[str] = None): raise FloException( f"""Model not found: {model_name}. 
The model you would like to use should be registered to the session using session.register_model api, - and the same model name should be used here instead of `{model_name}`""") + and the same model name should be used here instead of `{model_name}`""" + ) return session.models[model_name] @staticmethod - def __create_agentic_agent(session: FloSession, agent: AgentConfig, tool_map) -> FloAgent: + def __create_agentic_agent( + session: FloSession, agent: AgentConfig, tool_map + ) -> FloAgent: agent_model = AgentFactory.__resolve_model(session, agent.model) tools = [tool_map[tool.name] for tool in agent.tools] flo_agent: FloAgent = FloAgent.Builder( - session, - agent, - tools, - llm=agent_model, - on_error=session.on_agent_error + session, agent, tools, llm=agent_model, on_error=session.on_agent_error ).build() return flo_agent @@ -70,18 +70,22 @@ def __create_llm_agent(session: FloSession, agent: AgentConfig) -> FloLLMAgent: builder = FloLLMAgent.Builder(session, agent, llm=agent_model) llm_agent: FloLLMAgent = builder.build() return llm_agent - + @staticmethod def __create_runnable_agent(session: FloSession, agent: AgentConfig) -> FloLLMAgent: runnable = session.tools[agent.tools[0].name] return FloToolAgent.Builder(session, agent, runnable).build() - + @staticmethod - def __create_reflection_agent(session: FloSession, agent: AgentConfig) -> FloReflectionAgent: + def __create_reflection_agent( + session: FloSession, agent: AgentConfig + ) -> FloReflectionAgent: agent_model = AgentFactory.__resolve_model(session, agent.model) return FloReflectionAgent.Builder(session, agent, llm=agent_model).build() - + @staticmethod - def __create_delegator_agent(session: FloSession, agent: AgentConfig) -> FloReflectionAgent: + def __create_delegator_agent( + session: FloSession, agent: AgentConfig + ) -> FloReflectionAgent: agent_model = AgentFactory.__resolve_model(session, agent.model) - return FloDelegatorAgent.Builder(session, agent, llm=agent_model).build() \ No newline at end of file + return FloDelegatorAgent.Builder(session, agent, llm=agent_model).build() diff --git a/flo_ai/helpers/utils.py b/flo_ai/helpers/utils.py index c457fa0c..b1db025b 100644 --- a/flo_ai/helpers/utils.py +++ b/flo_ai/helpers/utils.py @@ -1,7 +1,8 @@ import random import string + def random_str(length: int = 5): letters = string.ascii_letters + string.digits result_str = ''.join(random.choice(letters) for i in range(length)) - return result_str \ No newline at end of file + return result_str diff --git a/flo_ai/models/exception.py b/flo_ai/models/exception.py index b335a053..9477254a 100644 --- a/flo_ai/models/exception.py +++ b/flo_ai/models/exception.py @@ -2,6 +2,7 @@ class FloValidationException(Exception): def __init__(self, *args: object) -> None: super().__init__(*args) + class FloIllegalStateException(Exception): def __init__(self, *args: object) -> None: - super().__init__(*args) \ No newline at end of file + super().__init__(*args) diff --git a/flo_ai/models/flo_agent.py b/flo_ai/models/flo_agent.py index 98881051..0ba124f3 100644 --- a/flo_ai/models/flo_agent.py +++ b/flo_ai/models/flo_agent.py @@ -10,46 +10,54 @@ from flo_ai.yaml.config import AgentConfig from flo_ai.models.flo_executable import ExecutableType + class FloAgent(ExecutableFlo): - def __init__(self, - agent: Runnable, - executor: AgentExecutor, - config: AgentConfig) -> None: + def __init__( + self, agent: Runnable, executor: AgentExecutor, config: AgentConfig + ) -> None: super().__init__(config.name, executor, ExecutableType.agentic) - 
self.agent: Runnable = agent, + self.agent: Runnable = (agent,) self.executor: AgentExecutor = executor self.config: AgentConfig = config class Builder: - def __init__(self, - session: FloSession, - config: AgentConfig, - tools: list[BaseTool], - verbose: bool = True, - role: Optional[str] = None, - llm: Union[BaseLanguageModel, None] = None, - on_error: Union[str, Callable] = True) -> None: - + def __init__( + self, + session: FloSession, + config: AgentConfig, + tools: list[BaseTool], + verbose: bool = True, + role: Optional[str] = None, + llm: Union[BaseLanguageModel, None] = None, + on_error: Union[str, Callable] = True, + ) -> None: prompt: Union[ChatPromptTemplate, str] = config.job self.name: str = config.name self.llm = llm if llm is not None else session.llm self.config = config - system_prompts = [("system", "You are a {}".format(role)), ("system", prompt)] if role is not None else [("system", prompt)] - system_prompts.append(MessagesPlaceholder(variable_name="messages")) - system_prompts.append(MessagesPlaceholder(variable_name="agent_scratchpad")) - self.prompt: ChatPromptTemplate = ChatPromptTemplate.from_messages( - system_prompts - ) if isinstance(prompt, str) else prompt + system_prompts = ( + [('system', 'You are a {}'.format(role)), ('system', prompt)] + if role is not None + else [('system', prompt)] + ) + system_prompts.append(MessagesPlaceholder(variable_name='messages')) + system_prompts.append(MessagesPlaceholder(variable_name='agent_scratchpad')) + self.prompt: ChatPromptTemplate = ( + ChatPromptTemplate.from_messages(system_prompts) + if isinstance(prompt, str) + else prompt + ) self.tools: list[BaseTool] = tools self.verbose = verbose self.on_error = on_error - def build(self) -> AgentExecutor: agent = create_tool_calling_agent(self.llm, self.tools, self.prompt) - executor = AgentExecutor(agent=agent, - tools=self.tools, - verbose=self.verbose, - return_intermediate_steps=True, - handle_parsing_errors=self.on_error) + executor = AgentExecutor( + agent=agent, + tools=self.tools, + verbose=self.verbose, + return_intermediate_steps=True, + handle_parsing_errors=self.on_error, + ) return FloAgent(agent, executor, self.config) diff --git a/flo_ai/models/flo_delegation_agent.py b/flo_ai/models/flo_delegation_agent.py index 5cc71804..a04cb571 100644 --- a/flo_ai/models/flo_delegation_agent.py +++ b/flo_ai/models/flo_delegation_agent.py @@ -10,72 +10,69 @@ class FloDelegatorAgent(ExecutableFlo): - - def __init__(self, - executor: Runnable, - config: AgentConfig) -> None: + def __init__(self, executor: Runnable, config: AgentConfig) -> None: super().__init__(config.name, executor, ExecutableType.delegator) self.executor: Runnable = executor self.config: AgentConfig = config - - class Builder(): - def __init__(self, - session: FloSession, - agentConfig: AgentConfig, - llm: Optional[BaseLanguageModel] = None) -> None: + class Builder: + def __init__( + self, + session: FloSession, + agentConfig: AgentConfig, + llm: Optional[BaseLanguageModel] = None, + ) -> None: self.config = agentConfig delegator_base_system_message = ( - "You are a delegator tasked with routing a conversation between the" - " following {member_type}: {members}. Given the following rules," - " respond with the worker to act next " + 'You are a delegator tasked with routing a conversation between the' + ' following {member_type}: {members}. 
Given the following rules,' + ' respond with the worker to act next ' ) self.llm = session.llm if llm is None else llm self.options = [x.name for x in agentConfig.to] self.llm_router_prompt = ChatPromptTemplate.from_messages( [ - ("system", delegator_base_system_message), - MessagesPlaceholder(variable_name="messages"), - ("system", "Rules: {delegator_rules}"), + ('system', delegator_base_system_message), + MessagesPlaceholder(variable_name='messages'), + ('system', 'Rules: {delegator_rules}'), ( - "system", - "Given the conversation above, who should act next?" - "Select one of: {options}", + 'system', + 'Given the conversation above, who should act next?' + 'Select one of: {options}', ), ] ).partial( - options=str(self.options), - members=", ".join(self.options), - member_type="agents", - delegator_rules=agentConfig.job + options=str(self.options), + members=', '.join(self.options), + member_type='agents', + delegator_rules=agentConfig.job, ) def build(self): function_def = { - "name": "route", - "description": "Select the next role.", - "parameters": { - "title": "routeSchema", - "type": "object", - "properties": { - "next": { - "title": "Next", - "anyOf": [ - {"enum": self.options}, + 'name': 'route', + 'description': 'Select the next role.', + 'parameters': { + 'title': 'routeSchema', + 'type': 'object', + 'properties': { + 'next': { + 'title': 'Next', + 'anyOf': [ + {'enum': self.options}, ], } }, - "required": ["next"], - } + 'required': ['next'], + }, } - + chain = ( self.llm_router_prompt - | self.llm.bind_functions(functions=[function_def], function_call="route") + | self.llm.bind_functions( + functions=[function_def], function_call='route' + ) | JsonOutputFunctionsParser() ) - return FloDelegatorAgent(executor = chain, - config=self.config) - - \ No newline at end of file + return FloDelegatorAgent(executor=chain, config=self.config) diff --git a/flo_ai/models/flo_executable.py b/flo_ai/models/flo_executable.py index 1dd73060..28043676 100644 --- a/flo_ai/models/flo_executable.py +++ b/flo_ai/models/flo_executable.py @@ -4,18 +4,19 @@ from enum import Enum from flo_ai.state.flo_state import STATE_NAME_MESSAGES + class ExecutableType(Enum): - agentic = "agentic" - llm = "llm" - tool = "tool" - reflection = "reflection" - delegator = "delegator" - team = "team" - router = "router" + agentic = 'agentic' + llm = 'llm' + tool = 'tool' + reflection = 'reflection' + delegator = 'delegator' + team = 'team' + router = 'router' @staticmethod def isAgent(type: 'ExecutableType'): - match(type): + match type: case ExecutableType.agentic: return True case ExecutableType.llm: @@ -24,42 +25,39 @@ def isAgent(type: 'ExecutableType'): return True return False + class ExecutableFlo(FloMember): - def __init__(self, - name: str, - runnable: Runnable, - type: str = ExecutableType.team) -> None: + def __init__( + self, name: str, runnable: Runnable, type: str = ExecutableType.team + ) -> None: super().__init__(name, type) self.runnable = runnable - def stream(self, work, config = None): - return self.runnable.stream({ - STATE_NAME_MESSAGES: [ - HumanMessage(content=work) - ] - }, config) - - def astream(self, work, config = None): - return self.runnable.astream({ - STATE_NAME_MESSAGES: [ - HumanMessage(content=work) - ] - }, config) - - def invoke(self, work, config = None): - return self.runnable.invoke({ - STATE_NAME_MESSAGES: [ - HumanMessage(content=work) - ], - }, config) - - def ainvoke(self, work, config = None): - return self.runnable.ainvoke({ - STATE_NAME_MESSAGES: [ - 
HumanMessage(content=work) - ], - }, config) - + def stream(self, work, config=None): + return self.runnable.stream( + {STATE_NAME_MESSAGES: [HumanMessage(content=work)]}, config + ) + + def astream(self, work, config=None): + return self.runnable.astream( + {STATE_NAME_MESSAGES: [HumanMessage(content=work)]}, config + ) + + def invoke(self, work, config=None): + return self.runnable.invoke( + { + STATE_NAME_MESSAGES: [HumanMessage(content=work)], + }, + config, + ) + + def ainvoke(self, work, config=None): + return self.runnable.ainvoke( + { + STATE_NAME_MESSAGES: [HumanMessage(content=work)], + }, + config, + ) def draw(self, xray=True): - return self.runnable.get_graph().draw_mermaid_png() \ No newline at end of file + return self.runnable.get_graph().draw_mermaid_png() diff --git a/flo_ai/models/flo_llm_agent.py b/flo_ai/models/flo_llm_agent.py index 814db9fc..fce815b6 100644 --- a/flo_ai/models/flo_llm_agent.py +++ b/flo_ai/models/flo_llm_agent.py @@ -8,33 +8,38 @@ from flo_ai.yaml.config import AgentConfig from flo_ai.models.flo_executable import ExecutableType -class FloLLMAgent(ExecutableFlo): - def __init__(self, - executor: Runnable, - config: AgentConfig) -> None: +class FloLLMAgent(ExecutableFlo): + def __init__(self, executor: Runnable, config: AgentConfig) -> None: super().__init__(config.name, executor, ExecutableType.llm) self.executor: Runnable = executor self.config: AgentConfig = config class Builder: - def __init__(self, - session: FloSession, - config: AgentConfig, - llm: Union[BaseLanguageModel, None] = None) -> None: + def __init__( + self, + session: FloSession, + config: AgentConfig, + llm: Union[BaseLanguageModel, None] = None, + ) -> None: prompt: Union[ChatPromptTemplate, str] = config.job - + self.name: str = config.name self.llm = llm if llm is not None else session.llm # TODO improve to add more context of what other agents are available - system_prompts = [("system", "You are a {}".format(config.role)), ("system", prompt)] if config.role is not None else [("system", prompt)] - system_prompts.append(MessagesPlaceholder(variable_name="messages")) - self.prompt: ChatPromptTemplate = ChatPromptTemplate.from_messages( - system_prompts - ) if isinstance(prompt, str) else prompt + system_prompts = ( + [('system', 'You are a {}'.format(config.role)), ('system', prompt)] + if config.role is not None + else [('system', prompt)] + ) + system_prompts.append(MessagesPlaceholder(variable_name='messages')) + self.prompt: ChatPromptTemplate = ( + ChatPromptTemplate.from_messages(system_prompts) + if isinstance(prompt, str) + else prompt + ) self.config = config - def build(self) -> Runnable: executor = self.prompt | self.llm | StrOutputParser() return FloLLMAgent(executor, self.config) diff --git a/flo_ai/models/flo_member.py b/flo_ai/models/flo_member.py index c3966004..cfd39ad4 100644 --- a/flo_ai/models/flo_member.py +++ b/flo_ai/models/flo_member.py @@ -1,4 +1,4 @@ -class FloMember(): +class FloMember: def __init__(self, name: str, type: str) -> None: self.name = name - self.type = type \ No newline at end of file + self.type = type diff --git a/flo_ai/models/flo_node.py b/flo_ai/models/flo_node.py index 5c1dfd84..a8272075 100644 --- a/flo_ai/models/flo_node.py +++ b/flo_ai/models/flo_node.py @@ -8,58 +8,93 @@ from flo_ai.models.flo_executable import ExecutableType from typing import Union -class FloNode(): - def __init__(self, - func: functools.partial, - name: str, - kind: ExecutableType, - config: Union[AgentConfig | TeamConfig]) -> None: +class FloNode: + def 
__init__( + self, + func: functools.partial, + name: str, + kind: ExecutableType, + config: Union[AgentConfig | TeamConfig], + ) -> None: self.name = name self.func = func self.kind: ExecutableType = kind self.config: Union[AgentConfig | TeamConfig] = config - class Builder(): - + class Builder: def build_from_agent(self, flo_agent: FloAgent) -> 'FloNode': - agent_func = functools.partial(FloNode.Builder.__teamflo_agent_node, agent=flo_agent.runnable, name=flo_agent.name, agent_config=flo_agent.config) + agent_func = functools.partial( + FloNode.Builder.__teamflo_agent_node, + agent=flo_agent.runnable, + name=flo_agent.name, + agent_config=flo_agent.config, + ) return FloNode(agent_func, flo_agent.name, flo_agent.type, flo_agent.config) - + def build_from_team(self, flo_team: FloRoutedTeam) -> 'FloNode': - team_chain = (functools.partial(FloNode.Builder.__teamflo_team_node, members=flo_team.runnable.nodes) | flo_team.runnable) - return FloNode(( - FloNode.Builder.__get_last_message | team_chain | FloNode.Builder.__join_graph - ), flo_team.name, flo_team.type, flo_team.config) - + team_chain = ( + functools.partial( + FloNode.Builder.__teamflo_team_node, members=flo_team.runnable.nodes + ) + | flo_team.runnable + ) + return FloNode( + ( + FloNode.Builder.__get_last_message + | team_chain + | FloNode.Builder.__join_graph + ), + flo_team.name, + flo_team.type, + flo_team.config, + ) + def build_from_router(self, flo_router) -> 'FloNode': - router_func = functools.partial(FloNode.Builder.__teamflo_router_node, agent=flo_router.executor, name=flo_router.router_name, agent_config=flo_router.config) - return FloNode(router_func, flo_router.router_name, flo_router.type, flo_router.config) + router_func = functools.partial( + FloNode.Builder.__teamflo_router_node, + agent=flo_router.executor, + name=flo_router.router_name, + agent_config=flo_router.config, + ) + return FloNode( + router_func, flo_router.router_name, flo_router.type, flo_router.config + ) @staticmethod - def __teamflo_agent_node(state: TeamFloAgentState, agent: AgentExecutor, name: str, agent_config: AgentConfig): + def __teamflo_agent_node( + state: TeamFloAgentState, + agent: AgentExecutor, + name: str, + agent_config: AgentConfig, + ): result = agent.invoke(state) - output = result if isinstance(result, str) else result["output"] - return { STATE_NAME_MESSAGES: [HumanMessage(content=output, name=name)] } - + output = result if isinstance(result, str) else result['output'] + return {STATE_NAME_MESSAGES: [HumanMessage(content=output, name=name)]} + @staticmethod - def __teamflo_router_node(state: TeamFloAgentState, agent: AgentExecutor, name: str, agent_config: AgentConfig): + def __teamflo_router_node( + state: TeamFloAgentState, + agent: AgentExecutor, + name: str, + agent_config: AgentConfig, + ): result = agent.invoke(state) - nextNode = result if isinstance(result, str) else result["next"] - return { "next": nextNode } + nextNode = result if isinstance(result, str) else result['next'] + return {'next': nextNode} @staticmethod def __get_last_message(state: TeamFloAgentState) -> str: return state[STATE_NAME_MESSAGES][-1].content - + @staticmethod def __join_graph(response: dict): - return { STATE_NAME_MESSAGES: [ response[STATE_NAME_MESSAGES][-1] ] } - + return {STATE_NAME_MESSAGES: [response[STATE_NAME_MESSAGES][-1]]} + @staticmethod def __teamflo_team_node(message: str, members: list[str]): results = { STATE_NAME_MESSAGES: [HumanMessage(content=message)], - "team_members": ", ".join(members), + 'team_members': ', 
'.join(members), } return results diff --git a/flo_ai/models/flo_rag.py b/flo_ai/models/flo_rag.py index a66576d7..3c3468fb 100644 --- a/flo_ai/models/flo_rag.py +++ b/flo_ai/models/flo_rag.py @@ -7,40 +7,41 @@ from flo_ai.state.flo_state import TeamFloAgentState from langgraph.prebuilt import tools_condition from flo_ai.models.flo_executable import ExecutableFlo -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain import hub +from langchain.prompts import ChatPromptTemplate + class FloRag(ExecutableFlo): - def __init__(self, - name: str, - graph: CompiledGraph) -> None: - super().__init__(name, "team") + def __init__(self, name: str, graph: CompiledGraph) -> None: + super().__init__(name, 'team') self.graph = graph def draw(self, xray=True): return self.graph.get_graph(xray=xray).draw_mermaid_png() -class FloRagBuilder: - def __init__(self, - name: str, - tools: list[Tool], - llm: BaseLanguageModel, - prompt: Optional[ChatPromptTemplate] = None) -> None: +class FloRagBuilder: + def __init__( + self, + name: str, + tools: list[Tool], + llm: BaseLanguageModel, + prompt: Optional[ChatPromptTemplate] = None, + ) -> None: self.name = name self.llm = llm self.tools = tools - self.prompt = hub.pull("rlm/rag-prompt") if prompt is None else prompt - + self.prompt = hub.pull('rlm/rag-prompt') if prompt is None else prompt + def retriever_agent(self, state: TeamFloAgentState): - messages = state["messages"] + messages = state['messages'] model = self.llm.bind_tools(self.tools) response = model.invoke(messages) # We return a list, because this will get added to the existing list - return { "messages": [response] } - + return {'messages': [response]} + def generate(self, state: TeamFloAgentState): - messages = state["messages"] + messages = state['messages'] question = messages[0].content last_message = messages[-1] @@ -48,28 +49,28 @@ def generate(self, state: TeamFloAgentState): docs = last_message.content rag_chain = self.prompt | self.llm - response = rag_chain.invoke({"context": docs, "question": question}) - return { "messages": [response] } - + response = rag_chain.invoke({'context': docs, 'question': question}) + return {'messages': [response]} + def build(self) -> FloRag: retrieve = ToolNode(self.tools) workflow = StateGraph(TeamFloAgentState) - workflow.add_node("agent", self.retriever_agent) - workflow.add_node("retrieve", retrieve) # retrieval - workflow.add_node("generate", self.generate) + workflow.add_node('agent', self.retriever_agent) + workflow.add_node('retrieve', retrieve) # retrieval + workflow.add_node('generate', self.generate) workflow.add_conditional_edges( - "agent", + 'agent', tools_condition, { - "tools": "retrieve", + 'tools': 'retrieve', END: END, }, ) - workflow.add_edge("retrieve", "generate") - workflow.add_edge("generate", END) + workflow.add_edge('retrieve', 'generate') + workflow.add_edge('generate', END) - workflow.set_entry_point("agent") + workflow.set_entry_point('agent') graph = workflow.compile() return FloRag(self.name, graph=graph) diff --git a/flo_ai/models/flo_reflection_agent.py b/flo_ai/models/flo_reflection_agent.py index 97b5a287..72903c04 100644 --- a/flo_ai/models/flo_reflection_agent.py +++ b/flo_ai/models/flo_reflection_agent.py @@ -10,28 +10,37 @@ class FloReflectionAgent(ExecutableFlo): - def __init__(self, executor: Runnable, config: AgentConfig) -> None: super().__init__(config.name, executor, ExecutableType.reflection) self.config = config - class Builder(): - def __init__(self, - session: FloSession, - 
config: AgentConfig, - llm: Union[BaseLanguageModel, None] = None) -> None: - + class Builder: + def __init__( + self, + session: FloSession, + config: AgentConfig, + llm: Union[BaseLanguageModel, None] = None, + ) -> None: prompt_message: Union[ChatPromptTemplate, str] = config.job self.name: str = config.name self.llm = llm if llm is not None else session.llm self.config = config - system_prompts = [("system", "You are a {}".format(config.role)), ("system", prompt_message)] if config.role is not None else [("system", prompt_message)] - system_prompts.append(MessagesPlaceholder(variable_name="messages")) - self.prompt: ChatPromptTemplate = ChatPromptTemplate.from_messages( - system_prompts - ) if isinstance(prompt_message, str) else prompt_message + system_prompts = ( + [ + ('system', 'You are a {}'.format(config.role)), + ('system', prompt_message), + ] + if config.role is not None + else [('system', prompt_message)] + ) + system_prompts.append(MessagesPlaceholder(variable_name='messages')) + self.prompt: ChatPromptTemplate = ( + ChatPromptTemplate.from_messages(system_prompts) + if isinstance(prompt_message, str) + else prompt_message + ) def build(self): executor = self.prompt | self.llm | StrOutputParser() - return FloReflectionAgent(executor, self.config) \ No newline at end of file + return FloReflectionAgent(executor, self.config) diff --git a/flo_ai/models/flo_routed_team.py b/flo_ai/models/flo_routed_team.py index 552258b7..48fa0ff2 100644 --- a/flo_ai/models/flo_routed_team.py +++ b/flo_ai/models/flo_routed_team.py @@ -2,12 +2,12 @@ from langgraph.graph.graph import CompiledGraph from flo_ai.yaml.config import TeamConfig -class FloRoutedTeam(ExecutableFlo): +class FloRoutedTeam(ExecutableFlo): def __init__(self, name: str, graph: CompiledGraph, config: TeamConfig) -> None: super().__init__(name, graph) self.config = config # Overridden for xray use, doesnt work in base class def draw(self, xray=True): - return self.runnable.get_graph(xray=xray).draw_mermaid_png() \ No newline at end of file + return self.runnable.get_graph(xray=xray).draw_mermaid_png() diff --git a/flo_ai/models/flo_team.py b/flo_ai/models/flo_team.py index feb56442..7bc7db54 100644 --- a/flo_ai/models/flo_team.py +++ b/flo_ai/models/flo_team.py @@ -1,7 +1,8 @@ from flo_ai.models.flo_member import FloMember from flo_ai.yaml.config import TeamConfig -class FloTeam(): + +class FloTeam: def __init__(self, team_config: TeamConfig, members: list[FloMember]) -> None: self.name = team_config.name self.config = team_config @@ -11,10 +12,7 @@ class Builder: def __init__(self, team_config: TeamConfig, members: list[FloMember]) -> None: self.team_config = team_config self.members = members - self.member_names= list(map(lambda x: x.name, self.members)) - + self.member_names = list(map(lambda x: x.name, self.members)) + def build(self): - return FloTeam( - team_config=self.team_config, - members=self.members - ) \ No newline at end of file + return FloTeam(team_config=self.team_config, members=self.members) diff --git a/flo_ai/models/flo_tool_agent.py b/flo_ai/models/flo_tool_agent.py index 0783bd39..37869650 100644 --- a/flo_ai/models/flo_tool_agent.py +++ b/flo_ai/models/flo_tool_agent.py @@ -4,20 +4,17 @@ from flo_ai.yaml.config import AgentConfig from flo_ai.models.flo_executable import ExecutableType -class FloToolAgent(ExecutableFlo): - def __init__(self, - executor: Runnable, - config: AgentConfig) -> None: +class FloToolAgent(ExecutableFlo): + def __init__(self, executor: Runnable, config: AgentConfig) -> None: 
super().__init__(config.name, executor, ExecutableType.tool) self.executor: Runnable = executor self.config: AgentConfig = config class Builder: - def __init__(self, - session: FloSession, - config: AgentConfig, - tool_runnable: Runnable) -> None: + def __init__( + self, session: FloSession, config: AgentConfig, tool_runnable: Runnable + ) -> None: self.name: str = config.name self.runnable = tool_runnable self.config = config diff --git a/flo_ai/models/flo_transformer.py b/flo_ai/models/flo_transformer.py new file mode 100644 index 00000000..35d6e261 --- /dev/null +++ b/flo_ai/models/flo_transformer.py @@ -0,0 +1,6 @@ +from flo_ai.models.flo_node import FloNode + + +class FloTransformer(FloNode): + def __init__(self, func: object, name: str) -> None: + super().__init__(func, name) diff --git a/flo_ai/retrievers/flo_compression_pipeline.py b/flo_ai/retrievers/flo_compression_pipeline.py index 82da9216..81338d44 100644 --- a/flo_ai/retrievers/flo_compression_pipeline.py +++ b/flo_ai/retrievers/flo_compression_pipeline.py @@ -3,14 +3,16 @@ from langchain_community.document_transformers import EmbeddingsRedundantFilter from langchain.retrievers.document_compressors import EmbeddingsFilter -class FloCompressionPipeline(): +class FloCompressionPipeline: def __init__(self, embeddings: Embeddings) -> None: self.__embeddings = embeddings self.__pipeline = [] - def add_chuncking(self, chunk_size = 300, chunk_overlap = 0): - splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=". ") + def add_chuncking(self, chunk_size=300, chunk_overlap=0): + splitter = CharacterTextSplitter( + chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator='. ' + ) self.__pipeline.append(splitter) def add_embedding_reduntant_filter(self): @@ -18,16 +20,22 @@ def add_embedding_reduntant_filter(self): self.__pipeline.append(redundant_filter) def add_embedding_relevant_filter(self, threshold: float = 0.50): - relevant_filter = EmbeddingsFilter(embeddings=self.__embeddings, similarity_threshold=threshold) - self.__pipeline.append(relevant_filter) + relevant_filter = EmbeddingsFilter( + embeddings=self.__embeddings, similarity_threshold=threshold + ) + self.__pipeline.append(relevant_filter) + + def add_flashrank_reranking(self, model_name='ms-marco-MultiBERT-L-12'): + from langchain.retrievers.document_compressors.flashrank_rerank import ( + FlashrankRerank, + ) - def add_flashrank_reranking(self, model_name="ms-marco-MultiBERT-L-12"): - from langchain.retrievers.document_compressors.flashrank_rerank import FlashrankRerank compressor = FlashrankRerank(model=model_name) self.__pipeline.append(compressor) - def add_cohere_reranking(self, model_name="rerank-english-v3.0"): + def add_cohere_reranking(self, model_name='rerank-english-v3.0'): from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank + compressor = CohereRerank(model=model_name) self.__pipeline.append(compressor) diff --git a/flo_ai/retrievers/flo_multi_query.py b/flo_ai/retrievers/flo_multi_query.py index 89881ad2..ac56dd70 100644 --- a/flo_ai/retrievers/flo_multi_query.py +++ b/flo_ai/retrievers/flo_multi_query.py @@ -1,6 +1,4 @@ from typing import List, Union - -from langchain.chains.llm import LLMChain from langchain_core.vectorstores import VectorStoreRetriever from langchain.output_parsers import PydanticOutputParser from langchain_core.prompts import PromptTemplate @@ -8,8 +6,9 @@ from flo_ai.state.flo_session import FloSession from langchain.retrievers.multi_query import 
MultiQueryRetriever + class LineList(BaseModel): - lines: List[str] = Field(description="Lines of text") + lines: List[str] = Field(description='Lines of text') class LineListOutputParser(PydanticOutputParser): @@ -17,38 +16,40 @@ def __init__(self) -> None: super().__init__(pydantic_object=LineList) def parse(self, text: str) -> LineList: - lines = text.strip().split("\n") + lines = text.strip().split('\n') return LineList(lines=lines) - -class FloMultiQueryRetriever(): + + +class FloMultiQueryRetriever: def __init__(self, retriever) -> None: self.retriever = retriever -class FloMultiQueryRetriverBuilder(): - def __init__(self, - session: FloSession, - retriver: VectorStoreRetriever, - query_prompt: Union[str, None] = None) -> None: +class FloMultiQueryRetriverBuilder: + def __init__( + self, + session: FloSession, + retriver: VectorStoreRetriever, + query_prompt: Union[str, None] = None, + ) -> None: self.session = session self.retriver = retriver self.output_parser = LineListOutputParser() self.prompt = PromptTemplate( - input_variables=["question"], + input_variables=['question'], template="""You are an AI language model assistant. Your task is to generate three different versions of the given user question to retrieve relevant documents from a vector database. By generating multiple perspectives on the user question, your goal is to help the user overcome some of the limitations of the distance-based similarity search. Provide these alternative questions separated by newlines. - Original question: {question}""" if query_prompt is None else query_prompt, + Original question: {question}""" + if query_prompt is None + else query_prompt, ) def build(self): multi_query_retriever = MultiQueryRetriever.from_llm( - retriever=self.retriver, - llm=self.session.llm, - prompt=self.prompt + retriever=self.retriver, llm=self.session.llm, prompt=self.prompt ) return FloMultiQueryRetriever(multi_query_retriever) - diff --git a/flo_ai/retrievers/flo_retriever.py b/flo_ai/retrievers/flo_retriever.py index 9e05940a..b095a354 100644 --- a/flo_ai/retrievers/flo_retriever.py +++ b/flo_ai/retrievers/flo_retriever.py @@ -3,7 +3,6 @@ from flo_ai.state.flo_session import FloSession from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import RunnablePassthrough -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, BasePromptTemplate from flo_ai.retrievers.flo_multi_query import FloMultiQueryRetriverBuilder from langchain.retrievers import ContextualCompressionRetriever from langchain.retrievers.document_compressors import DocumentCompressorPipeline @@ -18,14 +17,21 @@ PromptTemplate, aformat_document, format_document, + ChatPromptTemplate, + MessagesPlaceholder, ) from typing import List + class FloRagBaseMessage(BaseModel): content: str + class FloRagToolInput(BaseModel): - messages: List[FloRagBaseMessage] = Field(description="query to look up in the vector store") + messages: List[FloRagBaseMessage] = Field( + description='query to look up in the vector store' + ) + def _get_relevant_documents( messages: List[FloRagBaseMessage], @@ -34,32 +40,36 @@ def _get_relevant_documents( document_separator: str, callbacks: Callbacks = None, ) -> str: - docs = retriever.invoke(messages[-1].content, config={"callbacks": callbacks}) + docs = retriever.invoke(messages[-1].content, config={'callbacks': callbacks}) return document_separator.join( format_document(doc, document_prompt) for doc in docs ) + async def _aget_relevant_documents( messages: List[str], 
retriever: VectorStoreRetriever, document_prompt: BasePromptTemplate, document_separator: str, - callbacks: Callbacks = None + callbacks: Callbacks = None, ) -> str: - docs = await retriever.ainvoke(messages[-1].content, config={"callbacks": callbacks}) + docs = await retriever.ainvoke( + messages[-1].content, config={'callbacks': callbacks} + ) return document_separator.join( [await aformat_document(doc, document_prompt) for doc in docs] ) -class FloRagBuilder(): - def __init__(self, - session: FloSession, - retriever: VectorStoreRetriever) -> None: + +class FloRagBuilder: + def __init__(self, session: FloSession, retriever: VectorStoreRetriever) -> None: self.session = session self.retriever = retriever self.default_prompt = ChatPromptTemplate.from_messages( [ - ("system", """You are an assistant for question-answering tasks. + ( + 'system', + """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. @@ -67,9 +77,10 @@ def __init__(self, Here is the context: {context} - """), - MessagesPlaceholder(variable_name="chat_history"), - ("human", "{question}"), + """, + ), + MessagesPlaceholder(variable_name='chat_history'), + ('human', '{question}'), ] ) self.history_aware_retriever = self.__create_history_aware_retriever() @@ -78,18 +89,16 @@ def with_prompt(self, prompt: ChatPromptTemplate): self.default_prompt = prompt return self - def with_multi_query(self, prompt = None): - builder = FloMultiQueryRetriverBuilder(session=self.session, - retriver=self.retriever, - query_prompt=prompt) + def with_multi_query(self, prompt=None): + builder = FloMultiQueryRetriverBuilder( + session=self.session, retriver=self.retriever, query_prompt=prompt + ) multi_query_retriever = builder.build() self.retriever = multi_query_retriever.retriever return self - + def with_compression(self, pipeline: FloCompressionPipeline): - pipeline_compressor = DocumentCompressorPipeline( - transformers=pipeline.get() - ) + pipeline_compressor = DocumentCompressorPipeline(transformers=pipeline.get()) compression_retriever = ContextualCompressionRetriever( base_compressor=pipeline_compressor, base_retriever=self.retriever ) @@ -104,66 +113,71 @@ def __create_history_aware_retriever(self): contextualize_q_prompt = ChatPromptTemplate.from_messages( [ - ("system", contextualize_q_system_prompt), - MessagesPlaceholder("chat_history"), - ("human", "{question}"), + ('system', contextualize_q_system_prompt), + MessagesPlaceholder('chat_history'), + ('human', '{question}'), ] ) - self.history_aware_retriever = contextualize_q_prompt | self.session.llm | StrOutputParser() + self.history_aware_retriever = ( + contextualize_q_prompt | self.session.llm | StrOutputParser() + ) return self.history_aware_retriever def __get_retriever(self): def __precontext_retriver(input_prompt: dict): - if input_prompt.get("chat_history"): + if input_prompt.get('chat_history'): return self.history_aware_retriever else: - return input_prompt["question"] + return input_prompt['question'] + return __precontext_retriver | self.retriever - + def __format_docs(self, docs): - return "\n\n".join(doc.page_content for doc in docs) - + return '\n\n'.join(doc.page_content for doc in docs) + def __get_optional_chat_history(self, x): - return x["chat_history"] if "chat_history" in x else [] - + return x['chat_history'] if 'chat_history' in x else [] + def __build_history_aware_rag(self): 
rag_chain = ( RunnablePassthrough.assign( - context=(lambda x: x["context"]), + context=(lambda x: x['context']), ) | self.default_prompt - | self.session.llm + | self.session.llm ) rag_chain_with_source = RunnableParallel( { - "context": self.__get_retriever() | self.__format_docs, - "question": RunnablePassthrough(), - "chat_history": lambda x: self.__get_optional_chat_history(x) + 'context': self.__get_retriever() | self.__format_docs, + 'question': RunnablePassthrough(), + 'chat_history': lambda x: self.__get_optional_chat_history(x), } ).assign(answer=rag_chain) return rag_chain_with_source def build_rag(self): return self.__build_history_aware_rag() - + def build_retriever_tool(self, name, description): return self.__create_retriever_tool(self.retriever, name, description) - + @staticmethod def __get_rag_answer(messages: List[FloRagBaseMessage], runnable: Runnable): question = messages[-1].content chat_history = messages[:-1] - result = runnable.invoke({ "question": question, "chat_history": chat_history }) - return result["answer"].content + result = runnable.invoke({'question': question, 'chat_history': chat_history}) + return result['answer'].content @staticmethod async def __aget_rag_answer(messages: List[FloRagBaseMessage], runnable: Runnable): question = messages[-1].content chat_history = messages[:-1] - result = await runnable.ainvoke({ "question": question, "chat_history": chat_history }) - return result["answer"].content - + result = await runnable.ainvoke( + {'question': question, 'chat_history': chat_history} + ) + return result['answer'].content + def __create_retriever_tool( self, retriever: VectorStoreRetriever, @@ -171,9 +185,11 @@ def __create_retriever_tool( description: str, *, document_prompt: Optional[BasePromptTemplate] = None, - document_separator: str = "\n", + document_separator: str = '\n', ) -> Tool: - document_prompt = document_prompt or PromptTemplate.from_template("{page_content}") + document_prompt = document_prompt or PromptTemplate.from_template( + '{page_content}' + ) func = partial( _get_relevant_documents, retriever=retriever, @@ -195,21 +211,11 @@ def __create_retriever_tool( ) @staticmethod - def __create_rag_tool( - runnable_rag: Runnable, - name: str, - description: str - ) -> Tool: - func = partial( - FloRagBuilder.__get_rag_answer, - runnable=runnable_rag - ) + def __create_rag_tool(runnable_rag: Runnable, name: str, description: str) -> Tool: + func = partial(FloRagBuilder.__get_rag_answer, runnable=runnable_rag) + + afunc = partial(FloRagBuilder.__aget_rag_answer, runnable=runnable_rag) - afunc = partial( - FloRagBuilder.__aget_rag_answer, - runnable=runnable_rag - ) - return Tool( name=name, description=description, @@ -217,7 +223,7 @@ def __create_rag_tool( coroutine=afunc, args_schema=FloRagToolInput, ) - + def build_rag_tool(self, name, description) -> Tool: rag = self.__build_history_aware_rag() - return FloRagBuilder.__create_rag_tool(rag, name, description) \ No newline at end of file + return FloRagBuilder.__create_rag_tool(rag, name, description) diff --git a/flo_ai/router/flo_linear.py b/flo_ai/router/flo_linear.py index 53f8f36d..0f0e57b4 100644 --- a/flo_ai/router/flo_linear.py +++ b/flo_ai/router/flo_linear.py @@ -7,44 +7,56 @@ from flo_ai.state.flo_session import FloSession from flo_ai.models.flo_executable import ExecutableType -class FloLinear(FloRouter): +class FloLinear(FloRouter): def __init__(self, session: FloSession, config: TeamConfig, flo_team: FloTeam): - super().__init__(session=session, name=config.name, - 
flo_team=flo_team, executor=None, config=config) + super().__init__( + session=session, + name=config.name, + flo_team=flo_team, + executor=None, + config=config, + ) self.router_config = config.router - + def build_graph(self): flo_agent_nodes = [self.build_node(member) for member in self.members] - + workflow = StateGraph(TeamFloAgentState) - + for flo_node in flo_agent_nodes: agent_name = flo_node.name workflow.add_node(agent_name, flo_node.func) - + if self.router_config.edges is None: start_node = flo_agent_nodes[0] end_node = flo_agent_nodes[-1] workflow.add_edge(START, start_node.name) for i in range(len(flo_agent_nodes) - 1): parent_node = flo_agent_nodes[i] - child_node = flo_agent_nodes[i+1] - next_node = flo_agent_nodes[i+2] if (i+2) < len(flo_agent_nodes) else END - if (parent_node.kind == ExecutableType.reflection): + child_node = flo_agent_nodes[i + 1] + next_node = ( + flo_agent_nodes[i + 2] if (i + 2) < len(flo_agent_nodes) else END + ) + if parent_node.kind == ExecutableType.reflection: self.add_reflection_edge(workflow, parent_node, child_node) continue - if (child_node.kind == ExecutableType.delegator): - self.add_delegation_edge(workflow, parent_node, child_node, next_node) + if child_node.kind == ExecutableType.delegator: + self.add_delegation_edge( + workflow, parent_node, child_node, next_node + ) continue - if (child_node.kind != ExecutableType.reflection and parent_node.kind != ExecutableType.delegator): + if ( + child_node.kind != ExecutableType.reflection + and parent_node.kind != ExecutableType.delegator + ): workflow.add_edge(parent_node.name, child_node.name) - - if (end_node.kind == ExecutableType.reflection): + + if end_node.kind == ExecutableType.reflection: self.add_reflection_edge(workflow, end_node, END) - elif (end_node.kind != ExecutableType.delegator): - workflow.add_edge(end_node.name, END) + elif end_node.kind != ExecutableType.delegator: + workflow.add_edge(end_node.name, END) else: workflow.add_edge(START, self.router_config.start_node) for edge in self.router_config.edges: @@ -52,16 +64,16 @@ def build_graph(self): workflow.add_edge(self.router_config.end_node, END) workflow_graph = workflow.compile() - - return FloRoutedTeam(self.flo_team.name, workflow_graph, self.flo_team.config) - class Builder(): + return FloRoutedTeam(self.flo_team.name, workflow_graph, self.flo_team.config) - def __init__(self, session: FloSession, config: TeamConfig, flo_team: FloTeam) -> None: + class Builder: + def __init__( + self, session: FloSession, config: TeamConfig, flo_team: FloTeam + ) -> None: self.config = config self.session = session self.team = flo_team def build(self): return FloLinear(self.session, self.config, self.team) - diff --git a/flo_ai/router/flo_llm_router.py b/flo_ai/router/flo_llm_router.py index 054f3cb2..bca970d3 100644 --- a/flo_ai/router/flo_llm_router.py +++ b/flo_ai/router/flo_llm_router.py @@ -13,23 +13,19 @@ from langchain_core.output_parsers import JsonOutputParser from pydantic import BaseModel, Field + class NextAgent(BaseModel): - next: str = Field(description="Name of the next member to be called") + next: str = Field(description='Name of the next member to be called') + class FloLLMRouter(FloRouter): - - def __init__(self, - session: FloSession, - executor: Runnable, - flo_team: FloTeam, - name: str) -> None: + def __init__( + self, session: FloSession, executor: Runnable, flo_team: FloTeam, name: str + ) -> None: super().__init__( - session = session, - name = name, - flo_team = flo_team, - executor = executor + session=session, 
name=name, flo_team=flo_team, executor=executor ) - + def build_graph(self): flo_agent_nodes = [self.build_node(flo_agent) for flo_agent in self.members] workflow = StateGraph(TeamFloAgentState) @@ -42,15 +38,16 @@ def build_graph(self): workflow.set_entry_point(self.router_name) workflow_graph = workflow.compile() return FloRoutedTeam(self.flo_team.name, workflow_graph, self.flo_team.config) - + class Builder: - def __init__(self, - session: FloSession, - team_config: TeamConfig, - flo_team: FloTeam, - router_prompt: ChatPromptTemplate = None, - llm: Union[BaseLanguageModel, None] = None) -> None: - + def __init__( + self, + session: FloSession, + team_config: TeamConfig, + flo_team: FloTeam, + router_prompt: ChatPromptTemplate = None, + llm: Union[BaseLanguageModel, None] = None, + ) -> None: self.name = team_config.router.name self.session = session self.llm = llm if llm is not None else session.llm @@ -58,42 +55,42 @@ def __init__(self, self.agents = flo_team.members self.members = [agent.name for agent in flo_team.members] self.options = self.members + [FLO_FINISH] - member_type = "workers" if flo_team.members[0].type == "agent" else "team members" + member_type = ( + 'workers' if flo_team.members[0].type == 'agent' else 'team members' + ) router_base_system_message = ( - "You are a supervisor tasked with managing a conversation between the" - " following {member_type}: {members}. Given the following rules," - " respond with the worker to act next " + 'You are a supervisor tasked with managing a conversation between the' + ' following {member_type}: {members}. Given the following rules,' + ' respond with the worker to act next ' ) self.parser = JsonOutputParser(pydantic_object=NextAgent) self.llm_router_prompt = ChatPromptTemplate.from_messages( [ - ("system", router_base_system_message), - MessagesPlaceholder(variable_name="messages"), - ("system", "Rules: {router_prompt}"), + ('system', router_base_system_message), + MessagesPlaceholder(variable_name='messages'), + ('system', 'Rules: {router_prompt}'), ( - "system", - "Given the conversation above, who should act next?" - " Or should we FINISH if the task is already answered ? Select one of: {options} \n {format_instructions}", + 'system', + 'Given the conversation above, who should act next?' + ' Or should we FINISH if the task is already answered ? 
Select one of: {options} \n {format_instructions}', ), ] ).partial( - options=str(self.options), - members=", ".join(self.members), - member_type=member_type, + options=str(self.options), + members=', '.join(self.members), + member_type=member_type, router_prompt=router_prompt, - format_instructions=self.parser.get_format_instructions() - ) - - def build(self): - chain = ( - self.llm_router_prompt - | self.llm - | self.parser + format_instructions=self.parser.get_format_instructions(), ) - return FloLLMRouter(executor=chain, - flo_team=self.flo_team, - name=self.name, - session=self.session) \ No newline at end of file + def build(self): + chain = self.llm_router_prompt | self.llm | self.parser + + return FloLLMRouter( + executor=chain, + flo_team=self.flo_team, + name=self.name, + session=self.session, + ) diff --git a/flo_ai/router/flo_router.py b/flo_ai/router/flo_router.py index 96debcce..17a55589 100644 --- a/flo_ai/router/flo_router.py +++ b/flo_ai/router/flo_router.py @@ -1,4 +1,3 @@ - import functools from typing import Union from abc import ABC, abstractmethod @@ -11,13 +10,26 @@ from flo_ai.models.flo_routed_team import FloRoutedTeam from flo_ai.constants.prompt_constants import FLO_FINISH from flo_ai.models.flo_executable import ExecutableType -from flo_ai.state.flo_state import TeamFloAgentState, STATE_NAME_LOOP_CONTROLLER, STATE_NAME_NEXT -from flo_ai.constants.flo_node_contants import (INTERNAL_NODE_REFLECTION_MANAGER, INTERNAL_NODE_DELEGATION_MANAGER) +from flo_ai.state.flo_state import ( + TeamFloAgentState, + STATE_NAME_LOOP_CONTROLLER, + STATE_NAME_NEXT, +) +from flo_ai.constants.flo_node_contants import ( + INTERNAL_NODE_REFLECTION_MANAGER, + INTERNAL_NODE_DELEGATION_MANAGER, +) class FloRouter(ABC): - - def __init__(self, session: FloSession, name: str, flo_team: FloTeam, executor, config: TeamConfig = None): + def __init__( + self, + session: FloSession, + name: str, + flo_team: FloTeam, + executor, + config: TeamConfig = None, + ): self.router_name = name self.session: FloSession = session self.flo_team: FloTeam = flo_team @@ -26,7 +38,7 @@ def __init__(self, session: FloSession, name: str, flo_team: FloTeam, executor, self.type: ExecutableType = ExecutableType.router self.executor = executor self.config = config - + def build_routed_team(self) -> FloRoutedTeam: return self.build_graph() @@ -38,58 +50,71 @@ def build_node(self, flo_member: FloMember) -> FloNode: node_builder = FloNode.Builder() if flo_member.type == ExecutableType.router: return node_builder.build_from_router(flo_member) - if (flo_member.type == ExecutableType.team): + if flo_member.type == ExecutableType.team: return node_builder.build_from_team(flo_member) - if (flo_member.type == ExecutableType.delegator): - return FloNode(flo_member.executor, flo_member.name, flo_member.type, flo_member.config) + if flo_member.type == ExecutableType.delegator: + return FloNode( + flo_member.executor, flo_member.name, flo_member.type, flo_member.config + ) node_builder = FloNode.Builder() return node_builder.build_from_agent(flo_member) - + def router_fn(self, state: TeamFloAgentState): - next = state["next"] + next = state['next'] conditional_map = {k: k for k in self.member_names} conditional_map[FLO_FINISH] = END self.session.append(node=next) if self.session.is_looping(node=next): return conditional_map[FLO_FINISH] return conditional_map[next] - - def update_reflection_state(self, state: TeamFloAgentState, reflection_agent_name: str): + + def update_reflection_state( + self, state: TeamFloAgentState, 
reflection_agent_name: str + ): tracker = state.get(STATE_NAME_LOOP_CONTROLLER) or {} tracker[reflection_agent_name] = tracker.get(reflection_agent_name, 0) + 1 return {STATE_NAME_LOOP_CONTROLLER: tracker} - - def add_delegation_edge(self, workflow: StateGraph, parent: FloNode, delegation_node: FloNode, nextNode: Union[FloNode|str]): + def add_delegation_edge( + self, + workflow: StateGraph, + parent: FloNode, + delegation_node: FloNode, + nextNode: Union[FloNode | str], + ): to_agent_names = [x.name for x in delegation_node.config.to] delegation_node_name = delegation_node.name next_node_name = nextNode if isinstance(nextNode, str) else nextNode.name retry = delegation_node.config.retry or 1 - + conditional_map = {} for agent_name in to_agent_names: conditional_map[agent_name] = agent_name conditional_map[next_node_name] = next_node_name workflow.add_node( - INTERNAL_NODE_DELEGATION_MANAGER, + INTERNAL_NODE_DELEGATION_MANAGER, functools.partial( - self.update_reflection_state, - reflection_agent_name=delegation_node_name - ) + self.update_reflection_state, reflection_agent_name=delegation_node_name + ), ) workflow.add_edge(parent.name, INTERNAL_NODE_DELEGATION_MANAGER) workflow.add_conditional_edges( - INTERNAL_NODE_DELEGATION_MANAGER, - self.__get_refelection_routing_fn(retry, delegation_node_name, next_node_name), - { delegation_node_name: delegation_node_name, next_node_name: next_node_name} + INTERNAL_NODE_DELEGATION_MANAGER, + self.__get_refelection_routing_fn( + retry, delegation_node_name, next_node_name + ), + { + delegation_node_name: delegation_node_name, + next_node_name: next_node_name, + }, ) workflow.add_conditional_edges( - delegation_node_name, + delegation_node_name, FloRouter.__get_delegation_router_fn(next_node_name), - conditional_map + conditional_map, ) @staticmethod @@ -98,35 +123,48 @@ def delegation_router(state: TeamFloAgentState): if STATE_NAME_NEXT not in state: return nextNode return state[STATE_NAME_NEXT] + return delegation_router - - def add_reflection_edge(self, workflow: StateGraph, reflection_node: FloNode, nextNode: Union[FloNode | str]): + + def add_reflection_edge( + self, + workflow: StateGraph, + reflection_node: FloNode, + nextNode: Union[FloNode | str], + ): to_agent_name = reflection_node.config.to[0].name retry = reflection_node.config.retry or 1 reflection_agent_name = reflection_node.name next = nextNode if isinstance(nextNode, str) else nextNode.name - - workflow.add_node(INTERNAL_NODE_REFLECTION_MANAGER, functools.partial(self.update_reflection_state, reflection_agent_name=reflection_agent_name)) - + + workflow.add_node( + INTERNAL_NODE_REFLECTION_MANAGER, + functools.partial( + self.update_reflection_state, + reflection_agent_name=reflection_agent_name, + ), + ) + workflow.add_edge(to_agent_name, INTERNAL_NODE_REFLECTION_MANAGER) workflow.add_conditional_edges( - INTERNAL_NODE_REFLECTION_MANAGER, - self.__get_refelection_routing_fn(retry, reflection_agent_name, next), - { reflection_agent_name: reflection_agent_name, next: next } + INTERNAL_NODE_REFLECTION_MANAGER, + self.__get_refelection_routing_fn(retry, reflection_agent_name, next), + {reflection_agent_name: reflection_agent_name, next: next}, ) workflow.add_edge(reflection_agent_name, to_agent_name) @staticmethod - def __get_refelection_routing_fn(retries: int, reflection_agent_name, next_node_name): + def __get_refelection_routing_fn( + retries: int, reflection_agent_name, next_node_name + ): def reflection_routing_fn(state: TeamFloAgentState): tracker = 
state[STATE_NAME_LOOP_CONTROLLER] - if tracker is not None and reflection_agent_name in tracker and tracker[reflection_agent_name] > retries: + if ( + tracker is not None + and reflection_agent_name in tracker + and tracker[reflection_agent_name] > retries + ): return next_node_name return reflection_agent_name return reflection_routing_fn - - - - - \ No newline at end of file diff --git a/flo_ai/router/flo_router_factory.py b/flo_ai/router/flo_router_factory.py index 86e93da6..82460271 100644 --- a/flo_ai/router/flo_router_factory.py +++ b/flo_ai/router/flo_router_factory.py @@ -9,23 +9,31 @@ from flo_ai.error.flo_exception import FloException from flo_ai.constants.common_constants import DOCUMENTATION_ROUTER_ANCHOR -class FloRouterFactory: +class FloRouterFactory: @staticmethod - def create(session: FloSession, team_config: TeamConfig, flo_team: FloTeam) -> FloRouter: + def create( + session: FloSession, team_config: TeamConfig, flo_team: FloTeam + ) -> FloRouter: router_kind = team_config.router.kind - router_model = FloRouterFactory.__resolve_model(session, team_config.router.model) + router_model = FloRouterFactory.__resolve_model( + session, team_config.router.model + ) if router_kind == 'supervisor': - return FloSupervisor.Builder(session, team_config, flo_team, llm=router_model).build() + return FloSupervisor.Builder( + session, team_config, flo_team, llm=router_model + ).build() elif router_kind == 'linear': return FloLinear.Builder(session, team_config, flo_team).build() elif router_kind == 'llm': - return FloLLMRouter.Builder(session, team_config, flo_team, llm=router_model).build() + return FloLLMRouter.Builder( + session, team_config, flo_team, llm=router_model + ).build() else: raise Exception(f"""Unknown router type: {router_kind}. The supported types are supervisor, linear and llm. Check the documentation @ {DOCUMENTATION_ROUTER_ANCHOR}""") - + @staticmethod def __resolve_model(session: FloSession, model_name: Optional[str] = None): if model_name is None: @@ -34,5 +42,6 @@ def __resolve_model(session: FloSession, model_name: Optional[str] = None): raise FloException( f"""Model not found: {model_name}. The model you would like to use should be registered to the session using session.register_model api, - and the same model name should be used here instead of `{model_name}`""") + and the same model name should be used here instead of `{model_name}`""" + ) return session.models[model_name] diff --git a/flo_ai/router/flo_supervisor.py b/flo_ai/router/flo_supervisor.py index a16ba2b3..a537d88d 100644 --- a/flo_ai/router/flo_supervisor.py +++ b/flo_ai/router/flo_supervisor.py @@ -11,33 +11,30 @@ from flo_ai.router.flo_llm_router import NextAgent supervisor_system_message = ( - "You are a supervisor tasked with managing a conversation between the" - " following {member_type}: {members}. Given the following user request," - " respond with the worker to act next. Each worker will perform a" - " task and respond with their results and status. When the users question is answered or the assigned task is finished," - " respond with FINISH. " + 'You are a supervisor tasked with managing a conversation between the' + ' following {member_type}: {members}. Given the following user request,' + ' respond with the worker to act next. Each worker will perform a' + ' task and respond with their results and status. When the users question is answered or the assigned task is finished,' + ' respond with FINISH. 
' ) + class FloSupervisor(FloLLMRouter): - - def __init__(self, - session: FloSession, - executor: Runnable, - flo_team: FloTeam, - name: str) -> None: + def __init__( + self, session: FloSession, executor: Runnable, flo_team: FloTeam, name: str + ) -> None: super().__init__( - session = session, - name = name, - flo_team = flo_team, - executor = executor + session=session, name=name, flo_team=flo_team, executor=executor ) class Builder: - def __init__(self, - session: FloSession, - team_config: TeamConfig, - flo_team: FloTeam, - llm: Union[BaseLanguageModel, None] = None) -> None: + def __init__( + self, + session: FloSession, + team_config: TeamConfig, + flo_team: FloTeam, + llm: Union[BaseLanguageModel, None] = None, + ) -> None: self.name = team_config.router.name self.session = session self.llm = llm if llm is not None else session.llm @@ -45,33 +42,33 @@ def __init__(self, self.agents = flo_team.members self.members = [agent.name for agent in flo_team.members] self.options = self.members + [FLO_FINISH] - member_type = "workers" if flo_team.members[0].type == "agent" else "team members" + member_type = ( + 'workers' if flo_team.members[0].type == 'agent' else 'team members' + ) self.parser = JsonOutputParser(pydantic_object=NextAgent) self.supervisor_prompt = ChatPromptTemplate.from_messages( [ - ("system", supervisor_system_message), - MessagesPlaceholder(variable_name="messages"), + ('system', supervisor_system_message), + MessagesPlaceholder(variable_name='messages'), ( - "system", - "Given the conversation above, who should act next?" - " Or should we FINISH if the task is already answered, Select one of: {options} \n {format_instructions}", + 'system', + 'Given the conversation above, who should act next?' + ' Or should we FINISH if the task is already answered, Select one of: {options} \n {format_instructions}', ), ] ).partial( - options=str(self.options), - members=", ".join(self.members), + options=str(self.options), + members=', '.join(self.members), member_type=member_type, - format_instructions=self.parser.get_format_instructions() + format_instructions=self.parser.get_format_instructions(), ) - + def build(self): - chain = ( - self.supervisor_prompt - | self.llm - | self.parser - ) + chain = self.supervisor_prompt | self.llm | self.parser - return FloSupervisor(executor=chain, - flo_team=self.flo_team, - name=self.name, - session=self.session) \ No newline at end of file + return FloSupervisor( + executor=chain, + flo_team=self.flo_team, + name=self.name, + session=self.session, + ) diff --git a/flo_ai/state/flo_session.py b/flo_ai/state/flo_session.py index 7bb6557f..38cbd1bc 100644 --- a/flo_ai/state/flo_session.py +++ b/flo_ai/state/flo_session.py @@ -9,32 +9,34 @@ from typing import Optional + def _handle_agent_error(error) -> str: - error_message = str(error)[:50] - return f""" + error_message = str(error)[:50] + return f""" Following error happened while agent execution, please retry with the fix for the same: {error_message} """ -class FloSession: - def __init__(self, - default_llm: BaseLanguageModel = None, - loop_size: int = 2, - max_loop: int = 3, - llm: BaseLanguageModel = None, - log_level: Optional[str] = None, - callbacks: Optional[FloLangchainLogger] = None, - on_agent_error=_handle_agent_error) -> None: - +class FloSession: + def __init__( + self, + default_llm: BaseLanguageModel = None, + loop_size: int = 2, + max_loop: int = 3, + llm: BaseLanguageModel = None, + log_level: Optional[str] = None, + callbacks: Optional[FloLangchainLogger] = None, + 
on_agent_error=_handle_agent_error, + ) -> None: if log_level: warnings.warn( - "`log_level` is deprecated and will be removed in a future version. " - "Please use `Flo.set_log_level()` instead.", + '`log_level` is deprecated and will be removed in a future version. ' + 'Please use `Flo.set_log_level()` instead.', DeprecationWarning, stacklevel=2, ) - + self.session_id = str(random_str(16)) self.llm = self.resolve_llm(default_llm, llm) self.tools = dict() @@ -46,18 +48,20 @@ def __init__(self, self.loop_size: int = loop_size self.max_loop: int = max_loop self.on_agent_error = on_agent_error - + self.config: Union[FloRoutedTeamConfig, FloAgentConfig] = None - get_logger().info(f"New session created ...", self) - self.langchain_logger = FloLangchainLogger(self.session_id) + get_logger().info('New session created ...', self) + self.langchain_logger = FloLangchainLogger(self.session_id) - def resolve_llm(self, default_llm: BaseLanguageModel = None, llm: BaseLanguageModel = None): + def resolve_llm( + self, default_llm: BaseLanguageModel = None, llm: BaseLanguageModel = None + ): if default_llm is not None: return default_llm if llm: warnings.warn( - "`llm` is deprecated and will be removed in a future version. " - "Please use `default_llm` instead.", + '`llm` is deprecated and will be removed in a future version. ' + 'Please use `default_llm` instead.', DeprecationWarning, stacklevel=2, ) @@ -67,20 +71,22 @@ def register_tool(self, name: str, tool: BaseTool): self.tools[name] = tool get_logger().info(f"Tool '{name}' registered for session {self.session_id}") return self - + def register_model(self, name: str, model: BaseLanguageModel): self.models[name] = model get_logger().info(f"Model '{name}' registered for session {self.session_id}") return self def append(self, node: str) -> int: - get_logger().debug(f"Appending node: {node}") + get_logger().debug(f'Appending node: {node}') self.counter[node] = self.counter.get(node, 0) + 1 if node in self.navigation: - last_known_index = len(self.navigation) - 1 - self.navigation[::-1].index(node) - pattern_array = self.navigation[last_known_index: len(self.navigation)] + last_known_index = ( + len(self.navigation) - 1 - self.navigation[::-1].index(node) + ) + pattern_array = self.navigation[last_known_index : len(self.navigation)] if len(pattern_array) + 1 >= self.loop_size: - pattern = "|".join(pattern_array) + "|" + node + pattern = '|'.join(pattern_array) + '|' + node if node in self.pattern_series: self.pattern_series[node].append(pattern) else: @@ -88,11 +94,11 @@ def append(self, node: str) -> int: self.navigation.append(node) def is_looping(self, node) -> bool: - get_logger().debug(f"Checking if node {node} is looping") + get_logger().debug(f'Checking if node {node} is looping') patterns = self.pattern_series[node] if node in self.pattern_series else [] if len(patterns) < self.max_loop: return False - return patterns[-(self.max_loop):] == [patterns[-1]] * self.max_loop + return patterns[-(self.max_loop) :] == [patterns[-1]] * self.max_loop def stringify(self): - return str(self.counter) \ No newline at end of file + return str(self.counter) diff --git a/flo_ai/state/flo_state.py b/flo_ai/state/flo_state.py index 9e35b7e3..485e1aee 100644 --- a/flo_ai/state/flo_state.py +++ b/flo_ai/state/flo_state.py @@ -1,12 +1,13 @@ -from typing import Annotated, List, Sequence, TypedDict +from typing import Annotated, List, Sequence, TypedDict, Tuple from langchain_core.messages import BaseMessage -from typing import List, Tuple, Annotated, TypedDict + import 
operator -STATE_NAME_LOOP_CONTROLLER = "loop_tracker" -STATE_NAME_NEXT = "next" -STATE_NAME_MESSAGES = "messages" +STATE_NAME_LOOP_CONTROLLER = 'loop_tracker' +STATE_NAME_NEXT = 'next' +STATE_NAME_MESSAGES = 'messages' + # The agent state is the input to each node in the graph class TeamFloAgentState(TypedDict): @@ -18,8 +19,9 @@ class TeamFloAgentState(TypedDict): # used for reflection agents loop_tracker: dict + class TeamFloAgentStateWithPlan(TypedDict): input: str plan: List[str] past_steps: Annotated[List[Tuple], operator.add] - response: str \ No newline at end of file + response: str diff --git a/flo_ai/tools/flo_tool.py b/flo_ai/tools/flo_tool.py index 92dcb94d..da6bbd7a 100644 --- a/flo_ai/tools/flo_tool.py +++ b/flo_ai/tools/flo_tool.py @@ -2,18 +2,17 @@ from langchain.tools import tool from functools import wraps from typing import Optional -from flo_ai.constants.common_constants import DOCUMENTATION_AGENT_TOOLS -from flo_ai.error.flo_exception import FloException + def flotool( - name: str, - description: Optional[str] = None, - argument_contract: Optional[type] = None, - unsafe: bool = False): - + name: str, + description: Optional[str] = None, + argument_contract: Optional[type] = None, + unsafe: bool = False, +): def decorator(func): func.__doc__ = func.__doc__ or description - + @tool(name, args_schema=argument_contract) @wraps(func) def wrapper(*args, **kwargs): @@ -22,9 +21,8 @@ def wrapper(*args, **kwargs): except Exception as e: if unsafe: raise e - return f"An error occurred while executing the tool: {str(e)}, please retry with the corresponding fix" - - + return f'An error occurred while executing the tool: {str(e)}, please retry with the corresponding fix' + @tool(name, args_schema=argument_contract) @wraps(func) async def async_wrapper(*args, **kwargs): @@ -33,8 +31,8 @@ async def async_wrapper(*args, **kwargs): except Exception as e: if unsafe: raise e - return f"An error occurred while executing the tool: {str(e)}, please retry with the corresponding fix" - + return f'An error occurred while executing the tool: {str(e)}, please retry with the corresponding fix' + return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper - + return decorator diff --git a/flo_ai/yaml/config.py b/flo_ai/yaml/config.py index b0daa979..66361f8e 100644 --- a/flo_ai/yaml/config.py +++ b/flo_ai/yaml/config.py @@ -6,38 +6,41 @@ from flo_ai.models.exception import FloValidationException -KIND_SUPERVISED_TEAM = "FloRoutedTeam" -KIND_FLO_AGENT = "FloAgent" +KIND_SUPERVISED_TEAM = 'FloRoutedTeam' +KIND_FLO_AGENT = 'FloAgent' + +yaml_kinds = [KIND_SUPERVISED_TEAM, KIND_FLO_AGENT] -yaml_kinds = [ - KIND_SUPERVISED_TEAM, - KIND_FLO_AGENT -] class KeyValueArgs(BaseModel): name: str value: str + class FilterArgs(BaseModel): name: str description: str type: str + class ToolConfig(BaseModel): name: str args: Optional[List[KeyValueArgs]] = None properties: Optional[List[KeyValueArgs]] = None filters: Optional[List[FilterArgs]] = None + class PromptStrategy(BaseModel): kind: str agent_name: str retries: int next: str | None = None + class MemberKey(BaseModel): name: str + class AgentConfig(BaseModel): name: str role: Optional[str] = None @@ -48,11 +51,13 @@ class AgentConfig(BaseModel): retry: Optional[int] = 1 model: Optional[str] = None + class EdgeConfig(BaseModel): edge: List[str] type: Optional[str] = None rule: Optional[str] = None + class RouterConfig(BaseModel): name: str kind: str @@ -62,9 +67,11 @@ class RouterConfig(BaseModel): end_node: Union[Optional[str], List[str]] = None 
edges: Optional[List[EdgeConfig]] = None + class PlannerConfig(BaseModel): name: str + class TeamConfig(BaseModel): name: str kind: Optional[str] = None @@ -73,21 +80,24 @@ class TeamConfig(BaseModel): router: Optional[RouterConfig] = None planner: Optional[PlannerConfig] = None + class FloRoutedTeamConfig(BaseModel): apiVersion: str kind: str name: str team: TeamConfig + class FloAgentConfig(BaseModel): apiVersion: str kind: Optional[str] = None name: str agent: AgentConfig + def to_supervised_team(yaml_str: str) -> FloRoutedTeamConfig: parsed_data = yaml.safe_load(yaml_str) - kind = parsed_data["kind"] + kind = parsed_data['kind'] if kind == KIND_SUPERVISED_TEAM: flo_supervised_team = FloRoutedTeamConfig(**parsed_data) validate_sup_team_config(flo_supervised_team) @@ -97,14 +107,18 @@ def to_supervised_team(yaml_str: str) -> FloRoutedTeamConfig: validate_sup_team_config(flo_agent) return flo_agent else: - raise FloValidationException("Unknown kind: {}".format(kind)) + raise FloValidationException('Unknown kind: {}'.format(kind)) + def validate_sup_team_config(flo: FloRoutedTeamConfig): if flo.kind == KIND_FLO_AGENT: return if flo.name is None or not is_valid_name(flo.name): - raise FloValidationException("Invalid agent name while creating the flow, expected: [^[a-z][a-z0-9_-]*$]") - + raise FloValidationException( + 'Invalid agent name while creating the flow, expected: [^[a-z][a-z0-9_-]*$]' + ) + + def is_valid_name(s: str) -> bool: pattern = r'^[a-z][a-z0-9_-]*$' - return bool(re.match(pattern, s)) \ No newline at end of file + return bool(re.match(pattern, s)) diff --git a/flo_ai/yaml/validators.py b/flo_ai/yaml/validators.py index a0d318ab..288a87b2 100644 --- a/flo_ai/yaml/validators.py +++ b/flo_ai/yaml/validators.py @@ -3,6 +3,9 @@ name_regex = r'^[a-zA-Z0-9-_]+$' + def raise_for_name_error(string): if not re.match(name_regex, string): - raise FloException("Name must contain only alphanumeric characters and hyphens.") \ No newline at end of file + raise FloException( + 'Name must contain only alphanumeric characters and hyphens.' + ) diff --git a/poetry.lock b/poetry.lock index b8504c37..9dbadde3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -478,6 +478,17 @@ files = [ [package.dependencies] pycparser = "*" +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.4.0" @@ -848,6 +859,17 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "distlib" +version = "0.3.9" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87"}, + {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, +] + [[package]] name = "distro" version = "1.9.0" @@ -1468,6 +1490,20 @@ files = [ [package.dependencies] pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} +[[package]] +name = "identify" +version = "2.6.2" +description = "File identification library for Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "identify-2.6.2-py2.py3-none-any.whl", hash = "sha256:c097384259f49e372f4ea00a19719d95ae27dd5ff0fd77ad630aa891306b82f3"}, + {file = "identify-2.6.2.tar.gz", hash = "sha256:fab5c716c24d7a789775228823797296a2994b075fb6080ac83a102772a98cbd"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.10" @@ -2456,6 +2492,17 @@ files = [ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "numpy" version = "1.26.4" @@ -3022,6 +3069,24 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"] sentry = ["django", "sentry-sdk"] test = ["coverage", "django", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"] +[[package]] +name = "pre-commit" +version = "4.0.1" +description = "A framework for managing and maintaining multi-language pre-commit hooks." 
+optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878"}, + {file = "pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "prompt-toolkit" version = "3.0.48" @@ -4617,6 +4682,26 @@ dev = ["Cython (>=3.0,<4.0)", "setuptools (>=60)"] docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] test = ["aiohttp (>=3.10.5)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"] +[[package]] +name = "virtualenv" +version = "20.27.1" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.8" +files = [ + {file = "virtualenv-20.27.1-py3-none-any.whl", hash = "sha256:f11f1b8a29525562925f745563bfd48b189450f61fb34c4f9cc79dd5aa32a1f4"}, + {file = "virtualenv-20.27.1.tar.gz", hash = "sha256:142c6be10212543b32c6c45d3d3893dff89112cc588b7d0879ae5a1ec03a47ba"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "watchfiles" version = "0.24.0" @@ -4804,6 +4889,17 @@ files = [ {file = "websockets-14.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:3f1a697262e28682222f18fae70eb0800dfa50c6eb96b0561c6beb83d6cf78ca"}, {file = "websockets-14.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e0e543e0e81c55e68552bd3c081282721c710a6379a2a78e1ec793853479b25"}, {file = "websockets-14.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2786c74cbcb0263fd541e4a075aa8c932bdcaa91e5bbb8649c65304799acdd64"}, + {file = "websockets-14.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:176b39547950ff3520728bd1eadd0fa02c68492a1fabca636bab7883dd390905"}, + {file = "websockets-14.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86626d560ceb9d846d128b9c7bd2d0f247dbb62fb49c386762d109583140bf48"}, + {file = "websockets-14.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ca447967131023e98fcb4867f05cf8584adb424b9108180b2414745a6ff41c31"}, + {file = "websockets-14.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c4eb304743ab285f8f057344d115259fbe31e42151b9aae7610db83d2a7379b1"}, + {file = "websockets-14.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:cc7dbe53276429b2ca511a04a3979ce27aa2088fdd28c119c6913dccdfd0e909"}, + {file = "websockets-14.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6dd785f7a521189b1233d3c86c0b66fb73d4769a1d253ce5b31081c5946f05f"}, + {file = 
"websockets-14.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:77697c303b874daf1c76d4e167cd5d6871c26964bc189e4bdb40427067d53a86"}, + {file = "websockets-14.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20979614e4d7266f15018c154255d35dfb9fc828fdf6b4924166b6728fed359f"}, + {file = "websockets-14.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3fb3d9e3940ea15b30404200e768e6111c3ee2956c60ceb001cae057961ab058"}, + {file = "websockets-14.0-py3-none-any.whl", hash = "sha256:1a3bca8cfb66614e23a65aa5d6b87190876ec6f3247094939f9db877db55319c"}, + {file = "websockets-14.0.tar.gz", hash = "sha256:be90aa6dab180fed523c0c10a6729ad16c9ba79067402d01a4d8aa7ce48d4084"}, ] [[package]] @@ -5017,4 +5113,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "c74b260a1a71aa4ce734a0de3d6ee1cd77f11f3b27937c59c8fc64e15a24ee34" +content-hash = "e79cf016126b346a5c982216d64408333a26914df79755ac59453deede7b2438" diff --git a/pyproject.toml b/pyproject.toml index 0cbdb9ae..b4fb3835 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ python-dotenv = "^1.0.1" wikipedia = "^1.4.0" pytest = "^8.3.3" pytest-asyncio = "^0.24.0" +pre-commit = "^4.0.1" [build-system] requires = ["poetry-core"] diff --git a/setup.py b/setup.py index 7a6b8c69..d9cbfa82 100644 --- a/setup.py +++ b/setup.py @@ -1,23 +1,23 @@ import setuptools -with open("README.md", "r") as fh: +with open('README.md', 'r') as fh: long_description = fh.read() setuptools.setup( - name="flo-ai", - version="0.0.3", - author="Rootflo", - description="Create composable AI agents", - long_description=long_description, - long_description_content_type="text/markdown", - packages=setuptools.find_packages(), + name='flo-ai', + version='0.0.3', + author='Rootflo', + description='Create composable AI agents', + long_description=long_description, + long_description_content_type='text/markdown', + packages=setuptools.find_packages(), classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - python_requires='>=3.10', - py_modules=["flo"], - package_dir={'':'flo'}, - install_requires=[] -) \ No newline at end of file + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: MIT License', + 'Operating System :: OS Independent', + ], + python_requires='>=3.10', + py_modules=['flo'], + package_dir={'': 'flo'}, + install_requires=[], +) diff --git a/tests/test_comonent_name.py b/tests/test_comonent_name.py index 86188350..a1deff4c 100644 --- a/tests/test_comonent_name.py +++ b/tests/test_comonent_name.py @@ -1,18 +1,21 @@ import pytest from flo_ai.yaml.validators import raise_for_name_error -@pytest.mark.parametrize('flo_name, validity', [ - ('CorrectName', True), - ('Wrong Name', False), - ('correct_name', True), - ('correct-name', True), - ('wrong/name', False), -]) + +@pytest.mark.parametrize( + 'flo_name, validity', + [ + ('CorrectName', True), + ('Wrong Name', False), + ('correct_name', True), + ('correct-name', True), + ('wrong/name', False), + ], +) def test_flo_component_names(flo_name, validity): isException = False try: raise_for_name_error(flo_name) - except Exception as e: + except Exception: isException = True assert isException != validity - \ No newline at end of file diff --git a/tests/test_flotool.py b/tests/test_flotool.py index 2933d447..e814ad9f 100644 --- 
a/tests/test_flotool.py +++ b/tests/test_flotool.py @@ -8,39 +8,49 @@ from flo_ai.state.flo_session import FloSession from flo_ai.constants.common_constants import DOCUMENTATION_WEBSITE -@flotool(name="AdditionTool", description="Tool to add numbers") + +@flotool(name='AdditionTool', description='Tool to add numbers') async def addition_tool(numbers: List[int]) -> str: result = sum(numbers) - await asyncio.sleep(1) - return f"The sum is {result}" + await asyncio.sleep(1) + return f'The sum is {result}' + -@flotool(name="MultiplicationTool", description="Tool to multiply numbers to get product of numbers") +@flotool( + name='MultiplicationTool', + description='Tool to multiply numbers to get product of numbers', +) def mul_tool(numbers: List[int]) -> str: result = sum(numbers) - return f"The product is {result}" + return f'The product is {result}' + def test_flotool_async(): built_tool = asyncio.iscoroutinefunction(addition_tool.coroutine) assert built_tool + def test_flotool_sync(): built_tool = asyncio.iscoroutinefunction(mul_tool.coroutine) assert not built_tool + @pytest.mark.asyncio async def test_flotool_invoke_with_async_tool(): - result = await addition_tool.ainvoke({"numbers": [1, 32, 2]}) - assert f"The sum is 35" == result + result = await addition_tool.ainvoke({'numbers': [1, 32, 2]}) + assert 'The sum is 35' == result + def test_flotool_invoke_with_sync_tool(): - result = mul_tool.invoke({"numbers": [1, 32, 2]}) - assert f"The product is 35" == result + result = mul_tool.invoke({'numbers': [1, 32, 2]}) + assert 'The product is 35' == result + def test_session_registration_and_invoke(): - llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini', api_key="TEST_KEY") + llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini', api_key='TEST_KEY') session = FloSession(llm) - session.register_tool("adder", addition_tool) - session.register_tool("muller", mul_tool) + session.register_tool('adder', addition_tool) + session.register_tool('muller', mul_tool) mock_agent_yaml = mock_agent_yaml = """ apiVersion: flo/alpha-v1 @@ -57,6 +67,9 @@ def test_session_registration_and_invoke(): flo = Flo.build(session, mock_agent_yaml) try: - flo.invoke("What the whether in berlin") + flo.invoke('What the whether in berlin') except FloException as e: - assert str(e) == f"[Error -1] You seem to have atleast one async tool registered in this session. Please use flo.async_invoke or flo.async_stream. Checkout {DOCUMENTATION_WEBSITE}" \ No newline at end of file + assert ( + str(e) + == f'[Error -1] You seem to have atleast one async tool registered in this session. Please use flo.async_invoke or flo.async_stream. Checkout {DOCUMENTATION_WEBSITE}' + )
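
The pre-commit dependency added to pyproject.toml above (`pre-commit = "^4.0.1"`, with cfgv, distlib, identify, nodeenv and virtualenv pulled into poetry.lock) only takes effect once the hooks are activated locally. A minimal setup sketch, assuming a .pre-commit-config.yaml is present at the repository root (pre-commit's default config location); the commands below are standard poetry/pre-commit CLI usage, not part of this patch:

    # one-time setup after pulling this change
    poetry install                          # installs pre-commit alongside the other dev dependencies
    poetry run pre-commit install           # registers the git hook so the configured checks run on each commit
    poetry run pre-commit run --all-files   # optional: run the hooks across the whole tree once up front

After `pre-commit install`, the configured hooks run automatically on every `git commit` that touches matching files, so local commits are checked the same way the CI job checks them.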