diff --git a/.gitignore b/.gitignore index 0787c186..74199cf1 100644 --- a/.gitignore +++ b/.gitignore @@ -13,5 +13,6 @@ bin *.sql *.log *.yaml +examples/local/* +.logs scratch_pad.py -examples/local/* \ No newline at end of file diff --git a/README.md b/README.md index 5fef3657..feffe7d3 100644 --- a/README.md +++ b/README.md @@ -405,8 +405,8 @@ You can customize `DataCollector` implementation according to your database. A s from flo_ai.callbacks import FloExecutionLogger from flo_ai.storage.data_collector import JSONLFileCollector -# Initialize the file collector with a path for the JSONL log file -file_collector = JSONLFileCollector("./path/to/my_llm_logs.jsonl") +# Initialize the file collector with a folder path where the JSONL log files will be stored +file_collector = JSONLFileCollector('.logs') # Create a tool logger with the collector local_tracker = FloExecutionLogger(file_collector) @@ -439,6 +439,17 @@ The structured logs provide valuable training data that can be used to: - **Train new models** to replicate successful tool usage patterns - **Create supervised datasets** for tool selection and chain optimization +We have created a script to convert your logs to training data: + +```bash +python generate_training_data.py --logger-path PATH --tool-path PATH --output-path PATH +``` + +Arguments: +- *logger-path*: Path to the logger file containing tool and chain entries, eg: .logs/logs/logs.jsonl +- *tool-path*: Path to the tool descriptions file, eg: .logs/tools/tools.jsonl +- *output-path*: Path to save the output, eg: training-data.jsonl + ## 📖 Documentation diff --git a/examples/python/json_training_data_generation.py b/examples/python/json_training_data_generation.py new file mode 100644 index 00000000..ae26640a --- /dev/null +++ b/examples/python/json_training_data_generation.py @@ -0,0 +1,46 @@ +import os +from langchain_openai import ChatOpenAI +from langchain_core.prompts import PromptTemplate +from flo_ai.callbacks import FloExecutionLogger +from 
flo_ai.storage.data_collector import JSONLFileCollector +from flo_ai import Flo, FloSession +from flo_ai.models.flo_agent import FloAgent +from langchain_community.tools.tavily_search.tool import TavilySearchResults +from dotenv import load_dotenv + + +load_dotenv() + +file_collector = JSONLFileCollector('.logs') + +# Create a tool logger with the collector +local_tracker = FloExecutionLogger(file_collector) +# Create the LLM object +llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') + + +prompt = PromptTemplate.from_template('1 + {number} = ') + +chain = prompt | llm +print(chain.invoke({'number': 2})) + + +session = FloSession(llm) +session.register_callback(local_tracker) + +os.environ['TAVILY_API_KEY'] = os.getenv('TAVILY_API_KEY') +tavily_tool = TavilySearchResults() + +session.register_tool('thappal', tavily_tool) + +weather_agent = FloAgent.create( + session=session, + name='Blogger', + job='You can research the internet and create a blog about the topic given by the user', + tools=[tavily_tool], +) + + +agent_flo: Flo = Flo.create(session, weather_agent) + +print(agent_flo.invoke('Whats the whether in New Delhi, India ?')) diff --git a/examples/python/output_parser.py b/examples/python/output_parser.py index a3d920ed..4aeee0ae 100644 --- a/examples/python/output_parser.py +++ b/examples/python/output_parser.py @@ -1,25 +1,26 @@ -import os -from flo_ai import FloAgent, FloSession, Flo +from flo_ai import FloLLMAgent, FloSession, Flo from langchain_community.tools.tavily_search.tool import TavilySearchResults from dotenv import load_dotenv -from langchain_openai import AzureChatOpenAI +from langchain_openai import ChatOpenAI from flo_ai.parsers import FloJsonParser from flo_ai.state import FloJsonOutputCollector +from flo_ai.callbacks import FloExecutionLogger +from flo_ai.storage.data_collector import JSONLFileCollector load_dotenv() -llm = AzureChatOpenAI( - azure_endpoint=os.getenv('AZURE_GPT4_ENDPOINT'), - model_name='gpt-4o', - temperature=0.2, - 
max_tokens=4096, - api_version='2024-08-01-preview', - api_key=os.getenv('AZURE_OPEN_AI_API_KEY'), -) +llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini') session = FloSession(llm).register_tool( name='TavilySearchResults', tool=TavilySearchResults() ) + +file_collector = JSONLFileCollector('.logs') + +local_tracker = FloExecutionLogger(file_collector) + +session.register_callback(local_tracker) + format = { 'name': 'NameFormat', 'fields': [ @@ -48,12 +49,11 @@ dc = FloJsonOutputCollector() -researcher = FloAgent.create( +researcher = FloLLMAgent.create( session, - name='Researcher', - role='Internet Researcher', + name='Formatter', + role='Output formatter', job='What is the first name, last name and middle name of the the person user asks about', - tools=[TavilySearchResults()], parser=FloJsonParser.create(json_dict=format), data_collector=dc, ) diff --git a/flo_ai/callbacks/flo_execution_logger.py b/flo_ai/callbacks/flo_execution_logger.py index f0cb04bd..29d174ac 100644 --- a/flo_ai/callbacks/flo_execution_logger.py +++ b/flo_ai/callbacks/flo_execution_logger.py @@ -8,6 +8,13 @@ from langchain_core.prompts.chat import ChatPromptValue from flo_ai.storage.data_collector import DataCollector from flo_ai.common.flo_logger import get_logger +from abc import ABC, abstractmethod + + +class ToolLogger(ABC): + @abstractmethod + def log_all_tools(session_tools): + pass class EnhancedJSONEncoder(json.JSONEncoder): @@ -45,11 +52,14 @@ def default(self, obj): return super().default(obj) -class FloExecutionLogger(BaseCallbackHandler): +class FloExecutionLogger(BaseCallbackHandler, ToolLogger): def __init__(self, data_collector: DataCollector): self.data_collector = data_collector self.runs = {} self.encoder = EnhancedJSONEncoder() + self.query = None + self.added_tools = set() + self.prompt = {} def _encode_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: return json.loads(self.encoder.encode(entry)) @@ -57,10 +67,23 @@ def _encode_entry(self, entry: Dict[str, 
Any]) -> Dict[str, Any]: def _store_entry(self, entry: Dict[str, Any]) -> None: try: encoded_entry = self._encode_entry(entry) - self.data_collector.store_entry(encoded_entry) + self.data_collector.store_log(encoded_entry) except Exception as e: get_logger().error(f'Error storing entry in FloExecutionLogger: {e}') + def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + self.prompt[str(run_id)] = prompts + def on_chain_start( self, serialized: Dict[str, Any], @@ -73,11 +96,27 @@ def on_chain_start( chain_name = ( serialized.get('name', 'unnamed_chain') if serialized else 'unnamed_chain' ) + + if parent_run_id and chain_name != 'agent_chain': + return + if isinstance(inputs, dict): + user_input = inputs.get('messages', {}) + else: + user_input = {} + if ( + user_input + and len(user_input) > 0 + and isinstance(user_input[0], HumanMessage) + ): + if isinstance(user_input[0], HumanMessage): + self.query = user_input[0].content + self.runs[str(run_id)] = { 'type': 'chain', 'start_time': datetime.utcnow(), 'inputs': inputs, 'name': chain_name, + 'run_id': str(run_id), 'parent_run_id': str(parent_run_id) if parent_run_id else None, } @@ -91,11 +130,29 @@ def on_chain_end( ) -> None: if str(run_id) in self.runs: run_info = self.runs[str(run_id)] + if run_info['type'] != 'chain' and run_info['type'] != 'llm': + return run_info['end_time'] = datetime.utcnow() run_info['outputs'] = outputs run_info['status'] = 'completed' + run_info['parent_run_id'] = str(parent_run_id) if parent_run_id else None + run_info['prompt'] = ( + self.prompt[str(run_id)] if str(run_id) in self.prompt else [] + ) self._store_entry(run_info) del self.runs[str(run_id)] + else: + if isinstance(outputs, ChatPromptValue) or isinstance(outputs, AgentFinish): + run_info = {} + run_info['type'] = 'llm' + 
run_info['end_time'] = datetime.utcnow() + run_info['inputs'] = outputs + run_info['status'] = 'completed' + run_info['run_id'] = str(run_id) + run_info['parent_run_id'] = ( + str(parent_run_id) if parent_run_id else None + ) + self.runs[str(parent_run_id)] = run_info def on_chain_error( self, @@ -105,13 +162,7 @@ def on_chain_error( parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: - if str(run_id) in self.runs: - run_info = self.runs[str(run_id)] - run_info['end_time'] = datetime.utcnow() - run_info['error'] = str(error) - run_info['status'] = 'error' - self._store_entry(run_info) - del self.runs[str(run_id)] + pass def on_tool_start( self, @@ -120,10 +171,14 @@ def on_tool_start( *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: self.runs[str(run_id)] = { 'type': 'tool', + 'query': self.query, 'start_time': datetime.utcnow(), 'tool_name': serialized.get('name', 'unnamed_tool'), 'input': input_str, @@ -136,6 +191,7 @@ def on_tool_end( *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: if str(run_id) in self.runs: @@ -197,3 +253,25 @@ def on_agent_finish( 'parent_run_id': str(parent_run_id) if parent_run_id else None, } self._store_entry(log_entry) + + def log_all_tools(self, session_tools): + try: + tools = [] + + for val in session_tools: + tool_name = session_tools[val].name + if tool_name not in self.added_tools: + tools.append( + { + 'tool_name': tool_name, + 'description': session_tools.get(val).description, + 'args': session_tools.get(val).args, + } + ) + self.added_tools.add(tool_name) + + encoded_entry = self._encode_entry(tools) + if encoded_entry: + self.data_collector.store_tool_log(encoded_entry) + except Exception as e: + get_logger().error(f'Error storing tool in FloExecutionLogger: {e}') diff --git a/flo_ai/core.py 
b/flo_ai/core.py index a8ed2682..5f03dac6 100644 --- a/flo_ai/core.py +++ b/flo_ai/core.py @@ -22,6 +22,7 @@ from flo_ai.models.flo_node import FloNode from flo_ai.models.flo_agent import FloAgent from langchain.tools import StructuredTool +from flo_ai.callbacks.flo_execution_logger import ToolLogger class Flo: @@ -42,23 +43,17 @@ def async_stream(self, query, config=None) -> Iterator[Union[dict[str, Any], Any return self.runnable.astream(query, config) def invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: - config = config or {} - config['callbacks'] = ( - config.get('callbacks', []) - + [self.session.langchain_logger] - + self.session.callbacks - ) + config = self.session.prepare_config(config) + + for callback in self.session.callbacks: + if isinstance(callback, ToolLogger): + callback.log_all_tools(self.session.tools) + self.validate_invoke(self.session) get_logger().info(f"Invoking query: '{query}'", self.session) return self.runnable.invoke(query, config) def async_invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: - config = config or {} - config['callbacks'] = ( - config.get('callbacks', []) - + [self.session.langchain_logger] - + self.session.callbacks - ) get_logger().info(f"Invoking async query: '{query}'", self.session) return self.runnable.ainvoke(query, config) diff --git a/flo_ai/models/flo_node.py b/flo_ai/models/flo_node.py index 22d362bd..8f46f41d 100644 --- a/flo_ai/models/flo_node.py +++ b/flo_ai/models/flo_node.py @@ -36,14 +36,19 @@ def __init__( self.agent_executable = agent_executable def invoke(self, query, config): - return self.func({STATE_NAME_MESSAGES: [HumanMessage(content=query)]}) + return self.func( + {STATE_NAME_MESSAGES: [HumanMessage(content=query)]}, config=config + ) async def ainvoke(self, query, config): return await self.async_func( - {STATE_NAME_MESSAGES: [HumanMessage(content=query)]} + {STATE_NAME_MESSAGES: [HumanMessage(content=query)]}, config=config ) - def draw(self, 
xray=True): + def draw( + self, + xray=True, + ): return ( self.agent_executable.get_graph().draw_mermaid_png() if self.agent_executable is not None @@ -151,6 +156,7 @@ def __teamflo_agent_node( name: str, session: FloSession, model_name: str, + config=None, data_collector: Optional[FloOutputCollector] = None, ): agent_cbs: List[FloAgentCallback] = FloNode.Builder.__filter_callbacks( @@ -168,7 +174,7 @@ def __teamflo_agent_node( for callback in flo_cbs ] try: - result = agent.invoke(state) + result = agent.invoke(state, config=config) output = result if isinstance(result, str) else result['output'] if data_collector is not None: get_logger().info( @@ -202,6 +208,7 @@ async def __async_teamflo_agent_node( name: str, session: FloSession, model_name: str, + config: dict = None, data_collector: Optional[FloOutputCollector] = None, ): agent_cbs: List[FloAgentCallback] = FloNode.Builder.__filter_callbacks( @@ -219,7 +226,7 @@ async def __async_teamflo_agent_node( for callback in flo_cbs ] try: - result = await agent.ainvoke(state) + result = await agent.ainvoke(state, config=config) output = result if isinstance(result, str) else result['output'] if data_collector is not None: get_logger().info( @@ -258,6 +265,7 @@ def __teamflo_router_node( name: str, session: FloSession, model_name: str, + config: dict = None, ): agent_cbs: List[FloRouterCallback] = FloNode.Builder.__filter_callbacks( session, FloRouterCallback @@ -274,7 +282,7 @@ def __teamflo_router_node( for callback in flo_cbs ] try: - result = agent.invoke(state) + result = agent.invoke(state, config=config) nextNode = result if isinstance(result, str) else result['next'] except Exception as e: [ diff --git a/flo_ai/state/flo_session.py b/flo_ai/state/flo_session.py index 1efeb408..72d40971 100644 --- a/flo_ai/state/flo_session.py +++ b/flo_ai/state/flo_session.py @@ -108,11 +108,16 @@ def register_callback( filter(lambda x: isinstance(x, FloToolCallback), self.callbacks) ) self.langchain_logger = 
FloLangchainLogger(self.session_id, tool_callbacks) - - if self.llm is not None: - self.llm = self.llm.bind(callbacks=[callback]) + self.callbacks.append(self.langchain_logger) return self + def prepare_config(self, config=None): + get_logger().info(f'Binding all callbacks ... {len(self.callbacks)}') + config = {} if config is None else config + existing_cbs = config['callbacks'] if 'callbacks' in config else [] + config['callbacks'] = self.callbacks + existing_cbs + return config + def append(self, node: str) -> int: get_logger().debug(f'Appending node: {node}') self.counter[node] = self.counter.get(node, 0) + 1 diff --git a/flo_ai/storage/data_collector.py b/flo_ai/storage/data_collector.py index 79a6f053..14d72163 100644 --- a/flo_ai/storage/data_collector.py +++ b/flo_ai/storage/data_collector.py @@ -6,7 +6,11 @@ class DataCollector(ABC): @abstractmethod - def store_entry(self, entry: Dict[str, Any]) -> None: + def store_log(self, entry: Dict[str, Any]) -> None: + pass + + @abstractmethod + def store_tool_log(self, entry: Dict[str, Any]) -> None: pass @abstractmethod @@ -15,12 +19,20 @@ def close(self) -> None: class JSONLFileCollector(DataCollector): - def __init__(self, file_path: str): - self.file_path = Path(file_path) - self.file_path.parent.mkdir(parents=True, exist_ok=True) + def __init__(self, folder_path: str): + self.log_file_path = Path(f'{folder_path}/logs/logs.jsonl') + self.tool_file_path = Path(f'{folder_path}/tools/tools.jsonl') + + self.log_file_path.parent.mkdir(parents=True, exist_ok=True) + self.tool_file_path.parent.mkdir(parents=True, exist_ok=True) + + def store_log(self, entry: Dict[str, Any]) -> None: + with open(self.log_file_path, 'a') as f: + json.dump(entry, f) + f.write('\n') - def store_entry(self, entry: Dict[str, Any]) -> None: - with open(self.file_path, 'a') as f: + def store_tool_log(self, entry: Dict[str, Any]) -> None: + with open(self.tool_file_path, 'a') as f: json.dump(entry, f) f.write('\n') diff --git 
a/generate_training_data.py b/generate_training_data.py new file mode 100644 index 00000000..8d4760b0 --- /dev/null +++ b/generate_training_data.py @@ -0,0 +1,157 @@ +""" +Generate Training Data Script + +This script processes log files and tool descriptions to generate training datasets. +It handles both tool-based and chain-based data, transforming them into a proper format. + +Usage: + python generate_training_data.py --logger-path PATH --tool-path PATH --output-path PATH + +Arguments: + --logger-path: Path to the logger file containing tool and chain entries + --tool-path: Path to the tool descriptions file + --output-path: Path to save the output +""" + +import json +import argparse + + +def parse_arguments(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description='Generate training data from logs and tool descriptions' + ) + parser.add_argument( + '--logger-path', + required=True, + help='Path to the logger file containing tool and chain entries', + ) + parser.add_argument( + '--tool-path', required=True, help='Path to the tool descriptions file' + ) + parser.add_argument('--output-path', required=True, help='path to save the output') + return parser.parse_args() + + +def read_file(file_path): + """ + Read and parse JSON lines from a file. 
+ Returns: + List of parsed JSON objects + """ + try: + datas = [] + with open(file_path, 'r') as file: + for line in file: + datas.append(json.loads(line)) + return datas + except FileNotFoundError: + raise FileNotFoundError(f'Could not find file: {file_path}') + except Exception as e: + raise Exception(e) + + +def extracting_tool_details(tools, toolbox): + dataset = [] + for tool in tools: + query = tool.get('query') + tool_name = tool.get('tool_name') + for tool_d in toolbox[tool_name]: + if tool_d.get('tool_name') == tool_name: + description = tool_d['description'] + args = tool_d['args'] + tool_input = tool['input'] + dataset.append( + { + 'query': query, + 'tool_name': tool_name, + 'description': description, + 'args': args, + 'tool_input': tool_input, + } + ) + return dataset + + +def tool_transformation(input_data): + transform_data = {} + + for idx, data in enumerate(input_data): + tool_name = data['tool_name'] + description = (data['description'],) + args = data['args'] + tool_input = json.loads(data['tool_input'].replace("'", '"')) + for _, value in args.items(): + if 'title' in value: + del value['title'] + transformed_query = { + 'id': idx, + 'answers': json.dumps([{'name': tool_name, 'arguments': tool_input}]), + 'tools': json.dumps( + [{'name': tool_name, 'description': description, 'parameters': args}] + ), + } + transform_data[tool_name] = transformed_query + return transform_data + + +def chain_transformation(chains, start_idx): + chain_dataset = [] + + for i, chain in enumerate(chains, start_idx): + if chain.get('inputs'): + chain_dataset.append( + { + 'query': chain['prompt'][0], + 'id': i, + 'answer': chain['outputs']['output'], + } + ) + return chain_dataset + + +def llm_transformation(llm_logs, tool_set): + dataset = [] + for i, llm_log in enumerate(llm_logs): + if llm_log['inputs'] and 'messages' in llm_log['inputs']: + tools = None + answer = llm_log['outputs'] + print() + if ( + 'type' in llm_log['outputs'][0] + and 
llm_log['outputs'][0]['type'] == 'AgentAction' + ): + tools = tool_set[llm_log['outputs'][0]['tool']] + answer = tools['answers'] + dataset.append( + { + 'query': llm_log['inputs']['messages'], + 'id': i, + 'answers': answer, + 'tools': tools['tools'] if tools is not None else None, + } + ) + return dataset + + +if __name__ == '__main__': + args = parse_arguments() + + logger_data = read_file(args.logger_path) + tool_descriptions = read_file(args.tool_path) + + toolbox = {} + for td in tool_descriptions: + toolbox[td[0]['tool_name']] = td + + tools = [entry for entry in logger_data if entry['type'] == 'tool'] + llms = [entry for entry in logger_data if entry['type'] == 'llm'] + + tool_extraction = extracting_tool_details(tools, toolbox) + tool_transformed = tool_transformation(tool_extraction) + training_data = llm_transformation(llms, tool_transformed) + + with open(args.output_path, 'w') as f: + for data in training_data: + f.write(json.dumps(data) + '\n') diff --git a/poetry.lock b/poetry.lock index 6761944f..df4f21d2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiohappyeyeballs"