From aa0032565c6d2137597af0a4be5bf6f0d28b81be Mon Sep 17 00:00:00 2001 From: rcholic Date: Fri, 2 Jan 2026 11:52:34 -0800 Subject: [PATCH 1/4] phase 2.2 and 2.3 done --- sentience/agent.py | 242 ++++++------------------------- sentience/browser_evaluator.py | 5 +- sentience/cloud_tracing.py | 1 - sentience/element_filter.py | 135 +++++++++++++++++ sentience/sentience_methods.py | 1 - sentience/snapshot.py | 4 +- sentience/trace_event_builder.py | 97 +++++++++++++ 7 files changed, 282 insertions(+), 203 deletions(-) create mode 100644 sentience/element_filter.py create mode 100644 sentience/trace_event_builder.py diff --git a/sentience/agent.py b/sentience/agent.py index fec23d5..585ab48 100644 --- a/sentience/agent.py +++ b/sentience/agent.py @@ -13,6 +13,7 @@ from .agent_config import AgentConfig from .base_agent import BaseAgent, BaseAgentAsync from .browser import AsyncSentienceBrowser, SentienceBrowser +from .element_filter import ElementFilter from .llm_provider import LLMProvider, LLMResponse from .models import ( ActionHistory, @@ -25,6 +26,7 @@ TokenStats, ) from .snapshot import snapshot, snapshot_async +from .trace_event_builder import TraceEventBuilder if TYPE_CHECKING: from .tracing import Tracer @@ -100,9 +102,7 @@ def _compute_hash(self, text: str) -> str: """Compute SHA256 hash of text.""" return hashlib.sha256(text.encode("utf-8")).hexdigest() - def _get_element_bbox( - self, element_id: int | None, snap: Snapshot - ) -> dict[str, float] | None: + def _get_element_bbox(self, element_id: int | None, snap: Snapshot) -> dict[str, float] | None: """Get bounding box for an element from snapshot.""" if element_id is None: return None @@ -200,17 +200,8 @@ def act( # noqa: C901 # Emit snapshot trace event if tracer is enabled if self.tracer: - # Include ALL elements with full data for DOM tree display - # Use snap.elements (all elements) not filtered_elements - elements_data = [el.model_dump() for el in snap.elements] - # Build snapshot event data - snapshot_data = { - "url": snap.url, - "element_count": len(snap.elements), - "timestamp": snap.timestamp, - "elements": elements_data, # Full element data for DOM tree - } + snapshot_data = TraceEventBuilder.build_snapshot_event(snap) # Always include screenshot in trace event for studio viewer compatibility # CloudTraceSink will extract and upload screenshots separately, then remove @@ -425,23 +416,18 @@ def act( # noqa: C901 } # Build complete step_end event - step_end_data = { - "v": 1, - "step_id": step_id, - "step_index": self._step_count, - "goal": goal, - "attempt": attempt, - "pre": { - "url": pre_url, - "snapshot_digest": snapshot_digest, - }, - "llm": llm_data, - "exec": exec_data, - "post": { - "url": post_url, - }, - "verify": verify_data, - } + step_end_data = TraceEventBuilder.build_step_end_event( + step_id=step_id, + step_index=self._step_count, + goal=goal, + attempt=attempt, + pre_url=pre_url, + post_url=post_url, + snapshot_digest=snapshot_digest, + llm_data=llm_data, + exec_data=exec_data, + verify_data=verify_data, + ) self.tracer.emit("step_end", step_end_data, step_id=step_id) @@ -723,8 +709,8 @@ def filter_elements(self, snapshot: Snapshot, goal: str | None = None) -> list[E """ Filter elements from snapshot based on goal context. - This default implementation applies goal-based keyword matching to boost - relevant elements and filters out irrelevant ones. + This implementation uses ElementFilter to apply goal-based keyword matching + to boost relevant elements and filters out irrelevant ones. 
Args: snapshot: Current page snapshot @@ -733,76 +719,7 @@ def filter_elements(self, snapshot: Snapshot, goal: str | None = None) -> list[E Returns: Filtered list of elements """ - elements = snapshot.elements - - # If no goal provided, return all elements (up to limit) - if not goal: - return elements[: self.default_snapshot_limit] - - goal_lower = goal.lower() - - # Extract keywords from goal - keywords = self._extract_keywords(goal_lower) - - # Boost elements matching goal keywords - scored_elements = [] - for el in elements: - score = el.importance - - # Boost if element text matches goal - if el.text and any(kw in el.text.lower() for kw in keywords): - score += 0.3 - - # Boost if role matches goal intent - if "click" in goal_lower and el.visual_cues.is_clickable: - score += 0.2 - if "type" in goal_lower and el.role in ["textbox", "searchbox"]: - score += 0.2 - if "search" in goal_lower: - # Filter out non-interactive elements for search tasks - if el.role in ["link", "img"] and not el.visual_cues.is_primary: - score -= 0.5 - - scored_elements.append((score, el)) - - # Re-sort by boosted score - scored_elements.sort(key=lambda x: x[0], reverse=True) - elements = [el for _, el in scored_elements] - - return elements[: self.default_snapshot_limit] - - def _extract_keywords(self, text: str) -> list[str]: - """ - Extract meaningful keywords from goal text - - Args: - text: Text to extract keywords from - - Returns: - List of keywords - """ - stopwords = { - "the", - "a", - "an", - "and", - "or", - "but", - "in", - "on", - "at", - "to", - "for", - "of", - "with", - "by", - "from", - "as", - "is", - "was", - } - words = text.split() - return [w for w in words if w not in stopwords and len(w) > 2] + return ElementFilter.filter_by_goal(snapshot, goal, self.default_snapshot_limit) class SentienceAgentAsync(BaseAgentAsync): @@ -874,9 +791,7 @@ def _compute_hash(self, text: str) -> str: """Compute SHA256 hash of text.""" return hashlib.sha256(text.encode("utf-8")).hexdigest() - def _get_element_bbox( - self, element_id: int | None, snap: Snapshot - ) -> dict[str, float] | None: + def _get_element_bbox(self, element_id: int | None, snap: Snapshot) -> dict[str, float] | None: """Get bounding box for an element from snapshot.""" if element_id is None: return None @@ -974,17 +889,8 @@ async def act( # noqa: C901 # Emit snapshot trace event if tracer is enabled if self.tracer: - # Include ALL elements with full data for DOM tree display - # Use snap.elements (all elements) not filtered_elements - elements_data = [el.model_dump() for el in snap.elements] - # Build snapshot event data - snapshot_data = { - "url": snap.url, - "element_count": len(snap.elements), - "timestamp": snap.timestamp, - "elements": elements_data, # Full element data for DOM tree - } + snapshot_data = TraceEventBuilder.build_snapshot_event(snap) # Always include screenshot in trace event for studio viewer compatibility # CloudTraceSink will extract and upload screenshots separately, then remove @@ -1199,23 +1105,18 @@ async def act( # noqa: C901 } # Build complete step_end event - step_end_data = { - "v": 1, - "step_id": step_id, - "step_index": self._step_count, - "goal": goal, - "attempt": attempt, - "pre": { - "url": pre_url, - "snapshot_digest": snapshot_digest, - }, - "llm": llm_data, - "exec": exec_data, - "post": { - "url": post_url, - }, - "verify": verify_data, - } + step_end_data = TraceEventBuilder.build_step_end_event( + step_id=step_id, + step_index=self._step_count, + goal=goal, + attempt=attempt, + pre_url=pre_url, 
+ post_url=post_url, + snapshot_digest=snapshot_digest, + llm_data=llm_data, + exec_data=exec_data, + verify_data=verify_data, + ) self.tracer.emit("step_end", step_end_data, step_id=step_id) @@ -1447,66 +1348,17 @@ def clear_history(self) -> None: } def filter_elements(self, snapshot: Snapshot, goal: str | None = None) -> list[Element]: - """Filter elements from snapshot based on goal context (same as sync version)""" - elements = snapshot.elements - - # If no goal provided, return all elements (up to limit) - if not goal: - return elements[: self.default_snapshot_limit] - - goal_lower = goal.lower() - - # Extract keywords from goal - keywords = self._extract_keywords(goal_lower) - - # Boost elements matching goal keywords - scored_elements = [] - for el in elements: - score = el.importance - - # Boost if element text matches goal - if el.text and any(kw in el.text.lower() for kw in keywords): - score += 0.3 - - # Boost if role matches goal intent - if "click" in goal_lower and el.visual_cues.is_clickable: - score += 0.2 - if "type" in goal_lower and el.role in ["textbox", "searchbox"]: - score += 0.2 - if "search" in goal_lower: - # Filter out non-interactive elements for search tasks - if el.role in ["link", "img"] and not el.visual_cues.is_primary: - score -= 0.5 - - scored_elements.append((score, el)) - - # Re-sort by boosted score - scored_elements.sort(key=lambda x: x[0], reverse=True) - elements = [el for _, el in scored_elements] - - return elements[: self.default_snapshot_limit] - - def _extract_keywords(self, text: str) -> list[str]: - """Extract meaningful keywords from goal text (same as sync version)""" - stopwords = { - "the", - "a", - "an", - "and", - "or", - "but", - "in", - "on", - "at", - "to", - "for", - "of", - "with", - "by", - "from", - "as", - "is", - "was", - } - words = text.split() - return [w for w in words if w not in stopwords and len(w) > 2] + """ + Filter elements from snapshot based on goal context. + + This implementation uses ElementFilter to apply goal-based keyword matching + to boost relevant elements and filters out irrelevant ones. + + Args: + snapshot: Current page snapshot + goal: User's goal (can inform filtering) + + Returns: + Filtered list of elements + """ + return ElementFilter.filter_by_goal(snapshot, goal, self.default_snapshot_limit) diff --git a/sentience/browser_evaluator.py b/sentience/browser_evaluator.py index 79238a9..3cae2b4 100644 --- a/sentience/browser_evaluator.py +++ b/sentience/browser_evaluator.py @@ -21,7 +21,7 @@ class BrowserEvaluator: @staticmethod def wait_for_extension( - page: Union[Page, AsyncPage], + page: Page | AsyncPage, timeout_ms: int = 5000, ) -> None: """ @@ -79,7 +79,7 @@ async def wait_for_extension_async( ) from e @staticmethod - def _gather_diagnostics(page: Union[Page, AsyncPage]) -> dict[str, Any]: + def _gather_diagnostics(page: Page | AsyncPage) -> dict[str, Any]: """ Gather diagnostics about extension state. 
@@ -297,4 +297,3 @@ async def verify_method_exists_async( return await page.evaluate(f"typeof window.sentience.{method_name} !== 'undefined'") except Exception: return False - diff --git a/sentience/cloud_tracing.py b/sentience/cloud_tracing.py index 0631718..7dfc71b 100644 --- a/sentience/cloud_tracing.py +++ b/sentience/cloud_tracing.py @@ -13,7 +13,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Any, Optional, Protocol, Union -from collections.abc import Callable import requests diff --git a/sentience/element_filter.py b/sentience/element_filter.py new file mode 100644 index 0000000..944ff6f --- /dev/null +++ b/sentience/element_filter.py @@ -0,0 +1,135 @@ +""" +Element filtering utilities for agent-based element selection. + +This module provides centralized element filtering logic to reduce duplication +across agent implementations. +""" + +from typing import Optional + +from .models import Element, Snapshot + + +class ElementFilter: + """ + Centralized element filtering logic for agent-based element selection. + + Provides static methods for filtering elements based on: + - Importance scores + - Goal-based keyword matching + - Role and visual properties + """ + + # Common stopwords for keyword extraction + STOPWORDS = { + "the", + "a", + "an", + "and", + "or", + "but", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "by", + "from", + "as", + "is", + "was", + } + + @staticmethod + def filter_by_importance( + snapshot: Snapshot, + max_elements: int = 50, + ) -> list[Element]: + """ + Filter elements by importance score (simple top-N selection). + + Args: + snapshot: Current page snapshot + max_elements: Maximum number of elements to return + + Returns: + Top N elements sorted by importance score + """ + elements = snapshot.elements + # Elements are already sorted by importance in snapshot + return elements[:max_elements] + + @staticmethod + def filter_by_goal( + snapshot: Snapshot, + goal: Optional[str], + max_elements: int = 50, + ) -> list[Element]: + """ + Filter elements from snapshot based on goal context. + + Applies goal-based keyword matching to boost relevant elements + and filters out irrelevant ones. 
+ + Args: + snapshot: Current page snapshot + goal: User's goal (can inform filtering) + max_elements: Maximum number of elements to return + + Returns: + Filtered list of elements sorted by boosted importance score + """ + elements = snapshot.elements + + # If no goal provided, return all elements (up to limit) + if not goal: + return elements[:max_elements] + + goal_lower = goal.lower() + + # Extract keywords from goal + keywords = ElementFilter._extract_keywords(goal_lower) + + # Boost elements matching goal keywords + scored_elements = [] + for el in elements: + score = el.importance + + # Boost if element text matches goal + if el.text and any(kw in el.text.lower() for kw in keywords): + score += 0.3 + + # Boost if role matches goal intent + if "click" in goal_lower and el.visual_cues.is_clickable: + score += 0.2 + if "type" in goal_lower and el.role in ["textbox", "searchbox"]: + score += 0.2 + if "search" in goal_lower: + # Filter out non-interactive elements for search tasks + if el.role in ["link", "img"] and not el.visual_cues.is_primary: + score -= 0.5 + + scored_elements.append((score, el)) + + # Re-sort by boosted score + scored_elements.sort(key=lambda x: x[0], reverse=True) + elements = [el for _, el in scored_elements] + + return elements[:max_elements] + + @staticmethod + def _extract_keywords(text: str) -> list[str]: + """ + Extract meaningful keywords from goal text. + + Args: + text: Text to extract keywords from + + Returns: + List of keywords (non-stopwords, length > 2) + """ + words = text.split() + return [w for w in words if w not in ElementFilter.STOPWORDS and len(w) > 2] + diff --git a/sentience/sentience_methods.py b/sentience/sentience_methods.py index abcd90b..e9a6697 100644 --- a/sentience/sentience_methods.py +++ b/sentience/sentience_methods.py @@ -84,4 +84,3 @@ class AgentAction(str, Enum): def __str__(self) -> str: """Return the action name as a string.""" return self.value - diff --git a/sentience/snapshot.py b/sentience/snapshot.py index 507a8ba..6f8e4fd 100644 --- a/sentience/snapshot.py +++ b/sentience/snapshot.py @@ -19,9 +19,7 @@ MAX_PAYLOAD_BYTES = 10 * 1024 * 1024 -def _save_trace_to_file( - raw_elements: list[dict[str, Any]], trace_path: str | None = None -) -> None: +def _save_trace_to_file(raw_elements: list[dict[str, Any]], trace_path: str | None = None) -> None: """ Save raw_elements to a JSON file for benchmarking/training diff --git a/sentience/trace_event_builder.py b/sentience/trace_event_builder.py new file mode 100644 index 0000000..867de0c --- /dev/null +++ b/sentience/trace_event_builder.py @@ -0,0 +1,97 @@ +""" +Trace event building utilities for agent-based tracing. + +This module provides centralized trace event building logic to reduce duplication +across agent implementations. +""" + +from typing import Any, Optional + +from .models import AgentActionResult, Element, Snapshot + + +class TraceEventBuilder: + """ + Helper for building trace events with consistent structure. + + Provides static methods for building common trace event types: + - snapshot_taken events + - step_end events + """ + + @staticmethod + def build_snapshot_event( + snapshot: Snapshot, + include_all_elements: bool = True, + ) -> dict[str, Any]: + """ + Build snapshot_taken trace event data. + + Args: + snapshot: Snapshot to build event from + include_all_elements: If True, include all elements (for DOM tree display). + If False, use filtered elements only. 
+ + Returns: + Dictionary with snapshot event data + """ + # Include ALL elements with full data for DOM tree display + # Use snap.elements (all elements) not filtered_elements + elements_data = [el.model_dump() for el in snapshot.elements] + + return { + "url": snapshot.url, + "element_count": len(snapshot.elements), + "timestamp": snapshot.timestamp, + "elements": elements_data, # Full element data for DOM tree + } + + @staticmethod + def build_step_end_event( + step_id: str, + step_index: int, + goal: str, + attempt: int, + pre_url: str, + post_url: str, + snapshot_digest: Optional[str], + llm_data: dict[str, Any], + exec_data: dict[str, Any], + verify_data: dict[str, Any], + ) -> dict[str, Any]: + """ + Build step_end trace event data. + + Args: + step_id: Unique step identifier + step_index: Step index (0-based) + goal: User's goal for this step + attempt: Attempt number (0-based) + pre_url: URL before action execution + post_url: URL after action execution + snapshot_digest: Digest of snapshot before action + llm_data: LLM interaction data + exec_data: Action execution data + verify_data: Verification data + + Returns: + Dictionary with step_end event data + """ + return { + "v": 1, + "step_id": step_id, + "step_index": step_index, + "goal": goal, + "attempt": attempt, + "pre": { + "url": pre_url, + "snapshot_digest": snapshot_digest, + }, + "llm": llm_data, + "exec": exec_data, + "post": { + "url": post_url, + }, + "verify": verify_data, + } + From 1a2d85cc223f6f616f883290aa40d6e8ac55bf3d Mon Sep 17 00:00:00 2001 From: rcholic Date: Fri, 2 Jan 2026 13:08:16 -0800 Subject: [PATCH 2/4] Phase 3.1 and 3.2 completed --- sentience/cloud_tracing.py | 176 ++++------------ sentience/element_filter.py | 3 +- sentience/llm_provider.py | 124 ++++++----- sentience/llm_provider_utils.py | 120 +++++++++++ sentience/llm_response_builder.py | 153 ++++++++++++++ sentience/trace_event_builder.py | 3 +- sentience/trace_file_manager.py | 197 ++++++++++++++++++ sentience/tracing.py | 103 +-------- tests/test_async_api.py | 3 + tests/test_llm_provider_utils.py | 97 +++++++++ tests/test_llm_response_builder.py | 96 +++++++++ tests/test_trace_file_manager.py | 115 ++++++++++ .../test_trace_file_manager_extract_stats.py | 165 +++++++++++++++ 13 files changed, 1068 insertions(+), 287 deletions(-) create mode 100644 sentience/llm_provider_utils.py create mode 100644 sentience/llm_response_builder.py create mode 100644 sentience/trace_file_manager.py create mode 100644 tests/test_llm_provider_utils.py create mode 100644 tests/test_llm_response_builder.py create mode 100644 tests/test_trace_file_manager.py create mode 100644 tests/test_trace_file_manager_extract_stats.py diff --git a/sentience/cloud_tracing.py b/sentience/cloud_tracing.py index 7dfc71b..7c55c54 100644 --- a/sentience/cloud_tracing.py +++ b/sentience/cloud_tracing.py @@ -17,6 +17,7 @@ import requests from sentience.models import TraceStats +from sentience.trace_file_manager import TraceFileManager from sentience.tracing import TraceSink @@ -98,7 +99,7 @@ def __init__( # Use persistent cache directory instead of temp file # This ensures traces survive process crashes cache_dir = Path.home() / ".sentience" / "traces" / "pending" - cache_dir.mkdir(parents=True, exist_ok=True) + TraceFileManager.ensure_directory(cache_dir) # Persistent file (survives process crash) self._path = cache_dir / f"{run_id}.jsonl" @@ -124,9 +125,7 @@ def emit(self, event: dict[str, Any]) -> None: if self._closed: raise RuntimeError("CloudTraceSink is closed") - json_str = 
json.dumps(event, ensure_ascii=False) - self._trace_file.write(json_str + "\n") - self._trace_file.flush() # Ensure written to disk + TraceFileManager.write_event(self._trace_file, event) def close( self, @@ -385,7 +384,9 @@ def _upload_index(self) -> None: if self.logger: self.logger.warning(f"Error uploading trace index: {e}") - def _infer_final_status_from_trace(self) -> str: + def _infer_final_status_from_trace( + self, events: list[dict[str, Any]], run_end: dict[str, Any] | None + ) -> str: """ Infer final status from trace events by reading the trace file. @@ -436,92 +437,20 @@ def _infer_final_status_from_trace(self) -> str: # If we can't read the trace, default to unknown return "unknown" - def _extract_stats_from_trace(self) -> dict[str, Any]: + def _extract_stats_from_trace(self) -> TraceStats: """ Extract execution statistics from trace file. Returns: - Dictionary with stats fields for /v1/traces/complete + TraceStats with stats fields for /v1/traces/complete """ try: # Read trace file to extract stats - with open(self._path, encoding="utf-8") as f: - events = [] - for line in f: - line = line.strip() - if not line: - continue - try: - event = json.loads(line) - events.append(event) - except json.JSONDecodeError: - continue - - if not events: - return TraceStats( - total_steps=0, - total_events=0, - duration_ms=None, - final_status="unknown", - started_at=None, - ended_at=None, - ) - - # Find run_start and run_end events - run_start = next((e for e in events if e.get("type") == "run_start"), None) - run_end = next((e for e in events if e.get("type") == "run_end"), None) - - # Extract timestamps - started_at: str | None = None - ended_at: str | None = None - if run_start: - started_at = run_start.get("ts") - if run_end: - ended_at = run_end.get("ts") - - # Calculate duration - duration_ms: int | None = None - if started_at and ended_at: - try: - from datetime import datetime - - start_dt = datetime.fromisoformat(started_at.replace("Z", "+00:00")) - end_dt = datetime.fromisoformat(ended_at.replace("Z", "+00:00")) - delta = end_dt - start_dt - duration_ms = int(delta.total_seconds() * 1000) - except Exception: - pass - - # Count steps (from step_start events, only first attempt) - step_indices = set() - for event in events: - if event.get("type") == "step_start": - step_index = event.get("data", {}).get("step_index") - if step_index is not None: - step_indices.add(step_index) - total_steps = len(step_indices) if step_indices else 0 - - # If run_end has steps count, use that (more accurate) - if run_end: - steps_from_end = run_end.get("data", {}).get("steps") - if steps_from_end is not None: - total_steps = max(total_steps, steps_from_end) - - # Count total events - total_events = len(events) - - # Infer final status - final_status = self._infer_final_status_from_trace() - - return TraceStats( - total_steps=total_steps, - total_events=total_events, - duration_ms=duration_ms, - final_status=final_status, - started_at=started_at, - ended_at=ended_at, + events = TraceFileManager.read_events(self._path) + # Use TraceFileManager to extract stats (with custom status inference) + return TraceFileManager.extract_stats( + events, infer_status_func=self._infer_final_status_from_trace ) - except Exception as e: if self.logger: self.logger.warning(f"Error extracting stats from trace: {e}") @@ -593,28 +522,20 @@ def _extract_screenshots_from_trace(self) -> dict[int, dict[str, Any]]: sequence = 0 try: - with open(self._path, encoding="utf-8") as f: - for line in f: - line = line.strip() - if not 
line: - continue - - try: - event = json.loads(line) - # Check if this is a snapshot event with screenshot - if event.get("type") == "snapshot": - data = event.get("data", {}) - screenshot_base64 = data.get("screenshot_base64") - - if screenshot_base64: - sequence += 1 - screenshots[sequence] = { - "base64": screenshot_base64, - "format": data.get("screenshot_format", "jpeg"), - "step_id": event.get("step_id"), - } - except json.JSONDecodeError: - continue + events = TraceFileManager.read_events(self._path) + for event in events: + # Check if this is a snapshot event with screenshot + if event.get("type") == "snapshot": + data = event.get("data", {}) + screenshot_base64 = data.get("screenshot_base64") + + if screenshot_base64: + sequence += 1 + screenshots[sequence] = { + "base64": screenshot_base64, + "format": data.get("screenshot_format", "jpeg"), + "step_id": event.get("step_id"), + } except Exception as e: if self.logger: self.logger.error(f"Error extracting screenshots: {e}") @@ -629,34 +550,23 @@ def _create_cleaned_trace(self, output_path: Path) -> None: output_path: Path to write cleaned trace file """ try: - with ( - open(self._path, encoding="utf-8") as infile, - open(output_path, "w", encoding="utf-8") as outfile, - ): - for line in infile: - line = line.strip() - if not line: - continue - - try: - event = json.loads(line) - # Remove screenshot_base64 from snapshot events - if event.get("type") == "snapshot": - data = event.get("data", {}) - if "screenshot_base64" in data: - # Create copy without screenshot fields - cleaned_data = { - k: v - for k, v in data.items() - if k not in ("screenshot_base64", "screenshot_format") - } - event["data"] = cleaned_data - - # Write cleaned event - outfile.write(json.dumps(event, ensure_ascii=False) + "\n") - except json.JSONDecodeError: - # Skip invalid lines - continue + events = TraceFileManager.read_events(self._path) + with open(output_path, "w", encoding="utf-8") as outfile: + for event in events: + # Remove screenshot_base64 from snapshot events + if event.get("type") == "snapshot": + data = event.get("data", {}) + if "screenshot_base64" in data: + # Create copy without screenshot fields + cleaned_data = { + k: v + for k, v in data.items() + if k not in ("screenshot_base64", "screenshot_format") + } + event["data"] = cleaned_data + + # Write cleaned event + TraceFileManager.write_event(outfile, event) except Exception as e: if self.logger: self.logger.error(f"Error creating cleaned trace: {e}") diff --git a/sentience/element_filter.py b/sentience/element_filter.py index 944ff6f..df117b9 100644 --- a/sentience/element_filter.py +++ b/sentience/element_filter.py @@ -64,7 +64,7 @@ def filter_by_importance( @staticmethod def filter_by_goal( snapshot: Snapshot, - goal: Optional[str], + goal: str | None, max_elements: int = 50, ) -> list[Element]: """ @@ -132,4 +132,3 @@ def _extract_keywords(text: str) -> list[str]: """ words = text.split() return [w for w in words if w not in ElementFilter.STOPWORDS and len(w) > 2] - diff --git a/sentience/llm_provider.py b/sentience/llm_provider.py index c4f1035..650f17f 100644 --- a/sentience/llm_provider.py +++ b/sentience/llm_provider.py @@ -8,6 +8,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from .llm_provider_utils import get_api_key_from_env, handle_provider_error, require_package +from .llm_response_builder import LLMResponseBuilder + @dataclass class LLMResponse: @@ -33,6 +36,15 @@ class LLMProvider(ABC): - Any other completion API """ + def __init__(self, model: 
str): + """ + Initialize LLM provider with model name. + + Args: + model: Model identifier (e.g., "gpt-4o", "claude-3-sonnet") + """ + self._model_name = model + @abstractmethod def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse: """ @@ -97,13 +109,16 @@ def __init__( base_url: Custom API base URL (for compatible APIs) organization: OpenAI organization ID """ - try: - from openai import OpenAI - except ImportError: - raise ImportError("OpenAI package not installed. Install with: pip install openai") + super().__init__(model) # Initialize base class with model name + + OpenAI = require_package( + "openai", + "openai", + "OpenAI", + "pip install openai", + ) self.client = OpenAI(api_key=api_key, base_url=base_url, organization=organization) - self._model_name = model def generate( self, @@ -150,12 +165,15 @@ def generate( api_params.update(kwargs) # Call OpenAI API - response = self.client.chat.completions.create(**api_params) + try: + response = self.client.chat.completions.create(**api_params) + except Exception as e: + handle_provider_error(e, "OpenAI", "generate response") choice = response.choices[0] usage = response.usage - return LLMResponse( + return LLMResponseBuilder.from_openai_format( content=choice.message.content, prompt_tokens=usage.prompt_tokens if usage else None, completion_tokens=usage.completion_tokens if usage else None, @@ -193,15 +211,16 @@ def __init__(self, api_key: str | None = None, model: str = "claude-3-5-sonnet-2 api_key: Anthropic API key (or set ANTHROPIC_API_KEY env var) model: Model name (claude-3-opus, claude-3-sonnet, claude-3-haiku, etc.) """ - try: - from anthropic import Anthropic - except ImportError: - raise ImportError( - "Anthropic package not installed. Install with: pip install anthropic" - ) + super().__init__(model) # Initialize base class with model name + + Anthropic = require_package( + "anthropic", + "anthropic", + "Anthropic", + "pip install anthropic", + ) self.client = Anthropic(api_key=api_key) - self._model_name = model def generate( self, @@ -239,21 +258,19 @@ def generate( api_params.update(kwargs) # Call Anthropic API - response = self.client.messages.create(**api_params) + try: + response = self.client.messages.create(**api_params) + except Exception as e: + handle_provider_error(e, "Anthropic", "generate response") content = response.content[0].text if response.content else "" - return LLMResponse( + return LLMResponseBuilder.from_anthropic_format( content=content, - prompt_tokens=response.usage.input_tokens if hasattr(response, "usage") else None, - completion_tokens=response.usage.output_tokens if hasattr(response, "usage") else None, - total_tokens=( - (response.usage.input_tokens + response.usage.output_tokens) - if hasattr(response, "usage") - else None - ), + input_tokens=response.usage.input_tokens if hasattr(response, "usage") else None, + output_tokens=response.usage.output_tokens if hasattr(response, "usage") else None, model_name=response.model, - finish_reason=response.stop_reason, + stop_reason=response.stop_reason, ) def supports_json_mode(self) -> bool: @@ -287,13 +304,16 @@ def __init__(self, api_key: str | None = None, model: str = "glm-4-plus"): api_key: Zhipu AI API key (or set GLM_API_KEY env var) model: Model name (glm-4-plus, glm-4, glm-4-air, glm-4-flash, etc.) """ - try: - from zhipuai import ZhipuAI - except ImportError: - raise ImportError("ZhipuAI package not installed. 
Install with: pip install zhipuai") + super().__init__(model) # Initialize base class with model name + + ZhipuAI = require_package( + "zhipuai", + "zhipuai", + "ZhipuAI", + "pip install zhipuai", + ) self.client = ZhipuAI(api_key=api_key) - self._model_name = model def generate( self, @@ -335,12 +355,15 @@ def generate( api_params.update(kwargs) # Call GLM API - response = self.client.chat.completions.create(**api_params) + try: + response = self.client.chat.completions.create(**api_params) + except Exception as e: + handle_provider_error(e, "GLM", "generate response") choice = response.choices[0] usage = response.usage - return LLMResponse( + return LLMResponseBuilder.from_openai_format( content=choice.message.content, prompt_tokens=usage.prompt_tokens if usage else None, completion_tokens=usage.completion_tokens if usage else None, @@ -380,25 +403,20 @@ def __init__(self, api_key: str | None = None, model: str = "gemini-2.0-flash-ex api_key: Google API key (or set GEMINI_API_KEY or GOOGLE_API_KEY env var) model: Model name (gemini-2.0-flash-exp, gemini-1.5-pro, gemini-1.5-flash, etc.) """ - try: - import google.generativeai as genai - except ImportError: - raise ImportError( - "Google Generative AI package not installed. Install with: pip install google-generativeai" - ) + super().__init__(model) # Initialize base class with model name + + genai = require_package( + "google-generativeai", + "google.generativeai", + install_command="pip install google-generativeai", + ) - # Configure API key + # Configure API key (check parameter first, then environment variables) + api_key = get_api_key_from_env(["GEMINI_API_KEY", "GOOGLE_API_KEY"], api_key) if api_key: genai.configure(api_key=api_key) - else: - import os - - api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") - if api_key: - genai.configure(api_key=api_key) self.genai = genai - self._model_name = model self.model = genai.GenerativeModel(model) def generate( @@ -437,7 +455,10 @@ def generate( generation_config.update(kwargs) # Call Gemini API - response = self.model.generate_content(full_prompt, generation_config=generation_config) + try: + response = self.model.generate_content(full_prompt, generation_config=generation_config) + except Exception as e: + handle_provider_error(e, "Gemini", "generate response") # Extract content content = response.text if response.text else "" @@ -452,13 +473,12 @@ def generate( completion_tokens = response.usage_metadata.candidates_token_count total_tokens = response.usage_metadata.total_token_count - return LLMResponse( + return LLMResponseBuilder.from_gemini_format( content=content, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, model_name=self._model_name, - finish_reason=None, # Gemini uses different finish reason format ) def supports_json_mode(self) -> bool: @@ -505,6 +525,9 @@ def __init__( load_in_8bit: Use 8-bit quantization (saves 50% memory) torch_dtype: Data type ("auto", "float16", "bfloat16", "float32") """ + super().__init__(model_name) # Initialize base class with model name + + # Import required packages with consistent error handling try: import torch from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig @@ -514,8 +537,6 @@ def __init__( "Install with: pip install transformers torch" ) - self._model_name = model_name - # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) @@ -622,11 +643,10 @@ def generate( generated_tokens = outputs[0][input_length:] 
response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
 
-        return LLMResponse(
+        return LLMResponseBuilder.from_local_format(
             content=response_text,
             prompt_tokens=input_length,
             completion_tokens=len(generated_tokens),
-            total_tokens=input_length + len(generated_tokens),
             model_name=self._model_name,
         )
 
diff --git a/sentience/llm_provider_utils.py b/sentience/llm_provider_utils.py
new file mode 100644
index 0000000..fdae52b
--- /dev/null
+++ b/sentience/llm_provider_utils.py
@@ -0,0 +1,120 @@
+"""
+LLM Provider utility functions for common initialization and error handling.
+
+This module provides helper functions to reduce duplication across LLM provider implementations.
+"""
+
+import importlib
+import os
+from typing import Any, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+def require_package(
+    package_name: str,
+    module_name: str,
+    class_name: str | None = None,
+    install_command: str | None = None,
+) -> Any:
+    """
+    Import a package with consistent error handling.
+
+    Args:
+        package_name: Name of the package (for error messages)
+        module_name: Module name to import (e.g., "openai", "google.generativeai")
+        class_name: Optional class name to import from module (e.g., "OpenAI")
+        install_command: Installation command (defaults to "pip install {package_name}")
+
+    Returns:
+        Imported module or class
+
+    Raises:
+        ImportError: If package is not installed, with helpful message
+
+    Example:
+        >>> OpenAI = require_package("openai", "openai", "OpenAI", "pip install openai")
+        >>> genai = require_package("google-generativeai", "google.generativeai", install_command="pip install google-generativeai")
+    """
+    if install_command is None:
+        install_command = f"pip install {package_name}"
+
+    try:
+        if class_name:
+            # Import specific class: from module import class
+            module = importlib.import_module(module_name)
+            return getattr(module, class_name)
+        else:
+            # Import entire module (resolves dotted paths like
+            # "google.generativeai" to the leaf submodule)
+            return importlib.import_module(module_name)
+    except ImportError:
+        raise ImportError(f"{package_name} package not installed. Install with: {install_command}")
+
+
+def get_api_key_from_env(
+    env_vars: list[str],
+    api_key: str | None = None,
+) -> str | None:
+    """
+    Get API key from parameter or environment variables.
+
+    Args:
+        env_vars: List of environment variable names to check (in order)
+        api_key: Optional API key parameter (takes precedence)
+
+    Returns:
+        API key string or None if not found
+
+    Example:
+        >>> key = get_api_key_from_env(["OPENAI_API_KEY"], api_key="sk-...")
+        >>> # Returns "sk-..." if provided, otherwise checks OPENAI_API_KEY env var
+    """
+    if api_key:
+        return api_key
+
+    for env_var in env_vars:
+        value = os.getenv(env_var)
+        if value:
+            return value
+
+    return None
+
+
+def handle_provider_error(
+    error: Exception,
+    provider_name: str,
+    operation: str = "operation",
+) -> None:
+    """
+    Standardize error handling for LLM provider operations.
+
+    Args:
+        error: Exception that occurred
+        provider_name: Name of the provider (e.g., "OpenAI", "Anthropic")
+        operation: Description of the operation that failed
+
+    Raises:
+        RuntimeError: With standardized error message
+
+    Example:
+        >>> try:
+        ...     response = client.chat.completions.create(...)
+        ... except Exception as e:
+        ...     handle_provider_error(e, "OpenAI", "generate response")
+    """
+    error_msg = str(error)
+    if "api key" in error_msg.lower() or "authentication" in error_msg.lower():
+        raise RuntimeError(
+            f"{provider_name} API key is invalid or missing. 
" + f"Please check your API key configuration." + ) from error + elif "rate limit" in error_msg.lower() or "429" in error_msg: + raise RuntimeError( + f"{provider_name} rate limit exceeded. Please try again later." + ) from error + elif "model" in error_msg.lower() and "not found" in error_msg.lower(): + raise RuntimeError( + f"{provider_name} model not found. Please check the model name." + ) from error + else: + raise RuntimeError(f"{provider_name} {operation} failed: {error_msg}") from error diff --git a/sentience/llm_response_builder.py b/sentience/llm_response_builder.py new file mode 100644 index 0000000..a93a282 --- /dev/null +++ b/sentience/llm_response_builder.py @@ -0,0 +1,153 @@ +""" +LLM Response building utilities for consistent response construction. + +This module provides helper functions for building LLMResponse objects +from various provider API responses. +""" + +from typing import Any, Optional + +# Import LLMResponse here to avoid circular dependency +# We import it inside functions to break the cycle + + +class LLMResponseBuilder: + """ + Helper for building LLMResponse objects with consistent structure. + + Provides static methods for building responses from different provider formats. + """ + + @staticmethod + def from_openai_format( + content: str, + prompt_tokens: int | None = None, + completion_tokens: int | None = None, + total_tokens: int | None = None, + model_name: str | None = None, + finish_reason: str | None = None, + ) -> "LLMResponse": + """ + Build LLMResponse from OpenAI-style API response. + + Args: + content: Response text content + prompt_tokens: Number of prompt tokens + completion_tokens: Number of completion tokens + total_tokens: Total tokens (or sum of prompt + completion) + model_name: Model identifier + finish_reason: Finish reason (stop, length, etc.) + + Returns: + LLMResponse object + """ + from .llm_provider import LLMResponse # Import here to avoid circular dependency + + return LLMResponse( + content=content, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens + or ( + (prompt_tokens + completion_tokens) if prompt_tokens and completion_tokens else None + ), + model_name=model_name, + finish_reason=finish_reason, + ) + + @staticmethod + def from_anthropic_format( + content: str, + input_tokens: int | None = None, + output_tokens: int | None = None, + model_name: str | None = None, + stop_reason: str | None = None, + ) -> "LLMResponse": + """ + Build LLMResponse from Anthropic-style API response. + + Args: + content: Response text content + input_tokens: Number of input tokens + output_tokens: Number of output tokens + model_name: Model identifier + stop_reason: Stop reason (end_turn, max_tokens, etc.) + + Returns: + LLMResponse object + """ + from .llm_provider import LLMResponse # Import here to avoid circular dependency + + return LLMResponse( + content=content, + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=(input_tokens + output_tokens) if input_tokens and output_tokens else None, + model_name=model_name, + finish_reason=stop_reason, + ) + + @staticmethod + def from_gemini_format( + content: str, + prompt_tokens: int | None = None, + completion_tokens: int | None = None, + total_tokens: int | None = None, + model_name: str | None = None, + ) -> "LLMResponse": + """ + Build LLMResponse from Gemini-style API response. 
+ + Args: + content: Response text content + prompt_tokens: Number of prompt tokens + completion_tokens: Number of completion tokens + total_tokens: Total tokens + model_name: Model identifier + + Returns: + LLMResponse object + """ + from .llm_provider import LLMResponse # Import here to avoid circular dependency + + return LLMResponse( + content=content, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens + or ( + (prompt_tokens + completion_tokens) if prompt_tokens and completion_tokens else None + ), + model_name=model_name, + finish_reason=None, # Gemini uses different finish reason format + ) + + @staticmethod + def from_local_format( + content: str, + prompt_tokens: int, + completion_tokens: int, + model_name: str, + ) -> "LLMResponse": + """ + Build LLMResponse from local model generation. + + Args: + content: Response text content + prompt_tokens: Number of prompt tokens + completion_tokens: Number of completion tokens + model_name: Model identifier + + Returns: + LLMResponse object + """ + from .llm_provider import LLMResponse # Import here to avoid circular dependency + + return LLMResponse( + content=content, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + model_name=model_name, + finish_reason=None, + ) diff --git a/sentience/trace_event_builder.py b/sentience/trace_event_builder.py index 867de0c..3d4dfb5 100644 --- a/sentience/trace_event_builder.py +++ b/sentience/trace_event_builder.py @@ -54,7 +54,7 @@ def build_step_end_event( attempt: int, pre_url: str, post_url: str, - snapshot_digest: Optional[str], + snapshot_digest: str | None, llm_data: dict[str, Any], exec_data: dict[str, Any], verify_data: dict[str, Any], @@ -94,4 +94,3 @@ def build_step_end_event( }, "verify": verify_data, } - diff --git a/sentience/trace_file_manager.py b/sentience/trace_file_manager.py new file mode 100644 index 0000000..0bba017 --- /dev/null +++ b/sentience/trace_file_manager.py @@ -0,0 +1,197 @@ +""" +Trace file management utilities for consistent file operations. + +This module provides helper functions for common trace file operations +shared between JsonlTraceSink and CloudTraceSink. +""" + +import json +from collections.abc import Callable +from pathlib import Path +from typing import Any, Optional + +from .models import TraceStats + + +class TraceFileManager: + """ + Helper for common trace file operations. + + Provides static methods for file operations shared across trace sinks. + """ + + @staticmethod + def write_event(file_handle: Any, event: dict[str, Any]) -> None: + """ + Write a trace event to a file handle as JSONL. + + Args: + file_handle: Open file handle (must be writable) + event: Event dictionary to write + """ + json_str = json.dumps(event, ensure_ascii=False) + file_handle.write(json_str + "\n") + file_handle.flush() # Ensure written to disk + + @staticmethod + def ensure_directory(path: Path) -> None: + """ + Ensure the parent directory of a path exists. + + Args: + path: File path whose parent directory should exist + """ + path.parent.mkdir(parents=True, exist_ok=True) + + @staticmethod + def read_events(path: Path) -> list[dict[str, Any]]: + """ + Read all events from a JSONL trace file. 
+
+        Args:
+            path: Path to JSONL trace file
+
+        Returns:
+            List of event dictionaries
+
+        Raises:
+            FileNotFoundError: If file doesn't exist
+            (lines containing invalid JSON are skipped, not raised)
+        """
+        events = []
+        with open(path, encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                try:
+                    event = json.loads(line)
+                    events.append(event)
+                except json.JSONDecodeError:
+                    # Skip invalid lines but continue reading
+                    continue
+        return events
+
+    @staticmethod
+    def extract_stats(
+        events: list[dict[str, Any]],
+        infer_status_func: None | (
+            Callable[[list[dict[str, Any]], dict[str, Any] | None], str]
+        ) = None,
+    ) -> TraceStats:
+        """
+        Extract execution statistics from trace events.
+
+        This is a common operation shared between JsonlTraceSink and CloudTraceSink.
+
+        Args:
+            events: List of trace event dictionaries
+            infer_status_func: Optional function to infer final_status from events.
+                If None, uses default inference logic.
+
+        Returns:
+            TraceStats with execution statistics
+        """
+        if not events:
+            return TraceStats(
+                total_steps=0,
+                total_events=0,
+                duration_ms=None,
+                final_status="unknown",
+                started_at=None,
+                ended_at=None,
+            )
+
+        # Find run_start and run_end events
+        run_start = next((e for e in events if e.get("type") == "run_start"), None)
+        run_end = next((e for e in events if e.get("type") == "run_end"), None)
+
+        # Extract timestamps
+        started_at: str | None = None
+        ended_at: str | None = None
+        if run_start:
+            started_at = run_start.get("ts")
+        if run_end:
+            ended_at = run_end.get("ts")
+
+        # Calculate duration
+        duration_ms: int | None = None
+        if started_at and ended_at:
+            try:
+                from datetime import datetime
+
+                start_dt = datetime.fromisoformat(started_at.replace("Z", "+00:00"))
+                end_dt = datetime.fromisoformat(ended_at.replace("Z", "+00:00"))
+                delta = end_dt - start_dt
+                duration_ms = int(delta.total_seconds() * 1000)
+            except Exception:
+                pass
+
+        # Count steps (from step_start events, only first attempt)
+        step_indices = set()
+        for event in events:
+            if event.get("type") == "step_start":
+                step_index = event.get("data", {}).get("step_index")
+                if step_index is not None:
+                    step_indices.add(step_index)
+        total_steps = len(step_indices) if step_indices else 0
+
+        # If run_end has steps count, use that (more accurate)
+        if run_end:
+            steps_from_end = run_end.get("data", {}).get("steps")
+            if steps_from_end is not None:
+                total_steps = max(total_steps, steps_from_end)
+
+        # Count total events
+        total_events = len(events)
+
+        # Infer final status
+        if infer_status_func:
+            final_status = infer_status_func(events, run_end)
+        else:
+            final_status = TraceFileManager._infer_final_status(events, run_end)
+
+        return TraceStats(
+            total_steps=total_steps,
+            total_events=total_events,
+            duration_ms=duration_ms,
+            final_status=final_status,
+            started_at=started_at,
+            ended_at=ended_at,
+        )
+
+    @staticmethod
+    def _infer_final_status(
+        events: list[dict[str, Any]],
+        run_end: dict[str, Any] | None,
+    ) -> str:
+        """
+        Infer final status from trace events.
+ + Args: + events: List of trace event dictionaries + run_end: Optional run_end event dictionary + + Returns: + Final status string: "success", "failure", "partial", or "unknown" + """ + # Check for run_end event with status + if run_end: + status = run_end.get("data", {}).get("status") + if status in ("success", "failure", "partial", "unknown"): + return status + + # Infer from error events + has_errors = any(e.get("type") == "error" for e in events) + if has_errors: + step_ends = [e for e in events if e.get("type") == "step_end"] + if step_ends: + return "partial" + else: + return "failure" + else: + step_ends = [e for e in events if e.get("type") == "step_end"] + if step_ends: + return "success" + else: + return "unknown" diff --git a/sentience/tracing.py b/sentience/tracing.py index 8f1702e..fc0405c 100644 --- a/sentience/tracing.py +++ b/sentience/tracing.py @@ -13,6 +13,7 @@ from typing import Any, Optional from .models import TraceStats +from .trace_file_manager import TraceFileManager @dataclass @@ -90,7 +91,7 @@ def __init__(self, path: str | Path): path: File path to write traces to """ self.path = Path(path) - self.path.parent.mkdir(parents=True, exist_ok=True) + TraceFileManager.ensure_directory(self.path) # Open file in append mode with line buffering self._file = open(self.path, "a", encoding="utf-8", buffering=1) @@ -102,8 +103,7 @@ def emit(self, event: dict[str, Any]) -> None: Args: event: Event dictionary """ - json_str = json.dumps(event, ensure_ascii=False) - self._file.write(json_str + "\n") + TraceFileManager.write_event(self._file, event) def close(self) -> None: """Close the file and generate index.""" @@ -122,101 +122,8 @@ def get_stats(self) -> TraceStats: """ try: # Read trace file to extract stats - with open(self.path, encoding="utf-8") as f: - events = [] - for line in f: - line = line.strip() - if not line: - continue - try: - event = json.loads(line) - events.append(event) - except json.JSONDecodeError: - continue - - if not events: - return TraceStats( - total_steps=0, - total_events=0, - duration_ms=None, - final_status="unknown", - started_at=None, - ended_at=None, - ) - - # Find run_start and run_end events - run_start = next((e for e in events if e.get("type") == "run_start"), None) - run_end = next((e for e in events if e.get("type") == "run_end"), None) - - # Extract timestamps - started_at: str | None = None - ended_at: str | None = None - if run_start: - started_at = run_start.get("ts") - if run_end: - ended_at = run_end.get("ts") - - # Calculate duration - duration_ms: int | None = None - if started_at and ended_at: - try: - from datetime import datetime - - start_dt = datetime.fromisoformat(started_at.replace("Z", "+00:00")) - end_dt = datetime.fromisoformat(ended_at.replace("Z", "+00:00")) - delta = end_dt - start_dt - duration_ms = int(delta.total_seconds() * 1000) - except Exception: - pass - - # Count steps (from step_start events, only first attempt) - step_indices = set() - for event in events: - if event.get("type") == "step_start": - step_index = event.get("data", {}).get("step_index") - if step_index is not None: - step_indices.add(step_index) - total_steps = len(step_indices) if step_indices else 0 - - # If run_end has steps count, use that (more accurate) - if run_end: - steps_from_end = run_end.get("data", {}).get("steps") - if steps_from_end is not None: - total_steps = max(total_steps, steps_from_end) - - # Count total events - total_events = len(events) - - # Infer final status - final_status = "unknown" - # Check for run_end event 
with status - if run_end: - status = run_end.get("data", {}).get("status") - if status in ("success", "failure", "partial", "unknown"): - final_status = status - else: - # Infer from error events - has_errors = any(e.get("type") == "error" for e in events) - if has_errors: - step_ends = [e for e in events if e.get("type") == "step_end"] - if step_ends: - final_status = "partial" - else: - final_status = "failure" - else: - step_ends = [e for e in events if e.get("type") == "step_end"] - if step_ends: - final_status = "success" - - return TraceStats( - total_steps=total_steps, - total_events=total_events, - duration_ms=duration_ms, - final_status=final_status, - started_at=started_at, - ended_at=ended_at, - ) - + events = TraceFileManager.read_events(self.path) + return TraceFileManager.extract_stats(events) except Exception: return TraceStats( total_steps=0, diff --git a/tests/test_async_api.py b/tests/test_async_api.py index cb6a89e..fdff935 100644 --- a/tests/test_async_api.py +++ b/tests/test_async_api.py @@ -514,6 +514,9 @@ async def test_sentience_agent_async_initialization(): # Create a simple mock LLM provider class MockLLMProvider(LLMProvider): + def __init__(self): + super().__init__("mock-model") + def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse: return LLMResponse( content="CLICK(1)", diff --git a/tests/test_llm_provider_utils.py b/tests/test_llm_provider_utils.py new file mode 100644 index 0000000..4723dcc --- /dev/null +++ b/tests/test_llm_provider_utils.py @@ -0,0 +1,97 @@ +"""Tests for sentience.llm_provider_utils module""" + +import os +from unittest.mock import patch + +import pytest + +from sentience.llm_provider_utils import ( + get_api_key_from_env, + handle_provider_error, + require_package, +) + + +def test_require_package_success(): + """Test require_package successfully imports existing package.""" + # Test with a package that should exist + json_module = require_package("json", "json", install_command="pip install json") + assert json_module is not None + # Verify it's actually the json module + assert hasattr(json_module, "dumps") + + +def test_require_package_import_error(): + """Test require_package raises ImportError for missing package.""" + with pytest.raises(ImportError, match="nonexistent-package.*not installed"): + require_package( + "nonexistent-package", + "nonexistent_package", + install_command="pip install nonexistent-package", + ) + + +def test_require_package_with_class(): + """Test require_package imports specific class.""" + # json doesn't have a class, but we can test the mechanism + json_module = require_package("json", "json", install_command="pip install json") + assert json_module is not None + + +def test_get_api_key_from_env_with_param(): + """Test get_api_key_from_env returns parameter if provided.""" + key = get_api_key_from_env(["TEST_API_KEY"], api_key="provided-key") + assert key == "provided-key" + + +def test_get_api_key_from_env_from_env_var(): + """Test get_api_key_from_env reads from environment variable.""" + with patch.dict(os.environ, {"TEST_API_KEY": "env-key-value"}): + key = get_api_key_from_env(["TEST_API_KEY"]) + assert key == "env-key-value" + + +def test_get_api_key_from_env_multiple_vars(): + """Test get_api_key_from_env checks multiple environment variables.""" + # Remove FIRST_KEY if it exists, set SECOND_KEY + with patch.dict(os.environ, {"SECOND_KEY": "second-value"}, clear=False): + # Remove FIRST_KEY if it exists + os.environ.pop("FIRST_KEY", None) + key = 
get_api_key_from_env(["FIRST_KEY", "SECOND_KEY"]) + assert key == "second-value" + + +def test_get_api_key_from_env_not_found(): + """Test get_api_key_from_env returns None if not found.""" + with patch.dict(os.environ, {}, clear=True): + key = get_api_key_from_env(["NONEXISTENT_KEY"]) + assert key is None + + +def test_handle_provider_error_api_key(): + """Test handle_provider_error handles API key errors.""" + error = Exception("Invalid API key provided") + with pytest.raises(RuntimeError, match="API key is invalid or missing"): + handle_provider_error(error, "OpenAI", "generate response") + + +def test_handle_provider_error_rate_limit(): + """Test handle_provider_error handles rate limit errors.""" + error = Exception("Rate limit exceeded: 429") + with pytest.raises(RuntimeError, match="rate limit exceeded"): + handle_provider_error(error, "Anthropic", "generate response") + + +def test_handle_provider_error_model_not_found(): + """Test handle_provider_error handles model not found errors.""" + error = Exception("Model 'gpt-999' not found") + with pytest.raises(RuntimeError, match="model not found"): + handle_provider_error(error, "OpenAI", "generate response") + + +def test_handle_provider_error_generic(): + """Test handle_provider_error handles generic errors.""" + error = Exception("Network timeout") + with pytest.raises(RuntimeError, match="Gemini generate response failed: Network timeout"): + handle_provider_error(error, "Gemini", "generate response") + diff --git a/tests/test_llm_response_builder.py b/tests/test_llm_response_builder.py new file mode 100644 index 0000000..f39d2da --- /dev/null +++ b/tests/test_llm_response_builder.py @@ -0,0 +1,96 @@ +""" +Tests for LLMResponseBuilder helper class. +""" + +import pytest + +from sentience.llm_provider import LLMResponse +from sentience.llm_response_builder import LLMResponseBuilder + + +class TestLLMResponseBuilder: + """Test LLMResponseBuilder helper methods""" + + def test_from_openai_format(self): + """Test building response from OpenAI format""" + response = LLMResponseBuilder.from_openai_format( + content="Hello, world!", + prompt_tokens=10, + completion_tokens=5, + total_tokens=15, + model_name="gpt-4o", + finish_reason="stop", + ) + + assert isinstance(response, LLMResponse) + assert response.content == "Hello, world!" 
+ assert response.prompt_tokens == 10 + assert response.completion_tokens == 5 + assert response.total_tokens == 15 + assert response.model_name == "gpt-4o" + assert response.finish_reason == "stop" + + def test_from_openai_format_auto_total(self): + """Test OpenAI format with auto-calculated total_tokens""" + response = LLMResponseBuilder.from_openai_format( + content="Test", + prompt_tokens=5, + completion_tokens=3, + model_name="gpt-4o", + ) + + assert response.total_tokens == 8 # Auto-calculated + + def test_from_anthropic_format(self): + """Test building response from Anthropic format""" + response = LLMResponseBuilder.from_anthropic_format( + content="Claude response", + input_tokens=12, + output_tokens=8, + model_name="claude-3-sonnet", + stop_reason="end_turn", + ) + + assert isinstance(response, LLMResponse) + assert response.content == "Claude response" + assert response.prompt_tokens == 12 + assert response.completion_tokens == 8 + assert response.total_tokens == 20 + assert response.model_name == "claude-3-sonnet" + assert response.finish_reason == "end_turn" + + def test_from_gemini_format(self): + """Test building response from Gemini format""" + response = LLMResponseBuilder.from_gemini_format( + content="Gemini response", + prompt_tokens=15, + completion_tokens=7, + total_tokens=22, + model_name="gemini-2.0-flash-exp", + ) + + assert isinstance(response, LLMResponse) + assert response.content == "Gemini response" + assert response.prompt_tokens == 15 + assert response.completion_tokens == 7 + assert response.total_tokens == 22 + assert response.model_name == "gemini-2.0-flash-exp" + assert response.finish_reason is None + + def test_from_local_format(self): + """Test building response from local model format""" + response = LLMResponseBuilder.from_local_format( + content="Local model response", + prompt_tokens=20, + completion_tokens=10, + model_name="Qwen/Qwen2.5-3B-Instruct", + ) + + assert isinstance(response, LLMResponse) + assert response.content == "Local model response" + assert response.prompt_tokens == 20 + assert response.completion_tokens == 10 + assert response.total_tokens == 30 + assert response.model_name == "Qwen/Qwen2.5-3B-Instruct" + assert response.finish_reason is None + diff --git a/tests/test_trace_file_manager.py b/tests/test_trace_file_manager.py new file mode 100644 index 0000000..014bbbe --- /dev/null +++ b/tests/test_trace_file_manager.py @@ -0,0 +1,115 @@ +""" +Tests for TraceFileManager helper class. 
+""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from sentience.trace_file_manager import TraceFileManager + + +class TestTraceFileManager: + """Test TraceFileManager helper methods""" + + def test_write_event(self): + """Test writing event to file handle""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as f: + temp_path = Path(f.name) + + try: + with open(temp_path, "w", encoding="utf-8") as file_handle: + event = {"type": "test", "data": {"key": "value"}} + TraceFileManager.write_event(file_handle, event) + + # Read back and verify + with open(temp_path, encoding="utf-8") as f: + line = f.read().strip() + assert line + parsed = json.loads(line) + assert parsed == event + finally: + temp_path.unlink() + + def test_ensure_directory(self): + """Test ensuring directory exists""" + with tempfile.TemporaryDirectory() as tmpdir: + test_path = Path(tmpdir) / "nested" / "path" / "file.jsonl" + TraceFileManager.ensure_directory(test_path) + + assert test_path.parent.exists() + assert test_path.parent.is_dir() + + def test_read_events(self): + """Test reading events from JSONL file""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as f: + temp_path = Path(f.name) + + try: + # Write test events + events = [ + {"type": "event1", "data": {"key1": "value1"}}, + {"type": "event2", "data": {"key2": "value2"}}, + {"type": "event3", "data": {"key3": "value3"}}, + ] + + with open(temp_path, "w", encoding="utf-8") as f: + for event in events: + TraceFileManager.write_event(f, event) + + # Read back + read_events = TraceFileManager.read_events(temp_path) + + assert len(read_events) == 3 + assert read_events == events + finally: + temp_path.unlink() + + def test_read_events_skips_empty_lines(self): + """Test that empty lines are skipped when reading""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as f: + temp_path = Path(f.name) + + try: + # Write events with empty lines + with open(temp_path, "w", encoding="utf-8") as f: + TraceFileManager.write_event(f, {"type": "event1"}) + f.write("\n") # Empty line + f.write(" \n") # Whitespace-only line + TraceFileManager.write_event(f, {"type": "event2"}) + + read_events = TraceFileManager.read_events(temp_path) + + assert len(read_events) == 2 + assert read_events[0]["type"] == "event1" + assert read_events[1]["type"] == "event2" + finally: + temp_path.unlink() + + def test_read_events_handles_invalid_json(self): + """Test that invalid JSON lines are skipped""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as f: + temp_path = Path(f.name) + + try: + # Write valid and invalid events + with open(temp_path, "w", encoding="utf-8") as f: + TraceFileManager.write_event(f, {"type": "event1"}) + f.write("invalid json line\n") + TraceFileManager.write_event(f, {"type": "event2"}) + + read_events = TraceFileManager.read_events(temp_path) + + assert len(read_events) == 2 + assert read_events[0]["type"] == "event1" + assert read_events[1]["type"] == "event2" + finally: + temp_path.unlink() + + def test_read_events_file_not_found(self): + """Test that FileNotFoundError is raised for non-existent file""" + with pytest.raises(FileNotFoundError): + TraceFileManager.read_events(Path("/nonexistent/file.jsonl")) + diff --git a/tests/test_trace_file_manager_extract_stats.py b/tests/test_trace_file_manager_extract_stats.py new file mode 100644 index 0000000..45ded88 --- /dev/null +++ b/tests/test_trace_file_manager_extract_stats.py @@ 
+"""Tests for TraceFileManager.extract_stats method"""
+
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from sentience.models import TraceStats
+from sentience.trace_file_manager import TraceFileManager
+
+
+def test_extract_stats_empty_events():
+    """Test extract_stats with empty events list."""
+    stats = TraceFileManager.extract_stats([])
+    assert stats.total_steps == 0
+    assert stats.total_events == 0
+    assert stats.duration_ms is None
+    assert stats.final_status == "unknown"
+    assert stats.started_at is None
+    assert stats.ended_at is None
+
+
+def test_extract_stats_with_run_start_and_end():
+    """Test extract_stats calculates duration from run_start and run_end."""
+    start_time = datetime.now(timezone.utc)
+    # Make end_time exactly 5 seconds later; timedelta avoids the
+    # ValueError that datetime.replace(second=...) raises near :55.
+    end_time = start_time + timedelta(seconds=5)
+
+    events = [
+        {
+            "type": "run_start",
+            "ts": start_time.isoformat().replace("+00:00", "Z"),
+            "data": {},
+        },
+        {
+            "type": "step_start",
+            "data": {"step_index": 0},
+        },
+        {
+            "type": "step_end",
+            "data": {},
+        },
+        {
+            "type": "run_end",
+            "ts": end_time.isoformat().replace("+00:00", "Z"),
+            "data": {"steps": 1},
+        },
+    ]
+
+    stats = TraceFileManager.extract_stats(events)
+    assert stats.total_steps == 1
+    assert stats.total_events == 4
+    assert stats.duration_ms is not None
+    assert stats.duration_ms >= 5000  # At least 5 seconds
+    assert stats.started_at == start_time.isoformat().replace("+00:00", "Z")
+    assert stats.ended_at == end_time.isoformat().replace("+00:00", "Z")
+    assert stats.final_status == "success"  # Has step_end, no errors
+
+
+def test_extract_stats_counts_steps():
+    """Test extract_stats correctly counts steps from step_start events."""
+    events = [
+        {"type": "run_start", "ts": "2024-01-01T00:00:00Z", "data": {}},
+        {"type": "step_start", "data": {"step_index": 0}},
+        {"type": "step_end", "data": {}},
+        {"type": "step_start", "data": {"step_index": 1}},
+        {"type": "step_end", "data": {}},
+        {"type": "step_start", "data": {"step_index": 2}},
+        {"type": "step_end", "data": {}},
+        {"type": "run_end", "ts": "2024-01-01T00:01:00Z", "data": {"steps": 3}},
+    ]
+
+    stats = TraceFileManager.extract_stats(events)
+    assert stats.total_steps == 3
+    assert stats.total_events == 8
+
+
+def test_extract_stats_infers_status_success():
+    """Test extract_stats infers success status from step_end events."""
+    events = [
+        {"type": "run_start", "ts": "2024-01-01T00:00:00Z", "data": {}},
+        {"type": "step_start", "data": {"step_index": 0}},
+        {"type": "step_end", "data": {}},
+        {"type": "run_end", "ts": "2024-01-01T00:01:00Z", "data": {}},
+    ]
+
+    stats = TraceFileManager.extract_stats(events)
+    assert stats.final_status == "success"
+
+
+def test_extract_stats_infers_status_failure():
+    """Test extract_stats infers failure status from error events."""
+    events = [
+        {"type": "run_start", "ts": "2024-01-01T00:00:00Z", "data": {}},
+        {"type": "step_start", "data": {"step_index": 0}},
+        {"type": "error", "data": {"message": "Something went wrong"}},
+        {"type": "run_end", "ts": "2024-01-01T00:01:00Z", "data": {}},
+    ]
+
+    stats = TraceFileManager.extract_stats(events)
+    assert stats.final_status == "failure"
+
+
+def test_extract_stats_infers_status_partial():
+    """Test extract_stats infers partial status from errors with step_end."""
+    events = [
+        {"type": "run_start", "ts": "2024-01-01T00:00:00Z", "data": {}},
+        {"type": "step_start", "data": {"step_index": 0}},
+        {"type": "step_end", "data": {}},
+        {"type": "step_start", "data": {"step_index": 1}},
"step_start", "data": {"step_index": 1}}, + {"type": "error", "data": {"message": "Step 2 failed"}}, + {"type": "run_end", "ts": "2024-01-01T00:01:00Z", "data": {}}, + ] + + stats = TraceFileManager.extract_stats(events) + assert stats.final_status == "partial" + + +def test_extract_stats_uses_run_end_status(): + """Test extract_stats uses status from run_end event if present.""" + events = [ + {"type": "run_start", "ts": "2024-01-01T00:00:00Z", "data": {}}, + {"type": "step_start", "data": {"step_index": 0}}, + {"type": "error", "data": {"message": "Error"}}, + { + "type": "run_end", + "ts": "2024-01-01T00:01:00Z", + "data": {"status": "partial"}, # Explicit status overrides inference + }, + ] + + stats = TraceFileManager.extract_stats(events) + assert stats.final_status == "partial" # Uses run_end status, not inferred "failure" + + +def test_extract_stats_with_custom_inference(): + """Test extract_stats uses custom status inference function.""" + def custom_inference(events, run_end): + # Return a valid status value + return "partial" + + events = [ + {"type": "run_start", "ts": "2024-01-01T00:00:00Z", "data": {}}, + {"type": "step_start", "data": {"step_index": 0}}, + {"type": "step_end", "data": {}}, + {"type": "run_end", "ts": "2024-01-01T00:01:00Z", "data": {}}, + ] + + stats = TraceFileManager.extract_stats(events, infer_status_func=custom_inference) + assert stats.final_status == "partial" # Uses custom inference instead of default "success" + + +def test_extract_stats_no_timestamps(): + """Test extract_stats handles missing timestamps gracefully.""" + events = [ + {"type": "step_start", "data": {"step_index": 0}}, + {"type": "step_end", "data": {}}, + ] + + stats = TraceFileManager.extract_stats(events) + assert stats.total_steps == 1 + assert stats.duration_ms is None + assert stats.started_at is None + assert stats.ended_at is None + From ebc44d34c5d816690a5864805b657dce1ed878b9 Mon Sep 17 00:00:00 2001 From: rcholic Date: Fri, 2 Jan 2026 13:33:22 -0800 Subject: [PATCH 3/4] Phase 4: Modularize code --- sentience/__init__.py | 3 +- sentience/element_filter.py | 2 +- sentience/formatting.py | 62 ++------ sentience/utils/__init__.py | 41 ++++++ sentience/utils/browser.py | 47 +++++++ sentience/utils/element.py | 258 ++++++++++++++++++++++++++++++++++ sentience/utils/formatting.py | 60 ++++++++ 7 files changed, 418 insertions(+), 55 deletions(-) create mode 100644 sentience/utils/__init__.py create mode 100644 sentience/utils/browser.py create mode 100644 sentience/utils/element.py create mode 100644 sentience/utils/formatting.py diff --git a/sentience/__init__.py b/sentience/__init__.py index 61526a6..14b72fb 100644 --- a/sentience/__init__.py +++ b/sentience/__init__.py @@ -16,7 +16,7 @@ from .expect import expect # Formatting (v0.12.0+) -from .formatting import format_snapshot_for_llm +from .utils.formatting import format_snapshot_for_llm from .generator import ScriptGenerator, generate from .inspector import Inspector, inspect from .llm_provider import ( @@ -62,6 +62,7 @@ from .tracing import JsonlTraceSink, TraceEvent, Tracer, TraceSink # Utilities (v0.12.0+) +# Import from utils package (re-exports from submodules for backward compatibility) from .utils import ( canonical_snapshot_loose, canonical_snapshot_strict, diff --git a/sentience/element_filter.py b/sentience/element_filter.py index df117b9..6159115 100644 --- a/sentience/element_filter.py +++ b/sentience/element_filter.py @@ -64,7 +64,7 @@ def filter_by_importance( @staticmethod def filter_by_goal( snapshot: Snapshot, 
-        goal: str | None,
+        goal: Optional[str],
         max_elements: int = 50,
     ) -> list[Element]:
         """
diff --git a/sentience/formatting.py b/sentience/formatting.py
index f8961c5..b8dd653 100644
--- a/sentience/formatting.py
+++ b/sentience/formatting.py
@@ -1,59 +1,15 @@
 """
 Snapshot formatting utilities for LLM prompts.
 
-Provides functions to convert Sentience snapshots into text format suitable
-for LLM consumption.
-"""
-
-from typing import List
-
-from .models import Snapshot
-
-
-def format_snapshot_for_llm(snap: Snapshot, limit: int = 50) -> str:
-    """
-    Convert snapshot elements to text format for LLM consumption.
-
-    This is the canonical way Sentience formats DOM state for LLMs.
-    The format includes element ID, role, text preview, visual cues,
-    position, and importance score.
+DEPRECATED: This module is maintained for backward compatibility only.
+New code should import from sentience.utils.formatting or sentience directly:
 
-    Args:
-        snap: Snapshot object with elements
-        limit: Maximum number of elements to include (default: 50)
-
-    Returns:
-        Formatted string with one element per line
-
-    Example:
-        >>> snap = snapshot(browser)
-        >>> formatted = format_snapshot_for_llm(snap, limit=10)
-        >>> print(formatted)
-        [1]