Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/processing_pipeline/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
deployment = in_depth_analysis.to_deployment(
name="Stage 3: In-Depth Analysis",
concurrency_limit=100,
parameters=dict(snippet_ids=[], repeat=True),
parameters=dict(snippet_ids=[], skip_review=True, repeat=True),
)
serve(deployment, limit=100)
case "analysis_review":
Expand Down
29 changes: 29 additions & 0 deletions src/processing_pipeline/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
SafetySetting,
)

from utils import optional_task


def get_safety_settings():
return [
Expand All @@ -28,3 +30,30 @@ def get_safety_settings():
threshold=HarmBlockThreshold.BLOCK_NONE,
),
]


@optional_task(log_prints=True, retries=3)
def create_new_label_and_assign_to_snippet(supabase_client, snippet_id, label):
    """Create a new bilingual label and assign it to the given snippet.

    Args:
        supabase_client: Client exposing ``create_new_label`` and
            ``assign_label_to_snippet``.
        snippet_id: ID of the snippet the new label is attached to.
        label: Mapping with ``"english"`` and ``"spanish"`` label texts.
    """
    english_label_text = label["english"]
    spanish_label_text = label["spanish"]

    # Create the label. Use a distinct name so the result does not shadow
    # the `label` input parameter (previous code reused the same name).
    created_label = supabase_client.create_new_label(english_label_text, spanish_label_text)

    # Assign the newly created label to the snippet
    supabase_client.assign_label_to_snippet(label_id=created_label["id"], snippet_id=snippet_id)


@optional_task(log_prints=True, retries=3)
def delete_vector_embedding_of_snippet(supabase_client, snippet_id):
    """Delete the stored vector embedding for the snippet identified by ``snippet_id``.

    Thin delegation to the Supabase client; retried up to 3 times on failure.
    """
    supabase_client.delete_vector_embedding_of_snippet(snippet_id)


@optional_task(log_prints=True, retries=3)
def postprocess_snippet(supabase_client, snippet_id, disinformation_categories):
    """Attach one label per disinformation category, then drop the snippet's old embedding.

    Removing the stored embedding (if any) causes a fresh embedding to be generated.
    """
    # Each detected category becomes a label created and linked to the snippet.
    for disinformation_category in disinformation_categories:
        create_new_label_and_assign_to_snippet(supabase_client, snippet_id, disinformation_category)

    # Clear any stale vector embedding so a new one gets computed.
    delete_vector_embedding_of_snippet(supabase_client, snippet_id)
117 changes: 40 additions & 77 deletions src/processing_pipeline/stage_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
FinishReason,
GenerateContentConfig,
GoogleSearch,
HarmBlockThreshold,
HarmCategory,
SafetySetting,
ThinkingConfig,
Tool,
)
from processing_pipeline.supabase_utils import SupabaseClient
from processing_pipeline.processing_utils import (
get_safety_settings,
postprocess_snippet,
)
from processing_pipeline.constants import (
GeminiModel,
get_system_instruction_for_stage_3,
Expand Down Expand Up @@ -72,36 +73,25 @@ def __download_audio_file_from_s3(s3_client, r2_bucket_name, file_path):
def update_snippet_in_supabase(
supabase_client,
snippet_id,
transcription,
translation,
title,
summary,
explanation,
disinformation_categories,
keywords_detected,
language,
confidence_scores,
emotional_tone,
context,
political_leaning,
gemini_response,
grounding_metadata,
status,
error_message,
):
supabase_client.update_snippet(
id=snippet_id,
transcription=transcription,
translation=translation,
title=title,
summary=summary,
explanation=explanation,
disinformation_categories=disinformation_categories,
keywords_detected=keywords_detected,
language=language,
confidence_scores=confidence_scores,
emotional_tone=emotional_tone,
context=context,
political_leaning=political_leaning,
transcription=gemini_response["transcription"],
translation=gemini_response["translation"],
title=gemini_response["title"],
summary=gemini_response["summary"],
explanation=gemini_response["explanation"],
disinformation_categories=gemini_response["disinformation_categories"],
keywords_detected=gemini_response["keywords_detected"],
language=gemini_response["language"],
confidence_scores=gemini_response["confidence_scores"],
emotional_tone=gemini_response["emotional_tone"],
context=gemini_response["context"],
political_leaning=gemini_response["political_leaning"],
grounding_metadata=grounding_metadata,
status=status,
error_message=error_message,
Expand Down Expand Up @@ -150,49 +140,42 @@ def __get_metadata(snippet):


@optional_task(log_prints=True)
def process_snippet(supabase_client, snippet, local_file, gemini_key, skip_review: bool):
    """Run Stage-3 in-depth analysis on one downloaded snippet audio file.

    Sends the audio plus snippet metadata to Gemini, writes the structured
    response back to Supabase, and (when review is skipped) immediately
    runs label/embedding post-processing.

    Args:
        supabase_client: Supabase client used for snippet updates.
        snippet: Snippet row; must contain at least ``"id"``.
        local_file: Path of the locally downloaded audio file.
        gemini_key: Gemini API key.
        skip_review: When True, mark the snippet "Processed" and post-process
            right away instead of leaving it "Ready for review".

    On any failure the snippet's status is set to "Error" with the message.
    """
    try:
        # Fixed: the log previously said "Gemini 2.5 Flash" although the
        # model actually used below is GEMINI_2_5_PRO.
        print(f"Processing snippet: {local_file} with Gemini 2.5 Pro")

        metadata = get_metadata(snippet)
        # ensure_ascii=False keeps non-ASCII (e.g. Spanish) text readable in logs.
        print(f"Metadata:\n{json.dumps(metadata, indent=2, ensure_ascii=False)}")

        response, grounding_metadata = Stage3Executor.run(
            gemini_key=gemini_key,
            model_name=GeminiModel.GEMINI_2_5_PRO,
            audio_file=local_file,
            metadata=metadata,
        )

        # Skipping human review means the snippet is final once written.
        status = "Processed" if skip_review else "Ready for review"
        update_snippet_in_supabase(
            supabase_client=supabase_client,
            snippet_id=snippet["id"],
            gemini_response=response,
            grounding_metadata=grounding_metadata,
            status=status,
            error_message=None,
        )

        if skip_review:
            # No reviewer in the loop: create labels and reset the embedding now.
            postprocess_snippet(supabase_client, snippet["id"], response["disinformation_categories"])

        print(f"Processing completed for audio file {local_file} - snippet ID: {snippet['id']}")

    except Exception as e:
        print(f"Failed to process {local_file}: {e}")
        supabase_client.set_snippet_status(snippet["id"], "Error", str(e))


@optional_flow(name="Stage 3: In-depth Analysis", log_prints=True, task_runner=ConcurrentTaskRunner)
def in_depth_analysis(snippet_ids, repeat):
def in_depth_analysis(snippet_ids, skip_review, repeat):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parameter order changed in in_depth_analysis: now defined as (snippet_ids, skip_review, repeat) instead of (snippet_ids, repeat, skip_review=True). This could break positional calls; consider retaining default values and order for backward compatibility.

# Setup S3 Client
R2_BUCKET_NAME = os.getenv("R2_BUCKET_NAME")
s3_client = boto3.client(
Expand All @@ -217,7 +200,7 @@ def in_depth_analysis(snippet_ids, repeat):
local_file = download_audio_file_from_s3(s3_client, R2_BUCKET_NAME, snippet["file_path"])

# Process the snippet
process_snippet(supabase_client, snippet, local_file, GEMINI_KEY)
process_snippet(supabase_client, snippet, local_file, GEMINI_KEY, skip_review=skip_review)

print(f"Delete the downloaded snippet clip: {local_file}")
os.remove(local_file)
Expand All @@ -229,7 +212,7 @@ def in_depth_analysis(snippet_ids, repeat):
local_file = download_audio_file_from_s3(s3_client, R2_BUCKET_NAME, snippet["file_path"])

# Process the snippet
process_snippet(supabase_client, snippet, local_file, GEMINI_KEY)
process_snippet(supabase_client, snippet, local_file, GEMINI_KEY, skip_review=skip_review)

print(f"Delete the downloaded snippet clip: {local_file}")
os.remove(local_file)
Expand Down Expand Up @@ -315,13 +298,14 @@ def run(
finally:
client.files.delete(name=uploaded_audio_file.name)

@optional_task(log_prints=True, retries=3)
@classmethod
def __analyze_with_search(
cls,
client: genai.Client,
model_name: GeminiModel,
user_prompt: str,
audio_file: File,
uploaded_audio_file: File,
):
"""
Step 1: Analyze audio with Google Search tool enabled.
Expand All @@ -333,22 +317,24 @@ def __analyze_with_search(

response = client.models.generate_content(
model=model_name,
contents=[user_prompt, audio_file],
contents=[user_prompt, uploaded_audio_file],
config=GenerateContentConfig(
system_instruction=cls.SYSTEM_INSTRUCTION,
max_output_tokens=16384,
tools=[Tool(google_search=GoogleSearch())],
thinking_config=ThinkingConfig(thinking_budget=4096),
safety_settings=cls.__get_safety_settings(),
safety_settings=get_safety_settings(),
),
)

grounding_metadata = str(response.candidates[0].grounding_metadata) if response.candidates else None

if not response.text:
finish_reason = response.candidates[0].finish_reason
finish_reason = response.candidates[0].finish_reason if response.candidates else None

if finish_reason == FinishReason.MAX_TOKENS:
raise ValueError("The response from Gemini was too long and was cut off in step 1.")

print(f"Response finish reason: {finish_reason}")
raise ValueError("No response from Gemini in step 1.")

Expand Down Expand Up @@ -413,44 +399,21 @@ def __structure_with_schema(
system_instruction=system_instruction,
max_output_tokens=8192,
thinking_config=ThinkingConfig(thinking_budget=0),
safety_settings=cls.__get_safety_settings(),
safety_settings=get_safety_settings(),
),
)

parsed_response = response.parsed

if not parsed_response:
finish_reason = response.candidates[0].finish_reason if response.candidates else None

if finish_reason == FinishReason.MAX_TOKENS:
raise ValueError("The response from Gemini was too long and was cut off in step 2.")

raise ValueError(f"No response from Gemini in step 2. Response finished with reason: {finish_reason}")

if not parsed_response.get("is_convertible"):
raise ValueError("[Stage 3] The response from Gemini could not be converted to the required schema.")

return parsed_response

@classmethod
def __get_safety_settings(cls):
    """Build safety settings that set every covered harm category to BLOCK_NONE."""
    unblocked_categories = (
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
    )
    # Same five settings, same order, as the explicit list form — just built
    # with a comprehension so the BLOCK_NONE threshold is stated once.
    return [
        SafetySetting(category=category, threshold=HarmBlockThreshold.BLOCK_NONE)
        for category in unblocked_categories
    ]
Loading