From 86bcd13a37adbfe96b3fc69c77502ca42358cc91 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Thu, 28 Aug 2025 17:44:45 -0700 Subject: [PATCH] Add RULER scoring to LG integration example --- docs/integrations/langgraph-integration.mdx | 66 +++++++++++---------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/docs/integrations/langgraph-integration.mdx b/docs/integrations/langgraph-integration.mdx index 7911dbb95..79c661f93 100644 --- a/docs/integrations/langgraph-integration.mdx +++ b/docs/integrations/langgraph-integration.mdx @@ -137,7 +137,7 @@ if final_answer: correctness_judge_response = await judge_correctness( scenario, traj.final_answer.answer ) - traj.metrics["correct"] = float(correctness_judge_response.accept) + traj.metrics["correct"] = correctness_judge_response.accept ``` ### Training Loop with LangGraph Integration @@ -179,7 +179,7 @@ for batch in training_iterator: # Apply RULER scoring judged_groups = [] for group in finished_groups: - judged_group = await ruler_score_group(group, "openai/o4-mini", debug=True) + judged_group = await ruler_score_group(group, "openai/o4-mini") judged_groups.append(judged_group) # Train the model @@ -326,7 +326,7 @@ async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudge You are given a question, the reference answer, and an answer generated by an AI assistant. Your task is to decide whether the AI answer is correct and should be accepted. """) - + messages = [ {"role": "system", "content": system_prompt}, { @@ -338,13 +338,13 @@ async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudge ), }, ] - + response = await acompletion( model="openai/gpt-4o-mini", messages=messages, response_format=CorrectnessJudgeResponse, ) - + return CorrectnessJudgeResponse.model_validate_json( response.choices[0].message.content or "{}" ) @@ -354,7 +354,7 @@ async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudge async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory: scenario = email_scenario.scenario MAX_TURNS = 10 - + traj = ProjectTrajectory( reward=0.0, messages_and_choices=[], @@ -363,62 +363,62 @@ async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTra "step": email_scenario.step, }, ) - + system_prompt = dedent(f""" You are an email search agent. Use the tools to search emails and find answers. User's email address: {scenario.inbox_address} Today's date: {scenario.query_date} - + When you find the answer, use return_final_answer_tool with the answer and source message IDs. """) - + final_answer = None - + @tool def search_inbox_tool(keywords: List[str]) -> List[dict]: """Search inbox for emails matching keywords""" results = search_emails(scenario.inbox_address, keywords, scenario.query_date) return [asdict(result) for result in results] - + @tool def read_email_tool(message_id: str) -> dict | None: """Read a specific email by message ID""" email = read_email(message_id) return email.model_dump() if email else None - + @tool def return_final_answer_tool(answer: str, reference_message_ids: List[str]) -> dict: """Return final answer with source message IDs""" nonlocal final_answer final_answer = FinalAnswer(answer=answer, source_ids=reference_message_ids) return final_answer.model_dump() - + tools = [search_inbox_tool, read_email_tool, return_final_answer_tool] chat_model = init_chat_model(model.name, temperature=1.0) react_agent = create_react_agent(chat_model, tools) - + try: config = { "configurable": {"thread_id": str(uuid.uuid4())}, "recursion_limit": MAX_TURNS, } - + await react_agent.ainvoke({ "messages": [ SystemMessage(content=system_prompt), HumanMessage(content=scenario.question), ] }, config=config) - + if final_answer: traj.final_answer = final_answer correctness_judge_response = await judge_correctness(scenario, final_answer.answer) traj.metrics["correct"] = float(correctness_judge_response.accept) - + except Exception as e: print(f"Error running agent: {e}") traj.messages_and_choices.append({"role": "assistant", "content": f"Error: {str(e)}"}) - + return traj # Main training function @@ -433,17 +433,17 @@ async def main(): query_date="2024-01-20" ), Scenario( - id="2", + id="2", question="Look for urgent project updates", answer="Project deadline moved to next month", inbox_address="user@company.com", query_date="2024-01-20" ), ] - + # Register model with backend await model.register(backend) - + # Training configuration training_config = { "groups_per_step": 2, @@ -452,7 +452,7 @@ async def main(): "learning_rate": 1e-5, "max_steps": 5, } - + # Training iterator training_iterator = iterate_dataset( training_scenarios, @@ -460,11 +460,11 @@ async def main(): num_epochs=training_config["num_epochs"], initial_step=await model.get_step(), ) - + # Training loop for batch in training_iterator: print(f"Training step {batch.step}, epoch {batch.epoch}") - + # Create trajectory groups groups = [] for scenario in batch.items: @@ -476,22 +476,28 @@ async def main(): for _ in range(training_config["rollouts_per_group"]) ]) ) - + # Gather trajectories finished_groups = await art.gather_trajectory_groups( groups, pbar_desc="gather", max_exceptions=training_config["rollouts_per_group"] * len(batch.items), ) - + + # Apply RULER scoring + judged_groups = [] + for group in finished_groups: + judged_group = await ruler_score_group(group, "openai/o4-mini") + judged_groups.append(judged_group) + # Train model await model.train( - finished_groups, + judged_groups, config=art.TrainConfig(learning_rate=training_config["learning_rate"]), ) - + print(f"Completed training step {batch.step}") - + if batch.step >= training_config["max_steps"]: break @@ -503,7 +509,7 @@ This complete example shows how to: 1. **Set up the environment** with model, backend, and data structures 2. **Define custom tools** for email search and retrieval -3. **Create a LangGraph ReAct agent** with proper configuration +3. **Create a LangGraph ReAct agent** with proper configuration 4. **Implement trajectory tracking** with custom reward scoring 5. **Run the full training loop** with proper error handling 6. **Use wrap_rollout** to automatically capture agent interactions