diff --git a/README.md b/README.md index d085e876a..519bf7880 100644 --- a/README.md +++ b/README.md @@ -34,25 +34,26 @@ ART's **LangGraph integration** enables you to train sophisticated ReAct-style a ```python import art -from art.langgraph import init_chat_model +from art.langgraph import init_chat_model, wrap_rollout from langgraph.prebuilt import create_react_agent async def email_rollout(model: art.Model, scenario: str) -> art.Trajectory: - traj = art.Trajectory(reward=0.0, messages_and_choices=[]) - # Create LangGraph agent with ART's chat model chat_model = init_chat_model(model.name) agent = create_react_agent(chat_model, tools) - + await agent.ainvoke({"messages": [("user", scenario)]}) - traj.reward = 1.0 # Score based on results - return traj + return art.Trajectory(reward=1.0, messages_and_choices=[]) # Train your agent scenarios = ["Find urgent emails", "Search Q4 budget"] -groups = await art.gather_trajectory_groups( - (art.TrajectoryGroup(email_rollout(model, s) for _ in range(4)) for s in scenarios) -) + +# Using wrap_rollout (captures interactions automatically) +groups = await art.gather_trajectory_groups([ + art.TrajectoryGroup(wrap_rollout(model, email_rollout)(model, s) for _ in range(4)) + for s in scenarios +]) + await model.train(groups) ``` diff --git a/docs/integrations/langgraph-integration.mdx b/docs/integrations/langgraph-integration.mdx index 4dad672f8..7911dbb95 100644 --- a/docs/integrations/langgraph-integration.mdx +++ b/docs/integrations/langgraph-integration.mdx @@ -33,419 +33,483 @@ LangGraph provides an excellent framework for building various types of agents - - **Multi-step trajectory support**: Handles complex agent workflows with tool calls and reasoning steps - **RULER compatibility**: Use ART's general-purpose reward function to train agents without hand-crafted rewards -## Basic Usage +## Code Examples -Here's how to integrate ART with your LangGraph agent: +Here are easily readable code snippets demonstrating the LangGraph integration functionality: + +### Basic Setup and Initialization ```python import uuid +import weave from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool from langgraph.prebuilt import create_react_agent -from pydantic import BaseModel from art.langgraph import init_chat_model import art -# Custom trajectory class to store additional data -class EmailTrajectory(art.Trajectory): - final_answer: str | None = None - found_emails: list[str] = [] +# Initialize Weave tracking (optional) +if os.getenv("WANDB_API_KEY", ""): + weave.init(model.project, settings={"print_call_link": False}) +``` -# Scenario data class -class EmailScenario(BaseModel): - step: int - query: str - user_email: str +### Defining Tools for Your Agent -# Rollout function that integrates with LangGraph -async def email_rollout(model: art.Model, scenario: EmailScenario) -> EmailTrajectory: - traj = EmailTrajectory( - reward=0.0, - messages_and_choices=[], - metadata={ - "step": scenario.step, - "query": scenario.query, - }, +```python +@tool +def search_inbox_tool(keywords: list[str]) -> list[dict]: + """Search the inbox for emails matching the given keywords and return + a list of dictionaries so the LLM can easily consume them.""" + results = search_emails( + inbox=scenario.inbox_address, + keywords=keywords, + sent_before=scenario.query_date, ) + return [asdict(result) for result in results] + +@tool +def read_email_tool(message_id: str) -> dict | None: + """Read a specific email by message ID.""" + email = read_email(message_id) + if email: + return email.model_dump() + return None + +@tool +def return_final_answer_tool(answer: str, reference_message_ids: list[str]) -> dict: + """Return the final answer and the message IDs used to generate the answer.""" + nonlocal final_answer + final_answer = FinalAnswer(answer=answer, source_ids=reference_message_ids) + return final_answer.model_dump() +``` - # Define tools inside the rollout to access local variables - @tool - def search_emails(keywords: list[str]) -> list[dict]: - """Search emails by keywords.""" - # Your email search implementation here - results = [ - {"id": "email_1", "subject": f"Found email about {keywords[0]}", "from": "john@company.com"}, - {"id": "email_2", "subject": f"Another email about {keywords[0]}", "from": "sarah@company.com"}, - ] - traj.found_emails.extend([r["id"] for r in results]) - return results - - @tool - def read_email(email_id: str) -> dict: - """Read a specific email.""" - # Your email reading implementation here - return { - "id": email_id, - "content": f"Email content for {email_id}", - "timestamp": "2024-01-15" - } - - @tool - def provide_answer(answer: str) -> str: - """Provide the final answer.""" - traj.final_answer = answer - # Score the trajectory based on quality (simple example) - traj.reward = 1.0 if len(answer) > 10 else 0.0 - return f"Final answer: {answer}" - - # Create LangGraph tools and agent - tools = [search_emails, read_email, provide_answer] - chat_model = init_chat_model(model.name, temperature=0.7) - agent = create_react_agent(chat_model, tools) - - # Run the agent - try: - config = { - "configurable": {"thread_id": str(uuid.uuid4())}, - "recursion_limit": 20, - } - - await agent.ainvoke( - { - "messages": [ - SystemMessage(content=f"You are an email assistant for {scenario.user_email}. Help find and read emails."), - HumanMessage(content=scenario.query), - ] - }, - config=config, - ) - - except Exception as e: - print(f"Error in agent execution: {e}") - traj.reward = 0.0 - - return traj - -async def train_email_agent(): - from art.local import LocalBackend +### Creating and Running a LangGraph ReAct Agent - backend = LocalBackend() +```python +@weave.op +async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory: + # Initialize chat model with temperature + chat_model = init_chat_model(model.name, temperature=1.0) - # Create your trainable model - model = art.TrainableModel( - name="email-agent-langgraph", - project="email-search-agent", - base_model="Qwen/Qwen2.5-7B-Instruct", - ) + # Define available tools + tools = [search_inbox_tool, read_email_tool, return_final_answer_tool] - await model.register(backend) + # Create the LangGraph ReAct agent + react_agent = create_react_agent(chat_model, tools) - # Create training scenarios - scenarios = [ - EmailScenario(step=0, query="Find emails from John about quarterly reports", user_email="user@company.com"), - EmailScenario(step=0, query="Search for budget discussions from last week", user_email="user@company.com"), - EmailScenario(step=0, query="Find the latest email from Sarah", user_email="user@company.com"), - ] + # Configure agent execution + config = { + "configurable": {"thread_id": str(uuid.uuid4())}, + "recursion_limit": MAX_TURNS, + } - # Generate trajectory groups - trajectory_groups = await art.gather_trajectory_groups( - ( - art.TrajectoryGroup( - email_rollout(model, scenario) - for _ in range(4) # Multiple trajectories per scenario - ) - for scenario in scenarios - ), - pbar_desc="gather" + # Run the agent with system and user messages + await react_agent.ainvoke( + { + "messages": [ + SystemMessage(content=system_prompt), + HumanMessage(content=scenario.question), + ] + }, + config=config, ) - - # Train the model - await model.train(trajectory_groups, config=art.TrainConfig(learning_rate=1e-5)) - -if __name__ == "__main__": - import asyncio - asyncio.run(train_email_agent()) ``` -## How It Works - -The ART-LangGraph integration works through these key components: - -### 1. LLM Wrapper (`init_chat_model`) - -ART's `init_chat_model` automatically captures all LLM interactions for training: +### Trajectory Tracking and Scoring ```python -from art.langgraph import init_chat_model - -# Create a chat model that logs to ART -chat_model = init_chat_model(model.name, temperature=0.7) +class ProjectTrajectory(art.Trajectory): + final_answer: FinalAnswer | None = None + +# Create trajectory with metadata +traj = ProjectTrajectory( + reward=0.0, + messages_and_choices=[], + metadata={ + "scenario_id": scenario.id, + "step": email_scenario.step, + }, +) -# Use with LangGraph as normal -agent = create_react_agent(chat_model, tools) +# Score the trajectory using correctness judge +if final_answer: + traj.final_answer = final_answer + correctness_judge_response = await judge_correctness( + scenario, traj.final_answer.answer + ) + traj.metrics["correct"] = float(correctness_judge_response.accept) ``` -**What gets captured:** - -- All chat completions and their responses -- Tool calls and tool execution results -- Token usage and timing information -- Error states and exceptions - -### 2. Custom Trajectory Classes - -Extend `art.Trajectory` to store domain-specific data alongside the conversation: +### Training Loop with LangGraph Integration ```python -class MyTrajectory(art.Trajectory): - final_answer: str | None = None - tools_used: list[str] = [] - custom_metrics: dict = {} -``` +from art.langgraph import wrap_rollout + +# Training configuration +training_config = { + "groups_per_step": 2, + "num_epochs": 20, + "rollouts_per_group": 4, + "learning_rate": 1e-5, + "max_steps": 20, +} + +# Create trajectory groups for training +for batch in training_iterator: + groups = [] + for scenario in batch.items: + groups.append( + art.TrajectoryGroup( + ( + wrap_rollout(model, rollout)( + model, EmailScenario(step=batch.step, scenario=scenario) + ) + for _ in range(training_config["rollouts_per_group"]) + ) + ) + ) -This allows you to: + # Gather trajectory groups + finished_groups = await art.gather_trajectory_groups( + groups, + pbar_desc="gather", + max_exceptions=training_config["rollouts_per_group"] * len(batch.items), + ) -- Store the agent's final outputs -- Track custom metrics during execution -- Score trajectories based on your specific criteria + # Apply RULER scoring + judged_groups = [] + for group in finished_groups: + judged_group = await ruler_score_group(group, "openai/o4-mini", debug=True) + judged_groups.append(judged_group) -### 3. Rollout Functions + # Train the model + await model.train( + judged_groups, + config=art.TrainConfig(learning_rate=training_config["learning_rate"]), + _config={"logprob_calculation_chunk_size": 8}, + ) +``` -Instead of decorators, define rollout functions that create and execute your LangGraph agent: +### Correctness Evaluation ```python -async def my_rollout(model: art.Model, scenario: MyScenario) -> MyTrajectory: - # Create custom trajectory - traj = MyTrajectory(reward=0.0, messages_and_choices=[]) +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt - # Define tools with access to trajectory - @tool - def my_tool(arg: str) -> str: - traj.tools_used.append("my_tool") # Track usage - return "result" +class CorrectnessJudgeResponse(BaseModel): + reasoning: str = Field(description="Explanation of the reasoning process.") + accept: bool = Field(description="Whether the AI answer should be accepted.") - # Create and run agent - chat_model = init_chat_model(model.name) - agent = create_react_agent(chat_model, [my_tool]) +@retry(stop=stop_after_attempt(3)) +async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudgeResponse: + system_prompt = """ + You are given a question, the reference answer, and an answer generated by an AI assistant. + Your task is to decide whether the AI answer is correct and should be accepted. + """ - result = await agent.ainvoke({"messages": [...]}) + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": ( + f"Question: {scenario.question}\n" + f"Reference answer: {scenario.answer}\n" + f"AI answer: {answer}" + ), + }, + ] - # Score the trajectory based on results - traj.reward = calculate_reward(result) + response = await acompletion( + model="openai/gpt-4.1", + messages=messages, + response_format=CorrectnessJudgeResponse, + ) - return traj + return CorrectnessJudgeResponse.model_validate_json( + response.choices[0].message.content or "{}" + ) ``` -### 4. Training Loop Integration - -Use your rollout functions directly with ART's training loop: - -```python -# Generate trajectory groups -trajectory_groups = await art.gather_trajectory_groups( - ( - art.TrajectoryGroup( - my_rollout(model, scenario) - for _ in range(trajectories_per_scenario) - ) - for scenario in scenarios - ), - pbar_desc="gather" -) +### Key Components Summary -# Train the model -await model.train(trajectory_groups) -``` +1. **LangGraph ReAct Agent**: Uses `create_react_agent()` with custom tools and chat model +2. **Tool Definition**: Custom tools decorated with `@tool` for specific functionality +3. **Trajectory Tracking**: Custom trajectory class extends `art.Trajectory` +4. **Training Integration**: Uses `wrap_rollout()` and `art.gather_trajectory_groups()` +5. **Evaluation**: Automated correctness judging with retry logic +6. **Configuration**: Flexible training parameters and agent limits -## Advanced Example: Email Search Agent with Custom Judging +## Complete Email Agent Example -Here's a complete example that shows custom trajectory scoring, similar to production usage: +Here's a complete, runnable example that demonstrates training a LangGraph email search agent: ```python +import asyncio import uuid -from pydantic import BaseModel, Field +from dataclasses import asdict +from textwrap import dedent +from typing import List + +import art +import weave from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool from langgraph.prebuilt import create_react_agent from litellm import acompletion +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt -from art.langgraph import init_chat_model -import art -# Custom response format for judging -class CorrectnessJudge(BaseModel): - reasoning: str = Field(description="Explanation of the reasoning process.") - score: float = Field(description="Score from 0.0 to 1.0 for answer quality.") +from art.langgraph import init_chat_model, wrap_rollout +from art.utils import iterate_dataset + +# Initialize model and backend +model = art.Model(name="Qwen/Qwen2.5-7B-Instruct") +backend = art.backends.SkyPilotBackend() + +# Data models +class EmailResult(BaseModel): + message_id: str + subject: str + from_address: str + date: str + snippet: str + +class FinalAnswer(BaseModel): + answer: str + source_ids: List[str] -# Enhanced trajectory with custom fields -class EmailTrajectory(art.Trajectory): - final_answer: str | None = None - emails_searched: list[str] = [] - emails_read: list[str] = [] +class Scenario(BaseModel): + id: str + question: str + answer: str + inbox_address: str + query_date: str -# Scenario with ground truth for comparison class EmailScenario(BaseModel): step: int - query: str - user_email: str - expected_answer: str # Ground truth for judging + scenario: Scenario -# Custom judge function -@retry(stop=stop_after_attempt(3)) -async def judge_email_answer(scenario: EmailScenario, answer: str) -> CorrectnessJudge: - """Judge the quality of an email search answer.""" - system_prompt = """ - You are evaluating an AI assistant's answer to an email search query. - Consider accuracy, completeness, and helpfulness. - Score from 0.0 (completely wrong) to 1.0 (perfect answer). - """ +class ProjectTrajectory(art.Trajectory): + final_answer: FinalAnswer | None = None +class CorrectnessJudgeResponse(BaseModel): + reasoning: str = Field(description="Explanation of the reasoning process.") + accept: bool = Field(description="Whether the AI answer should be accepted.") + +# Mock email functions (replace with real implementation) +def search_emails(inbox: str, keywords: List[str], sent_before: str) -> List[EmailResult]: + """Mock email search function - replace with real implementation""" + return [ + EmailResult( + message_id="msg_123", + subject=f"Subject matching {keywords[0]}", + from_address="sender@example.com", + date="2024-01-15", + snippet=f"Email snippet containing {keywords[0]}" + ) + ] + +def read_email(message_id: str) -> EmailResult | None: + """Mock email read function - replace with real implementation""" + return EmailResult( + message_id=message_id, + subject="Full email subject", + from_address="sender@example.com", + date="2024-01-15", + snippet="Full email content here..." + ) + +# Correctness evaluation +@retry(stop=stop_after_attempt(3)) +async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudgeResponse: + system_prompt = dedent(""" + You are given a question, the reference answer, and an answer generated by an AI assistant. + Your task is to decide whether the AI answer is correct and should be accepted. + """) + messages = [ {"role": "system", "content": system_prompt}, { "role": "user", - "content": f"Query: {scenario.query}\nExpected: {scenario.expected_answer}\nAI Answer: {answer}" + "content": ( + f"Question: {scenario.question}\n" + f"Reference answer: {scenario.answer}\n" + f"AI answer: {answer}" + ), }, ] - + response = await acompletion( model="openai/gpt-4o-mini", messages=messages, - response_format=CorrectnessJudge, + response_format=CorrectnessJudgeResponse, + ) + + return CorrectnessJudgeResponse.model_validate_json( + response.choices[0].message.content or "{}" ) - - raw_content = response.choices[0].message.content or "{}" - try: - return CorrectnessJudge.model_validate_json(raw_content) - except Exception as e: - return CorrectnessJudge(reasoning=f"Parse error: {e}", score=0.0) # Main rollout function -async def email_rollout(model: art.Model, scenario: EmailScenario) -> EmailTrajectory: - traj = EmailTrajectory( +@weave.op +async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory: + scenario = email_scenario.scenario + MAX_TURNS = 10 + + traj = ProjectTrajectory( reward=0.0, messages_and_choices=[], metadata={ - "step": scenario.step, - "query": scenario.query, + "scenario_id": scenario.id, + "step": email_scenario.step, }, ) - - # Tools with trajectory tracking + + system_prompt = dedent(f""" + You are an email search agent. Use the tools to search emails and find answers. + User's email address: {scenario.inbox_address} + Today's date: {scenario.query_date} + + When you find the answer, use return_final_answer_tool with the answer and source message IDs. + """) + + final_answer = None + @tool - def search_emails(keywords: list[str]) -> list[dict]: - """Search emails by keywords.""" - # Simulate real email search - results = [ - {"id": f"email_{i}", "subject": f"Email about {kw}", "from": "colleague@company.com"} - for i, kw in enumerate(keywords[:3]) - ] - traj.emails_searched.extend([r["id"] for r in results]) - return results - + def search_inbox_tool(keywords: List[str]) -> List[dict]: + """Search inbox for emails matching keywords""" + results = search_emails(scenario.inbox_address, keywords, scenario.query_date) + return [asdict(result) for result in results] + @tool - def read_email(email_id: str) -> dict: - """Read a specific email.""" - traj.emails_read.append(email_id) - return { - "id": email_id, - "content": f"Content of {email_id}: Important project information...", - "timestamp": "2024-01-15T10:30:00Z" - } - + def read_email_tool(message_id: str) -> dict | None: + """Read a specific email by message ID""" + email = read_email(message_id) + return email.model_dump() if email else None + @tool - def provide_final_answer(answer: str) -> str: - """Provide the final answer to the user's query.""" - traj.final_answer = answer - return f"Final answer: {answer}" - - # Create and run agent - tools = [search_emails, read_email, provide_final_answer] - chat_model = init_chat_model(model.name, temperature=0.7) - agent = create_react_agent(chat_model, tools) - + def return_final_answer_tool(answer: str, reference_message_ids: List[str]) -> dict: + """Return final answer with source message IDs""" + nonlocal final_answer + final_answer = FinalAnswer(answer=answer, source_ids=reference_message_ids) + return final_answer.model_dump() + + tools = [search_inbox_tool, read_email_tool, return_final_answer_tool] + chat_model = init_chat_model(model.name, temperature=1.0) + react_agent = create_react_agent(chat_model, tools) + try: config = { "configurable": {"thread_id": str(uuid.uuid4())}, - "recursion_limit": 20, + "recursion_limit": MAX_TURNS, } - - await agent.ainvoke( - { - "messages": [ - SystemMessage(content=f"You are an email assistant for {scenario.user_email}. Search and read emails to answer the query."), - HumanMessage(content=scenario.query), - ] - }, - config=config, - ) - - # Score the trajectory if we got a final answer - if traj.final_answer: - judge_result = await judge_email_answer(scenario, traj.final_answer) - traj.reward = judge_result.score - traj.metrics["judge_reasoning"] = judge_result.reasoning - else: - traj.reward = 0.0 # No answer provided - + + await react_agent.ainvoke({ + "messages": [ + SystemMessage(content=system_prompt), + HumanMessage(content=scenario.question), + ] + }, config=config) + + if final_answer: + traj.final_answer = final_answer + correctness_judge_response = await judge_correctness(scenario, final_answer.answer) + traj.metrics["correct"] = float(correctness_judge_response.accept) + except Exception as e: - print(f"Error in email agent: {e}") - traj.reward = 0.0 - + print(f"Error running agent: {e}") + traj.messages_and_choices.append({"role": "assistant", "content": f"Error: {str(e)}"}) + return traj -async def train_advanced_email_agent(): - from art.local import LocalBackend - - backend = LocalBackend() - - model = art.TrainableModel( - name="advanced-email-agent", - project="email-agents", - base_model="Qwen/Qwen2.5-7B-Instruct", - ) - - await model.register(backend) - - # Training scenarios with expected answers - scenarios = [ - EmailScenario( - step=0, - query="Find emails from the finance team about Q4 budget", - user_email="manager@company.com", - expected_answer="Q4 budget is $2.5M with 15% increase from Q3..." +# Main training function +async def main(): + # Sample training scenarios (replace with real data) + training_scenarios = [ + Scenario( + id="1", + question="Find emails about the quarterly budget", + answer="Budget meeting scheduled for Q4 review", + inbox_address="user@company.com", + query_date="2024-01-20" ), - EmailScenario( - step=0, - query="Search for urgent emails from CEO this week", - user_email="manager@company.com", - expected_answer="CEO sent urgent email about board meeting on Friday..." + Scenario( + id="2", + question="Look for urgent project updates", + answer="Project deadline moved to next month", + inbox_address="user@company.com", + query_date="2024-01-20" ), ] - - # Generate trajectory groups - trajectory_groups = await art.gather_trajectory_groups( - ( - art.TrajectoryGroup( - email_rollout(model, scenario) - for _ in range(4) # Multiple attempts per scenario - ) - for scenario in scenarios - ), - pbar_desc="email agent training" - ) - - # Train with custom learning rate - await model.train( - trajectory_groups, - config=art.TrainConfig(learning_rate=5e-6) + + # Register model with backend + await model.register(backend) + + # Training configuration + training_config = { + "groups_per_step": 2, + "num_epochs": 3, + "rollouts_per_group": 4, + "learning_rate": 1e-5, + "max_steps": 5, + } + + # Training iterator + training_iterator = iterate_dataset( + training_scenarios, + groups_per_step=training_config["groups_per_step"], + num_epochs=training_config["num_epochs"], + initial_step=await model.get_step(), ) + + # Training loop + for batch in training_iterator: + print(f"Training step {batch.step}, epoch {batch.epoch}") + + # Create trajectory groups + groups = [] + for scenario in batch.items: + groups.append( + art.TrajectoryGroup([ + wrap_rollout(model, rollout)( + model, EmailScenario(step=batch.step, scenario=scenario) + ) + for _ in range(training_config["rollouts_per_group"]) + ]) + ) + + # Gather trajectories + finished_groups = await art.gather_trajectory_groups( + groups, + pbar_desc="gather", + max_exceptions=training_config["rollouts_per_group"] * len(batch.items), + ) + + # Train model + await model.train( + finished_groups, + config=art.TrainConfig(learning_rate=training_config["learning_rate"]), + ) + + print(f"Completed training step {batch.step}") + + if batch.step >= training_config["max_steps"]: + break if __name__ == "__main__": - import asyncio - asyncio.run(train_advanced_email_agent()) + asyncio.run(main()) ``` +This complete example shows how to: + +1. **Set up the environment** with model, backend, and data structures +2. **Define custom tools** for email search and retrieval +3. **Create a LangGraph ReAct agent** with proper configuration +4. **Implement trajectory tracking** with custom reward scoring +5. **Run the full training loop** with proper error handling +6. **Use wrap_rollout** to automatically capture agent interactions + +To use this example, simply replace the mock email functions (`search_emails`, `read_email`) with your actual email API integration, and provide real training scenarios in the `training_scenarios` list. + ## Troubleshooting ### Common Issues