diff --git a/README.md b/README.md index 94f74a21..eac4a92c 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,52 @@ session.register_tool(name='Adder', tool=addition_tool) **Note:** `@flotool` comes with inherent error handling capabilities to retry if an exception is thrown. Use `unsafe=True` to disable error handling +## 📊 Tool Logging and Data Collection + +FloAI provides built-in capabilities for logging tool calls and collecting data through the `FloExecutionLogger` and `DataCollector` classes, facilitating the creation of valuable training data. +You can customize `DataCollector` implementation according to your database. A sample implementation where logs are stored locally as JSON files is implemented in `JSONLFileCollector`. + +### Quick Setup + +```python +from flo_ai.callbacks import FloExecutionLogger +from flo_ai.storage.data_collector import JSONLFileCollector + +# Initialize the file collector with a path for the JSONL log file +file_collector = JSONLFileCollector("./path/to/my_llm_logs.jsonl") + +# Create a tool logger with the collector +local_tracker = FloExecutionLogger(file_collector) + +# Register the logger with your session +session.register_callback(local_tracker) +``` + +### Features + +- 📝 Logs all tool calls, chain executions, and agent actions +- 🕒 Includes timestamps for start and end of operations +- 🔍 Tracks inputs, outputs, and errors +- 💾 Stores data in JSONL format for easy analysis +- 📚 Facilitates the creation of training data from logged interactions + +### Log Data Structure + +The logger captures detailed information including: +- Tool name and inputs +- Execution timestamps +- Operation status (completed/error) +- Chain and agent activities +- Parent-child relationship between operations + +### Training Data Generation + +The structured logs provide valuable training data that can be used to: +- **Fine-tune LLMs** on your specific use cases +- **Train new models** to replicate successful tool usage patterns +- **Create 
supervised datasets** for tool selection and chain optimization + + ## 📖 Documentation Visit our [comprehensive documentation](https://flo-ai.rootflo.ai) for: diff --git a/examples/python/tool_data_logging_example.py b/examples/python/tool_data_logging_example.py new file mode 100644 index 00000000..ca4e59d2 --- /dev/null +++ b/examples/python/tool_data_logging_example.py @@ -0,0 +1,79 @@ +from flo_ai.callbacks import FloExecutionLogger +from flo_ai.storage.data_collector import JSONLFileCollector +from langchain_openai import AzureChatOpenAI +import os +from dotenv import load_dotenv +from flo_ai import Flo +from flo_ai import FloSession +from typing import List +from flo_ai.tools import flotool + +load_dotenv() + +llm = AzureChatOpenAI( + temperature=0, + deployment_name='gpt-4', + model_name='gpt-4', + azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'), + api_key=os.getenv('AZURE_OPENAI_API_KEY'), + api_version='2024-08-01-preview', +) + +session = FloSession( + llm, + log_level='ERROR', +) + + +@flotool(name='AdditionTool', description='Tool to add numbers') +def addition_tool(numbers: List[int]) -> str: + result = sum(numbers) + return f'The sum is {result}' + + +@flotool( + name='MultiplicationTool', + description='Tool to multiply numbers to get product of numbers', +) +def mul_tool(numbers: List[int]) -> str: + result = 1 + for num in numbers: + result *= num + return f'The product is {result}' + + +session.register_tool(name='Adder', tool=addition_tool).register_tool( + name='Multiplier', tool=mul_tool +) + +simple_calculator_agent = """ +apiVersion: flo/alpha-v1 +kind: FloAgent +name: calculating-assistant +agent: + name: SummationHelper + kind: agentic + job: > + You are a calculation assistant that MUST ONLY use the provided tools for calculations. + You MUST ONLY return the exact outputs from the tools without modification. + You MUST NOT perform any calculations yourself. 
+ If you need both sum and product, you MUST use both tools and combine their exact outputs. + tools: + - name: Adder + - name: Multiplier +""" + + +current_dir = os.path.dirname(os.path.abspath(__file__)) +log_file_path = os.path.join(current_dir, 'my_llm_logs.jsonl') + +file_collector = JSONLFileCollector(log_file_path) +local_tracker = FloExecutionLogger(file_collector) + +session.register_callback(local_tracker) + +flo = Flo.build(session, simple_calculator_agent, log_level='ERROR') + +result = flo.invoke( + 'find the sum of first three numbers and last three numbers and multiply the result. Numbers are 1, 3, 4, 2, 0, 1', +) diff --git a/flo_ai/callbacks/__init__.py b/flo_ai/callbacks/__init__.py index 69fba216..03eb1a64 100644 --- a/flo_ai/callbacks/__init__.py +++ b/flo_ai/callbacks/__init__.py @@ -4,10 +4,12 @@ flo_tool_callback, flo_call_back, ) +from flo_ai.callbacks.flo_execution_logger import FloExecutionLogger __all__ = [ 'flo_agent_callback', 'flo_router_callback', 'flo_tool_callback', 'flo_call_back', + 'FloExecutionLogger', ] diff --git a/flo_ai/callbacks/flo_execution_logger.py b/flo_ai/callbacks/flo_execution_logger.py new file mode 100644 index 00000000..f0cb04bd --- /dev/null +++ b/flo_ai/callbacks/flo_execution_logger.py @@ -0,0 +1,199 @@ +import json +from typing import Any, Dict, Optional +from datetime import datetime +from uuid import UUID +from langchain_core.callbacks import BaseCallbackHandler +from langchain.schema.agent import AgentAction, AgentFinish +from langchain.schema import HumanMessage, AIMessage, BaseMessage +from langchain_core.prompts.chat import ChatPromptValue +from flo_ai.storage.data_collector import DataCollector +from flo_ai.common.flo_logger import get_logger + + +class EnhancedJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (HumanMessage, AIMessage, BaseMessage)): + return { + 'type': obj.__class__.__name__, + 'content': obj.content, + 'additional_kwargs': obj.additional_kwargs, + } + 
elif isinstance(obj, AgentAction): + return { + 'type': 'AgentAction', + 'tool': obj.tool, + 'tool_input': obj.tool_input, + 'log': obj.log, + } + elif isinstance(obj, AgentFinish): + return { + 'type': 'AgentFinish', + 'return_values': obj.return_values, + 'log': obj.log, + } + elif isinstance(obj, ChatPromptValue): + return { + 'type': 'ChatPromptValue', + 'messages': [self.default(msg) for msg in obj.messages], + } + elif isinstance(obj, datetime): + return obj.isoformat() + elif isinstance(obj, UUID): + return str(obj) + elif hasattr(obj, 'to_dict'): + return obj.to_dict() + return super().default(obj) + + +class FloExecutionLogger(BaseCallbackHandler): + def __init__(self, data_collector: DataCollector): + self.data_collector = data_collector + self.runs = {} + self.encoder = EnhancedJSONEncoder() + + def _encode_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: + return json.loads(self.encoder.encode(entry)) + + def _store_entry(self, entry: Dict[str, Any]) -> None: + try: + encoded_entry = self._encode_entry(entry) + self.data_collector.store_entry(encoded_entry) + except Exception as e: + get_logger().error(f'Error storing entry in FloExecutionLogger: {e}') + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + chain_name = ( + serialized.get('name', 'unnamed_chain') if serialized else 'unnamed_chain' + ) + self.runs[str(run_id)] = { + 'type': 'chain', + 'start_time': datetime.utcnow(), + 'inputs': inputs, + 'name': chain_name, + 'parent_run_id': str(parent_run_id) if parent_run_id else None, + } + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if str(run_id) in self.runs: + run_info = self.runs[str(run_id)] + run_info['end_time'] = datetime.utcnow() + run_info['outputs'] = outputs + run_info['status'] = 'completed' + 
self._store_entry(run_info) + del self.runs[str(run_id)] + + def on_chain_error( + self, + error: Exception, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if str(run_id) in self.runs: + run_info = self.runs[str(run_id)] + run_info['end_time'] = datetime.utcnow() + run_info['error'] = str(error) + run_info['status'] = 'error' + self._store_entry(run_info) + del self.runs[str(run_id)] + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + self.runs[str(run_id)] = { + 'type': 'tool', + 'start_time': datetime.utcnow(), + 'tool_name': serialized.get('name', 'unnamed_tool'), + 'input': input_str, + 'parent_run_id': str(parent_run_id) if parent_run_id else None, + } + + def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if str(run_id) in self.runs: + run_info = self.runs[str(run_id)] + run_info['end_time'] = datetime.utcnow() + run_info['output'] = output + run_info['status'] = 'completed' + self._store_entry(run_info) + del self.runs[str(run_id)] + + def on_tool_error( + self, + error: Exception, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + if str(run_id) in self.runs: + run_info = self.runs[str(run_id)] + run_info['end_time'] = datetime.utcnow() + run_info['error'] = str(error) + run_info['status'] = 'error' + self._store_entry(run_info) + del self.runs[str(run_id)] + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + agent_info = { + 'type': 'agent_action', + 'start_time': datetime.utcnow(), + 'tool': action.tool, + 'tool_input': action.tool_input, + 'log': action.log, + 'parent_run_id': str(parent_run_id) if parent_run_id else None, + } + self.runs[str(run_id)] = agent_info + 
self._store_entry(agent_info) + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + log_entry = { + 'type': 'agent_finish', + 'time': datetime.utcnow(), + 'output': finish.return_values, + 'log': finish.log, + 'parent_run_id': str(parent_run_id) if parent_run_id else None, + } + self._store_entry(log_entry) diff --git a/flo_ai/core.py b/flo_ai/core.py index 79b8f9bb..0a76d452 100644 --- a/flo_ai/core.py +++ b/flo_ai/core.py @@ -40,13 +40,23 @@ def async_stream(self, query, config=None) -> Iterator[Union[dict[str, Any], Any return self.runnable.astream(query, config) def invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: - config = {'callbacks': [self.session.langchain_logger]} + config = config or {} + config['callbacks'] = ( + config.get('callbacks', []) + + [self.session.langchain_logger] + + self.session.callbacks + ) self.validate_invoke(self.session) get_logger().info(f"Invoking query: '{query}'", self.session) return self.runnable.invoke(query, config) def async_invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]: - config = {'callbacks': [self.session.langchain_logger]} + config = config or {} + config['callbacks'] = ( + config.get('callbacks', []) + + [self.session.langchain_logger] + + self.session.callbacks + ) get_logger().info(f"Invoking async query: '{query}'", self.session) return self.runnable.ainvoke(query, config) diff --git a/flo_ai/state/flo_session.py b/flo_ai/state/flo_session.py index 82457fb0..a4a9f9a8 100644 --- a/flo_ai/state/flo_session.py +++ b/flo_ai/state/flo_session.py @@ -87,6 +87,9 @@ def register_callback( filter(lambda x: isinstance(x, FloToolCallback), self.callbacks) ) self.langchain_logger = FloLangchainLogger(self.session_id, tool_callbacks) + + if self.llm is not None: + self.llm = self.llm.bind(callbacks=[callback]) return self def append(self, node: str) -> int: diff --git 
a/flo_ai/storage/data_collector.py b/flo_ai/storage/data_collector.py new file mode 100644 index 00000000..79a6f053 --- /dev/null +++ b/flo_ai/storage/data_collector.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict +from pathlib import Path +import json + + +class DataCollector(ABC): + @abstractmethod + def store_entry(self, entry: Dict[str, Any]) -> None: + pass + + @abstractmethod + def close(self) -> None: + pass + + +class JSONLFileCollector(DataCollector): + def __init__(self, file_path: str): + self.file_path = Path(file_path) + self.file_path.parent.mkdir(parents=True, exist_ok=True) + + def store_entry(self, entry: Dict[str, Any]) -> None: + with open(self.file_path, 'a') as f: + json.dump(entry, f) + f.write('\n') + + def close(self) -> None: + pass