diff --git a/prompts/Stage_3_output_schema.json b/prompts/Stage_3_output_schema.json
index a8e86a6..39fb1b5 100644
--- a/prompts/Stage_3_output_schema.json
+++ b/prompts/Stage_3_output_schema.json
@@ -15,6 +15,10 @@
     "political_leaning"
   ],
   "properties": {
+    "is_convertible": {
+      "type": "boolean",
+      "description": "Indicates whether the provided text can be converted into a valid JSON object according to the given schema. Only use when asked to convert provided text into a valid JSON object."
+    },
     "transcription": {
       "type": "string",
       "description": "Transcription of the entire audio clip in the original language."
diff --git a/requirements.txt b/requirements.txt
index c2dffc1..b6f3fab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,6 @@ prefect==3.0.1
 boto3==1.35.15
 selenium==4.24.0
 webdriver-manager==4.0.2
-google-generativeai==0.8.3
 google-genai==1.11.0
 psutil==6.0.0
 sentry-sdk==2.14.0
diff --git a/src/processing_pipeline/stage_3.py b/src/processing_pipeline/stage_3.py
index d4a1053..475b2b6 100644
--- a/src/processing_pipeline/stage_3.py
+++ b/src/processing_pipeline/stage_3.py
@@ -4,15 +4,19 @@
 from google import genai
 import json
 import boto3
+from pydantic import ValidationError
 from prefect.task_runners import ConcurrentTaskRunner
 from google.genai.types import (
+    File,
     FinishReason,
     GenerateContentConfig,
+    GoogleSearch,
     HarmBlockThreshold,
     HarmCategory,
     SafetySetting,
     ThinkingConfig,
+    Tool,
 )
 from processing_pipeline.supabase_utils import SupabaseClient
 from processing_pipeline.constants import (
@@ -21,6 +25,7 @@
     get_output_schema_for_stage_3,
     get_user_prompt_for_stage_3,
 )
+from processing_pipeline.stage_3_models import Stage3Output
 from utils import optional_flow, optional_task
@@ -79,6 +84,7 @@ def update_snippet_in_supabase(
     emotional_tone,
     context,
     political_leaning,
+    grounding_metadata,
     status,
     error_message,
 ):
@@ -96,6 +102,7 @@
         emotional_tone=emotional_tone,
         context=context,
         political_leaning=political_leaning,
+        grounding_metadata=grounding_metadata,
         status=status,
         error_message=error_message,
     )
@@ -150,7 +157,7 @@ def process_snippet(supabase_client, snippet, local_file, gemini_key):
     metadata = get_metadata(snippet)
     print(f"Metadata:\n{json.dumps(metadata, indent=2)}")

-    pro_response = Stage3Executor.run(
+    response, grounding_metadata = Stage3Executor.run(
         gemini_key=gemini_key,
         model_name=GeminiModel.GEMINI_FLASH_LATEST,
         audio_file=local_file,
@@ -160,18 +167,19 @@
     update_snippet_in_supabase(
         supabase_client=supabase_client,
         snippet_id=snippet["id"],
-        transcription=pro_response["transcription"],
-        translation=pro_response["translation"],
-        title=pro_response["title"],
-        summary=pro_response["summary"],
-        explanation=pro_response["explanation"],
-        disinformation_categories=pro_response["disinformation_categories"],
-        keywords_detected=pro_response["keywords_detected"],
-        language=pro_response["language"],
-        confidence_scores=pro_response["confidence_scores"],
-        emotional_tone=pro_response["emotional_tone"],
-        context=pro_response["context"],
-        political_leaning=pro_response["political_leaning"],
+        transcription=response["transcription"],
+        translation=response["translation"],
+        title=response["title"],
+        summary=response["summary"],
+        explanation=response["explanation"],
+        disinformation_categories=response["disinformation_categories"],
+        keywords_detected=response["keywords_detected"],
+        language=response["language"],
+        confidence_scores=response["confidence_scores"],
+        emotional_tone=response["emotional_tone"],
+        context=response["context"],
+        political_leaning=response["political_leaning"],
+        grounding_metadata=grounding_metadata,
         status="Ready for review",
         error_message=None,
     )
@@ -246,18 +254,41 @@ class Stage3Executor:
     OUTPUT_SCHEMA = get_output_schema_for_stage_3()

     @classmethod
-    def run(cls, gemini_key, model_name, audio_file, metadata):
+    def run(
+        cls,
+        gemini_key: str,
+        model_name: GeminiModel,
+        audio_file: str,
+        metadata: dict,
+    ):
+        """
+        Main execution method for Stage 3 analysis.
+
+        Performs two-stage processing with validation optimization:
+        1. Step 1: Analyze audio with Google Search enabled
+        2. Validate: Try to validate response with Pydantic model
+        3. Step 2 (conditional): If validation fails, restructure with response_schema
+
+        Args:
+            gemini_key: Google Gemini API key
+            model_name: Name of the Gemini model to use
+            audio_file: Path to the audio file
+            metadata: Metadata dictionary for the audio clip
+
+        Returns:
+            tuple: Structured and validated analysis output, plus the grounding metadata
+        """
         if not gemini_key:
             raise ValueError("Google Gemini API key was not set!")

         client = genai.Client(api_key=gemini_key)

         # Upload the audio file and wait for it to finish processing
-        audio_file = client.files.upload(file=audio_file)
-        while audio_file.state.name == "PROCESSING":
+        uploaded_audio_file = client.files.upload(file=audio_file)
+        while uploaded_audio_file.state.name == "PROCESSING":
             print("Processing the uploaded audio file...")
             time.sleep(1)
-            audio_file = client.files.get(name=audio_file.name)
+            uploaded_audio_file = client.files.get(name=uploaded_audio_file.name)

         # Prepare the user prompt
         user_prompt = (
@@ -265,47 +296,161 @@ def run(cls, gemini_key, model_name, audio_file, metadata):
         )

         try:
-            result = client.models.generate_content(
-                model=model_name,
-                contents=[audio_file, user_prompt],
-                config=GenerateContentConfig(
-                    response_mime_type="application/json",
-                    response_schema=cls.OUTPUT_SCHEMA,
-                    system_instruction=cls.SYSTEM_INSTRUCTION,
-                    max_output_tokens=16384,
-                    thinking_config=ThinkingConfig(thinking_budget=4096),
-                    safety_settings=[
-                        SafetySetting(
-                            category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-                            threshold=HarmBlockThreshold.BLOCK_NONE,
-                        ),
-                        SafetySetting(
-                            category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                            threshold=HarmBlockThreshold.BLOCK_NONE,
-                        ),
-                        SafetySetting(
-                            category=HarmCategory.HARM_CATEGORY_HARASSMENT,
-                            threshold=HarmBlockThreshold.BLOCK_NONE,
-                        ),
-                        SafetySetting(
-                            category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                            threshold=HarmBlockThreshold.BLOCK_NONE,
-                        ),
-                        SafetySetting(
-                            category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
-                            threshold=HarmBlockThreshold.BLOCK_NONE,
-                        ),
-                    ],
-                ),
+            # Step 1: Analyze with Google Search
+            analysis_text, grounding_metadata = cls.__analyze_with_search(
+                client,
+                model_name,
+                user_prompt,
+                uploaded_audio_file,
             )

-            if not result.parsed:
-                finish_reason = result.candidates[0].finish_reason
-                if finish_reason == FinishReason.MAX_TOKENS:
-                    raise ValueError("The response from Gemini was too long and was cut off.")
-                print(f"Response finish reason: {finish_reason}")
-                raise ValueError("No response from Gemini.")
+            # Try to validate with Pydantic model first
+            validated_output = cls.__validate_with_pydantic(analysis_text)

-            return result.parsed
+            if validated_output:
+                return validated_output, grounding_metadata
+
+            # Step 2: Structure with response_schema (if validation failed)
+            return cls.__structure_with_schema(client, analysis_text), grounding_metadata
         finally:
-            client.files.delete(name=audio_file.name)
+            client.files.delete(name=uploaded_audio_file.name)
+
+    @classmethod
+    def __analyze_with_search(
+        cls,
+        client: genai.Client,
+        model_name: GeminiModel,
+        user_prompt: str,
+        audio_file: File,
+    ):
+        """
+        Step 1: Analyze audio with the Google Search tool enabled.
+
+        Returns:
+            tuple: The response text from Gemini and its grounding metadata
+        """
+        print("Analyzing audio with web search...")
+
+        response = client.models.generate_content(
+            model=model_name,
+            contents=[user_prompt, audio_file],
+            config=GenerateContentConfig(
+                system_instruction=cls.SYSTEM_INSTRUCTION,
+                max_output_tokens=16384,
+                tools=[Tool(google_search=GoogleSearch())],
+                thinking_config=ThinkingConfig(thinking_budget=4096),
+                safety_settings=cls.__get_safety_settings(),
+            ),
+        )
+
+        grounding_metadata = str(response.candidates[0].grounding_metadata) if response.candidates else None
+
+        if not response.text:
+            finish_reason = response.candidates[0].finish_reason
+            if finish_reason == FinishReason.MAX_TOKENS:
+                raise ValueError("The response from Gemini was too long and was cut off in step 1.")
+            print(f"Response finish reason: {finish_reason}")
+            raise ValueError("No response from Gemini in step 1.")
+
+        return response.text, grounding_metadata
+
+    @classmethod
+    def __validate_with_pydantic(cls, response_text: str):
+        """
+        Attempts to validate the response text with the Pydantic model.
+
+        Returns:
+            dict: Validated and structured output if successful
+            None: If validation fails
+        """
+        try:
+            print("Attempting to validate response with Pydantic model...")
+            start_idx = response_text.find("{")
+            end_idx = response_text.rfind("}")
+
+            if start_idx == -1 or end_idx == -1:
+                print("No JSON object found in the response.")
+                return None
+
+            parsed = Stage3Output.model_validate_json(response_text[start_idx : end_idx + 1])
+            print("Validation successful - returning structured output")
+            return parsed.model_dump()
+        except ValidationError as e:
+            print(f"Validation failed: {e}")
+            return None
+
+    @classmethod
+    def __structure_with_schema(
+        cls,
+        client: genai.Client,
+        analysis_text: str,
+    ):
+        """
+        Step 2: Structure the analysis results using response_schema.
+
+        Returns:
+            dict: Structured and validated output
+        """
+        print("Restructuring response with schema validation...")
+
+        system_instruction = """You are a helpful assistant whose task is to convert provided text into a valid JSON object following a given schema. Your responsibilities are:
+
+1. **Validation**: Check if the provided text can be converted into a valid JSON object that adheres to the specified schema.
+2. **Conversion**:
+    - If the text is convertible, convert it into a valid JSON object according to the schema.
+    - Set field `"is_convertible": true` in the JSON object.
+3. **Error Handling**:
+    - If the text is not convertible (e.g., missing fields, incorrect data types), return a JSON object with the field `"is_convertible": false`."""
+
+        user_prompt = f"Please structure the following analysis text into the required JSON format:\n\n{analysis_text}"
+
+        response = client.models.generate_content(
+            model=GeminiModel.GEMINI_FLASH_LATEST,
+            contents=[user_prompt],
+            config=GenerateContentConfig(
+                response_mime_type="application/json",
+                response_schema=cls.OUTPUT_SCHEMA,
+                system_instruction=system_instruction,
+                max_output_tokens=8192,
+                thinking_config=ThinkingConfig(thinking_budget=0),
+                safety_settings=cls.__get_safety_settings(),
+            ),
+        )
+
+        parsed_response = response.parsed
+
+        if not parsed_response:
+            finish_reason = response.candidates[0].finish_reason if response.candidates else None
+            if finish_reason == FinishReason.MAX_TOKENS:
+                raise ValueError("The response from Gemini was too long and was cut off in step 2.")
+            raise ValueError(f"No response from Gemini in step 2. Response finished with reason: {finish_reason}")
+
+        if not parsed_response.get("is_convertible"):
+            raise ValueError("[Stage 3] The response from Gemini could not be converted to the required schema.")
+
+        return parsed_response
+
+    @classmethod
+    def __get_safety_settings(cls):
+        return [
+            SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold=HarmBlockThreshold.BLOCK_NONE,
+            ),
+            SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold=HarmBlockThreshold.BLOCK_NONE,
+            ),
+            SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold=HarmBlockThreshold.BLOCK_NONE,
+            ),
+            SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold=HarmBlockThreshold.BLOCK_NONE,
+            ),
+            SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
+                threshold=HarmBlockThreshold.BLOCK_NONE,
+            ),
+        ]
diff --git a/src/processing_pipeline/stage_3_models.py b/src/processing_pipeline/stage_3_models.py
new file mode 100644
index 0000000..c70671d
--- /dev/null
+++ b/src/processing_pipeline/stage_3_models.py
@@ -0,0 +1,151 @@
+from typing import Literal
+from pydantic import BaseModel, Field
+
+
+class Title(BaseModel):
+    spanish: str = Field(description="Title of the snippet in Spanish")
+    english: str = Field(description="Title of the snippet in English")
+
+
+class Summary(BaseModel):
+    spanish: str = Field(description="Summary of the snippet in Spanish")
+    english: str = Field(description="Summary of the snippet in English")
+
+
+class Explanation(BaseModel):
+    spanish: str = Field(description="Explanation of the snippet in Spanish")
+    english: str = Field(description="Explanation of the snippet in English")
+
+
+class DisinformationCategory(BaseModel):
+    spanish: str = Field(description="Disinformation category in Spanish")
+    english: str = Field(description="Disinformation category in English")
+
+
+class Language(BaseModel):
+    primary_language: str = Field(description="Primary language of the audio (e.g., Spanish, Arabic)")
+    dialect: str = Field(description="Specific dialect or regional variation")
+    register_: str = Field(alias="register", description="Language register (formal, informal, colloquial, slang)")
+
+
+class Context(BaseModel):
+    before: str = Field(description="Part of the audio clip transcription that precedes the snippet")
+    before_en: str = Field(description="Translation of the 'before' part into English")
+    after: str = Field(description="Part of the audio clip transcription that follows the snippet")
follows the snippet") + after_en: str = Field(description="Translation of the 'after' part into English") + main: str = Field(description="The transcription of the snippet itself") + main_en: str = Field(description="Translation of the 'main' part into English") + + +class Claim(BaseModel): + quote: str = Field(description="Direct quote of the false or misleading claim") + evidence: str = Field(description="Evidence demonstrating why the claim is false") + score: int = Field(description="Confidence score for this specific claim") + + +class ValidationChecklist(BaseModel): + specific_claims_quoted: bool + evidence_provided: bool + scoring_falsity: bool + defensible_to_factcheckers: bool + consistent_explanations: bool + + +class ScoreAdjustments(BaseModel): + initial_score: int + final_score: int + adjustment_reason: str + + +class Analysis(BaseModel): + claims: list[Claim] + validation_checklist: ValidationChecklist + score_adjustments: ScoreAdjustments + + +class CategoryScore(BaseModel): + category: str = Field(description="Name of the disinformation category") + score: int = Field(ge=0, le=100, description="Confidence score for this category, ranging from 0 to 100") + + +class ConfidenceScores(BaseModel): + overall: int = Field(ge=0, le=100, description="Overall confidence score of the analysis, ranging from 0 to 100") + analysis: Analysis + categories: list[CategoryScore] + + +class EmotionText(BaseModel): + spanish: str + english: str + + +class EmotionEvidence(BaseModel): + vocal_cues: list[str] = Field(description="Specific vocal characteristics observed") + phrases: list[str] = Field(description="Direct quotes demonstrating the emotion") + patterns: list[str] = Field(description="Recurring emotional patterns or themes") + + +class EmotionImpact(BaseModel): + credibility: str + audience_reception: str + cultural_context: str + + +class EmotionExplanation(BaseModel): + spanish: str + english: str + impact: EmotionImpact + + +class EmotionalToneItem(BaseModel): + emotion: EmotionText + intensity: int = Field(description="Intensity of the emotion, ranging from 0 to 100") + evidence: EmotionEvidence + explanation: EmotionExplanation + + +class PoliticalEvidence(BaseModel): + policy_positions: list[str] = Field(description="Explicit policy positions stated") + arguments: list[str] = Field(description="Specific arguments made") + rhetoric: list[str] = Field(description="Key phrases and rhetoric used") + sources: list[str] = Field(description="Sources or authorities cited") + solutions: list[str] = Field(description="Solutions proposed") + + +class PoliticalScoreAdjustments(BaseModel): + initial_score: float + final_score: float + reasoning: str + + +class PoliticalExplanation(BaseModel): + spanish: str + english: str + score_adjustments: PoliticalScoreAdjustments + + +class PoliticalLeaning(BaseModel): + score: float = Field(ge=-1.0, le=1.0, description="Political leaning score, ranging from -1.0 to 1.0") + evidence: PoliticalEvidence + explanation: PoliticalExplanation + + +class Stage3Output(BaseModel): + """Main model for Stage 3 output.""" + + transcription: str = Field(description="Transcription of the entire audio clip in the original language") + translation: str = Field(description="Translation of the transcription into English") + title: Title = Field(description="Descriptive title of the snippet") + summary: Summary = Field(description="Objective summary of the snippet") + explanation: Explanation = Field(description="Detailed explanation of why the snippet constitutes 
disinformation") + disinformation_categories: list[DisinformationCategory] = Field( + description="Disinformation categories that the snippet belongs to" + ) + keywords_detected: list[str] = Field( + description="Specific words or phrases that triggered the flag, in original language" + ) + language: Language + context: Context + confidence_scores: ConfidenceScores + emotional_tone: list[EmotionalToneItem] + political_leaning: PoliticalLeaning diff --git a/src/processing_pipeline/stage_4.py b/src/processing_pipeline/stage_4.py index 6543970..8ae0677 100644 --- a/src/processing_pipeline/stage_4.py +++ b/src/processing_pipeline/stage_4.py @@ -10,6 +10,7 @@ HarmBlockThreshold, HarmCategory, SafetySetting, + ThinkingConfig, Tool, ) from prefect.task_runners import ConcurrentTaskRunner @@ -277,7 +278,6 @@ def __ensure_json_format(cls, text): raise ValueError("Google Gemini API key was not set!") client = genai.Client(api_key=gemini_key) - model_id = GeminiModel.GEMINI_2_5_PRO # Prepare the user prompt user_prompt = ( @@ -296,12 +296,13 @@ def __ensure_json_format(cls, text): ) response = client.models.generate_content( - model=model_id, + model=GeminiModel.GEMINI_FLASH_LATEST, contents=user_prompt, config=GenerateContentConfig( response_mime_type="application/json", response_schema=cls.OUTPUT_SCHEMA, max_output_tokens=8192, + thinking_config=ThinkingConfig(thinking_budget=0), safety_settings=[ SafetySetting( category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, diff --git a/src/processing_pipeline/supabase_utils.py b/src/processing_pipeline/supabase_utils.py index 8673ef6..b907430 100644 --- a/src/processing_pipeline/supabase_utils.py +++ b/src/processing_pipeline/supabase_utils.py @@ -197,6 +197,7 @@ def update_snippet( emotional_tone, context, political_leaning, + grounding_metadata, status, error_message ): @@ -216,6 +217,7 @@ def update_snippet( "emotional_tone": emotional_tone, "context": context, "political_leaning": political_leaning, + "grounding_metadata": grounding_metadata, "previous_analysis": None, "status": status, "error_message": error_message, diff --git a/tests/processing_pipeline/test_stage_3.py b/tests/processing_pipeline/test_stage_3.py index 2b45c72..468bd01 100644 --- a/tests/processing_pipeline/test_stage_3.py +++ b/tests/processing_pipeline/test_stage_3.py @@ -143,6 +143,7 @@ def test_update_snippet(self, mock_supabase_client): emotional_tone="neutral", context="Test context", political_leaning="neutral", + grounding_metadata=None, status="Processed", error_message=None, ) @@ -160,26 +161,10 @@ def test_get_metadata(self, sample_snippet): assert result["end_time"] == "01:30" assert result["duration"] == "01:00" - @patch("google.genai.Client") - def test_process_snippet(self, mock_client_class, mock_supabase_client, sample_snippet, mock_gemini_response): + @patch("processing_pipeline.stage_3.Stage3Executor.run") + def test_process_snippet(self, mock_run, mock_supabase_client, sample_snippet, mock_gemini_response): """Test processing a snippet""" - # Configure mock audio file - mock_audio_file = Mock() - mock_audio_file.state.name = "PROCESSED" - mock_audio_file.name = "test-audio-file" - - # Configure mock client - mock_client = Mock() - mock_client.files.upload.return_value = mock_audio_file - mock_client.files.get.return_value = mock_audio_file - mock_client.files.delete = Mock() - - # Configure mock response - mock_result = Mock() - mock_result.parsed = mock_gemini_response - mock_client.models.generate_content.return_value = mock_result - - 
+        mock_run.return_value = (mock_gemini_response, "test_grounding_metadata")

         process_snippet(mock_supabase_client, sample_snippet, "test.mp3", "test-key")
@@ -198,6 +183,7 @@ def test_process_snippet(self, mock_client_class, mock_supabase_client, sample_s
             emotional_tone=mock_gemini_response["emotional_tone"],
             context=mock_gemini_response["context"],
             political_leaning=mock_gemini_response["political_leaning"],
+            grounding_metadata="test_grounding_metadata",
             status="Ready for review",
             error_message=None,
         )
@@ -236,9 +222,17 @@ def test_stage_3_executor(self, mock_client_class):
         mock_client.files.get.return_value = mock_audio_file
         mock_client.files.delete = Mock()

-        mock_result = Mock()
-        mock_result.parsed = {"test": "response"}
-        mock_client.models.generate_content.return_value = mock_result
+        # Mock the analysis response with grounding metadata
+        mock_analysis_response = Mock()
+        mock_analysis_response.text = '{"test": "response"}'
+        mock_analysis_response.candidates = [Mock(grounding_metadata="test_grounding_metadata")]
+
+        # Mock the structured response
+        mock_structured_response = Mock()
+        mock_structured_response.parsed = {"test": "response", "is_convertible": True}
+
+        # Return different responses for different calls
+        mock_client.models.generate_content.side_effect = [mock_analysis_response, mock_structured_response]

         mock_client_class.return_value = mock_client
@@ -249,7 +243,12 @@ def test_stage_3_executor(self, mock_client_class):
             metadata={"test": "metadata"},
         )

-        assert isinstance(result, dict)
+        # Result should be a tuple (response, grounding_metadata)
+        assert isinstance(result, tuple)
+        assert len(result) == 2
+        response, grounding_metadata = result
+        assert isinstance(response, dict)
+        assert grounding_metadata is not None

     def test_stage_3_executor_without_api_key(self):
         """Test Stage3Executor without API key"""
@@ -310,31 +309,18 @@ def test_in_depth_analysis_no_snippets(self, mock_supabase_client, mock_s3_clien
         mock_s3_client.download_file.assert_not_called()

-    @patch("google.genai.Client")
+    @patch("processing_pipeline.stage_3.Stage3Executor.run")
     def test_process_snippet_no_disinformation_categories(
         self,
-        mock_client_class,
+        mock_run,
         mock_supabase_client,
         sample_snippet,
         mock_gemini_response,
     ):
         """Test processing snippet without disinformation categories"""
         mock_gemini_response["disinformation_categories"] = []

-        mock_audio_file = Mock()
-        mock_audio_file.state.name = "PROCESSED"
-        mock_audio_file.name = "test-audio-file"
-
-        # Configure mock client
-        mock_client = Mock()
-        mock_client.files.upload.return_value = mock_audio_file
-        mock_client.files.get.return_value = mock_audio_file
-        mock_client.files.delete = Mock()
-
-        mock_result = Mock()
-        mock_result.parsed = mock_gemini_response
-        mock_client.models.generate_content.return_value = mock_result
-
-        mock_client_class.return_value = mock_client
+        # Mock Stage3Executor.run to return a tuple (response, grounding_metadata)
+        mock_run.return_value = (mock_gemini_response, "test_grounding_metadata")

         process_snippet(mock_supabase_client, sample_snippet, "test.mp3", "test-key")