From 62e958732146c9e6cb59fe645da6e4c8c31dc6fb Mon Sep 17 00:00:00 2001 From: jacobsanosh Date: Tue, 14 Jan 2025 17:17:20 +0530 Subject: [PATCH 1/6] setting the training data --- .codegpt/head | 1 + examples/python/json_encode.py | 63 +++++++++++++++++ examples/python/llm_chain.py | 62 +++++++++++++++++ flo_ai/callbacks/flo_execution_logger.py | 85 +++++++++++++++++++--- flo_ai/core.py | 7 ++ flo_ai/models/flo_node.py | 8 ++- flo_ai/storage/data_collector.py | 15 ++++ poetry.lock | 2 +- produce_training.py | 89 ++++++++++++++++++++++++ 9 files changed, 320 insertions(+), 12 deletions(-) create mode 100644 .codegpt/head create mode 100644 examples/python/json_encode.py create mode 100644 examples/python/llm_chain.py create mode 100644 produce_training.py diff --git a/.codegpt/head b/.codegpt/head new file mode 100644 index 00000000..15f0de48 --- /dev/null +++ b/.codegpt/head @@ -0,0 +1 @@ +ab9178a3-7e3c-481f-918d-03a98ba1d391 \ No newline at end of file diff --git a/examples/python/json_encode.py b/examples/python/json_encode.py new file mode 100644 index 00000000..7570a6d9 --- /dev/null +++ b/examples/python/json_encode.py @@ -0,0 +1,63 @@ +from langchain_openai import AzureChatOpenAI +from langchain.chains import LLMChain +from langchain_core.prompts import PromptTemplate +from flo_ai.callbacks import FloExecutionLogger +from flo_ai.storage.data_collector import JSONLFileCollector, TOOLFileCollector +import os +from flo_ai import Flo, FloSession +from flo_ai.models.flo_agent import FloAgent +from langchain_community.tools.tavily_search.tool import TavilySearchResults +from dotenv import load_dotenv + + +load_dotenv() +api_key = os.getenv("API_KEY") +end_point = os.getenv("END_POINT") +model_name = os.getenv("MODEL") +temperature = os.getenv("TEMP") +api_version = os.getenv("API_VERSION") + +file_collector = JSONLFileCollector('./logger.jsonl') + +tool_collector = TOOLFileCollector('./tools.jsonl') + +# Create a tool logger with the collector +local_tracker = FloExecutionLogger(file_collector, tool_collector) +# Create the LLM object +llm = AzureChatOpenAI( + azure_endpoint=end_point, + model_name=model_name, + temperature=temperature, + api_version=api_version, + api_key=api_key, + callbacks=[local_tracker], +) + + +prompt = PromptTemplate.from_template('1 + {number} = ') + +chain = LLMChain(llm=llm, prompt=prompt, callbacks=[local_tracker]) +print(chain.invoke({'number': 2})) + + +session = FloSession(llm) +session.register_callback(local_tracker) + +os.environ['TAVILY_API_KEY'] = os.getenv("TAVILY_API_KEY") +tavily_tool = TavilySearchResults() + +session.register_tool('thappal', tavily_tool) + + +weather_agent = FloAgent.create( + session=session, + name='Blogger', + job='You can research the internet and create a blog about the topic given by the user', + tools=[tavily_tool], +) + + +agent_flo: Flo = Flo.create(session, weather_agent) +# print("\n\n agent_flo",agent_flo.runnable) +print(agent_flo.invoke('Whats the whether in New Delhi, India ?')) +print(agent_flo.invoke('Whats the whether in Aroor kochi, India ?')) diff --git a/examples/python/llm_chain.py b/examples/python/llm_chain.py new file mode 100644 index 00000000..067f9c84 --- /dev/null +++ b/examples/python/llm_chain.py @@ -0,0 +1,62 @@ +from langchain_core.prompts import PromptTemplate +from langchain_openai import AzureChatOpenAI +from flo_ai.storage.data_collector import JSONLFileCollector, TOOLFileCollector +from flo_ai.callbacks import FloExecutionLogger +from langchain.chains import LLMChain +import os +from flo_ai import Flo, FloSession +from flo_ai.models.flo_agent import FloAgent +from langchain_community.tools.tavily_search.tool import TavilySearchResults +from dotenv import load_dotenv +import os + +load_dotenv() +api_key = os.getenv("API_KEY") +end_point = os.getenv("END_POINT") +model_name = os.getenv("MODEL") +temperature = os.getenv("TEMP") +api_version = os.getenv("API_VERSION") + +prompt_template = 'Tell me a {adjective} joke' +prompt = PromptTemplate(input_variables=['adjective'], template=prompt_template) + +file_collector = JSONLFileCollector('./logger.jsonl') +tool_collector = TOOLFileCollector('./tools.jsonl') + +# Create a tool logger with the collector +local_tracker = FloExecutionLogger(file_collector, tool_collector) + +llm = AzureChatOpenAI( + azure_endpoint=end_point, + model_name=model_name, + temperature=temperature, + api_version=api_version, + api_key=api_key, + callbacks=[local_tracker], +) + + +chain = LLMChain(llm=llm, prompt=prompt, callbacks=[local_tracker]) + +print(chain.invoke({'adjective': 'funny'})) + +session = FloSession(llm) +session.register_callback(local_tracker) + +os.environ['TAVILY_API_KEY'] = os.getenv("TAVILY_API_KEY") +tavily_tool = TavilySearchResults() + +session.register_tool('thappal', tavily_tool) + + +weather_agent = FloAgent.create( + session=session, + name='Blogger', + job='You can research the internet and create a blog about the topic given by the user', + tools=[tavily_tool], +) + + +agent_flo: Flo = Flo.create(session, weather_agent) +# print("\n\n agent_flo",agent_flo.runnable) +print(agent_flo.invoke('Whats the whether in New Delhi, India ?')) diff --git a/flo_ai/callbacks/flo_execution_logger.py b/flo_ai/callbacks/flo_execution_logger.py index f0cb04bd..d0d76879 100644 --- a/flo_ai/callbacks/flo_execution_logger.py +++ b/flo_ai/callbacks/flo_execution_logger.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from datetime import datetime from uuid import UUID from langchain_core.callbacks import BaseCallbackHandler @@ -8,6 +8,14 @@ from langchain_core.prompts.chat import ChatPromptValue from flo_ai.storage.data_collector import DataCollector from flo_ai.common.flo_logger import get_logger +from abc import ABC, abstractmethod +from langchain.schema import HumanMessage + + +class ToolLogger(ABC): + @abstractmethod + def log_all_tools(): + pass class EnhancedJSONEncoder(json.JSONEncoder): @@ -45,11 +53,14 @@ def default(self, obj): return super().default(obj) -class FloExecutionLogger(BaseCallbackHandler): - def __init__(self, data_collector: DataCollector): +class FloExecutionLogger(BaseCallbackHandler, ToolLogger): + def __init__(self, data_collector: DataCollector, tool_collector: DataCollector): self.data_collector = data_collector + self.tool_collector = tool_collector self.runs = {} self.encoder = EnhancedJSONEncoder() + self.query = None + self.added_tools = set() def _encode_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: return json.loads(self.encoder.encode(entry)) @@ -61,6 +72,19 @@ def _store_entry(self, entry: Dict[str, Any]) -> None: except Exception as e: get_logger().error(f'Error storing entry in FloExecutionLogger: {e}') + def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + self.prompt = prompts + def on_chain_start( self, serialized: Dict[str, Any], @@ -73,6 +97,21 @@ def on_chain_start( chain_name = ( serialized.get('name', 'unnamed_chain') if serialized else 'unnamed_chain' ) + + if parent_run_id and chain_name != 'agent_chain': + return + if isinstance(inputs, dict): + user_input = inputs.get('messages', {}) + else: + user_input = {} + if ( + user_input + and isinstance(user_input[0], HumanMessage) + and len(user_input) > 0 + ): + if isinstance(user_input[0], HumanMessage): + self.query = user_input[0].content + self.runs[str(run_id)] = { 'type': 'chain', 'start_time': datetime.utcnow(), @@ -91,9 +130,13 @@ def on_chain_end( ) -> None: if str(run_id) in self.runs: run_info = self.runs[str(run_id)] + if run_info['type'] != 'chain': + return + run_info['type'] = 'chain' run_info['end_time'] = datetime.utcnow() run_info['outputs'] = outputs run_info['status'] = 'completed' + run_info['prompt'] = self.prompt self._store_entry(run_info) del self.runs[str(run_id)] @@ -105,13 +148,7 @@ def on_chain_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - if str(run_id) in self.runs: - run_info = self.runs[str(run_id)] - run_info['end_time'] = datetime.utcnow() - run_info['error'] = str(error) - run_info['status'] = 'error' - self._store_entry(run_info) - del self.runs[str(run_id)] + pass def on_tool_start( self, @@ -120,15 +157,20 @@ def on_tool_start( *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: self.runs[str(run_id)] = { 'type': 'tool', + 'query': self.query, 'start_time': datetime.utcnow(), 'tool_name': serialized.get('name', 'unnamed_tool'), 'input': input_str, 'parent_run_id': str(parent_run_id) if parent_run_id else None, } + # pass def on_tool_end( self, @@ -136,6 +178,7 @@ def on_tool_end( *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: if str(run_id) in self.runs: @@ -197,3 +240,25 @@ def on_agent_finish( 'parent_run_id': str(parent_run_id) if parent_run_id else None, } self._store_entry(log_entry) + + def log_all_tools(self, session_tools, query): + try: + tools = [] + + for val in session_tools: + tool_name = session_tools[val].name + if tool_name not in self.added_tools: + tools.append( + { + 'tool_name': tool_name, + 'description': session_tools.get(val).description, + 'args': session_tools.get(val).args, + } + ) + self.added_tools.add(tool_name) + + encoded_entry = self._encode_entry(tools) + if encoded_entry: + self.tool_collector.store_entry(encoded_entry) + except Exception as e: + get_logger().error(f'Error storing tool in FloExecutionLogger: {e}') diff --git a/flo_ai/core.py b/flo_ai/core.py index a8ed2682..3084dcd8 100644 --- a/flo_ai/core.py +++ b/flo_ai/core.py @@ -22,6 +22,8 @@ from flo_ai.models.flo_node import FloNode from flo_ai.models.flo_agent import FloAgent from langchain.tools import StructuredTool +import json +from flo_ai.callbacks.flo_execution_logger import ToolLogger class Flo: @@ -48,6 +50,11 @@ def invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: + [self.session.langchain_logger] + self.session.callbacks ) + + for callback in config['callbacks']: + if isinstance(callback, ToolLogger): + callback.log_all_tools(self.session.tools, query) + self.validate_invoke(self.session) get_logger().info(f"Invoking query: '{query}'", self.session) return self.runnable.invoke(query, config) diff --git a/flo_ai/models/flo_node.py b/flo_ai/models/flo_node.py index 22d362bd..51bb980e 100644 --- a/flo_ai/models/flo_node.py +++ b/flo_ai/models/flo_node.py @@ -168,7 +168,13 @@ def __teamflo_agent_node( for callback in flo_cbs ] try: - result = agent.invoke(state) + config = {} + config['callbacks'] = ( + config.get('callbacks', []) + + [session.langchain_logger] + + session.callbacks + ) + result = agent.invoke(state, config) output = result if isinstance(result, str) else result['output'] if data_collector is not None: get_logger().info( diff --git a/flo_ai/storage/data_collector.py b/flo_ai/storage/data_collector.py index 79a6f053..43061c22 100644 --- a/flo_ai/storage/data_collector.py +++ b/flo_ai/storage/data_collector.py @@ -26,3 +26,18 @@ def store_entry(self, entry: Dict[str, Any]) -> None: def close(self) -> None: pass + + +class TOOLFileCollector(DataCollector): + def __init__(self, file_path: str): + print('inint') + self.file_path = Path(file_path) + self.file_path.parent.mkdir(parents=True, exist_ok=True) + + def store_entry(self, entry: Dict[str, Any]) -> None: + with open(self.file_path, 'a') as f: + json.dump(entry, f) + f.write('\n') + + def close(self) -> None: + pass diff --git a/poetry.lock b/poetry.lock index 6761944f..df4f21d2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" diff --git a/produce_training.py b/produce_training.py new file mode 100644 index 00000000..01ee1e17 --- /dev/null +++ b/produce_training.py @@ -0,0 +1,89 @@ +import json + +logger_path = 'logger path' +tool_path = 'tool path' + + +def read_file(file_path): + datas = [] + with open(file_path, 'r') as file: + for line in file: + datas.append(json.loads(line)) + return datas + + +logger = read_file(logger_path) +tools = [] +chains = [] + +for entry in logger: + if entry['type'] == 'tool': + tools.append(entry) + if entry['type'] == 'chain': + chains.append(entry) + +tool_description = read_file(tool_path)[0] + + +dataset = [] +for tool in tools: + query = tool.get('query') + tool_name = tool.get('tool_name') + for tool_d in tool_description: + if tool_d.get('tool_name') == tool_name: + description = tool_d['description'] + args = tool_d['args'] + tool_input = tool['input'] + dataset.append( + { + 'query': query, + 'tool_name': tool_name, + 'description': description, + 'args': args, + 'tool_input': tool_input, + } + ) + + +def transorm_data(input_data): + transform_data = [] + + for idx, data in enumerate(input_data): + query = data['query'] + tool_name = data['tool_name'] + description = (data['description'],) + args = data['args'] + tool_input = json.loads(data['tool_input'].replace("'", '"')) + for key, value in args.items(): + if 'title' in value: + del value['title'] + transformed_query = { + 'query': query, + 'id': idx, + 'answers': json.dumps([{'name': tool_name, 'arguments': tool_input}]), + 'tools': json.dumps( + [{'name': tool_name, 'description': description, 'parameters': args}] + ), + } + transform_data.append(transformed_query) + return transform_data + + +transformed_data = transorm_data(dataset) + +chain_dataset = [] +i = len(transformed_data) +for chain in chains: + if chain.get('inputs'): + chain_dataset.append( + { + 'query': chain['prompt'][0][7:], + 'id': i, + 'answer': json.dumps([chain['outputs']]), + } + ) + i += 1 +transformed_data.extend(chain_dataset) + +for data in transformed_data: + print('\n\n', data) From be8affc83adfe046193d18c6e0bf935c8c197153 Mon Sep 17 00:00:00 2001 From: jacobsanosh Date: Tue, 14 Jan 2025 17:19:12 +0530 Subject: [PATCH 2/6] formatting the files --- examples/python/json_encode.py | 14 +++++++------- examples/python/llm_chain.py | 14 +++++++------- flo_ai/callbacks/flo_execution_logger.py | 3 +-- flo_ai/core.py | 1 - 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/examples/python/json_encode.py b/examples/python/json_encode.py index 7570a6d9..79a1ee21 100644 --- a/examples/python/json_encode.py +++ b/examples/python/json_encode.py @@ -11,11 +11,11 @@ load_dotenv() -api_key = os.getenv("API_KEY") -end_point = os.getenv("END_POINT") -model_name = os.getenv("MODEL") -temperature = os.getenv("TEMP") -api_version = os.getenv("API_VERSION") +api_key = os.getenv('API_KEY') +end_point = os.getenv('END_POINT') +model_name = os.getenv('MODEL') +temperature = os.getenv('TEMP') +api_version = os.getenv('API_VERSION') file_collector = JSONLFileCollector('./logger.jsonl') @@ -25,7 +25,7 @@ local_tracker = FloExecutionLogger(file_collector, tool_collector) # Create the LLM object llm = AzureChatOpenAI( - azure_endpoint=end_point, + azure_endpoint=end_point, model_name=model_name, temperature=temperature, api_version=api_version, @@ -43,7 +43,7 @@ session = FloSession(llm) session.register_callback(local_tracker) -os.environ['TAVILY_API_KEY'] = os.getenv("TAVILY_API_KEY") +os.environ['TAVILY_API_KEY'] = os.getenv('TAVILY_API_KEY') tavily_tool = TavilySearchResults() session.register_tool('thappal', tavily_tool) diff --git a/examples/python/llm_chain.py b/examples/python/llm_chain.py index 067f9c84..61f02ee8 100644 --- a/examples/python/llm_chain.py +++ b/examples/python/llm_chain.py @@ -8,14 +8,14 @@ from flo_ai.models.flo_agent import FloAgent from langchain_community.tools.tavily_search.tool import TavilySearchResults from dotenv import load_dotenv -import os + load_dotenv() -api_key = os.getenv("API_KEY") -end_point = os.getenv("END_POINT") -model_name = os.getenv("MODEL") -temperature = os.getenv("TEMP") -api_version = os.getenv("API_VERSION") +api_key = os.getenv('API_KEY') +end_point = os.getenv('END_POINT') +model_name = os.getenv('MODEL') +temperature = os.getenv('TEMP') +api_version = os.getenv('API_VERSION') prompt_template = 'Tell me a {adjective} joke' prompt = PromptTemplate(input_variables=['adjective'], template=prompt_template) @@ -43,7 +43,7 @@ session = FloSession(llm) session.register_callback(local_tracker) -os.environ['TAVILY_API_KEY'] = os.getenv("TAVILY_API_KEY") +os.environ['TAVILY_API_KEY'] = os.getenv('TAVILY_API_KEY') tavily_tool = TavilySearchResults() session.register_tool('thappal', tavily_tool) diff --git a/flo_ai/callbacks/flo_execution_logger.py b/flo_ai/callbacks/flo_execution_logger.py index d0d76879..6e850f89 100644 --- a/flo_ai/callbacks/flo_execution_logger.py +++ b/flo_ai/callbacks/flo_execution_logger.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional, List +from typing import Any, Dict, Optional from datetime import datetime from uuid import UUID from langchain_core.callbacks import BaseCallbackHandler @@ -9,7 +9,6 @@ from flo_ai.storage.data_collector import DataCollector from flo_ai.common.flo_logger import get_logger from abc import ABC, abstractmethod -from langchain.schema import HumanMessage class ToolLogger(ABC): diff --git a/flo_ai/core.py b/flo_ai/core.py index 3084dcd8..1691e9b5 100644 --- a/flo_ai/core.py +++ b/flo_ai/core.py @@ -22,7 +22,6 @@ from flo_ai.models.flo_node import FloNode from flo_ai.models.flo_agent import FloAgent from langchain.tools import StructuredTool -import json from flo_ai.callbacks.flo_execution_logger import ToolLogger From 1b2827ee2d880bb681123b8b6121c5bf9b74c20b Mon Sep 17 00:00:00 2001 From: jacobsanosh Date: Tue, 14 Jan 2025 17:28:28 +0530 Subject: [PATCH 3/6] formatting the files --- examples/python/json_encode.py | 2 +- examples/python/llm_chain.py | 2 +- flo_ai/storage/data_collector.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/python/json_encode.py b/examples/python/json_encode.py index 79a1ee21..22ae6528 100644 --- a/examples/python/json_encode.py +++ b/examples/python/json_encode.py @@ -58,6 +58,6 @@ agent_flo: Flo = Flo.create(session, weather_agent) -# print("\n\n agent_flo",agent_flo.runnable) + print(agent_flo.invoke('Whats the whether in New Delhi, India ?')) print(agent_flo.invoke('Whats the whether in Aroor kochi, India ?')) diff --git a/examples/python/llm_chain.py b/examples/python/llm_chain.py index 61f02ee8..0679937e 100644 --- a/examples/python/llm_chain.py +++ b/examples/python/llm_chain.py @@ -58,5 +58,5 @@ agent_flo: Flo = Flo.create(session, weather_agent) -# print("\n\n agent_flo",agent_flo.runnable) + print(agent_flo.invoke('Whats the whether in New Delhi, India ?')) diff --git a/flo_ai/storage/data_collector.py b/flo_ai/storage/data_collector.py index 43061c22..d90770a3 100644 --- a/flo_ai/storage/data_collector.py +++ b/flo_ai/storage/data_collector.py @@ -30,7 +30,6 @@ def close(self) -> None: class TOOLFileCollector(DataCollector): def __init__(self, file_path: str): - print('inint') self.file_path = Path(file_path) self.file_path.parent.mkdir(parents=True, exist_ok=True) From de011b110b17a8903d7ad1353855a19809dd42dd Mon Sep 17 00:00:00 2001 From: jacobsanosh Date: Fri, 17 Jan 2025 16:39:12 +0530 Subject: [PATCH 4/6] setting up command line execution for logger --- .codegpt/head | 1 - examples/python/json_encode.py | 6 +- examples/python/llm_chain.py | 4 +- flo_ai/callbacks/flo_execution_logger.py | 6 +- flo_ai/storage/data_collector.py | 2 +- generate_training_data.py | 134 +++++++++++++++++++++++ produce_training.py | 89 --------------- 7 files changed, 142 insertions(+), 100 deletions(-) delete mode 100644 .codegpt/head create mode 100644 generate_training_data.py delete mode 100644 produce_training.py diff --git a/.codegpt/head b/.codegpt/head deleted file mode 100644 index 15f0de48..00000000 --- a/.codegpt/head +++ /dev/null @@ -1 +0,0 @@ -ab9178a3-7e3c-481f-918d-03a98ba1d391 \ No newline at end of file diff --git a/examples/python/json_encode.py b/examples/python/json_encode.py index 22ae6528..772fe026 100644 --- a/examples/python/json_encode.py +++ b/examples/python/json_encode.py @@ -2,7 +2,7 @@ from langchain.chains import LLMChain from langchain_core.prompts import PromptTemplate from flo_ai.callbacks import FloExecutionLogger -from flo_ai.storage.data_collector import JSONLFileCollector, TOOLFileCollector +from flo_ai.storage.data_collector import JSONLFileCollector, ToolLFileCollector import os from flo_ai import Flo, FloSession from flo_ai.models.flo_agent import FloAgent @@ -19,7 +19,7 @@ file_collector = JSONLFileCollector('./logger.jsonl') -tool_collector = TOOLFileCollector('./tools.jsonl') +tool_collector = ToolLFileCollector('./tools.jsonl') # Create a tool logger with the collector local_tracker = FloExecutionLogger(file_collector, tool_collector) @@ -60,4 +60,4 @@ agent_flo: Flo = Flo.create(session, weather_agent) print(agent_flo.invoke('Whats the whether in New Delhi, India ?')) -print(agent_flo.invoke('Whats the whether in Aroor kochi, India ?')) +# print(agent_flo.invoke('Whats the whether in Aroor kochi, India ?')) diff --git a/examples/python/llm_chain.py b/examples/python/llm_chain.py index 0679937e..19e07dbf 100644 --- a/examples/python/llm_chain.py +++ b/examples/python/llm_chain.py @@ -1,6 +1,6 @@ from langchain_core.prompts import PromptTemplate from langchain_openai import AzureChatOpenAI -from flo_ai.storage.data_collector import JSONLFileCollector, TOOLFileCollector +from flo_ai.storage.data_collector import JSONLFileCollector, ToolLFileCollector from flo_ai.callbacks import FloExecutionLogger from langchain.chains import LLMChain import os @@ -21,7 +21,7 @@ prompt = PromptTemplate(input_variables=['adjective'], template=prompt_template) file_collector = JSONLFileCollector('./logger.jsonl') -tool_collector = TOOLFileCollector('./tools.jsonl') +tool_collector = ToolLFileCollector('./tools.jsonl') # Create a tool logger with the collector local_tracker = FloExecutionLogger(file_collector, tool_collector) diff --git a/flo_ai/callbacks/flo_execution_logger.py b/flo_ai/callbacks/flo_execution_logger.py index 6e850f89..dd5a5f28 100644 --- a/flo_ai/callbacks/flo_execution_logger.py +++ b/flo_ai/callbacks/flo_execution_logger.py @@ -104,10 +104,9 @@ def on_chain_start( else: user_input = {} if ( - user_input + user_input and len(user_input) > 0 and isinstance(user_input[0], HumanMessage) - and len(user_input) > 0 - ): + ): if isinstance(user_input[0], HumanMessage): self.query = user_input[0].content @@ -169,7 +168,6 @@ def on_tool_start( 'input': input_str, 'parent_run_id': str(parent_run_id) if parent_run_id else None, } - # pass def on_tool_end( self, diff --git a/flo_ai/storage/data_collector.py b/flo_ai/storage/data_collector.py index d90770a3..d7600728 100644 --- a/flo_ai/storage/data_collector.py +++ b/flo_ai/storage/data_collector.py @@ -28,7 +28,7 @@ def close(self) -> None: pass -class TOOLFileCollector(DataCollector): +class ToolLFileCollector(DataCollector): def __init__(self, file_path: str): self.file_path = Path(file_path) self.file_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/generate_training_data.py b/generate_training_data.py new file mode 100644 index 00000000..8be08de3 --- /dev/null +++ b/generate_training_data.py @@ -0,0 +1,134 @@ +""" +Generate Training Data Script + +This script processes log files and tool descriptions to generate training datasets. +It handles both tool-based and chain-based data, transforming them into a proper format. + +Usage: + python generate_training_data.py --logger-path PATH --tool-path PATH [--output PATH] + +Arguments: + --logger-path: Path to the logger file containing tool and chain entries + --tool-path: Path to the tool descriptions file + --output: Optional path to save the output (defaults to stdout) +""" + + +import json +import argparse + +def parse_arguments(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Generate training data from logs and tool descriptions") + parser.add_argument( + '--logger-path', + required=True, + help='Path to the logger file containing tool and chain entries' + ) + parser.add_argument( + '--tool-path', + required=True, + help='Path to the tool descriptions file' + ) + parser.add_argument( + '--output-path', + required=True, + help='path to save the output' + ) + return parser.parse_args() + +def read_file(file_path): + """ + Read and parse JSON lines from a file. + Returns: + List of parsed JSON objects + """ + try: + datas = [] + with open(file_path, 'r') as file: + for line in file: + datas.append(json.loads(line)) + return datas + except FileNotFoundError: + raise FileNotFoundError(f"Could not find file: {file_path}") + except Exception as e: + raise Exception(e) + +def extracting_tool_details(tools,tool_description): + dataset = [] + for tool in tools: + query = tool.get('query') + tool_name = tool.get('tool_name') + for tool_d in tool_description: + if tool_d.get('tool_name') == tool_name: + description = tool_d['description'] + args = tool_d['args'] + tool_input = tool['input'] + dataset.append( + { + 'query': query, + 'tool_name': tool_name, + 'description': description, + 'args': args, + 'tool_input': tool_input, + } + ) + return dataset + +def tool_transformation(input_data): + transform_data = [] + + for idx, data in enumerate(input_data): + query = data['query'] + tool_name = data['tool_name'] + description = (data['description'],) + args = data['args'] + tool_input = json.loads(data['tool_input'].replace("'", '"')) + for key, value in args.items(): + if 'title' in value: + del value['title'] + transformed_query = { + 'query': query, + 'id': idx, + 'answers': json.dumps([{'name': tool_name, 'arguments': tool_input}]), + 'tools': json.dumps( + [{'name': tool_name, 'description': description, 'parameters': args}] + ), + } + transform_data.append(transformed_query) + return transform_data + +def chain_transformation(chains,start_idx): + chain_dataset = [] + + for i, chain in enumerate(chains, start_idx): + if chain.get('inputs'): + chain_dataset.append({ + 'query': chain['prompt'][0][7:], + 'id': i, + 'answer': json.dumps([chain['outputs']]), + }) + return chain_dataset + +def main(): + args = parse_arguments() + + logger_data = read_file(args.logger_path) + tool_description = read_file(args.tool_path)[0] + + tools = [entry for entry in logger_data if entry['type'] == 'tool'] + chains = [entry for entry in logger_data if entry['type'] == 'chain'] + + tool_extraction = extracting_tool_details(tools,tool_description) + tool_transformed = tool_transformation(tool_extraction) + + chain_dataset = chain_transformation(chains,len(tool_transformed)) + + final_dataset = tool_transformed + chain_dataset + + with open(args.output_path, 'w') as f: + for data in final_dataset: + f.write(json.dumps(data) + '\n') + +if __name__ == "__main__": + main() diff --git a/produce_training.py b/produce_training.py deleted file mode 100644 index 01ee1e17..00000000 --- a/produce_training.py +++ /dev/null @@ -1,89 +0,0 @@ -import json - -logger_path = 'logger path' -tool_path = 'tool path' - - -def read_file(file_path): - datas = [] - with open(file_path, 'r') as file: - for line in file: - datas.append(json.loads(line)) - return datas - - -logger = read_file(logger_path) -tools = [] -chains = [] - -for entry in logger: - if entry['type'] == 'tool': - tools.append(entry) - if entry['type'] == 'chain': - chains.append(entry) - -tool_description = read_file(tool_path)[0] - - -dataset = [] -for tool in tools: - query = tool.get('query') - tool_name = tool.get('tool_name') - for tool_d in tool_description: - if tool_d.get('tool_name') == tool_name: - description = tool_d['description'] - args = tool_d['args'] - tool_input = tool['input'] - dataset.append( - { - 'query': query, - 'tool_name': tool_name, - 'description': description, - 'args': args, - 'tool_input': tool_input, - } - ) - - -def transorm_data(input_data): - transform_data = [] - - for idx, data in enumerate(input_data): - query = data['query'] - tool_name = data['tool_name'] - description = (data['description'],) - args = data['args'] - tool_input = json.loads(data['tool_input'].replace("'", '"')) - for key, value in args.items(): - if 'title' in value: - del value['title'] - transformed_query = { - 'query': query, - 'id': idx, - 'answers': json.dumps([{'name': tool_name, 'arguments': tool_input}]), - 'tools': json.dumps( - [{'name': tool_name, 'description': description, 'parameters': args}] - ), - } - transform_data.append(transformed_query) - return transform_data - - -transformed_data = transorm_data(dataset) - -chain_dataset = [] -i = len(transformed_data) -for chain in chains: - if chain.get('inputs'): - chain_dataset.append( - { - 'query': chain['prompt'][0][7:], - 'id': i, - 'answer': json.dumps([chain['outputs']]), - } - ) - i += 1 -transformed_data.extend(chain_dataset) - -for data in transformed_data: - print('\n\n', data) From 4ec266294bd0ebb52e099e10d23475c5266b56ec Mon Sep 17 00:00:00 2001 From: jacobsanosh Date: Fri, 17 Jan 2025 16:41:00 +0530 Subject: [PATCH 5/6] fixing comments --- generate_training_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generate_training_data.py b/generate_training_data.py index 8be08de3..e0f99f92 100644 --- a/generate_training_data.py +++ b/generate_training_data.py @@ -10,7 +10,7 @@ Arguments: --logger-path: Path to the logger file containing tool and chain entries --tool-path: Path to the tool descriptions file - --output: Optional path to save the output (defaults to stdout) + --output: path to save the output """ From 1bfb8296f7ec8009068ac3c7dc0dc167dbf74c56 Mon Sep 17 00:00:00 2001 From: vizsatiz Date: Sat, 15 Feb 2025 14:52:39 +0530 Subject: [PATCH 6/6] Fix for training data generation script --- .gitignore | 3 +- README.md | 15 ++- ...de.py => json_training_data_generation.py} | 31 ++--- examples/python/llm_chain.py | 62 ---------- examples/python/output_parser.py | 30 ++--- flo_ai/callbacks/flo_execution_logger.py | 40 +++++-- flo_ai/core.py | 17 +-- flo_ai/models/flo_node.py | 26 +++-- flo_ai/state/flo_session.py | 11 +- flo_ai/storage/data_collector.py | 32 +++-- generate_training_data.py | 109 +++++++++++------- 11 files changed, 171 insertions(+), 205 deletions(-) rename examples/python/{json_encode.py => json_training_data_generation.py} (54%) delete mode 100644 examples/python/llm_chain.py diff --git a/.gitignore b/.gitignore index 69cc4315..1cb57571 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,5 @@ bin *.sql *.log *.yaml -examples/local/* \ No newline at end of file +examples/local/* +.logs \ No newline at end of file diff --git a/README.md b/README.md index 5fef3657..feffe7d3 100644 --- a/README.md +++ b/README.md @@ -405,8 +405,8 @@ You can customize `DataCollector` implementation according to your database. A s from flo_ai.callbacks import FloExecutionLogger from flo_ai.storage.data_collector import JSONLFileCollector -# Initialize the file collector with a path for the JSONL log file -file_collector = JSONLFileCollector("./path/to/my_llm_logs.jsonl") +# Initialize the file collector with a path for the JSONL log file to be stored +file_collector = JSONLFileCollector("'.logs'") # Create a tool logger with the collector local_tracker = FloExecutionLogger(file_collector) @@ -439,6 +439,17 @@ The structured logs provide valuable training data that can be used to: - **Train new models** to replicate successful tool usage patterns - **Create supervised datasets** for tool selection and chain optimization +We have created a script to convert your logs to training data: + +```python +python generate_training_data.py --logger-path PATH --tool-path PATH [--output PATH] +``` + +Arguments: +- *logger-path*: Path to the logger file containing tool and chain entries, eg: .logs/logs/log.jsonl +- *tool-path*: Path to the tool descriptions file eg: eg: .logs/tools/tools.jsonl +- *output*: path to save the output eg: training-data.jsonl + ## 📖 Documentation diff --git a/examples/python/json_encode.py b/examples/python/json_training_data_generation.py similarity index 54% rename from examples/python/json_encode.py rename to examples/python/json_training_data_generation.py index 772fe026..ae26640a 100644 --- a/examples/python/json_encode.py +++ b/examples/python/json_training_data_generation.py @@ -1,9 +1,8 @@ -from langchain_openai import AzureChatOpenAI -from langchain.chains import LLMChain +import os +from langchain_openai import ChatOpenAI from langchain_core.prompts import PromptTemplate from flo_ai.callbacks import FloExecutionLogger -from flo_ai.storage.data_collector import JSONLFileCollector, ToolLFileCollector -import os +from flo_ai.storage.data_collector import JSONLFileCollector from flo_ai import Flo, FloSession from flo_ai.models.flo_agent import FloAgent from langchain_community.tools.tavily_search.tool import TavilySearchResults @@ -11,32 +10,18 @@ load_dotenv() -api_key = os.getenv('API_KEY') -end_point = os.getenv('END_POINT') -model_name = os.getenv('MODEL') -temperature = os.getenv('TEMP') -api_version = os.getenv('API_VERSION') - -file_collector = JSONLFileCollector('./logger.jsonl') -tool_collector = ToolLFileCollector('./tools.jsonl') +file_collector = JSONLFileCollector('.logs') # Create a tool logger with the collector -local_tracker = FloExecutionLogger(file_collector, tool_collector) +local_tracker = FloExecutionLogger(file_collector) # Create the LLM object -llm = AzureChatOpenAI( - azure_endpoint=end_point, - model_name=model_name, - temperature=temperature, - api_version=api_version, - api_key=api_key, - callbacks=[local_tracker], -) +llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') prompt = PromptTemplate.from_template('1 + {number} = ') -chain = LLMChain(llm=llm, prompt=prompt, callbacks=[local_tracker]) +chain = prompt | llm print(chain.invoke({'number': 2})) @@ -48,7 +33,6 @@ session.register_tool('thappal', tavily_tool) - weather_agent = FloAgent.create( session=session, name='Blogger', @@ -60,4 +44,3 @@ agent_flo: Flo = Flo.create(session, weather_agent) print(agent_flo.invoke('Whats the whether in New Delhi, India ?')) -# print(agent_flo.invoke('Whats the whether in Aroor kochi, India ?')) diff --git a/examples/python/llm_chain.py b/examples/python/llm_chain.py deleted file mode 100644 index 19e07dbf..00000000 --- a/examples/python/llm_chain.py +++ /dev/null @@ -1,62 +0,0 @@ -from langchain_core.prompts import PromptTemplate -from langchain_openai import AzureChatOpenAI -from flo_ai.storage.data_collector import JSONLFileCollector, ToolLFileCollector -from flo_ai.callbacks import FloExecutionLogger -from langchain.chains import LLMChain -import os -from flo_ai import Flo, FloSession -from flo_ai.models.flo_agent import FloAgent -from langchain_community.tools.tavily_search.tool import TavilySearchResults -from dotenv import load_dotenv - - -load_dotenv() -api_key = os.getenv('API_KEY') -end_point = os.getenv('END_POINT') -model_name = os.getenv('MODEL') -temperature = os.getenv('TEMP') -api_version = os.getenv('API_VERSION') - -prompt_template = 'Tell me a {adjective} joke' -prompt = PromptTemplate(input_variables=['adjective'], template=prompt_template) - -file_collector = JSONLFileCollector('./logger.jsonl') -tool_collector = ToolLFileCollector('./tools.jsonl') - -# Create a tool logger with the collector -local_tracker = FloExecutionLogger(file_collector, tool_collector) - -llm = AzureChatOpenAI( - azure_endpoint=end_point, - model_name=model_name, - temperature=temperature, - api_version=api_version, - api_key=api_key, - callbacks=[local_tracker], -) - - -chain = LLMChain(llm=llm, prompt=prompt, callbacks=[local_tracker]) - -print(chain.invoke({'adjective': 'funny'})) - -session = FloSession(llm) -session.register_callback(local_tracker) - -os.environ['TAVILY_API_KEY'] = os.getenv('TAVILY_API_KEY') -tavily_tool = TavilySearchResults() - -session.register_tool('thappal', tavily_tool) - - -weather_agent = FloAgent.create( - session=session, - name='Blogger', - job='You can research the internet and create a blog about the topic given by the user', - tools=[tavily_tool], -) - - -agent_flo: Flo = Flo.create(session, weather_agent) - -print(agent_flo.invoke('Whats the whether in New Delhi, India ?')) diff --git a/examples/python/output_parser.py b/examples/python/output_parser.py index a3d920ed..4aeee0ae 100644 --- a/examples/python/output_parser.py +++ b/examples/python/output_parser.py @@ -1,25 +1,26 @@ -import os -from flo_ai import FloAgent, FloSession, Flo +from flo_ai import FloLLMAgent, FloSession, Flo from langchain_community.tools.tavily_search.tool import TavilySearchResults from dotenv import load_dotenv -from langchain_openai import AzureChatOpenAI +from langchain_openai import ChatOpenAI from flo_ai.parsers import FloJsonParser from flo_ai.state import FloJsonOutputCollector +from flo_ai.callbacks import FloExecutionLogger +from flo_ai.storage.data_collector import JSONLFileCollector load_dotenv() -llm = AzureChatOpenAI( - azure_endpoint=os.getenv('AZURE_GPT4_ENDPOINT'), - model_name='gpt-4o', - temperature=0.2, - max_tokens=4096, - api_version='2024-08-01-preview', - api_key=os.getenv('AZURE_OPEN_AI_API_KEY'), -) +llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') session = FloSession(llm).register_tool( name='TavilySearchResults', tool=TavilySearchResults() ) + +file_collector = JSONLFileCollector('.logs') + +local_tracker = FloExecutionLogger(file_collector) + +session.register_callback(local_tracker) + format = { 'name': 'NameFormat', 'fields': [ @@ -48,12 +49,11 @@ dc = FloJsonOutputCollector() -researcher = FloAgent.create( +researcher = FloLLMAgent.create( session, - name='Researcher', - role='Internet Researcher', + name='Formatter', + role='Output formatter', job='What is the first name, last name and middle name of the the person user asks about', - tools=[TavilySearchResults()], parser=FloJsonParser.create(json_dict=format), data_collector=dc, ) diff --git a/flo_ai/callbacks/flo_execution_logger.py b/flo_ai/callbacks/flo_execution_logger.py index dd5a5f28..29d174ac 100644 --- a/flo_ai/callbacks/flo_execution_logger.py +++ b/flo_ai/callbacks/flo_execution_logger.py @@ -13,7 +13,7 @@ class ToolLogger(ABC): @abstractmethod - def log_all_tools(): + def log_all_tools(session_tools): pass @@ -53,13 +53,13 @@ def default(self, obj): class FloExecutionLogger(BaseCallbackHandler, ToolLogger): - def __init__(self, data_collector: DataCollector, tool_collector: DataCollector): + def __init__(self, data_collector: DataCollector): self.data_collector = data_collector - self.tool_collector = tool_collector self.runs = {} self.encoder = EnhancedJSONEncoder() self.query = None self.added_tools = set() + self.prompt = {} def _encode_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: return json.loads(self.encoder.encode(entry)) @@ -67,7 +67,7 @@ def _encode_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: def _store_entry(self, entry: Dict[str, Any]) -> None: try: encoded_entry = self._encode_entry(entry) - self.data_collector.store_entry(encoded_entry) + self.data_collector.store_log(encoded_entry) except Exception as e: get_logger().error(f'Error storing entry in FloExecutionLogger: {e}') @@ -82,7 +82,7 @@ def on_llm_start( metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: - self.prompt = prompts + self.prompt[str(run_id)] = prompts def on_chain_start( self, @@ -104,9 +104,10 @@ def on_chain_start( else: user_input = {} if ( - user_input and len(user_input) > 0 + user_input + and len(user_input) > 0 and isinstance(user_input[0], HumanMessage) - ): + ): if isinstance(user_input[0], HumanMessage): self.query = user_input[0].content @@ -115,6 +116,7 @@ def on_chain_start( 'start_time': datetime.utcnow(), 'inputs': inputs, 'name': chain_name, + 'run_id': str(run_id), 'parent_run_id': str(parent_run_id) if parent_run_id else None, } @@ -128,15 +130,29 @@ def on_chain_end( ) -> None: if str(run_id) in self.runs: run_info = self.runs[str(run_id)] - if run_info['type'] != 'chain': + if run_info['type'] != 'chain' and run_info['type'] != 'llm': return - run_info['type'] = 'chain' run_info['end_time'] = datetime.utcnow() run_info['outputs'] = outputs run_info['status'] = 'completed' - run_info['prompt'] = self.prompt + run_info['parent_run_id'] = str(parent_run_id) if parent_run_id else None + run_info['prompt'] = ( + self.prompt[str(run_id)] if str(run_id) in self.prompt else [] + ) self._store_entry(run_info) del self.runs[str(run_id)] + else: + if isinstance(outputs, ChatPromptValue) or isinstance(outputs, AgentFinish): + run_info = {} + run_info['type'] = 'llm' + run_info['end_time'] = datetime.utcnow() + run_info['inputs'] = outputs + run_info['status'] = 'completed' + run_info['run_id'] = str(run_id) + run_info['parent_run_id'] = ( + str(parent_run_id) if parent_run_id else None + ) + self.runs[str(parent_run_id)] = run_info def on_chain_error( self, @@ -238,7 +254,7 @@ def on_agent_finish( } self._store_entry(log_entry) - def log_all_tools(self, session_tools, query): + def log_all_tools(self, session_tools): try: tools = [] @@ -256,6 +272,6 @@ def log_all_tools(self, session_tools, query): encoded_entry = self._encode_entry(tools) if encoded_entry: - self.tool_collector.store_entry(encoded_entry) + self.data_collector.store_tool_log(encoded_entry) except Exception as e: get_logger().error(f'Error storing tool in FloExecutionLogger: {e}') diff --git a/flo_ai/core.py b/flo_ai/core.py index 1691e9b5..5f03dac6 100644 --- a/flo_ai/core.py +++ b/flo_ai/core.py @@ -43,28 +43,17 @@ def async_stream(self, query, config=None) -> Iterator[Union[dict[str, Any], Any return self.runnable.astream(query, config) def invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: - config = config or {} - config['callbacks'] = ( - config.get('callbacks', []) - + [self.session.langchain_logger] - + self.session.callbacks - ) + config = self.session.prepare_config(config) - for callback in config['callbacks']: + for callback in self.session.callbacks: if isinstance(callback, ToolLogger): - callback.log_all_tools(self.session.tools, query) + callback.log_all_tools(self.session.tools) self.validate_invoke(self.session) get_logger().info(f"Invoking query: '{query}'", self.session) return self.runnable.invoke(query, config) def async_invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: - config = config or {} - config['callbacks'] = ( - config.get('callbacks', []) - + [self.session.langchain_logger] - + self.session.callbacks - ) get_logger().info(f"Invoking async query: '{query}'", self.session) return self.runnable.ainvoke(query, config) diff --git a/flo_ai/models/flo_node.py b/flo_ai/models/flo_node.py index 51bb980e..8f46f41d 100644 --- a/flo_ai/models/flo_node.py +++ b/flo_ai/models/flo_node.py @@ -36,14 +36,19 @@ def __init__( self.agent_executable = agent_executable def invoke(self, query, config): - return self.func({STATE_NAME_MESSAGES: [HumanMessage(content=query)]}) + return self.func( + {STATE_NAME_MESSAGES: [HumanMessage(content=query)]}, config=config + ) async def ainvoke(self, query, config): return await self.async_func( - {STATE_NAME_MESSAGES: [HumanMessage(content=query)]} + {STATE_NAME_MESSAGES: [HumanMessage(content=query)]}, config=config ) - def draw(self, xray=True): + def draw( + self, + xray=True, + ): return ( self.agent_executable.get_graph().draw_mermaid_png() if self.agent_executable is not None @@ -151,6 +156,7 @@ def __teamflo_agent_node( name: str, session: FloSession, model_name: str, + config=None, data_collector: Optional[FloOutputCollector] = None, ): agent_cbs: List[FloAgentCallback] = FloNode.Builder.__filter_callbacks( @@ -168,13 +174,7 @@ def __teamflo_agent_node( for callback in flo_cbs ] try: - config = {} - config['callbacks'] = ( - config.get('callbacks', []) - + [session.langchain_logger] - + session.callbacks - ) - result = agent.invoke(state, config) + result = agent.invoke(state, config=config) output = result if isinstance(result, str) else result['output'] if data_collector is not None: get_logger().info( @@ -208,6 +208,7 @@ async def __async_teamflo_agent_node( name: str, session: FloSession, model_name: str, + config: dict = None, data_collector: Optional[FloOutputCollector] = None, ): agent_cbs: List[FloAgentCallback] = FloNode.Builder.__filter_callbacks( @@ -225,7 +226,7 @@ async def __async_teamflo_agent_node( for callback in flo_cbs ] try: - result = await agent.ainvoke(state) + result = await agent.ainvoke(state, config=config) output = result if isinstance(result, str) else result['output'] if data_collector is not None: get_logger().info( @@ -264,6 +265,7 @@ def __teamflo_router_node( name: str, session: FloSession, model_name: str, + config: dict = None, ): agent_cbs: List[FloRouterCallback] = FloNode.Builder.__filter_callbacks( session, FloRouterCallback @@ -280,7 +282,7 @@ def __teamflo_router_node( for callback in flo_cbs ] try: - result = agent.invoke(state) + result = agent.invoke(state, config=config) nextNode = result if isinstance(result, str) else result['next'] except Exception as e: [ diff --git a/flo_ai/state/flo_session.py b/flo_ai/state/flo_session.py index 1efeb408..72d40971 100644 --- a/flo_ai/state/flo_session.py +++ b/flo_ai/state/flo_session.py @@ -108,11 +108,16 @@ def register_callback( filter(lambda x: isinstance(x, FloToolCallback), self.callbacks) ) self.langchain_logger = FloLangchainLogger(self.session_id, tool_callbacks) - - if self.llm is not None: - self.llm = self.llm.bind(callbacks=[callback]) + self.callbacks.append(self.langchain_logger) return self + def prepare_config(self, config=None): + get_logger().info(f'Binding all callbacks ... {len(self.callbacks)}') + config = {} if config is None else config + existing_cbs = config['callbacks'] if 'callbacks' in config else [] + config['callbacks'] = self.callbacks + existing_cbs + return config + def append(self, node: str) -> int: get_logger().debug(f'Appending node: {node}') self.counter[node] = self.counter.get(node, 0) + 1 diff --git a/flo_ai/storage/data_collector.py b/flo_ai/storage/data_collector.py index d7600728..14d72163 100644 --- a/flo_ai/storage/data_collector.py +++ b/flo_ai/storage/data_collector.py @@ -6,7 +6,11 @@ class DataCollector(ABC): @abstractmethod - def store_entry(self, entry: Dict[str, Any]) -> None: + def store_log(self, entry: Dict[str, Any]) -> None: + pass + + @abstractmethod + def store_tool_log(self, entry: Dict[str, Any]) -> None: pass @abstractmethod @@ -15,26 +19,20 @@ def close(self) -> None: class JSONLFileCollector(DataCollector): - def __init__(self, file_path: str): - self.file_path = Path(file_path) - self.file_path.parent.mkdir(parents=True, exist_ok=True) + def __init__(self, folder_path: str): + self.log_file_path = Path(f'{folder_path}/logs/logs.jsonl') + self.tool_file_path = Path(f'{folder_path}/tools/tools.jsonl') - def store_entry(self, entry: Dict[str, Any]) -> None: - with open(self.file_path, 'a') as f: + self.log_file_path.parent.mkdir(parents=True, exist_ok=True) + self.tool_file_path.parent.mkdir(parents=True, exist_ok=True) + + def store_log(self, entry: Dict[str, Any]) -> None: + with open(self.log_file_path, 'a') as f: json.dump(entry, f) f.write('\n') - def close(self) -> None: - pass - - -class ToolLFileCollector(DataCollector): - def __init__(self, file_path: str): - self.file_path = Path(file_path) - self.file_path.parent.mkdir(parents=True, exist_ok=True) - - def store_entry(self, entry: Dict[str, Any]) -> None: - with open(self.file_path, 'a') as f: + def store_tool_log(self, entry: Dict[str, Any]) -> None: + with open(self.tool_file_path, 'a') as f: json.dump(entry, f) f.write('\n') diff --git a/generate_training_data.py b/generate_training_data.py index e0f99f92..8d4760b0 100644 --- a/generate_training_data.py +++ b/generate_training_data.py @@ -10,38 +10,35 @@ Arguments: --logger-path: Path to the logger file containing tool and chain entries --tool-path: Path to the tool descriptions file - --output: path to save the output + --output: path to save the output """ - import json import argparse + def parse_arguments(): """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="Generate training data from logs and tool descriptions") + parser = argparse.ArgumentParser( + description='Generate training data from logs and tool descriptions' + ) parser.add_argument( '--logger-path', required=True, - help='Path to the logger file containing tool and chain entries' + help='Path to the logger file containing tool and chain entries', ) parser.add_argument( - '--tool-path', - required=True, - help='Path to the tool descriptions file' - ) - parser.add_argument( - '--output-path', - required=True, - help='path to save the output' + '--tool-path', required=True, help='Path to the tool descriptions file' ) + parser.add_argument('--output-path', required=True, help='path to save the output') return parser.parse_args() + def read_file(file_path): """ Read and parse JSON lines from a file. Returns: - List of parsed JSON objects + List of parsed JSON objects """ try: datas = [] @@ -50,16 +47,17 @@ def read_file(file_path): datas.append(json.loads(line)) return datas except FileNotFoundError: - raise FileNotFoundError(f"Could not find file: {file_path}") + raise FileNotFoundError(f'Could not find file: {file_path}') except Exception as e: raise Exception(e) -def extracting_tool_details(tools,tool_description): + +def extracting_tool_details(tools, toolbox): dataset = [] for tool in tools: query = tool.get('query') tool_name = tool.get('tool_name') - for tool_d in tool_description: + for tool_d in toolbox[tool_name]: if tool_d.get('tool_name') == tool_name: description = tool_d['description'] args = tool_d['args'] @@ -75,60 +73,85 @@ def extracting_tool_details(tools,tool_description): ) return dataset + def tool_transformation(input_data): - transform_data = [] + transform_data = {} for idx, data in enumerate(input_data): - query = data['query'] tool_name = data['tool_name'] description = (data['description'],) args = data['args'] tool_input = json.loads(data['tool_input'].replace("'", '"')) - for key, value in args.items(): + for _, value in args.items(): if 'title' in value: del value['title'] transformed_query = { - 'query': query, 'id': idx, 'answers': json.dumps([{'name': tool_name, 'arguments': tool_input}]), 'tools': json.dumps( [{'name': tool_name, 'description': description, 'parameters': args}] ), } - transform_data.append(transformed_query) + transform_data[tool_name] = transformed_query return transform_data -def chain_transformation(chains,start_idx): + +def chain_transformation(chains, start_idx): chain_dataset = [] for i, chain in enumerate(chains, start_idx): if chain.get('inputs'): - chain_dataset.append({ - 'query': chain['prompt'][0][7:], - 'id': i, - 'answer': json.dumps([chain['outputs']]), - }) + chain_dataset.append( + { + 'query': chain['prompt'][0], + 'id': i, + 'answer': chain['outputs']['output'], + } + ) return chain_dataset -def main(): + +def llm_transformation(llm_logs, tool_set): + dataset = [] + for i, llm_log in enumerate(llm_logs): + if llm_log['inputs'] and 'messages' in llm_log['inputs']: + tools = None + answer = llm_log['outputs'] + print() + if ( + 'type' in llm_log['outputs'][0] + and llm_log['outputs'][0]['type'] == 'AgentAction' + ): + tools = tool_set[llm_log['outputs'][0]['tool']] + answer = tools['answers'] + dataset.append( + { + 'query': llm_log['inputs']['messages'], + 'id': i, + 'answers': answer, + 'tools': tools['tools'] if tools is not None else None, + } + ) + return dataset + + +if __name__ == '__main__': args = parse_arguments() - + logger_data = read_file(args.logger_path) - tool_description = read_file(args.tool_path)[0] - + tool_descriptions = read_file(args.tool_path) + + toolbox = {} + for td in tool_descriptions: + toolbox[td[0]['tool_name']] = td + tools = [entry for entry in logger_data if entry['type'] == 'tool'] - chains = [entry for entry in logger_data if entry['type'] == 'chain'] - - tool_extraction = extracting_tool_details(tools,tool_description) + llms = [entry for entry in logger_data if entry['type'] == 'llm'] + + tool_extraction = extracting_tool_details(tools, toolbox) tool_transformed = tool_transformation(tool_extraction) - - chain_dataset = chain_transformation(chains,len(tool_transformed)) - - final_dataset = tool_transformed + chain_dataset - + training_data = llm_transformation(llms, tool_transformed) + with open(args.output_path, 'w') as f: - for data in final_dataset: + for data in training_data: f.write(json.dumps(data) + '\n') - -if __name__ == "__main__": - main()