diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bd941ad
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+.venv/
+__pycache__/
+*.py[cod]
+.pytest_cache/
+
+outputs/events.*
+samples/*.wav
+*.tmp.wav
+test.mp4
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9430395
--- /dev/null
+++ b/README.md
@@ -0,0 +1,156 @@
+# Intelligent CC Generation - Module 1 MVP
+
+This demo implements the first module from the PlanetRead Intelligent Closed Caption
+Suggestion Tool pipeline:
+
+```text
+video input -> audio extraction -> sound event detection -> JSON/CSV output
+```
+
+The current MVP does not decide whether a caption should be shown and does not check
+speaker or scene reaction. It only detects candidate non-speech sound events with
+timestamps and confidence scores.
+
+## What It Uses
+
+- Python 3.12
+- YAMNet from TensorFlow Hub for sound event classification
+- `imageio-ffmpeg` for a pip-provided FFmpeg binary
+- `soundfile` for loading WAV audio
+- built-in `json` and `csv` for outputs
+
+## Setup
+
+```bash
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+```
+
+No system FFmpeg install is required. The project uses FFmpeg from the
+`imageio-ffmpeg` Python package.
+
+## Run On A Video
+
+```bash
+python detect_sound_events.py \
+    --input samples/sample_video.mp4 \
+    --json outputs/sample_events.json \
+    --csv outputs/sample_events.csv
+```
+
+The first run may take extra time because TensorFlow Hub downloads and caches the
+YAMNet model.
+
+## Test Video Demo
+
+This PR also includes a short test video:
+
+```text
+samples/test_video.mp4
+```
+
+Run the detector on it with a slightly higher confidence threshold for cleaner
+demo output:
+
+```bash
+python detect_sound_events.py \
+    --input samples/test_video.mp4 \
+    --json outputs/test_video_events.json \
+    --csv outputs/test_video_events.csv \
+    --min-confidence 0.5 \
+    --block-label Animal,Bird
+```
+
+Example detected events from this video include `Explosion`, `Gunshot, gunfire`,
+and `Machine gun`.
+
+## Create A Small Sample Video
+
+```bash
+python scripts/create_sample_video.py
+```
+
+This creates:
+
+```text
+samples/sample_audio.wav
+samples/sample_video.mp4
+```
+
+Then run the detector command above.
+
+## Output Format
+
+JSON output:
+
+```json
+[
+  {
+    "label": "Busy signal",
+    "caption_label": "[busy signal]",
+    "start_time": 0.48,
+    "end_time": 1.44,
+    "confidence": 0.8768,
+    "start_timestamp": "00:00:00.480",
+    "end_timestamp": "00:00:01.440",
+    "duration": 0.96
+  }
+]
+```
+
+CSV output:
+
+```csv
+label,caption_label,start_time,end_time,start_timestamp,end_timestamp,duration,confidence
+Busy signal,[busy signal],0.48,1.44,00:00:00.480,00:00:01.440,0.96,0.8768
+```
+
+## Useful Options
+
+- `--min-confidence 0.5` keeps only stronger detections.
+- `--top-k 3` keeps more candidate labels per YAMNet frame.
+- `--block-label Animal,Bird` suppresses noisy labels for a specific demo clip.
+- `--keep-audio outputs/audio.wav` saves the extracted 16 kHz mono WAV.
+
+## Current Limitations
+
+- This is only Module 1, so it does not check visual reaction yet.
+- YAMNet labels are generic AudioSet labels and may need mapping to cleaner CC text.
+- The default threshold is conservative but not tuned on PlanetRead content yet.
+- First run requires internet access to download the YAMNet model from TensorFlow Hub.
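+
+## Consuming The Output Downstream
+
+Later modules will decide which detected events should become captions. As a
+minimal sketch of downstream use (illustrative only, not part of this module;
+the `events_to_srt` helper is a hypothetical name), the JSON output can be
+converted into SRT caption cues:
+
+```python
+import json
+from pathlib import Path
+
+
+def srt_timestamp(value: str) -> str:
+    # This tool writes HH:MM:SS.mmm timestamps; SRT uses a comma before the
+    # milliseconds.
+    return value.replace(".", ",")
+
+
+def events_to_srt(events_path: Path, srt_path: Path) -> None:
+    events = json.loads(events_path.read_text(encoding="utf-8"))
+    lines = []
+    for number, event in enumerate(events, start=1):
+        lines.append(str(number))
+        lines.append(
+            f"{srt_timestamp(event['start_timestamp'])} --> "
+            f"{srt_timestamp(event['end_timestamp'])}"
+        )
+        lines.append(event["caption_label"])
+        lines.append("")
+    srt_path.write_text("\n".join(lines), encoding="utf-8")
+
+
+events_to_srt(Path("outputs/test_video_events.json"), Path("outputs/test_video_events.srt"))
+```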
diff --git a/detect_sound_events.py b/detect_sound_events.py
new file mode 100644
index 0000000..edeabb8
--- /dev/null
+++ b/detect_sound_events.py
@@ -0,0 +1,12 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
+
+from cc_event_detector.cli import main
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/outputs/test_video_events.csv b/outputs/test_video_events.csv
new file mode 100644
index 0000000..e762dbe
--- /dev/null
+++ b/outputs/test_video_events.csv
@@ -0,0 +1,5 @@
+label,caption_label,start_time,end_time,start_timestamp,end_timestamp,duration,confidence
+Explosion,[explosion],0.96,2.4,00:00:00.960,00:00:02.400,1.44,0.6279
+Fusillade,[rapid gunfire],1.92,2.88,00:00:01.920,00:00:02.880,0.96,0.7086
+"Gunshot, gunfire",[gunshot],9.6,12.48,00:00:09.600,00:00:12.480,2.88,0.9488
+Machine gun,[machine gun],12.48,13.44,00:00:12.480,00:00:13.440,0.96,0.8211
diff --git a/outputs/test_video_events.json b/outputs/test_video_events.json
new file mode 100644
index 0000000..9f1a599
--- /dev/null
+++ b/outputs/test_video_events.json
@@ -0,0 +1,42 @@
+[
+  {
+    "label": "Explosion",
+    "caption_label": "[explosion]",
+    "start_time": 0.96,
+    "end_time": 2.4,
+    "confidence": 0.6279,
+    "start_timestamp": "00:00:00.960",
+    "end_timestamp": "00:00:02.400",
+    "duration": 1.44
+  },
+  {
+    "label": "Fusillade",
+    "caption_label": "[rapid gunfire]",
+    "start_time": 1.92,
+    "end_time": 2.88,
+    "confidence": 0.7086,
+    "start_timestamp": "00:00:01.920",
+    "end_timestamp": "00:00:02.880",
+    "duration": 0.96
+  },
+  {
+    "label": "Gunshot, gunfire",
+    "caption_label": "[gunshot]",
+    "start_time": 9.6,
+    "end_time": 12.48,
+    "confidence": 0.9488,
+    "start_timestamp": "00:00:09.600",
+    "end_timestamp": "00:00:12.480",
+    "duration": 2.88
+  },
+  {
+    "label": "Machine gun",
+    "caption_label": "[machine gun]",
+    "start_time": 12.48,
+    "end_time": 13.44,
+    "confidence": 0.8211,
+    "start_timestamp": "00:00:12.480",
+    "end_timestamp": "00:00:13.440",
+    "duration": 0.96
+  }
+]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ad4176f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+setuptools<81
+tensorflow==2.21.0
+tensorflow-hub==0.16.1
+numpy>=1.26,<3
+soundfile>=0.13.1
+imageio-ffmpeg>=0.6.0
diff --git a/samples/test_video.mp4 b/samples/test_video.mp4
new file mode 100644
index 0000000..b2bc562
Binary files /dev/null and b/samples/test_video.mp4 differ
diff --git a/scripts/create_sample_video.py b/scripts/create_sample_video.py
new file mode 100644
index 0000000..010ba50
--- /dev/null
+++ b/scripts/create_sample_video.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import math
+import subprocess
+import wave
+from pathlib import Path
+
+import imageio_ffmpeg
+
+
+SAMPLE_RATE = 16_000
+DURATION_SECONDS = 6
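+# 16 kHz mono matches the detector's TARGET_SAMPLE_RATE in
+# src/cc_event_detector/audio.py, so YAMNet can consume this audio directly.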
+
+
+def tone(sample_index: int, frequency: float, amplitude: float) -> int:
+    value = amplitude * math.sin(2 * math.pi * frequency * sample_index / SAMPLE_RATE)
+    return int(max(-1.0, min(1.0, value)) * 32767)
+
+
+def make_sample_wav(path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    total_samples = SAMPLE_RATE * DURATION_SECONDS
+    frames = bytearray()
+
+    for index in range(total_samples):
+        second = index / SAMPLE_RATE
+        if 1.0 <= second < 1.8:
+            sample = tone(index, 880.0, 0.55)
+        elif 3.2 <= second < 4.1:
+            sample = tone(index, 440.0, 0.45) + tone(index, 660.0, 0.25)
+        else:
+            sample = 0
+        frames.extend(int(sample).to_bytes(2, byteorder="little", signed=True))
+
+    with wave.open(str(path), "wb") as handle:
+        handle.setnchannels(1)
+        handle.setsampwidth(2)
+        handle.setframerate(SAMPLE_RATE)
+        handle.writeframes(frames)
+
+
+def make_sample_video(audio_path: Path, video_path: Path) -> None:
+    video_path.parent.mkdir(parents=True, exist_ok=True)
+    ffmpeg = imageio_ffmpeg.get_ffmpeg_exe()
+    command = [
+        ffmpeg,
+        "-y",
+        "-f",
+        "lavfi",
+        "-i",
+        f"color=c=black:s=640x360:d={DURATION_SECONDS}",
+        "-i",
+        str(audio_path),
+        "-shortest",
+        "-c:v",
+        "libx264",
+        "-pix_fmt",
+        "yuv420p",
+        "-c:a",
+        "aac",
+        str(video_path),
+    ]
+    completed = subprocess.run(command, capture_output=True, text=True, check=False)
+    if completed.returncode != 0:
+        raise RuntimeError(completed.stderr.strip() or "Could not create sample video.")
+
+
+def main() -> int:
+    audio_path = Path("samples/sample_audio.wav")
+    video_path = Path("samples/sample_video.mp4")
+    make_sample_wav(audio_path)
+    make_sample_video(audio_path, video_path)
+    print(f"Wrote {audio_path}")
+    print(f"Wrote {video_path}")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/src/cc_event_detector/__init__.py b/src/cc_event_detector/__init__.py
new file mode 100644
index 0000000..1abb581
--- /dev/null
+++ b/src/cc_event_detector/__init__.py
@@ -0,0 +1,3 @@
+"""Module 1 MVP for non-speech sound event detection."""
+
+__version__ = "0.1.0"
diff --git a/src/cc_event_detector/__main__.py b/src/cc_event_detector/__main__.py
new file mode 100644
index 0000000..5417fda
--- /dev/null
+++ b/src/cc_event_detector/__main__.py
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+from .cli import main
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/src/cc_event_detector/audio.py b/src/cc_event_detector/audio.py
new file mode 100644
index 0000000..8fd5877
--- /dev/null
+++ b/src/cc_event_detector/audio.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import subprocess
+from pathlib import Path
+
+import imageio_ffmpeg
+import soundfile as sf
+
+
+SUPPORTED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".mov", ".avi", ".webm"}
+SUPPORTED_AUDIO_EXTENSIONS = {".wav", ".mp3", ".m4a", ".aac", ".flac", ".ogg"}
+TARGET_SAMPLE_RATE = 16_000
+
+
+class AudioExtractionError(RuntimeError):
+    """Raised when audio extraction from a media file fails."""
+
+
+def is_video_file(path: Path) -> bool:
+    return path.suffix.lower() in SUPPORTED_VIDEO_EXTENSIONS
+
+
+def is_audio_file(path: Path) -> bool:
+    return path.suffix.lower() in SUPPORTED_AUDIO_EXTENSIONS
+
+
+def extract_audio_to_wav(input_path: Path, output_path: Path) -> Path:
+    """Extract 16 kHz mono WAV audio using the FFmpeg binary bundled by pip."""
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
+    command = [
+        ffmpeg_path,
+        "-y",
+        "-i",
+        str(input_path),
+        "-vn",
+        "-ac",
+        "1",
+        "-ar",
+        str(TARGET_SAMPLE_RATE),
+        "-f",
+        "wav",
+        str(output_path),
+    ]
+    completed = subprocess.run(command, capture_output=True, text=True, check=False)
+    if completed.returncode != 0:
+        message = completed.stderr.strip() or "FFmpeg failed while extracting audio."
+        raise AudioExtractionError(message)
+    return output_path
+
+
+def load_wav_mono(path: Path) -> tuple[list[float], int]:
+    """Load WAV audio as mono float samples."""
+    audio, sample_rate = sf.read(str(path), dtype="float32")
+    if getattr(audio, "ndim", 1) > 1:
+        audio = audio.mean(axis=1)
+    return audio.tolist(), int(sample_rate)
diff --git a/src/cc_event_detector/cli.py b/src/cc_event_detector/cli.py
new file mode 100644
index 0000000..57898c9
--- /dev/null
+++ b/src/cc_event_detector/cli.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from .audio import extract_audio_to_wav, is_audio_file, is_video_file
+from .export import write_csv, write_json
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Detect non-speech sound events from a video or audio file.",
+    )
+    parser.add_argument("--input", required=True, type=Path, help="Input video or audio file")
+    parser.add_argument("--json", type=Path, default=Path("outputs/events.json"))
+    parser.add_argument("--csv", type=Path, default=Path("outputs/events.csv"))
+    parser.add_argument("--min-confidence", type=float, default=0.25)
+    parser.add_argument("--top-k", type=int, default=1)
+    parser.add_argument("--merge-gap", type=float, default=0.6)
+    parser.add_argument(
+        "--block-label",
+        action="append",
+        default=[],
+        help="Extra YAMNet label to suppress. Can be repeated or comma-separated.",
+    )
+    parser.add_argument(
+        "--keep-audio",
+        type=Path,
+        help="Optional path to save the extracted 16 kHz mono WAV file.",
+    )
+    return parser
+
+
+def parse_labels(values: list[str]) -> set[str]:
+    labels: set[str] = set()
+    for value in values:
+        labels.update(part.strip() for part in value.split(",") if part.strip())
+    return labels
+
+
+def main() -> int:
+    args = build_parser().parse_args()
+    input_path = args.input
+    if not input_path.exists():
+        print(f"Input file not found: {input_path}")
+        return 1
+
+    try:
+        with TemporaryDirectory() as temp_dir:
+            if is_video_file(input_path) or is_audio_file(input_path):
+                wav_path = args.keep_audio or Path(temp_dir) / "extracted_audio.wav"
+                extract_audio_to_wav(input_path, wav_path)
+            else:
+                print("Unsupported input. Use a video/audio file such as .mp4, .mov, .wav, or .mp3.")
+                return 1
+
+            from .yamnet import detect_sound_events
+            from .yamnet import DEFAULT_BLOCKLIST
+
+            events = detect_sound_events(
+                wav_path,
+                min_confidence=args.min_confidence,
+                top_k=args.top_k,
+                merge_gap_seconds=args.merge_gap,
+                blocklist=DEFAULT_BLOCKLIST | parse_labels(args.block_label),
+            )
+
+            write_json(events, args.json)
+            write_csv(events, args.csv)
+    except Exception as exc:
+        print(f"Detection failed: {exc}")
+        return 1
+
+    print(f"Detected {len(events)} sound event(s).")
+    print(f"JSON output: {args.json}")
+    print(f"CSV output: {args.csv}")
+    return 0
diff --git a/src/cc_event_detector/events.py b/src/cc_event_detector/events.py
new file mode 100644
index 0000000..9dabc7d
--- /dev/null
+++ b/src/cc_event_detector/events.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from dataclasses import asdict, dataclass
+
+
+def format_timestamp(seconds: float) -> str:
+    millis = int(round(seconds * 1000))
+    hours, remainder = divmod(millis, 3_600_000)
+    minutes, remainder = divmod(remainder, 60_000)
+    secs, millis = divmod(remainder, 1000)
+    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"
+
+
+@dataclass
+class SoundEvent:
+    label: str
+    caption_label: str
+    start_time: float
+    end_time: float
+    confidence: float
+
+    def to_dict(self) -> dict[str, float | str]:
+        data = asdict(self)
+        data["start_time"] = round(self.start_time, 3)
+        data["end_time"] = round(self.end_time, 3)
+        data["start_timestamp"] = format_timestamp(self.start_time)
+        data["end_timestamp"] = format_timestamp(self.end_time)
+        data["duration"] = round(max(0.0, self.end_time - self.start_time), 3)
+        data["confidence"] = round(self.confidence, 4)
+        return data
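+
+
+if __name__ == "__main__":
+    # Illustrative self-check (not used by the pipeline): 3661.5 seconds
+    # formats as 1 hour, 1 minute, 1.5 seconds.
+    assert format_timestamp(3661.5) == "01:01:01.500"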
diff --git a/src/cc_event_detector/export.py b/src/cc_event_detector/export.py
new file mode 100644
index 0000000..de4c332
--- /dev/null
+++ b/src/cc_event_detector/export.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import csv
+import json
+from pathlib import Path
+
+from .events import SoundEvent
+
+
+def write_json(events: list[SoundEvent], path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    payload = [event.to_dict() for event in events]
+    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
+
+
+def write_csv(events: list[SoundEvent], path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", newline="", encoding="utf-8") as handle:
+        writer = csv.DictWriter(
+            handle,
+            fieldnames=[
+                "label",
+                "caption_label",
+                "start_time",
+                "end_time",
+                "start_timestamp",
+                "end_timestamp",
+                "duration",
+                "confidence",
+            ],
+        )
+        writer.writeheader()
+        for event in events:
+            writer.writerow(event.to_dict())
diff --git a/src/cc_event_detector/labels.py b/src/cc_event_detector/labels.py
new file mode 100644
index 0000000..77d9f87
--- /dev/null
+++ b/src/cc_event_detector/labels.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+
+CAPTION_LABELS = {
+    "Alarm": "[alarm]",
+    "Applause": "[applause]",
+    "Bird": "[bird chirping]",
+    "Car alarm": "[car alarm]",
+    "Cheering": "[cheering]",
+    "Door": "[door]",
+    "Explosion": "[explosion]",
+    "Fusillade": "[rapid gunfire]",
+    "Glass": "[glass breaking]",
+    "Glass breaking": "[glass breaking]",
+    "Gunshot, gunfire": "[gunshot]",
+    "Keys jangling": "[keys jangling]",
+    "Laughter": "[laughter]",
+    "Machine gun": "[machine gun]",
+    "Music": "[music]",
+    "Siren": "[siren]",
+    "Telephone": "[telephone ringing]",
+    "Vehicle horn, car horn, honking": "[horn honking]",
+}
+
+
+def caption_label_for(yamnet_label: str) -> str:
+    """Return a simple CC-style label for a YAMNet class."""
+    return CAPTION_LABELS.get(yamnet_label, f"[{yamnet_label.lower()}]")
diff --git a/src/cc_event_detector/yamnet.py b/src/cc_event_detector/yamnet.py
new file mode 100644
index 0000000..f1aea31
--- /dev/null
+++ b/src/cc_event_detector/yamnet.py
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+import csv
+import os
+import warnings
+from pathlib import Path
+
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
+os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
+warnings.filterwarnings(
+    "ignore",
+    message="pkg_resources is deprecated as an API.*",
+    category=UserWarning,
+)
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_hub as hub
+
+from .audio import TARGET_SAMPLE_RATE, load_wav_mono
+from .events import SoundEvent
+from .labels import caption_label_for
+
+
+YAMNET_MODEL_URL = "https://tfhub.dev/google/yamnet/1"
+YAMNET_PATCH_SECONDS = 0.96
+YAMNET_HOP_SECONDS = 0.48
+
+DEFAULT_BLOCKLIST = {
+    "Silence",
+    "Speech",
+    "Inside, small room",
+    "Narration, monologue",
+    "Conversation",
+}
+
+
+class SoundDetectionError(RuntimeError):
+    """Raised when sound event detection cannot be completed."""
+
+
+def load_yamnet_model() -> object:
+    return hub.load(YAMNET_MODEL_URL)
+
+
+def load_class_names(model: object) -> list[str]:
+    class_map_path = model.class_map_path().numpy().decode("utf-8")
+    names: list[str] = []
+    with tf.io.gfile.GFile(class_map_path) as handle:
+        reader = csv.DictReader(handle)
+        for row in reader:
+            names.append(row["display_name"])
+    return names
+
+
+def detect_sound_events(
+    wav_path: Path,
+    min_confidence: float = 0.25,
+    top_k: int = 1,
+    merge_gap_seconds: float = 0.6,
+    blocklist: set[str] | None = None,
+) -> list[SoundEvent]:
+    samples, sample_rate = load_wav_mono(wav_path)
+    if sample_rate != TARGET_SAMPLE_RATE:
+        raise SoundDetectionError(
+            f"Expected {TARGET_SAMPLE_RATE} Hz WAV audio, got {sample_rate} Hz. "
+            "Use the video input path or extract audio with this tool first."
+        )
+    if not samples:
+        return []
+
+    model = load_yamnet_model()
+    class_names = load_class_names(model)
+    waveform = tf.convert_to_tensor(samples, dtype=tf.float32)
+    scores, _, _ = model(waveform)
+    score_matrix = scores.numpy()
+    blocked = blocklist or DEFAULT_BLOCKLIST
+
+    candidates = frame_scores_to_events(
+        score_matrix,
+        class_names,
+        min_confidence=min_confidence,
+        top_k=top_k,
+        blocklist=blocked,
+    )
+    return merge_adjacent_events(candidates, merge_gap_seconds=merge_gap_seconds)
+
+
+def frame_scores_to_events(
+    score_matrix: np.ndarray,
+    class_names: list[str],
+    min_confidence: float,
+    top_k: int,
+    blocklist: set[str],
+) -> list[SoundEvent]:
+    events: list[SoundEvent] = []
+    for frame_index, frame_scores in enumerate(score_matrix):
+        top_indices = np.argsort(frame_scores)[::-1][:top_k]
+        for class_index in top_indices:
+            confidence = float(frame_scores[class_index])
+            label = class_names[class_index]
+            if confidence < min_confidence or label in blocklist:
+                continue
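+            # Each YAMNet score frame covers a 0.96 s patch hopped every
+            # 0.48 s, so frame i spans [0.48 * i, 0.48 * i + 0.96] seconds.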
+            start = frame_index * YAMNET_HOP_SECONDS
+            end = start + YAMNET_PATCH_SECONDS
+            events.append(
+                SoundEvent(
+                    label=label,
+                    caption_label=caption_label_for(label),
+                    start_time=start,
+                    end_time=end,
+                    confidence=confidence,
+                )
+            )
+    return events
+
+
+def merge_adjacent_events(
+    events: list[SoundEvent],
+    merge_gap_seconds: float,
+) -> list[SoundEvent]:
+    if not events:
+        return []
+
+    sorted_events = sorted(events, key=lambda event: (event.label, event.start_time))
+    merged: list[SoundEvent] = []
+    for event in sorted_events:
+        if (
+            merged
+            and merged[-1].label == event.label
+            and event.start_time - merged[-1].end_time <= merge_gap_seconds
+        ):
+            merged[-1].end_time = max(merged[-1].end_time, event.end_time)
+            merged[-1].confidence = max(merged[-1].confidence, event.confidence)
+            continue
+        merged.append(event)
+
+    return sorted(merged, key=lambda event: event.start_time)
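+
+
+if __name__ == "__main__":
+    # Minimal self-check of the merge logic (illustrative only, not exercised
+    # by the CLI; assumes running as `python -m cc_event_detector.yamnet` with
+    # src/ on sys.path). Two overlapping "Explosion" frames one hop apart
+    # collapse into a single event that keeps the higher confidence.
+    demo = merge_adjacent_events(
+        [
+            SoundEvent("Explosion", "[explosion]", 0.0, 0.96, 0.5),
+            SoundEvent("Explosion", "[explosion]", 0.48, 1.44, 0.7),
+        ],
+        merge_gap_seconds=0.6,
+    )
+    assert len(demo) == 1
+    assert demo[0].end_time == 1.44 and demo[0].confidence == 0.7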