121 changes: 121 additions & 0 deletions README.md
@@ -0,0 +1,121 @@
# PR: Intelligent CC Suggestion Tool — Module 1 Complete

## Summary

This PR delivers a fully working **Module 1** (Sound Event Detection → SRT/SLS output) and lays the architectural groundwork for Module 2 (Visual Reaction Detection). The pipeline accepts any video or audio file and produces closed-caption suggestions for meaningful non-speech audio events — without over-captioning ambient sounds.

**What this PR includes:**
- Full YAMNet-based detection pipeline with a transient-aware 3-path filter (not just a confidence threshold)
- English + Hindi SRT/SLS/JSON/CSV export — no translation API, fully offline
- Silero VAD speech suppression so speech frames never become false CC events
- librosa onset pass to catch short transients (<0.2s) that YAMNet's ~1s analysis window misses
- Architectural groundwork for Module 2 visual reaction scoring

---

## Pipeline Architecture

```
INPUT VIDEO
  ├──▶ AUDIO EXTRACTION (imageio-ffmpeg, no system install needed)
  │          │
  │   ┌──────┴──────┐
  │   │ Silero VAD  │──▶ speech intervals (suppressed from detection)
  │   └─────────────┘
  │          │
  │   ┌──────▼──────────────────────────────────┐
  │   │  YAMNet · RMS gate · Blocklist          │
  │   │      │                                  │
  │   │  Transient? ──YES──▶ accept immediately │
  │   │      │               (dog bark, gunshot,│
  │   │      NO              door slam, glass…) │
  │   │      ▼                                  │
  │   │  Consensus voting (2/3 frames)          │
  │   │  + onset check (engine, rain, crowd)    │
  │   └──────────────────┬──────────────────────┘
  │                      │
  │       librosa onset pass (catches <0.2s events)
  │                      │
  │        Merge · Deduplicate · Sort
  └──▶ SRT (EN + HI) · SLS · JSON · CSV
```
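
In code, the frame-level decision reduces to roughly the following. This is a minimal sketch, not the actual implementation: the `Frame` container and `TRANSIENT_LABELS` set are illustrative stand-ins, with defaults taken from the CLI flags below.

```python
# Sketch of the transient-aware filter; Frame and TRANSIENT_LABELS are
# illustrative stand-ins. Defaults mirror the CLI (--min-confidence, etc.).
from dataclasses import dataclass

TRANSIENT_LABELS = {"Dog", "Gunshot, gunfire", "Door", "Glass"}  # path 1

@dataclass
class Frame:
    label: str          # top YAMNet class for this frame
    confidence: float   # YAMNet score
    rms: float          # frame energy, for the silence gate
    in_speech: bool     # overlaps a Silero VAD speech interval

def accept(frames: list[Frame], i: int,
           conf_thresh: float = 0.25, rms_thresh: float = 0.010,
           window: int = 3, k: int = 2) -> bool:
    f = frames[i]
    # Gates first: speech suppression, RMS silence gate, confidence floor
    if f.in_speech or f.rms < rms_thresh or f.confidence < conf_thresh:
        return False
    # Path 1: transient classes are accepted immediately
    if f.label in TRANSIENT_LABELS:
        return True
    # Path 2: sustained sounds need k-of-window consensus (default 2/3)
    votes = sum(1 for g in frames[max(0, i - window + 1): i + 1]
                if g.label == f.label and g.confidence >= conf_thresh)
    return votes >= k
```

Path 3, the librosa onset pass, runs independently over the raw waveform and its hits are merged in afterwards.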

---

## Run

```bash
python detect.py --input video.mp4 --srt outputs/cc_en.srt --srt-hi outputs/cc_hi.srt
```
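
Every gate in the pipeline is tunable from the CLI; for example (example values):

```bash
# Looser confidence gate, wider event merging, extra blocked labels
python detect.py --input video.mp4 \
    --min-confidence 0.20 \
    --merge-gap 2.0 \
    --block-label "Music,Silence"
```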

📎 **Colab links:**
- https://colab.research.google.com/drive/1aAbBrZBw1xg8ASqS98lyCewVWRSZb_Bj?usp=sharing
- https://colab.research.google.com/drive/15kpMJkWYWQO0sBoJZhYFMqcRBbLLzVMy?usp=sharing


---

## Research: Benchmark Across 5 Model Families

Before settling on YAMNet as the production solution, we benchmarked five model families. Here's what we found.

---

### WAV2CLIP + CLAP — Not viable

Both models embed audio into CLIP/text embedding space and score it against text prompts via cosine similarity (the shared pattern is sketched after this list). In theory, free-form labels; in practice:

- **CLAP (HTSAT-base)** had repeated checkpoint/architecture mismatches — `laion_clap`'s `load_ckpt()` silently builds a different model width depending on `enable_fusion`, causing a `RuntimeError` on every load attempt across three configurations
- **WAV2CLIP** loaded but produced inconsistent, low-confidence labelling — it lives in CLIP's *visual* embedding space, which wasn't built for diverse audio
- **Verdict:** The cosine-similarity approach is brittle without a dedicated audio backbone. Not worth pursuing further.
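
For context, the scoring pattern both models share looks like this. It is a generic sketch: `embed_audio` and `embed_text` are placeholders for the model-specific encoders, not the actual `laion_clap` or `wav2clip` APIs.

```python
# Generic zero-shot audio tagging via cosine similarity.
# embed_audio / embed_text are hypothetical stand-ins for the encoders.
import numpy as np

def rank_labels(audio, prompts, embed_audio, embed_text):
    a = embed_audio(audio)                          # (d,) audio embedding
    t = np.stack([embed_text(p) for p in prompts])  # (n, d) text embeddings
    a = a / np.linalg.norm(a)
    t = t / np.linalg.norm(t, axis=1, keepdims=True)
    sims = t @ a                                    # cosine score per prompt
    order = np.argsort(-sims)
    return [(prompts[i], float(sims[i])) for i in order]
```

The whole approach stands or falls on the quality of the audio embedding, which is exactly where both models fell short.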

---

### PANNs CNN14 — Better mAP, but wrong fit for CC

PANNs CNN14 (mAP 0.385 vs YAMNet's 0.306) is technically the stronger AudioSet model. We ran a fair benchmark with the blocklist off and thresholds matched to YAMNet's sensitivity.

**The problem:** PANNs is trained on AudioSet's full 527-class hierarchy including very broad meta-classes — `"Music"`, `"Animal"`, `"Sound"`. On a real video, over 1500 frames fired on these broad labels. They're not wrong, but they're not CC-worthy. With blocklist on, too many real events get suppressed as collateral; with it off, the output is noisy.

PANNs' higher mAP comes from scoring those broad categories well. For CC specifically — where you want narrow, specific, actionable events — the broad-label training is a liability, not an asset. **YAMNet's narrower 521-class set, which looks like a weakness on paper, is actually an advantage here.**
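
The tradeoff is easy to see in a toy sketch (the meta-class names follow the examples above; the filter itself is hypothetical):

```python
# Hypothetical illustration of the broad-label problem with PANNs.
BROAD_META = {"Music", "Animal", "Sound"}  # meta-classes that fire constantly

def cc_worthy(label: str, confidence: float, thresh: float = 0.25) -> bool:
    # Blocklist ON: the noisy meta-classes disappear, but a bark that PANNs
    # scored only under the parent "Animal" is suppressed as collateral.
    # Blocklist OFF: 1500+ "Music"/"Animal"/"Sound" frames flood the output.
    return confidence >= thresh and label not in BROAD_META
```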

---

### Qwen2-Audio-7B — Most promising, fine-tune path forward

Qwen2-Audio is a 7B audio-language model (Whisper-large-v2 encoder + LLM). Instead of cosine similarity or fixed class indices, it reasons about audio in natural language and returns structured JSON.

**What stood out:**
- Contextual descriptions, not bare labels: `"glass breaking, likely a fight scene"` — directly useful for CC editors reviewing output
- Native Hindi/multilingual support — no separate translation step needed
- Zero-shot on any category; the label bank is a prompt, not a fixed classifier
- Self-reported confidence calibrated better than embedding-similarity scores

**Limitation:** 7B parameters need ~14GB of VRAM even at half precision (fp16); we ran 4-bit quantized on a T4 (fits in 15GB). Inference is ~3–5× slower than YAMNet.
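
A minimal load sketch for that setup, assuming the Hugging Face `Qwen/Qwen2-Audio-7B-Instruct` checkpoint and a `transformers` build recent enough to ship the Qwen2-Audio classes, plus `bitsandbytes`:

```python
# 4-bit quantized Qwen2-Audio on a T4: a sketch under the assumptions above.
import torch
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2AudioForConditionalGeneration,
)

model_id = "Qwen/Qwen2-Audio-7B-Instruct"
quant = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # T4 has no bf16 support
)
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    model_id, quantization_config=quant, device_map="auto"
)
```

From there, prompting follows the standard processor → `model.generate` flow, with the label bank embedded in the prompt.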

**Fine-tuning is the path forward.** Qwen2-Audio can be adapted to Indian content without starting from scratch (a sketch of the adapter setup follows this list):
1. Start from AudioSet classes YAMNet is trained on as the base — strong prior already exists
2. Augment with clips for underrepresented India-specific sounds: dhol, shehnai, auto-rickshaw horn, switch/click sounds, crowd chanting
3. Map new classes to existing AudioSet parents where possible (dhol → `Drum`, shehnai → `Wind instrument`) so existing weights transfer
4. Fine-tune only the output mapping / last few layers — audio encoder is already strong
5. A ~1000-clip augmented dataset fine-tunes in 2–3 hours on a T4
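
Under those assumptions, the adapter setup could look like this: a PEFT/LoRA sketch reusing the quantized model from the load sketch above, with `target_modules` names that are illustrative and must be checked against the actual checkpoint.

```python
# Hypothetical LoRA setup for step 4: adapt only the top of the model.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)  # QLoRA prep for the 4-bit base
lora = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # illustrative; inspect named_modules()
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()  # audio encoder stays frozen; adapters train
```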

---

### Full Comparison Table

| | YAMNet | PANNs CNN14 | CLAP | WAV2CLIP | Qwen2-Audio |
|---|---|---|---|---|---|
| AudioSet mAP | 0.306 | 0.385 | ~0.47 | ~0.40 | n/a (LLM-based) |
| Parameters | 3.7M | 81M | 87M | ~60M | 7B |
| Label type | Fixed 521 | Fixed 527 | Free text | Free text | Free text + reasoning |
| Hindi support | manual map | manual map | via prompt | via prompt | native |
| India-specific sounds | poor | poor | moderate | poor | best (zero-shot) |
| False positive control | gates + blocklist | blocklist too aggressive | mAP ceiling | inconsistent | LLM reasoning |
| Speed | fastest | fast | fast | moderate | slow |
| Offline | ✅ | ✅ | ✅ | ✅ | ✅ (quantized) |
| **Verdict** | ✅ **production** | ❌ noisy for CC | ❌ load errors | ❌ low quality | 🔬 fine-tune target |



cc @abinash-sketch @keerthiseelan-planetread
20 changes: 20 additions & 0 deletions detect.py
@@ -0,0 +1,20 @@
"""
Root-level entry point.

python detect.py --input video.mp4 [options]

This simply re-exports the CLI main() so the tool can be run from
the project root without installing the package.
"""
from __future__ import annotations

import sys
from pathlib import Path

# Allow running from project root without pip install
sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))

from cc_detector.cli import main

if __name__ == "__main__":
    raise SystemExit(main())
8 changes: 8 additions & 0 deletions requirements.txt
@@ -0,0 +1,8 @@
tensorflow>=2.13.0
tensorflow-hub>=0.14.0
torch>=2.1.0
torchaudio>=2.1.0
soundfile>=0.13.1
librosa>=0.10.0
numpy>=1.26,<2.0
imageio-ffmpeg>=0.6.0
16 changes: 16 additions & 0 deletions src/cc_detector/__init__.py
@@ -0,0 +1,16 @@
"""
cc_detector — Intelligent Closed Caption Suggestion Tool.
Module 1 MVP: YAMNet-based non-speech sound event detection.

Improvements over baseline YAMNet approaches:
- 5-gate filtering pipeline (speech, RMS, spectral, harmonic, blocklist)
- Top-K consensus voting across frames before accepting an event
- Temporal smoothing with a sliding-window majority vote
- Librosa spectral harmonic gating suppresses tonal music artefacts
- Librosa onset detection catches short transients (<0.2s) YAMNet misses
- Dual SRT output: English + Hindi with tuple-based keyword matching
- JSON + CSV + SRT export with per-event metadata debug fields
"""

__version__ = "0.2.0"

6 changes: 6 additions & 0 deletions src/cc_detector/__main__.py
@@ -0,0 +1,6 @@
from __future__ import annotations

from .cli import main

if __name__ == "__main__":
    raise SystemExit(main())
80 changes: 80 additions & 0 deletions src/cc_detector/audio.py
@@ -0,0 +1,80 @@
"""
Audio extraction and loading.

Uses imageio-ffmpeg's bundled binary — no system FFmpeg required.
All downstream models (YAMNet, Silero VAD, librosa) expect 16 kHz mono.
"""

from __future__ import annotations

import subprocess
from pathlib import Path

import imageio_ffmpeg
import librosa
import numpy as np
import soundfile as sf

TARGET_SR = 16_000 # YAMNet + Silero VAD both expect 16 kHz mono
_FFMPEG = imageio_ffmpeg.get_ffmpeg_exe()

SUPPORTED_VIDEO: frozenset[str] = frozenset(
    {".mp4", ".mkv", ".mov", ".avi", ".webm", ".flv", ".ts", ".m2ts"}
)
SUPPORTED_AUDIO: frozenset[str] = frozenset(
    {".wav", ".mp3", ".m4a", ".aac", ".flac", ".ogg", ".opus"}
)


class MediaError(RuntimeError):
    """Raised when media extraction or audio loading fails."""


def is_video(path: Path) -> bool:
    return path.suffix.lower() in SUPPORTED_VIDEO


def is_audio(path: Path) -> bool:
    return path.suffix.lower() in SUPPORTED_AUDIO


def extract_audio(media_path: Path, out_wav: Path) -> Path:
    """
    Extract 16 kHz mono WAV from any video or audio file.

    Uses the imageio-ffmpeg bundled binary — callers need not install
    system FFmpeg. Raises MediaError on failure.
    """
    out_wav.parent.mkdir(parents=True, exist_ok=True)
    cmd = [
        _FFMPEG, "-y",
        "-i", str(media_path),
        "-vn",
        "-ac", "1",
        "-ar", str(TARGET_SR),
        "-f", "wav",
        str(out_wav),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise MediaError(
            f"FFmpeg failed on {media_path.name}:\n"
            + (result.stderr[-600:] or "(no stderr)")
        )
    return out_wav


def load_mono_f32(wav_path: Path) -> tuple[np.ndarray, int]:
    """
    Load a WAV file as a float32 mono numpy array.
    Returns (samples, sample_rate).
    Normalises amplitude to [-1, 1] if needed (YAMNet expects this range).
    """
    audio, sr = sf.read(str(wav_path), dtype="float32", always_2d=False)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    # Normalise if decoding produced samples outside [-1, 1]
    peak = np.abs(audio).max()
    if peak > 1.0:
        audio = audio / peak
    return audio, int(sr)
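
A minimal usage sketch of this module (file paths are hypothetical):

```python
# Hypothetical usage: extract then load 16 kHz mono audio from a video.
from pathlib import Path

from cc_detector.audio import extract_audio, load_mono_f32

wav = extract_audio(Path("video.mp4"), Path("outputs/audio.wav"))
samples, sr = load_mono_f32(wav)
assert sr == 16_000  # ffmpeg already resampled to TARGET_SR
```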
140 changes: 140 additions & 0 deletions src/cc_detector/cli.py
@@ -0,0 +1,140 @@
"""Command-line interface for the Intelligent CC Suggestion Tool."""
from __future__ import annotations

import argparse
import sys
import time
from pathlib import Path
from tempfile import TemporaryDirectory

from .audio import extract_audio, is_video, is_audio, MediaError
from .export import write_srt, write_sls, write_json, write_csv
from .vad import get_speech_intervals
from .yamnet import detect


def _build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        prog="cc_detector",
        description=(
            "Intelligent CC Suggestion Tool — Module 1\n"
            "YAMNet-based non-speech sound event detection → SRT/SLS/JSON/CSV"
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--input", "-i", required=True, type=Path, metavar="FILE",
                   help="Input video or audio file")
    p.add_argument("--json", type=Path, default=Path("outputs/events.json"))
    p.add_argument("--csv", type=Path, default=Path("outputs/events.csv"))
    p.add_argument("--srt", type=Path, default=Path("outputs/cc_english.srt"))
    p.add_argument("--srt-hi", type=Path, default=Path("outputs/cc_hindi.srt"))
    p.add_argument("--sls", type=Path, default=None)
    p.add_argument("--sls-hi", type=Path, default=None)
    p.add_argument("--keep-audio", type=Path, default=None)

    p.add_argument("--min-confidence", type=float, default=0.25,
                   help="YAMNet confidence threshold (default: 0.25)")
    p.add_argument("--rms-threshold", type=float, default=0.010,
                   help="Silence gate (default: 0.010)")
    p.add_argument("--merge-gap", type=float, default=1.5,
                   help="Merge gap seconds (default: 1.5)")
    p.add_argument("--top-k", type=int, default=5,
                   help="Top-K labels per frame (default: 5)")
    p.add_argument("--vad-threshold", type=float, default=0.50,
                   help="Silero VAD threshold (default: 0.50)")
    p.add_argument("--consensus-window", type=int, default=3)
    p.add_argument("--consensus-k", type=int, default=2)
    p.add_argument("--no-onset-pass", action="store_true")
    p.add_argument("--block-label", action="append", default=[], metavar="LABEL")
    return p


def main(argv=None) -> int:
    args = _build_parser().parse_args(argv)

    if not args.input.exists():
        print(f"[ERROR] Input not found: {args.input}", file=sys.stderr)
        return 1
    if not (is_video(args.input) or is_audio(args.input)):
        print(f"[ERROR] Unsupported format: {args.input.suffix}", file=sys.stderr)
        return 1

    if args.block_label:
        from . import labels as lbl_module
        extra = frozenset(
            part.strip().lower()
            for val in args.block_label
            for part in val.split(",") if part.strip()
        )
        lbl_module.BLOCKLIST = lbl_module.BLOCKLIST | extra

    t_start = time.time()
    try:
        with TemporaryDirectory() as tmpdir:
            wav_path = args.keep_audio or Path(tmpdir) / "audio.wav"

            print(f"[1/4] Extracting audio from: {args.input.name}")
            extract_audio(args.input, wav_path)

            print("[2/4] Running Silero VAD (speech suppression)...")
            speech_intervals = get_speech_intervals(
                str(wav_path), threshold=args.vad_threshold
            )
            print(f"      {len(speech_intervals)} speech segment(s) found")

            print("[3/4] Running YAMNet sound event detection...")
            events, stats, infer_time = detect(
                wav_path, speech_intervals,
                conf_thresh=args.min_confidence,
                rms_thresh=args.rms_threshold,
                merge_gap=args.merge_gap,
                top_k=args.top_k,
                use_onset_pass=not args.no_onset_pass,
                consensus_window=args.consensus_window,
                consensus_k=args.consensus_k,
                vad_tolerance=0.35,
            )
            print(f"      YAMNet inference: {infer_time:.2f}s")
            print(f"      Gate stats: {stats}")
            print(f"      {len(events)} CC event(s) detected")

            print("[4/4] Exporting outputs...")
            write_json(events, args.json)
            write_csv(events, args.csv)
            write_srt(events, args.srt, hindi=False)
            write_srt(events, args.srt_hi, hindi=True)
            if args.sls:
                write_sls(events, args.sls, hindi=False)
            if args.sls_hi:
                write_sls(events, args.sls_hi, hindi=True)

    except MediaError as exc:
        print(f"[ERROR] {exc}", file=sys.stderr)
        return 1
    except Exception as exc:
        print(f"[ERROR] {exc}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1

    elapsed = time.time() - t_start
    print()
    print("=" * 60)
    print("  CC DETECTION COMPLETE")
    print("=" * 60)
    print(f"  Events detected  : {len(events)}")
    print(f"  Total wall time  : {elapsed:.1f}s")
    print(f"  YAMNet inference : {infer_time:.1f}s")
    print()
    print("  Output files:")
    print(f"    JSON : {args.json}")
    print(f"    CSV  : {args.csv}")
    print(f"    SRT  : {args.srt}")
    print(f"    SRT  : {args.srt_hi} (Hindi)")
    if args.sls:
        print(f"    SLS  : {args.sls}")
    if args.sls_hi:
        print(f"    SLS  : {args.sls_hi} (Hindi)")
    print()

    if events:
        print(f"  {'#':<4} {'Start':>8} {'End':>8} {'Label':<22} "
              f"{'Conf':>6} {'Frames':>6} {'Src':<8} Hindi CC")
        print("  " + "─" * 85)
        for i, ev in enumerate(events, 1):
            print(f"  {i:<4} {ev.start_time:>7.2f}s {ev.end_time:>7.2f}s "
                  f"{ev.label:<22} {ev.confidence:>6.3f} {ev.frame_count:>6} "
                  f"{ev.onset_source:<8} {ev.caption_hi}")
    return 0