diff --git a/requirements.txt b/requirements.txt index 2e5ba58..b5042df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiohttp==3.13.3 requests==2.32.5 urllib3>=2.6.3 python-ffmpeg==2.0.12 @@ -11,6 +12,7 @@ selenium==4.24.0 webdriver-manager==4.0.2 google-genai==1.62.0 google-adk==1.24.1 +html2text==2025.4.15 psutil==6.0.0 sentry-sdk==2.14.0 pytest==8.3.3 diff --git a/src/processing_pipeline/stage_3/constants.py b/src/processing_pipeline/stage_3/constants.py new file mode 100644 index 0000000..ac17cb5 --- /dev/null +++ b/src/processing_pipeline/stage_3/constants.py @@ -0,0 +1,4 @@ +from processing_pipeline.constants import GeminiModel + +MAIN_MODEL = GeminiModel.GEMINI_2_5_PRO +FALLBACK_MODEL = GeminiModel.GEMINI_2_5_FLASH diff --git a/src/processing_pipeline/stage_3/executors.py b/src/processing_pipeline/stage_3/executors.py index a510d0b..27d9ac7 100644 --- a/src/processing_pipeline/stage_3/executors.py +++ b/src/processing_pipeline/stage_3/executors.py @@ -1,38 +1,30 @@ +import asyncio from datetime import datetime, timezone import json -import os -import subprocess -import tempfile -import time -from typing import Any from google import genai from google.genai.types import ( + AutomaticFunctionCallingConfig, File, FinishReason, GenerateContentConfig, - GoogleSearch, ThinkingConfig, - Tool, ) from pydantic import ValidationError -from processing_pipeline.constants import ( - GeminiCLIEventType, - GeminiModel, -) +from processing_pipeline.constants import GeminiModel from processing_pipeline.processing_utils import get_safety_settings from processing_pipeline.stage_3.models import Stage3Output -from utils import optional_task +from processing_pipeline.stage_3.web_tools import searxng_web_search, web_url_read class Stage3Executor: """Executor for Stage 3 in-depth analysis.""" @classmethod - def run( + async def run_async( cls, - gemini_key: str, + gemini_client: genai.Client, model_name: GeminiModel, audio_file: str, metadata: dict, @@ -41,13 +33,11 @@ def run( """ Main execution method for Stage 3 analysis. - Processing strategy: - 1. Step 1: Try Gemini CLI with custom search, fallback to Google Genai SDK with Google Search grounding if CLI fails - 2. Validate: Try to validate response with Pydantic model - 3. Step 2 (conditional): If validation fails, restructure with response_schema + Uses the Google GenAI SDK with web search tools (searxng_web_search, + web_url_read) for fact checking via automatic function calling. Args: - gemini_key: Google Gemini API key + gemini_client: Google GenAI client instance model_name: Name of the Gemini model to use audio_file: Path to the audio file metadata: Metadata dictionary for the audio clip @@ -56,206 +46,100 @@ def run( Returns: dict: Structured and validated analysis output """ - if not gemini_key: - raise ValueError("Google Gemini API key was not set!") - - client = genai.Client(api_key=gemini_key) - # Prepare the user prompt using the prompt version + # Prepare the user prompt user_prompt = ( f"{prompt_version['user_prompt']}\n\n" f"Here is the metadata of the attached audio clip:\n{json.dumps(metadata, indent=2)}\n\n" f"Here is the current date and time: {datetime.now(timezone.utc).strftime('%B %-d, %Y %-I:%M %p UTC')}\n\n" ) - # Strategy: Try CLI first, fallback to SDK - analysis_text = None - thought_summaries_from_api = None - uploaded_audio_file = None + # Upload audio file + uploaded_audio_file = gemini_client.files.upload(file=audio_file) try: - user_prompt_with_file = user_prompt + f"Here is the audio file attached: @{os.path.basename(audio_file)}" - analysis_text = cls.__analyze_with_custom_search( - model_name=model_name, - user_prompt=user_prompt_with_file, - system_instruction=prompt_version["system_instruction"], - ) - except RuntimeError as e: - print("Falling back to Google Search grounding with SDK...") - - uploaded_audio_file = client.files.upload(file=audio_file) while uploaded_audio_file.state.name == "PROCESSING": print("Processing the uploaded audio file...") - time.sleep(1) - uploaded_audio_file = client.files.get(name=uploaded_audio_file.name) - - sdk_result = cls.__analyze_with_google_search_grounding( - client, - model_name, - user_prompt, - uploaded_audio_file, + await asyncio.sleep(1) + uploaded_audio_file = gemini_client.files.get(name=uploaded_audio_file.name) + + # Analyze with web search tools + analysis_text, thought_summaries = await cls.__analyze_with_web_search( + gemini_client=gemini_client, + model_name=model_name, + uploaded_audio_file=uploaded_audio_file, + user_prompt=user_prompt, system_instruction=prompt_version["system_instruction"], ) - analysis_text = sdk_result["text"] - thought_summaries_from_api = sdk_result.get("thought_summaries") - try: - # Try to validate with Pydantic model first - validated_output = cls.__validate_with_pydantic(analysis_text) - - if validated_output: - thought_summaries = thought_summaries_from_api or validated_output.get("thought_summaries") - grounding_metadata = json.dumps(validated_output.get("verification_evidence"), indent=2) - return { - "response": validated_output, - "grounding_metadata": grounding_metadata, - "thought_summaries": thought_summaries, - } - - # Step 2: Structure with response_schema (if validation failed) - structured_output = cls.__structure_with_schema(client, analysis_text, prompt_version["output_schema"]) - thought_summaries = thought_summaries_from_api or structured_output.get("thought_summaries") - grounding_metadata = json.dumps(structured_output.get("verification_evidence"), indent=2) + # Validate with Pydantic, fall back to schema restructuring + output = cls.__validate_with_pydantic(analysis_text) + + if not output: + output = await cls.__structure_with_schema( + gemini_client, analysis_text, prompt_version["output_schema"] + ) + return { - "response": structured_output, - "grounding_metadata": grounding_metadata, - "thought_summaries": thought_summaries, + "response": output, + "grounding_metadata": json.dumps(output.get("verification_evidence"), indent=2), + "thought_summaries": thought_summaries or output.get("thought_summaries"), } finally: if uploaded_audio_file: - client.files.delete(name=uploaded_audio_file.name) + gemini_client.files.delete(name=uploaded_audio_file.name) - @optional_task(log_prints=True, retries=3) @classmethod - def __analyze_with_custom_search( + async def __analyze_with_web_search( cls, + gemini_client: genai.Client, model_name: GeminiModel, + uploaded_audio_file: File, user_prompt: str, system_instruction: str, ): """ - Analyze using Gemini CLI with custom search tools (MCP-based). + Analyze using the GenAI SDK with web search tools. - This method uses the Gemini CLI which provides: - - Custom search via MCP tools - - Streaming JSON output - - System instruction from file + Uses searxng_web_search and web_url_read as plain Python function tools + with the SDK's automatic function calling. Returns: - str: Final response text from Gemini CLI - - Raises: - RuntimeError: If CLI execution fails (for fallback to SDK method) + tuple: (analysis_text, thought_summaries) """ - print("Analyzing with Gemini CLI (custom search)...") - - events: list[dict[str, Any]] = [] - final_response = "" - timeout = 300 - - # Write system instruction to a temporary file for CLI - with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as tmp_file: - tmp_file.write(system_instruction) - system_instruction_path = tmp_file.name - - env = { - "PATH": os.environ.get("PATH", ""), - "HOME": os.environ.get("HOME", ""), - "GEMINI_API_KEY": os.environ["GOOGLE_GEMINI_KEY"], - "GEMINI_SYSTEM_MD": system_instruction_path, - "SEARXNG_URL": os.environ.get("SEARXNG_URL", ""), - } - - cmd = [ - "gemini", - "--model", - model_name, - "--output-format", - "stream-json", - user_prompt, - ] - - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - env=env, - timeout=timeout, - ) - - # Parse JSONL output - for line in result.stdout.strip().split("\n"): - if not line: - continue - try: - event = json.loads(line) - events.append(event) - - # Concatenate assistant message content - if event.get("type") == GeminiCLIEventType.MESSAGE and event.get("role") == "assistant": - content = event.get("content") - if content and isinstance(content, str): - final_response += content - except json.JSONDecodeError: - pass - - if result.returncode != 0: - raise RuntimeError(f"Gemini CLI exited with code {result.returncode}: {result.stderr}") - - if not final_response: - raise RuntimeError("Gemini CLI returned no response") - - return final_response - - except subprocess.TimeoutExpired as e: - raise RuntimeError(f"Gemini CLI timed out after {timeout} seconds") from e - finally: - if os.path.exists(system_instruction_path): - os.remove(system_instruction_path) - - @optional_task(log_prints=True, retries=3) - @classmethod - def __analyze_with_google_search_grounding( - cls, - client: genai.Client, - model_name: GeminiModel, - user_prompt: str, - uploaded_audio_file: File, - system_instruction: str, - ): - print("Analyzing audio with web search...") + print("Analyzing with SDK + web search tools...") - response = client.models.generate_content( + response = await gemini_client.aio.models.generate_content( model=model_name, contents=[user_prompt, uploaded_audio_file], config=GenerateContentConfig( system_instruction=system_instruction, - max_output_tokens=16384, - tools=[Tool(google_search=GoogleSearch())], + max_output_tokens=32768, + tools=[searxng_web_search, web_url_read], + automatic_function_calling=AutomaticFunctionCallingConfig( + maximum_remote_calls=20, + ), thinking_config=ThinkingConfig(thinking_budget=4096, include_thoughts=True), safety_settings=get_safety_settings(), ), ) thoughts = "" - for part in response.candidates[0].content.parts: - if part.thought and part.text: - thoughts += part.text + if response.candidates and response.candidates[0].content: + for part in response.candidates[0].content.parts: + if part.thought and part.text: + thoughts += part.text if not response.text: finish_reason = response.candidates[0].finish_reason if response.candidates else None if finish_reason == FinishReason.MAX_TOKENS: - raise ValueError("The response from Gemini was too long and was cut off in step 1.") + raise ValueError("The response from Gemini was too long and was cut off.") print(f"Response finish reason: {finish_reason}") - raise ValueError("No response from Gemini in step 1.") + raise ValueError("No response from Gemini.") - return { - "text": response.text, - "thought_summaries": thoughts, - } + return response.text, thoughts @classmethod def __validate_with_pydantic(cls, response_text: str): @@ -276,9 +160,9 @@ def __validate_with_pydantic(cls, response_text: str): return None @classmethod - def __structure_with_schema( + async def __structure_with_schema( cls, - client: genai.Client, + gemini_client: genai.Client, analysis_text: str, output_schema: dict, ): @@ -295,7 +179,7 @@ def __structure_with_schema( user_prompt = f"Please structure the following analysis text into the required JSON format:\n\n{analysis_text}" - response = client.models.generate_content( + response = await gemini_client.aio.models.generate_content( model=GeminiModel.GEMINI_2_5_FLASH, contents=[user_prompt], config=GenerateContentConfig( diff --git a/src/processing_pipeline/stage_3/flows.py b/src/processing_pipeline/stage_3/flows.py index 2ae6f74..004413f 100644 --- a/src/processing_pipeline/stage_3/flows.py +++ b/src/processing_pipeline/stage_3/flows.py @@ -1,7 +1,8 @@ +import asyncio import os -import time import boto3 +from google import genai from prefect.flows import Flow from prefect.client.schemas import FlowRun, State from prefect.task_runners import ConcurrentTaskRunner @@ -37,7 +38,7 @@ def reset_snippet_status_hook(flow: Flow, flow_run: FlowRun, state: State): on_crashed=[reset_snippet_status_hook], on_cancellation=[reset_snippet_status_hook], ) -def in_depth_analysis(snippet_ids, skip_review, repeat): +async def in_depth_analysis(snippet_ids, skip_review, repeat): # Setup S3 Client R2_BUCKET_NAME = os.getenv("R2_BUCKET_NAME") s3_client = boto3.client( @@ -47,8 +48,11 @@ def in_depth_analysis(snippet_ids, skip_review, repeat): aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"), ) - # Setup Gemini Key + # Setup Gemini client GEMINI_KEY = os.getenv("GOOGLE_GEMINI_KEY") + if not GEMINI_KEY: + raise ValueError("No Gemini API key set (GOOGLE_GEMINI_KEY)") + gemini_client = genai.Client(api_key=GEMINI_KEY) # Setup Supabase client supabase_client = SupabaseClient(supabase_url=os.getenv("SUPABASE_URL"), supabase_key=os.getenv("SUPABASE_KEY")) @@ -65,11 +69,11 @@ def in_depth_analysis(snippet_ids, skip_review, repeat): local_file = download_audio_file_from_s3(s3_client, R2_BUCKET_NAME, snippet["file_path"]) # Process the snippet - process_snippet( - supabase_client, - snippet, - local_file, - GEMINI_KEY, + await process_snippet( + supabase_client=supabase_client, + gemini_client=gemini_client, + snippet=snippet, + local_file=local_file, skip_review=skip_review, prompt_version=prompt_version, ) @@ -84,11 +88,11 @@ def in_depth_analysis(snippet_ids, skip_review, repeat): local_file = download_audio_file_from_s3(s3_client, R2_BUCKET_NAME, snippet["file_path"]) # Process the snippet - process_snippet( - supabase_client, - snippet, - local_file, - GEMINI_KEY, + await process_snippet( + supabase_client=supabase_client, + gemini_client=gemini_client, + snippet=snippet, + local_file=local_file, skip_review=skip_review, prompt_version=prompt_version, ) @@ -106,4 +110,4 @@ def in_depth_analysis(snippet_ids, skip_review, repeat): sleep_time = 60 print(f"Sleep for {sleep_time} seconds before the next iteration") - time.sleep(sleep_time) + await asyncio.sleep(sleep_time) diff --git a/src/processing_pipeline/stage_3/tasks.py b/src/processing_pipeline/stage_3/tasks.py index 058745e..927d860 100644 --- a/src/processing_pipeline/stage_3/tasks.py +++ b/src/processing_pipeline/stage_3/tasks.py @@ -3,15 +3,17 @@ import json import os +from google import genai from google.genai import errors from processing_pipeline.constants import ( CONFIDENCE_THRESHOLD, - GeminiModel, ProcessingStatus, ) from processing_pipeline.processing_utils import postprocess_snippet +from processing_pipeline.stage_3.constants import FALLBACK_MODEL, MAIN_MODEL from processing_pipeline.stage_3.executors import Stage3Executor +from processing_pipeline.supabase_utils import SupabaseClient from utils import optional_task @@ -131,67 +133,60 @@ def __get_metadata(snippet): @optional_task(log_prints=True) -def analyze_snippet(gemini_key, audio_file, metadata, prompt_version: dict): - main_model = GeminiModel.GEMINI_2_5_PRO - fallback_model = GeminiModel.GEMINI_2_5_FLASH +async def analyze_snippet(gemini_client, audio_file, metadata, prompt_version: dict): + model = MAIN_MODEL try: - print(f"Attempting analysis with {main_model}") - analyzing_response = Stage3Executor.run( - gemini_key=gemini_key, - model_name=main_model, + print(f"Attempting analysis with {model}") + analyzing_response = await Stage3Executor.run_async( + gemini_client=gemini_client, + model_name=model, audio_file=audio_file, metadata=metadata, prompt_version=prompt_version, ) - return { - **analyzing_response, - "analyzed_by": main_model, - } - except errors.ServerError as e: - print(f"Server error with {main_model} (code {e.code}): {e.message}") - print(f"Falling back to {fallback_model}") - analyzing_response = Stage3Executor.run( - gemini_key=gemini_key, - model_name=fallback_model, + except (errors.ServerError, errors.ClientError) as e: + if e.code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]: + print(f"Auth error with {model} (code {e.code}): {e.message}") + raise + + print( + f"{e.__class__.__name__} with {model} (code {e.code}): {e.message} " + f"Falling back to {FALLBACK_MODEL}" + ) + + model = FALLBACK_MODEL + analyzing_response = await Stage3Executor.run_async( + gemini_client=gemini_client, + model_name=model, audio_file=audio_file, metadata=metadata, prompt_version=prompt_version, ) - return { - **analyzing_response, - "analyzed_by": fallback_model, - } - except errors.ClientError as e: - if e.code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]: - print(f"Auth error with {main_model} (code {e.code}): {e.message}") - raise - else: - print(f"Client error with {main_model} (code {e.code}): {e.message}") - print(f"Falling back to {fallback_model}") - analyzing_response = Stage3Executor.run( - gemini_key=gemini_key, - model_name=fallback_model, - audio_file=audio_file, - metadata=metadata, - prompt_version=prompt_version, - ) - return { - **analyzing_response, - "analyzed_by": fallback_model, - } + + return { + **analyzing_response, + "analyzed_by": model, + } @optional_task(log_prints=True) -def process_snippet(supabase_client, snippet, local_file, gemini_key, skip_review: bool, prompt_version: dict): +async def process_snippet( + supabase_client: SupabaseClient, + gemini_client: genai.Client, + snippet: dict, + local_file: str, + skip_review: bool, + prompt_version: dict, +): print(f"Processing snippet: {local_file}") try: metadata = get_metadata(snippet) print(f"Metadata:\n{json.dumps(metadata, indent=2, ensure_ascii=False)}") - analyzing_response = analyze_snippet( - gemini_key=gemini_key, + analyzing_response = await analyze_snippet( + gemini_client=gemini_client, audio_file=local_file, metadata=metadata, prompt_version=prompt_version, @@ -222,5 +217,9 @@ def process_snippet(supabase_client, snippet, local_file, gemini_key, skip_revie print(f"Processing completed for audio file {local_file} - snippet ID: {snippet['id']}") except Exception as e: - print(f"Failed to process {local_file}: {e}") - supabase_client.set_snippet_status(snippet["id"], ProcessingStatus.ERROR, str(e)) + if isinstance(e, ExceptionGroup): + error_message = "\n".join(f"- {type(exc).__name__}: {exc}" for exc in e.exceptions) + else: + error_message = f"{type(e).__name__}: {e}" + print(f"Failed to process {local_file}:\n{error_message}") + supabase_client.set_snippet_status(snippet["id"], ProcessingStatus.ERROR, error_message) diff --git a/src/processing_pipeline/stage_3/web_tools.py b/src/processing_pipeline/stage_3/web_tools.py new file mode 100644 index 0000000..d792ba7 --- /dev/null +++ b/src/processing_pipeline/stage_3/web_tools.py @@ -0,0 +1,109 @@ +import os +import ssl + +import aiohttp +import certifi +import html2text + +SEARXNG_URL = os.environ.get("SEARXNG_URL", "") +HTTP_TIMEOUT = aiohttp.ClientTimeout(total=10) + +# SSL context using certifi's CA bundle for environments where the system +# certificate store may be incomplete (e.g., macOS Python without Homebrew certs) +_ssl_context = ssl.create_default_context(cafile=certifi.where()) + + +async def searxng_web_search( + query: str, + pageno: int = 1, + time_range: str | None = None, + language: str = "all", + safesearch: int = 0, +) -> dict: + """Performs a web search using the SearXNG API. + + Args: + query: The search query string. + pageno: Page number for pagination, starting from 1. + time_range: Time range filter. One of 'day', 'month', or 'year'. + language: Language code for search results, or 'all' for no filter. + safesearch: Safe search level. 0 for off, 1 for moderate, 2 for strict. + + Returns: + A dictionary with a list of search results, each containing + title, url, content snippet, and relevance score. + """ + if not SEARXNG_URL: + raise ValueError("SEARXNG_URL environment variable is not set") + + params = { + "q": query, + "format": "json", + "pageno": pageno, + } + if time_range in ("day", "month", "year"): + params["time_range"] = time_range + if language and language != "all": + params["language"] = language + if safesearch in (0, 1, 2): + params["safesearch"] = safesearch + + async with aiohttp.ClientSession(timeout=HTTP_TIMEOUT, connector=aiohttp.TCPConnector(ssl=_ssl_context)) as session: + async with session.get(f"{SEARXNG_URL}/search", params=params) as response: + response.raise_for_status() + data = await response.json() + + results = data.get("results", []) + return { + "query": query, + "results": [ + { + "title": r.get("title", ""), + "url": r.get("url", ""), + "content": r.get("content", ""), + "score": r.get("score"), + "publishedDate": r.get("publishedDate"), + "engines": r.get("engines", []), + } + for r in results + ], + } + + +async def web_url_read( + url: str, + start_char: int = 0, + max_length: int | None = None, +) -> dict: + """Read the content from a URL and convert it to markdown. + + Args: + url: The URL to read content from. + start_char: Starting character position for content extraction. + max_length: Maximum number of characters to return. + + Returns: + A dictionary with the URL and its content converted to markdown. + """ + async with aiohttp.ClientSession( + timeout=HTTP_TIMEOUT, connector=aiohttp.TCPConnector(ssl=_ssl_context) + ) as session: + async with session.get(url) as response: + response.raise_for_status() + html_content = await response.text() + + converter = html2text.HTML2Text() + converter.ignore_links = False + converter.ignore_images = True + converter.body_width = 0 + markdown = converter.handle(html_content) + + if start_char > 0: + markdown = markdown[start_char:] + if max_length is not None and max_length > 0: + markdown = markdown[:max_length] + + return { + "url": url, + "content": markdown, + }