Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
816 changes: 3 additions & 813 deletions prompts/Stage_4_review_prompt.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ boto3==1.35.15
selenium==4.24.0
webdriver-manager==4.0.2
google-generativeai==0.8.3
google-genai==1.11.0
psutil==6.0.0
sentry-sdk==2.14.0
pytest==8.3.3
supabase==2.9.0
supabase==2.15.0
pydub==0.25.1
openai==1.57.1
tiktoken==0.8.0
Expand Down
75 changes: 28 additions & 47 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,41 @@
import json
import os
import boto3

from dotenv import load_dotenv

from processing_pipeline.timestamped_transcription_generator import TimestampedTranscriptionGenerator
from processing_pipeline.stage_4 import Stage4Executor, prepare_snippet_for_review
from processing_pipeline.supabase_utils import SupabaseClient

load_dotenv()

# Setup S3 Client
R2_BUCKET_NAME = os.getenv("R2_BUCKET_NAME")
s3_client = boto3.client(
"s3",
endpoint_url=os.getenv("R2_ENDPOINT_URL"),
aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"),
)

# Setup Gemini Key
GEMINI_KEY = os.getenv("GOOGLE_GEMINI_KEY")


def test_timestamped_transcription_generator():
# Download the audio file from R2
audio_file = "radio_1853b3_20241127_102353.mp3"
try:
s3_client.download_file(
R2_BUCKET_NAME, "radio_1853b3/radio_1853b3_20241127_102353.mp3", "radio_1853b3_20241127_102353.mp3"
)

if os.path.exists(audio_file):
result = TimestampedTranscriptionGenerator.run(audio_file, GEMINI_KEY, 10)
print(result)
else:
print(f"File {audio_file} does not exist")
finally:
# Delete the local file if it exists
if os.path.exists(audio_file):
os.remove(audio_file)

# def test_stage_4():
# supabase_client = SupabaseClient(supabase_url=os.getenv("SUPABASE_URL"), supabase_key=os.getenv("SUPABASE_KEY"))
# snippet = supabase_client.get_snippet_by_id(id="d8a5b222-0b0a-4503-8e6c-ad51e0d784d7")
# previous_analysis = snippet["previous_analysis"]
# transcription, metadata, analysis_json = prepare_snippet_for_review(previous_analysis)
# print(
# f"TRANSCRIPTION:\n{transcription}\n\n"
# f"METADATA:\n{json.dumps(metadata, indent=2)}\n\n"
# f"ANALYSIS_JSON:\n{json.dumps(analysis_json, indent=2)}"
# )
# response, grounding_metadata = Stage4Executor.run(
# transcription=transcription,
# metadata=metadata,
# analysis_json=analysis_json,
# )
# print(response)
# print(grounding_metadata)
def test_stage_4():
supabase_client = SupabaseClient(supabase_url=os.getenv("SUPABASE_URL"), supabase_key=os.getenv("SUPABASE_KEY"))
snippet = supabase_client.get_snippet_by_id(id="3b39f536-7466-44da-9772-b10dcf72c6be")
previous_analysis = snippet["previous_analysis"]
transcription, disinformation_snippet, metadata, analysis_json = prepare_snippet_for_review(previous_analysis)
print(
f"TRANSCRIPTION:\n{transcription}\n\n"
f"DISINFORMATION_SNIPPET:\n{disinformation_snippet}\n\n"
f"METADATA:\n{json.dumps(metadata, indent=2)}\n\n"
f"ANALYSIS_JSON:\n{json.dumps(analysis_json, indent=2)}"
)
response, grounding_metadata = Stage4Executor.run(
transcription=transcription,
disinformation_snippet=disinformation_snippet,
metadata=metadata,
analysis_json=analysis_json,
)
print("RESULT:")
print(json.dumps(response, indent=2))
print("\nGROUNDING_METADATA:")
print(grounding_metadata)

# We need to change the paid key to the free key in stage 4


if __name__ == "__main__":
test_timestamped_transcription_generator()
test_stage_4()
14 changes: 9 additions & 5 deletions src/processing_pipeline/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

GEMINI_1_5_PRO = "gemini-1.5-pro-002"
GEMINI_1_5_FLASH = "gemini-1.5-flash"
GEMINI_2_5_FLASH = "gemini-2.5-flash"
GEMINI_2_5_PRO = "gemini-2.5-pro"


def get_transcription_prompt_for_stage_1_preprocess():
Expand Down Expand Up @@ -43,9 +45,11 @@ def get_system_instruction_for_stage_3():
def get_output_schema_for_stage_3():
return json.load(open("prompts/Stage_3_output_schema.json", "r"))


def get_timestamped_transcription_generation_prompt():
return open("prompts/Timestamped_transcription_generation_prompt.md", "r").read()


def get_timestamped_transcription_generation_output_schema():
return json.load(open("prompts/Timestamped_transcription_generation_output_schema.json", "r"))

Expand All @@ -62,8 +66,8 @@ def get_output_schema_for_stage_4():
return json.load(open("prompts/Stage_4_output_schema.json", "r"))


def get_gemini_1206_transcription_generation_prompt():
return open("prompts/Gemini_1206_transcription_generation_prompt.md", "r").read()
def get_gemini_2_5_pro_transcription_generation_prompt():
return open("prompts/Gemini_2_5_pro_transcription_generation_prompt.md", "r").read()


if __name__ == "__main__":
Expand Down Expand Up @@ -127,6 +131,6 @@ def get_gemini_1206_transcription_generation_prompt():
# output_schema_for_stage_4 = get_output_schema_for_stage_4()
# print(json.dumps(output_schema_for_stage_4, indent=2))

# Print gemini 1206 transcription generation prompt
gemini_1206_transcription_generation_prompt = get_gemini_1206_transcription_generation_prompt()
print(gemini_1206_transcription_generation_prompt)
# Print gemini 2.5 pro transcription generation prompt
gemini_2_5_pro_transcription_generation_prompt = get_gemini_2_5_pro_transcription_generation_prompt()
print(gemini_2_5_pro_transcription_generation_prompt)
41 changes: 0 additions & 41 deletions src/processing_pipeline/gemini_1206_transcription_generator.py

This file was deleted.

59 changes: 59 additions & 0 deletions src/processing_pipeline/gemini_2_5_pro_transcription_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import time
from google import genai
from google.genai.types import SafetySetting, HarmCategory, HarmBlockThreshold, GenerateContentConfig, ThinkingConfig
from processing_pipeline.constants import (
GEMINI_2_5_PRO,
get_gemini_2_5_pro_transcription_generation_prompt,
)


class Gemini25ProTranscriptionGenerator:

USER_PROMPT = get_gemini_2_5_pro_transcription_generation_prompt()

@classmethod
def run(cls, audio_file, gemini_key):
if not gemini_key:
raise ValueError("Google Gemini API key was not set!")

client = genai.Client(api_key=gemini_key)

# Upload the audio file and wait for it to finish processing
audio_file = client.files.upload(file=audio_file, config={"mime_type": "audio/mp3"})
while audio_file.state.name == "PROCESSING":
print("Processing the uploaded audio file...")
time.sleep(1)
audio_file = client.files.get(name=audio_file.name)

try:
result = client.models.generate_content(
model=GEMINI_2_5_PRO,
contents=[cls.USER_PROMPT, audio_file],
config=GenerateContentConfig(
max_output_tokens=8192,
thinking_config=ThinkingConfig(include_thoughts=False),
safety_settings=[
SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=HarmBlockThreshold.BLOCK_NONE,
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=HarmBlockThreshold.BLOCK_NONE,
),
],
),
)

if not result.text:
raise ValueError("No content in response - likely truncated or blocked")

return result.text
finally:
client.files.delete(name=audio_file.name)
Loading