3 changes: 2 additions & 1 deletion .gitignore
@@ -13,5 +13,6 @@ bin
*.sql
*.log
*.yaml
examples/local/*
.logs
scratch_pad.py
examples/local/*
15 changes: 13 additions & 2 deletions README.md
@@ -405,8 +405,8 @@ You can customize `DataCollector` implementation according to your database. A s
from flo_ai.callbacks import FloExecutionLogger
from flo_ai.storage.data_collector import JSONLFileCollector

# Initialize the file collector with a path for the JSONL log file
file_collector = JSONLFileCollector("./path/to/my_llm_logs.jsonl")
# Initialize the file collector with the directory where the JSONL log files will be stored
file_collector = JSONLFileCollector('.logs')

# Create a tool logger with the collector
local_tracker = FloExecutionLogger(file_collector)
@@ -439,6 +439,17 @@ The structured logs provide valuable training data that can be used to:
- **Train new models** to replicate successful tool usage patterns
- **Create supervised datasets** for tool selection and chain optimization

We provide a script to convert your logs into training data:

```bash
python generate_training_data.py --logger-path PATH --tool-path PATH [--output PATH]
```

Arguments:
- *logger-path*: Path to the logger file containing tool and chain entries, e.g. `.logs/logs/log.jsonl`
- *tool-path*: Path to the tool descriptions file, e.g. `.logs/tools/tools.jsonl`
- *output*: Path to save the output, e.g. `training-data.jsonl`
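
For example, assuming the default `.logs` layout shown above (the exact paths depend on where your collector writes its files), a run might look like this:

```bash
# Illustrative invocation; adjust the paths to match your log directory
python generate_training_data.py \
  --logger-path .logs/logs/log.jsonl \
  --tool-path .logs/tools/tools.jsonl \
  --output training-data.jsonl
```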


## 📖 Documentation

46 changes: 46 additions & 0 deletions examples/python/json_training_data_generation.py
@@ -0,0 +1,46 @@
import os
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from flo_ai.callbacks import FloExecutionLogger
from flo_ai.storage.data_collector import JSONLFileCollector
from flo_ai import Flo, FloSession
from flo_ai.models.flo_agent import FloAgent
from langchain_community.tools.tavily_search.tool import TavilySearchResults
from dotenv import load_dotenv


load_dotenv()

file_collector = JSONLFileCollector('.logs')

# Create a tool logger with the collector
local_tracker = FloExecutionLogger(file_collector)
# Create the LLM object
llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini')


prompt = PromptTemplate.from_template('1 + {number} = ')

chain = prompt | llm
print(chain.invoke({'number': 2}))


session = FloSession(llm)
session.register_callback(local_tracker)

os.environ['TAVILY_API_KEY'] = os.getenv('TAVILY_API_KEY')
tavily_tool = TavilySearchResults()

session.register_tool('thappal', tavily_tool)

blogger_agent = FloAgent.create(
session=session,
name='Blogger',
job='You can research the internet and create a blog about the topic given by the user',
tools=[tavily_tool],
)


agent_flo: Flo = Flo.create(session, blogger_agent)

print(agent_flo.invoke("What's the weather in New Delhi, India?"))
30 changes: 15 additions & 15 deletions examples/python/output_parser.py
@@ -1,25 +1,26 @@
import os
from flo_ai import FloAgent, FloSession, Flo
from flo_ai import FloLLMAgent, FloSession, Flo
from langchain_community.tools.tavily_search.tool import TavilySearchResults
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from langchain_openai import ChatOpenAI
from flo_ai.parsers import FloJsonParser
from flo_ai.state import FloJsonOutputCollector
from flo_ai.callbacks import FloExecutionLogger
from flo_ai.storage.data_collector import JSONLFileCollector

load_dotenv()
llm = AzureChatOpenAI(
azure_endpoint=os.getenv('AZURE_GPT4_ENDPOINT'),
model_name='gpt-4o',
temperature=0.2,
max_tokens=4096,
api_version='2024-08-01-preview',
api_key=os.getenv('AZURE_OPEN_AI_API_KEY'),
)
llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini')

session = FloSession(llm).register_tool(
name='TavilySearchResults', tool=TavilySearchResults()
)


file_collector = JSONLFileCollector('.logs')

local_tracker = FloExecutionLogger(file_collector)

session.register_callback(local_tracker)

format = {
'name': 'NameFormat',
'fields': [
@@ -48,12 +49,11 @@

dc = FloJsonOutputCollector()

researcher = FloAgent.create(
researcher = FloLLMAgent.create(
session,
name='Researcher',
role='Internet Researcher',
name='Formatter',
role='Output formatter',
    job='What is the first name, last name and middle name of the person the user asks about',
tools=[TavilySearchResults()],
parser=FloJsonParser.create(json_dict=format),
data_collector=dc,
)
96 changes: 87 additions & 9 deletions flo_ai/callbacks/flo_execution_logger.py
@@ -8,6 +8,13 @@
from langchain_core.prompts.chat import ChatPromptValue
from flo_ai.storage.data_collector import DataCollector
from flo_ai.common.flo_logger import get_logger
from abc import ABC, abstractmethod


class ToolLogger(ABC):
    """Interface for callbacks that can log all tools registered in a session."""

    @abstractmethod
    def log_all_tools(self, session_tools):
        pass


class EnhancedJSONEncoder(json.JSONEncoder):
@@ -45,22 +52,38 @@ def default(self, obj):
return super().default(obj)


class FloExecutionLogger(BaseCallbackHandler):
class FloExecutionLogger(BaseCallbackHandler, ToolLogger):
def __init__(self, data_collector: DataCollector):
self.data_collector = data_collector
self.runs = {}
self.encoder = EnhancedJSONEncoder()
self.query = None
self.added_tools = set()
self.prompt = {}

def _encode_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
return json.loads(self.encoder.encode(entry))

def _store_entry(self, entry: Dict[str, Any]) -> None:
try:
encoded_entry = self._encode_entry(entry)
self.data_collector.store_entry(encoded_entry)
self.data_collector.store_log(encoded_entry)
except Exception as e:
get_logger().error(f'Error storing entry in FloExecutionLogger: {e}')

def on_llm_start(
self,
serialized: dict[str, Any],
prompts: list[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
self.prompt[str(run_id)] = prompts

def on_chain_start(
self,
serialized: Dict[str, Any],
@@ -73,11 +96,27 @@ def on_chain_start(
chain_name = (
serialized.get('name', 'unnamed_chain') if serialized else 'unnamed_chain'
)

if parent_run_id and chain_name != 'agent_chain':
return
if isinstance(inputs, dict):
user_input = inputs.get('messages', {})
else:
user_input = {}
        if (
            user_input
            and len(user_input) > 0
            and isinstance(user_input[0], HumanMessage)
        ):
            self.query = user_input[0].content

self.runs[str(run_id)] = {
'type': 'chain',
'start_time': datetime.utcnow(),
'inputs': inputs,
'name': chain_name,
'run_id': str(run_id),
'parent_run_id': str(parent_run_id) if parent_run_id else None,
}

@@ -91,11 +130,29 @@ def on_chain_end(
) -> None:
if str(run_id) in self.runs:
run_info = self.runs[str(run_id)]
if run_info['type'] != 'chain' and run_info['type'] != 'llm':
return
run_info['end_time'] = datetime.utcnow()
run_info['outputs'] = outputs
run_info['status'] = 'completed'
run_info['parent_run_id'] = str(parent_run_id) if parent_run_id else None
run_info['prompt'] = (
self.prompt[str(run_id)] if str(run_id) in self.prompt else []
)
self._store_entry(run_info)
del self.runs[str(run_id)]
else:
            if isinstance(outputs, (ChatPromptValue, AgentFinish)):
                run_info = {
                    'type': 'llm',
                    'end_time': datetime.utcnow(),
                    'inputs': outputs,
                    'status': 'completed',
                    'run_id': str(run_id),
                    'parent_run_id': str(parent_run_id) if parent_run_id else None,
                }
                self.runs[str(parent_run_id)] = run_info

def on_chain_error(
self,
Expand All @@ -105,13 +162,7 @@ def on_chain_error(
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
if str(run_id) in self.runs:
run_info = self.runs[str(run_id)]
run_info['end_time'] = datetime.utcnow()
run_info['error'] = str(error)
run_info['status'] = 'error'
self._store_entry(run_info)
del self.runs[str(run_id)]
pass

def on_tool_start(
self,
Expand All @@ -120,10 +171,14 @@ def on_tool_start(
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
self.runs[str(run_id)] = {
'type': 'tool',
'query': self.query,
'start_time': datetime.utcnow(),
'tool_name': serialized.get('name', 'unnamed_tool'),
'input': input_str,
Expand All @@ -136,6 +191,7 @@ def on_tool_end(
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
if str(run_id) in self.runs:
@@ -197,3 +253,25 @@ def on_agent_finish(
'parent_run_id': str(parent_run_id) if parent_run_id else None,
}
self._store_entry(log_entry)

def log_all_tools(self, session_tools):
try:
tools = []

for val in session_tools:
tool_name = session_tools[val].name
if tool_name not in self.added_tools:
tools.append(
{
'tool_name': tool_name,
'description': session_tools.get(val).description,
'args': session_tools.get(val).args,
}
)
self.added_tools.add(tool_name)

encoded_entry = self._encode_entry(tools)
if encoded_entry:
self.data_collector.store_tool_log(encoded_entry)
except Exception as e:
get_logger().error(f'Error storing tool in FloExecutionLogger: {e}')
19 changes: 7 additions & 12 deletions flo_ai/core.py
@@ -22,6 +22,7 @@
from flo_ai.models.flo_node import FloNode
from flo_ai.models.flo_agent import FloAgent
from langchain.tools import StructuredTool
from flo_ai.callbacks.flo_execution_logger import ToolLogger


class Flo:
@@ -42,23 +43,17 @@ def async_stream(self, query, config=None) -> Iterator[Union[dict[str, Any], Any
return self.runnable.astream(query, config)

def invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]:
config = config or {}
config['callbacks'] = (
config.get('callbacks', [])
+ [self.session.langchain_logger]
+ self.session.callbacks
)
config = self.session.prepare_config(config)

for callback in self.session.callbacks:
if isinstance(callback, ToolLogger):
callback.log_all_tools(self.session.tools)

self.validate_invoke(self.session)
get_logger().info(f"Invoking query: '{query}'", self.session)
return self.runnable.invoke(query, config)

def async_invoke(self, query, config=None) -> Iterator[Union[dict[str, Any], Any]]:
config = config or {}
config['callbacks'] = (
config.get('callbacks', [])
+ [self.session.langchain_logger]
+ self.session.callbacks
)
get_logger().info(f"Invoking async query: '{query}'", self.session)
return self.runnable.ainvoke(query, config)
