From 5293077950a33ff38c18825ecde09ea349910bc3 Mon Sep 17 00:00:00 2001 From: Akshay Ballal Date: Sun, 28 Dec 2025 22:49:46 +0100 Subject: [PATCH 1/3] Add Dockerfile for CUDA support and enhance image embedding functionality - Introduced a new Dockerfile for building a server with CUDA development tools. - Updated the main server Dockerfile to improve the build process. - Added support for image embeddings, including new request and response structures for handling base64 images. - Enhanced the embedding logic to differentiate between text and image inputs, ensuring proper error handling for mixed input types. - Updated dependencies in Cargo.toml and Cargo.lock to include base64 and image libraries. --- Cargo.lock | 2 + server-cuda.Dockerfile | 88 ++++++++++ server.Dockerfile | 2 +- server/Cargo.toml | 2 + server/src/lib.rs | 361 +++++++++++++++++++++++++++++++++++++++-- 5 files changed, 438 insertions(+), 17 deletions(-) create mode 100644 server-cuda.Dockerfile diff --git a/Cargo.lock b/Cargo.lock index f723a718..07c9f972 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5576,8 +5576,10 @@ version = "0.7.0" dependencies = [ "actix-multipart", "actix-web", + "base64 0.22.1", "embed_anything", "futures-util", + "image", "serde", "tempfile", "tokio", diff --git a/server-cuda.Dockerfile b/server-cuda.Dockerfile new file mode 100644 index 00000000..5de13a2d --- /dev/null +++ b/server-cuda.Dockerfile @@ -0,0 +1,88 @@ +# Stage 1: Chef base with CUDA development tools +FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS chef +WORKDIR /app + +# Set non-interactive mode +ENV DEBIAN_FRONTEND=noninteractive +ENV TORCH_CUDA_ARCH_LIST=Turing +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES video,compute,utility +# Set CUDA compute capability for candle-kernels (Turing = 7.5, encoded as 75) +# This bypasses the need for nvidia-smi during build +# Format: major * 10 + minor (e.g., 7.5 -> 75, 8.0 -> 80, 8.6 -> 86) +ENV CUDA_COMPUTE_CAP=75 + +# Install build dependencies +RUN 
apt-get update && apt-get install -y \ + pkg-config \ + libssl-dev \ + python3 \ + python3-dev \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Install Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +ENV PATH="/root/.cargo/bin:${PATH}" + +# Install cargo-chef (we use CUDA base image instead of lukemathwalker/cargo-chef +# because we need CUDA development tools) +RUN cargo install cargo-chef --locked + +# Create a mock nvidia-smi script as fallback (in case CUDA_COMPUTE_CAP doesn't work) +# This script will be used if nvidia-smi is not available +RUN echo '#!/bin/bash\n\ +# Mock nvidia-smi for Docker build\n\ +# Returns compute capability 75 (Turing 7.5) for build-time detection\n\ +if echo "$*" | grep -q "compute_cap"; then\n\ + echo "compute_cap"\n\ + echo "75"\n\ +elif echo "$*" | grep -q "query"; then\n\ + # Handle --query-gpu format\n\ + if echo "$*" | grep -q "csv"; then\n\ + echo "compute_cap"\n\ + echo "75"\n\ + else\n\ + echo "CUDA Version: 12.2"\n\ + echo "Driver Version: 535.00"\n\ + echo "Compute Capability: 7.5"\n\ + fi\n\ +else\n\ + # Default output\n\ + echo "NVIDIA-SMI 535.00"\n\ + echo "Driver Version: 535.00"\n\ + echo "CUDA Version: 12.2"\n\ +fi\n\ +exit 0' > /usr/local/bin/nvidia-smi && chmod +x /usr/local/bin/nvidia-smi + +# Stage 2: Planner - prepare recipe +FROM chef AS planner +COPY . . +RUN cargo chef prepare --recipe-path recipe.json + +# Stage 3: Builder - cook dependencies and build +FROM chef AS builder +COPY --from=planner /app/recipe.json recipe.json +# Build dependencies - this is the caching Docker layer! +RUN cargo chef cook --release --recipe-path recipe.json --package server --features cuda +# Build application +COPY . . 
+RUN cargo build --release -p server --features cuda +RUN strip target/release/server + +# Stage 4: Runtime - minimal CUDA runtime image +FROM nvidia/cuda:12.2.2-runtime-ubuntu22.04 AS runtime +WORKDIR /app + +# Install minimal runtime dependencies +RUN apt-get update && apt-get install -y \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Copy the stripped binary from builder +COPY --from=builder /app/target/release/server /usr/local/bin/server + +EXPOSE 8080 + +CMD ["server"] + diff --git a/server.Dockerfile b/server.Dockerfile index 850ded43..3e85d322 100644 --- a/server.Dockerfile +++ b/server.Dockerfile @@ -17,7 +17,7 @@ COPY --from=planner /app/recipe.json recipe.json RUN cargo chef cook --release --recipe-path recipe.json --package server # Build application COPY . . -RUN cargo build --release --package server +RUN cargo build --release -p server RUN strip target/release/server # We do not need the Rust toolchain to run the binary! diff --git a/server/Cargo.toml b/server/Cargo.toml index 471ba7fd..a5e9d85a 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -14,6 +14,8 @@ futures-util = "0.3" serde = { version = "1.0", features = ["derive"] } tempfile = "3" tokio = { version = "1", features = ["fs", "io-util"] } +base64 = "0.22.1" +image = "0.25.6" [target.'cfg(not(target_os = "macos"))'.dependencies] embed_anything = {path = "../rust"} diff --git a/server/src/lib.rs b/server/src/lib.rs index 12511dc7..a37104ee 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -7,10 +7,11 @@ use std::time::{SystemTime, UNIX_EPOCH}; use actix_multipart::Multipart; use actix_web::dev::Server; use actix_web::{get, post, web, App, HttpResponse, HttpServer}; +use base64::Engine; use embed_anything::config::TextEmbedConfig; use embed_anything::{ embed_files_batch, embed_query, - embeddings::embed::{EmbedData, EmbedderBuilder, EmbeddingResult}, + embeddings::embed::{EmbedData, EmbedderBuilder, EmbeddingResult, EmbedImage}, }; use futures_util::StreamExt; use 
serde::{Deserialize, Serialize}; @@ -88,6 +89,29 @@ enum PdfEmbeddingVector { Multi(Vec>), } +// Image embedding request structure +#[derive(Deserialize)] +struct ImageEmbedRequest { + model: String, + images: Vec, // Base64 encoded images +} + +// Image embedding response structure +#[derive(Serialize)] +struct ImageEmbeddingResponse { + object: String, + data: Vec, + model: String, +} + +#[derive(Serialize)] +struct ImageEmbeddingData { + object: String, + index: usize, + embedding: Vec, + metadata: Option>, +} + fn pdf_embedding_response(model: String, embeddings: Vec) -> PdfEmbeddingResponse { let data = embeddings .into_iter() @@ -133,6 +157,21 @@ async fn create_embeddings(req: web::Json) -> HttpResponse { }); } + // Detect input type: check if all inputs are base64 images + let all_images = req.input.iter().all(|input| is_base64_image(input)); + let all_text = req.input.iter().all(|input| !is_base64_image(input)); + + // If mixed input types, return error + if !all_images && !all_text { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: "Mixed input types detected. 
Please provide either all text inputs or all base64 image inputs.".to_string(), + error_type: "invalid_request_error".to_string(), + code: Some("mixed_input_types".to_string()), + }, + }); + } + // Create embedder let embedder = match EmbedderBuilder::new() .model_id(Some(req.model.as_str())) @@ -150,21 +189,91 @@ async fn create_embeddings(req: web::Json) -> HttpResponse { } }; - // Convert input to string slices - let input_slices: Vec<&str> = req.input.iter().map(|s| s.as_str()).collect(); + // Route based on input type and model type + let embeddings = if all_images { + // Handle image embeddings + match embedder { + embed_anything::embeddings::embed::Embedder::Vision(vision_embedder) => { + // Create temp directory for image files + let temp_dir = match tempfile::tempdir() { + Ok(dir) => dir, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to create temp directory: {}", e), + error_type: "server_error".to_string(), + code: Some("temp_dir_creation_failed".to_string()), + }, + }); + } + }; - // Generate embeddings - let config = TextEmbedConfig::default(); - let embeddings = match embed_query(&input_slices, &embedder, Some(&config)).await { - Ok(embeddings) => embeddings, - Err(e) => { - return HttpResponse::InternalServerError().json(ErrorResponse { - error: ErrorDetail { - message: format!("Failed to generate embeddings: {}", e), - error_type: "server_error".to_string(), - code: Some("embedding_generation_failed".to_string()), - }, - }); + // Decode base64 images to temporary files + let mut image_paths = Vec::new(); + for (index, base64_image) in req.input.iter().enumerate() { + match decode_base64_to_temp_file(base64_image, index, &temp_dir).await { + Ok(path) => image_paths.push(path), + Err(e) => { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to decode image at index {}: {}", index, e), + error_type: 
"invalid_request_error".to_string(), + code: Some("base64_decode_failed".to_string()), + }, + }); + } + } + } + + // Generate embeddings for images + match vision_embedder.embed_image_batch(&image_paths, None).await { + Ok(embeddings) => embeddings, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to generate image embeddings: {}", e), + error_type: "server_error".to_string(), + code: Some("embedding_generation_failed".to_string()), + }, + }); + } + } + } + _ => { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: format!( + "Model '{}' does not support image embeddings. Please use a vision model (e.g., CLIP, SigLIP, DinoV2, ColPali).", + req.model + ), + error_type: "invalid_request_error".to_string(), + code: Some("unsupported_model".to_string()), + }, + }); + } + } + } else { + // Handle text embeddings + match embedder { + embed_anything::embeddings::embed::Embedder::Text(_) | embed_anything::embeddings::embed::Embedder::Vision(_) => { + // Convert input to string slices + let input_slices: Vec<&str> = req.input.iter().map(|s| s.as_str()).collect(); + + // Generate embeddings + let config = TextEmbedConfig::default(); + match embed_query(&input_slices, &embedder, Some(&config)).await { + Ok(embeddings) => embeddings, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to generate text embeddings: {}", e), + error_type: "server_error".to_string(), + code: Some("embedding_generation_failed".to_string()), + }, + }); + } + } + } } }; @@ -190,7 +299,11 @@ async fn create_embeddings(req: web::Json) -> HttpResponse { .collect(); // Calculate usage (simplified - you might want to implement proper token counting) - let total_tokens = req.input.iter().map(|s| s.split_whitespace().count()).sum(); + let total_tokens = if all_images { + req.input.len() // For images, count as 1 
token per image (simplified) + } else { + req.input.iter().map(|s| s.split_whitespace().count()).sum() + }; let response = OpenAIEmbedResponse { object: "list".to_string(), @@ -507,6 +620,221 @@ async fn create_pdf_embeddings_upload(mut payload: Multipart) -> HttpResponse { HttpResponse::Ok().json(response) } +/// Detects if a string is likely a base64-encoded image +fn is_base64_image(input: &str) -> bool { + // Check if it starts with data URL prefix + if input.starts_with("data:image/") { + return true; + } + + // Check if it's valid base64 and try to decode and validate as image + let base64_data = input.trim(); + + // Base64 strings should be reasonably long and contain only base64 characters + if base64_data.len() < 100 { + return false; + } + + // Check if it contains only base64 characters (with optional padding) + if !base64_data.chars().all(|c| { + c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '=' + }) { + return false; + } + + // Try to decode and validate as image + if let Ok(image_bytes) = base64::engine::general_purpose::STANDARD.decode(base64_data) { + // Try to detect image format from bytes + if image::ImageReader::new(std::io::Cursor::new(&image_bytes)) + .with_guessed_format() + .is_ok() + { + return true; + } + } + + false +} + +/// Decodes a base64 image string and writes it to a temporary file +async fn decode_base64_to_temp_file( + base64_str: &str, + index: usize, + temp_dir: &tempfile::TempDir, +) -> Result { + // Remove data URL prefix if present (e.g., "data:image/png;base64,") + let base64_data = if base64_str.starts_with("data:") { + base64_str + .split(',') + .nth(1) + .ok_or_else(|| "Invalid data URL format".to_string())? 
+ } else { + base64_str + }; + + // Decode base64 + let image_bytes = base64::engine::general_purpose::STANDARD + .decode(base64_data.trim()) + .map_err(|e| format!("Failed to decode base64: {}", e))?; + + // Try to load image from memory to validate and determine format + let image_format = image::ImageReader::new(std::io::Cursor::new(&image_bytes)) + .with_guessed_format() + .map_err(|e| format!("Failed to read image: {}", e))? + .format(); + + // Determine file extension from format + let extension = match image_format { + Some(image::ImageFormat::Png) => "png", + Some(image::ImageFormat::Jpeg) => "jpg", + Some(image::ImageFormat::WebP) => "webp", + Some(image::ImageFormat::Gif) => "gif", + Some(image::ImageFormat::Bmp) => "bmp", + _ => "png", // Default to PNG if format cannot be determined + }; + + // Create temporary file + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + let filename = format!("image_{}_{}.{}", timestamp, index, extension); + let file_path = temp_dir.path().join(filename); + + // Write image bytes to file + tokio::fs::write(&file_path, image_bytes) + .await + .map_err(|e| format!("Failed to write temp file: {}", e))?; + + Ok(file_path) +} + +#[post("/v1/image_embeddings")] +async fn create_image_embeddings(req: web::Json) -> HttpResponse { + // Validate input + if req.images.is_empty() { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: "Images cannot be empty".to_string(), + error_type: "invalid_request_error".to_string(), + code: Some("empty_images".to_string()), + }, + }); + } + + // Create temp directory for image files + let temp_dir = match tempfile::tempdir() { + Ok(dir) => dir, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to create temp directory: {}", e), + error_type: "server_error".to_string(), + code: Some("temp_dir_creation_failed".to_string()), + }, 
+ }); + } + }; + + // Decode base64 images to temporary files + let mut image_paths = Vec::new(); + for (index, base64_image) in req.images.iter().enumerate() { + match decode_base64_to_temp_file(base64_image, index, &temp_dir).await { + Ok(path) => image_paths.push(path), + Err(e) => { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to decode image at index {}: {}", index, e), + error_type: "invalid_request_error".to_string(), + code: Some("base64_decode_failed".to_string()), + }, + }); + } + } + } + + // Create embedder - try to determine if it's a vision model + let embedder = match EmbedderBuilder::new() + .model_id(Some(req.model.as_str())) + .from_pretrained_hf() + { + Ok(embedder) => embedder, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to initialize embedder: {}", e), + error_type: "server_error".to_string(), + code: Some("embedder_init_failed".to_string()), + }, + }); + } + }; + + // Check if embedder supports vision + let vision_embedder = match embedder { + embed_anything::embeddings::embed::Embedder::Vision(embedder) => embedder, + _ => { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: format!( + "Model '{}' does not support image embeddings. 
Please use a vision model (e.g., CLIP, SigLIP, DinoV2, ColPali).", + req.model + ), + error_type: "invalid_request_error".to_string(), + code: Some("unsupported_model".to_string()), + }, + }); + } + }; + + // Generate embeddings for images + let embeddings = match vision_embedder + .embed_image_batch(&image_paths, None) + .await + { + Ok(embeddings) => embeddings, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to generate embeddings: {}", e), + error_type: "server_error".to_string(), + code: Some("embedding_generation_failed".to_string()), + }, + }); + } + }; + + // Convert to response format + let embedding_data: Vec = embeddings + .into_iter() + .enumerate() + .map(|(index, embed_data)| { + let embedding_vector = match embed_data.embedding { + EmbeddingResult::DenseVector(vec) => vec, + EmbeddingResult::MultiVector(_) => { + // For multi-vector embeddings, return empty (or handle differently) + vec![] + } + }; + + ImageEmbeddingData { + object: "embedding".to_string(), + index, + embedding: embedding_vector, + metadata: embed_data.metadata, + } + }) + .collect(); + + let response = ImageEmbeddingResponse { + object: "list".to_string(), + data: embedding_data, + model: req.model.clone(), + }; + + HttpResponse::Ok().json(response) +} + pub fn run(listener: TcpListener) -> std::io::Result { let server = HttpServer::new(|| { App::new() @@ -514,6 +842,7 @@ pub fn run(listener: TcpListener) -> std::io::Result { .service(create_embeddings) .service(create_pdf_embeddings) .service(create_pdf_embeddings_upload) + .service(create_image_embeddings) }) .listen(listener)? 
.run(); From bb82b63aa63e9ee16686a87613c25a8543973223 Mon Sep 17 00:00:00 2001 From: Akshay Ballal Date: Thu, 1 Jan 2026 19:52:44 +0100 Subject: [PATCH 2/3] Refactor Dockerfile for improved CUDA build process and dependency management - Updated the Dockerfile to use a base image with CUDA 12.2.0 and streamlined the build stages. - Introduced sccache for caching Rust builds and improved the installation of Rust and cargo-chef. - Enhanced the build process by separating the planner and builder stages, ensuring better organization and efficiency. - Removed unnecessary mock scripts and optimized runtime dependencies for a cleaner image. --- server-cuda.Dockerfile | 129 ++++++++++++++++++++--------------------- 1 file changed, 63 insertions(+), 66 deletions(-) diff --git a/server-cuda.Dockerfile b/server-cuda.Dockerfile index 5de13a2d..c4ddc28b 100644 --- a/server-cuda.Dockerfile +++ b/server-cuda.Dockerfile @@ -1,88 +1,85 @@ -# Stage 1: Chef base with CUDA development tools -FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS chef -WORKDIR /app +FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 AS base-builder -# Set non-interactive mode -ENV DEBIAN_FRONTEND=noninteractive -ENV TORCH_CUDA_ARCH_LIST=Turing -ENV NVIDIA_VISIBLE_DEVICES all -ENV NVIDIA_DRIVER_CAPABILITIES video,compute,utility -# Set CUDA compute capability for candle-kernels (Turing = 7.5, encoded as 75) -# This bypasses the need for nvidia-smi during build -# Format: major * 10 + minor (e.g., 7.5 -> 75, 8.0 -> 80, 8.6 -> 86) +ENV SCCACHE=0.10.0 +ENV RUSTC_WRAPPER=/usr/local/bin/sccache +ENV PATH="/root/.cargo/bin:${PATH}" +# aligned with `cargo-chef` version in `lukemathwalker/cargo-chef:latest-rust-1.85-bookworm` +ENV CARGO_CHEF=0.1.71 ENV CUDA_COMPUTE_CAP=75 -# Install build dependencies -RUN apt-get update && apt-get install -y \ - pkg-config \ +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ libssl-dev \ + pkg-config \ python3 \ python3-dev \ - curl \ && rm 
-rf /var/lib/apt/lists/* -# Install Rust -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y -ENV PATH="/root/.cargo/bin:${PATH}" +# Download and configure sccache +RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ + chmod +x /usr/local/bin/sccache + +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +RUN cargo install cargo-chef --version $CARGO_CHEF --locked + +FROM base-builder AS planner + +WORKDIR /app + +COPY processors processors +COPY rust rust +COPY python python +COPY server server +COPY Cargo.toml ./ +COPY Cargo.lock ./ -# Install cargo-chef (we use CUDA base image instead of lukemathwalker/cargo-chef -# because we need CUDA development tools) -RUN cargo install cargo-chef --locked - -# Create a mock nvidia-smi script as fallback (in case CUDA_COMPUTE_CAP doesn't work) -# This script will be used if nvidia-smi is not available -RUN echo '#!/bin/bash\n\ -# Mock nvidia-smi for Docker build\n\ -# Returns compute capability 75 (Turing 7.5) for build-time detection\n\ -if echo "$*" | grep -q "compute_cap"; then\n\ - echo "compute_cap"\n\ - echo "75"\n\ -elif echo "$*" | grep -q "query"; then\n\ - # Handle --query-gpu format\n\ - if echo "$*" | grep -q "csv"; then\n\ - echo "compute_cap"\n\ - echo "75"\n\ - else\n\ - echo "CUDA Version: 12.2"\n\ - echo "Driver Version: 535.00"\n\ - echo "Compute Capability: 7.5"\n\ - fi\n\ -else\n\ - # Default output\n\ - echo "NVIDIA-SMI 535.00"\n\ - echo "Driver Version: 535.00"\n\ - echo "CUDA Version: 12.2"\n\ -fi\n\ -exit 0' > /usr/local/bin/nvidia-smi && chmod +x /usr/local/bin/nvidia-smi - -# Stage 2: Planner - prepare recipe -FROM chef AS planner -COPY . . 
RUN cargo chef prepare --recipe-path recipe.json -# Stage 3: Builder - cook dependencies and build -FROM chef AS builder +FROM base-builder AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +# sccache specific variables +ARG SCCACHE_GHA_ENABLED + +# Limit parallelism +ARG RAYON_NUM_THREADS=4 +ARG CARGO_BUILD_JOBS +ARG CARGO_BUILD_INCREMENTAL + +WORKDIR /app + COPY --from=planner /app/recipe.json recipe.json -# Build dependencies - this is the caching Docker layer! -RUN cargo chef cook --release --recipe-path recipe.json --package server --features cuda -# Build application -COPY . . -RUN cargo build --release -p server --features cuda -RUN strip target/release/server - -# Stage 4: Runtime - minimal CUDA runtime image -FROM nvidia/cuda:12.2.2-runtime-ubuntu22.04 AS runtime + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo chef cook --release --recipe-path recipe.json --package server --features cuda && sccache -s; + +COPY processors processors +COPY rust rust +COPY server server +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo build --release --bin server --features cuda && sccache -s; + +RUN strip /app/target/release/server + +FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 AS runtime + WORKDIR /app -# Install minimal runtime dependencies -RUN apt-get update && apt-get install -y \ +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ ca-certificates \ + libssl-dev \ + curl \ && rm -rf /var/lib/apt/lists/* -# Copy the stripped binary from builder COPY --from=builder /app/target/release/server /usr/local/bin/server EXPOSE 8080 CMD ["server"] - From 62a01e1e59d7684282fc6d47597d0051758f34d3 Mon Sep 17 00:00:00 2001 From: prachi-kedar Date: Mon, 26 Jan 2026 15:35:50 +0100 
Subject: [PATCH 3/3] feature: Added video embedding support and guide --- README.md | 8 +- docs/guides/video.md | 86 +++++++++ docs/index.md | 6 +- docs/roadmap/roadmap.md | 4 +- examples/video.py | 37 ++++ mkdocs.yml | 1 + processors/Cargo.toml | 2 + processors/src/lib.rs | 4 + processors/src/video_processor.rs | 145 +++++++++++++++ python/Cargo.toml | 1 + python/python/embed_anything/__init__.py | 5 +- .../python/embed_anything/_embed_anything.pyi | 97 +++++++++- python/src/config.rs | 41 +++++ python/src/lib.rs | 108 +++++++++++ rust/Cargo.toml | 1 + rust/src/config.rs | 35 ++++ rust/src/file_loader.rs | 45 +++++ rust/src/lib.rs | 174 +++++++++++++++++- 18 files changed, 788 insertions(+), 12 deletions(-) create mode 100644 docs/guides/video.md create mode 100644 examples/video.py create mode 100644 processors/src/video_processor.rs diff --git a/README.md b/README.md index acbf9869..b08ad7c5 100644 --- a/README.md +++ b/README.md @@ -86,12 +86,14 @@ EmbedAnything is a minimalist, yet highly performant, modular, lightning-fast, l - **Candle Backend** : Supports BERT, Jina, ColPali, Splade, ModernBERT, Reranker, Qwen - **ONNX Backend** : Supports BERT, Jina, ColPali, ColBERT Splade, Reranker, ModernBERT, Qwen - **Cloud Embedding Models:** : Supports OpenAI, Cohere, and Gemini. -- **MultiModality** : Works with text sources like PDFs, txt, md, Images JPG and Audio, .WAV +- **MultiModality** : Works with text sources like PDFs, txt, md, images, audio (.WAV), and videos (frame sampling; enable the `video` feature) - **GPU support** : Hardware acceleration on GPU as well. - **Chunking** : In-built chunking methods like semantic, late-chunking - **Vector Streaming:** : Separate file processing, Indexing and Inferencing on different threads, reduces latency. - **AWS S3 Bucket:** : Directly import AWS S3 bucket files. 
+- **Prebuilt Docker Image** : Just pull it: `docker pull starlightsearch/embedanything-server` - **SearchAgent** : Example of how you can use index for Searchr1 reasoning. +- **Video guide** : Quick start for frame sampling: https://embed-anything.com/guides/video/ ## 💡What is Vector Streaming @@ -473,7 +475,7 @@ We’re excited to share that we've expanded our platform to support multiple mo - [x] Images -- [ ] Videos +- [x] Videos (frame sampling; enable the `video` feature) - [ ] Graph @@ -493,7 +495,7 @@ We now support both candle and Onnx backend
We had multimodality from day one for our infrastructure. We have already included it for websites, images and audios but we want to expand it further to. ➡️ Graph embedding -- build deepwalks embeddings depth first and word to vec
-➡️ Video Embedding
+➡️ Video embedding improvements (temporal + audio)
➡️ Yolo Clip
diff --git a/docs/guides/video.md b/docs/guides/video.md new file mode 100644 index 00000000..7059fab0 --- /dev/null +++ b/docs/guides/video.md @@ -0,0 +1,86 @@ +# Video Embeddings (Frame Sampling) + +EmbedAnything supports video by sampling frames and embedding them with a vision model +(CLIP/SigLIP). This is opt-in via the `video` feature flag and requires the `ffmpeg` +CLI to be available on your system. If `ffmpeg` is not on `PATH`, set `FFMPEG_BIN` +to the full path of the executable. + +## Recommended Config + +`VideoEmbedConfig` controls how frames are sampled: + +- `frame_step`: sample every Nth frame. Default `30`. +- `max_frames`: maximum frames per video. Default `300`. +- `batch_size`: frames per embedding batch. Default `32`. + +Suggested starting point: + +```python +from embed_anything import VideoEmbedConfig + +config = VideoEmbedConfig(frame_step=30, max_frames=300, batch_size=16) +``` + +## Python Usage + +```python +import embed_anything +from embed_anything import VideoEmbedConfig + +model = embed_anything.EmbeddingModel.from_pretrained_hf( + model_id="openai/clip-vit-base-patch16" +) + +config = VideoEmbedConfig(frame_step=30, max_frames=200, batch_size=16) + +data = embed_anything.embed_video_file("path/to/video.mp4", embedder=model, config=config) +``` + +## Build with Video Support + +You must enable the `video` feature and have the `ffmpeg` CLI installed. + +### macOS + +```bash +brew install ffmpeg +cargo build --features video +# Python (maturin) +maturin develop --features "extension-module,video" +``` + +### Linux (Debian/Ubuntu) + +```bash +sudo apt-get update +sudo apt-get install -y ffmpeg +cargo build --features video +# Python (maturin) +maturin develop --features "extension-module,video" +``` + +### Windows (prebuilt FFmpeg) + + +1. Download a static build from https://www.gyan.dev/ffmpeg/builds/ +2. 
Extract it and set: + +```powershell +$env:FFMPEG_BIN = "C:\path\to\ffmpeg.exe" +``` + +Then build: + +```powershell +cargo build --features video +# Python (maturin) +maturin develop --features "extension-module,video" +``` + + +## Output Metadata + +Each embedding includes: + +- `video_path`: the source video file +- `frame_index`: the sampled frame index (0-based) diff --git a/docs/index.md b/docs/index.md index 9b828a13..1dc73cd1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -74,7 +74,7 @@ EmbedAnything is a minimalist, yet highly performant, modular, lightning-fast, l - **Candle Backend** : Supports BERT, Jina, ColPali, Splade, ModernBERT, Reranker, Qwen - **ONNX Backend**: Supports BERT, Jina, ColPali, ColBERT Splade, Reranker, ModernBERT, Qwen - **Cloud Embedding Models:**: Supports OpenAI, Cohere, and Gemini. -- **MultiModality** : Works with text sources like PDFs, txt, md, Images JPG and Audio, .WAV +- **MultiModality** : Works with text sources like PDFs, txt, md, images, audio (.WAV), and videos (frame sampling; enable the `video` feature) - **GPU support** : Hardware acceleration on GPU as well. - **Chunking** : In-built chunking methods like semantic, late-chunking - **Vector Streaming:** Separate file processing, Indexing and Inferencing on different threads, reduces latency. @@ -339,7 +339,7 @@ We’re excited to share that we've expanded our platform to support multiple mo - [x] Images -- [ ] Videos +- [x] Videos (frame sampling; enable the `video` feature) - [ ] Graph @@ -359,7 +359,7 @@ We now support both candle and Onnx backend
We had multimodality from day one for our infrastructure. We have already included it for websites, images and audios but we want to expand it further to. ➡️ Graph embedding -- build deepwalks embeddings depth first and word to vec
-➡️ Video Embedding
+➡️ Video embedding improvements (temporal + audio)
➡️ Yolo Clip
diff --git a/docs/roadmap/roadmap.md b/docs/roadmap/roadmap.md index d3a76d6b..11beae2a 100644 --- a/docs/roadmap/roadmap.md +++ b/docs/roadmap/roadmap.md @@ -17,7 +17,7 @@ We’re excited to share that we've expanded our platform to support multiple mo - [x] Images -- [ ] Videos +- [x] Videos (frame sampling; enable the `video` feature) - [ ] Graph @@ -58,7 +58,7 @@ To address this, we’re excited to announce that we’re introducing Candle-ONN We had multimodality from day one for our infrastructure. We have already included it for websites, images and audios but we want to expand it further to. ☑️Graph embedding -- build deepwalks embeddings depth first and word to vec
-☑️Video Embedding
+☑️Video embedding improvements (temporal + audio)
☑️ Yolo Clip
diff --git a/examples/video.py b/examples/video.py new file mode 100644 index 00000000..cab27670 --- /dev/null +++ b/examples/video.py @@ -0,0 +1,37 @@ +import os +from pathlib import Path + +import embed_anything +from embed_anything import EmbedData, VideoEmbedConfig + +# Load a vision model (CLIP/SigLIP) for frame embeddings +model = embed_anything.EmbeddingModel.from_pretrained_hf( + model_id="openai/clip-vit-base-patch16" +) + +# Sample every 30th frame (~1 fps for 30 fps videos), cap to 200 frames +config = VideoEmbedConfig(frame_step=30, max_frames=200, batch_size=16) + +video_path = os.environ.get("VIDEO_PATH", "path/to/video.mp4") +if not Path(video_path).exists(): + raise FileNotFoundError( + f"Video not found: {video_path}. Set VIDEO_PATH env var to a valid file." + ) + +# Embed a single video +video_embeddings: list[EmbedData] = embed_anything.embed_video_file( + video_path, + embedder=model, + config=config, +) +print(f"Embedded {len(video_embeddings)} frames from video.") + +video_dir = os.environ.get("VIDEO_DIR") +if video_dir: + dir_embeddings = embed_anything.embed_video_directory( + video_dir, + embedder=model, + config=config, + ) + if dir_embeddings is not None: + print(f"Embedded {len(dir_embeddings)} total frames from directory.") diff --git a/mkdocs.yml b/mkdocs.yml index 9278f06e..e3118209 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -52,6 +52,7 @@ nav: - Guides: - guides/colpali.md - guides/images.md + - guides/video.md - guides/semantic.md - guides/adapters.md - guides/onnx_models.md diff --git a/processors/Cargo.toml b/processors/Cargo.toml index 397e655c..c23a7aa4 100644 --- a/processors/Cargo.toml +++ b/processors/Cargo.toml @@ -30,9 +30,11 @@ pdf2image = "0.1.3" image = "0.25.6" thiserror = "2.0.12" tempfile = "3.19.1" +# Video processing (uses external ffmpeg CLI) [dev-dependencies] tempdir = "0.3.7" [features] default = [] +video = [] \ No newline at end of file diff --git a/processors/src/lib.rs b/processors/src/lib.rs index 
831daaea..8a91540a 100644 --- a/processors/src/lib.rs +++ b/processors/src/lib.rs @@ -15,3 +15,7 @@ pub mod html_processor; /// This module contains the file processor for DOCX files. pub mod docx_processor; + +/// This module contains the file processor for video files. +#[cfg(feature = "video")] +pub mod video_processor; diff --git a/processors/src/video_processor.rs b/processors/src/video_processor.rs new file mode 100644 index 00000000..6416a4ff --- /dev/null +++ b/processors/src/video_processor.rs @@ -0,0 +1,145 @@ +#![cfg(feature = "video")] + +use anyhow::{anyhow, Result}; +use std::env; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::{fs, path}; + +#[derive(Debug, Clone, Copy)] +pub enum VideoFrameFormat { + Jpeg, + Png, +} + +impl VideoFrameFormat { + fn extension(self) -> &'static str { + match self { + VideoFrameFormat::Jpeg => "jpg", + VideoFrameFormat::Png => "png", + } + } +} + +#[derive(Debug, Clone)] +pub struct VideoFrame { + pub index: usize, + pub path: PathBuf, +} + +#[derive(Debug, Clone)] +pub struct VideoProcessor { + frame_step: usize, + max_frames: Option, + output_format: VideoFrameFormat, + ffmpeg_bin: Option, +} + +impl VideoProcessor { + pub fn new(frame_step: usize) -> Self { + Self { + frame_step: frame_step.max(1), + max_frames: None, + output_format: VideoFrameFormat::Jpeg, + ffmpeg_bin: None, + } + } + + pub fn with_max_frames(mut self, max_frames: usize) -> Self { + self.max_frames = Some(max_frames); + self + } + + pub fn with_output_format(mut self, output_format: VideoFrameFormat) -> Self { + self.output_format = output_format; + self + } + + pub fn with_ffmpeg_bin>(mut self, ffmpeg_bin: P) -> Self { + self.ffmpeg_bin = Some(ffmpeg_bin.as_ref().to_path_buf()); + self + } + + fn resolve_ffmpeg_bin(&self) -> Result { + if let Some(bin) = &self.ffmpeg_bin { + return Ok(bin.clone()); + } + if let Ok(bin) = env::var("FFMPEG_BIN") { + return Ok(PathBuf::from(bin)); + } + Ok(PathBuf::from("ffmpeg")) + } + + pub 
fn extract_frames_to_dir, Q: AsRef>( + &self, + video_path: P, + output_dir: Q, + ) -> Result> { + let output_dir = output_dir.as_ref(); + fs::create_dir_all(output_dir)?; + + let ffmpeg_bin = self.resolve_ffmpeg_bin()?; + let frame_step = self.frame_step.max(1); + let filter = format!("select=not(mod(n\\,{}))", frame_step); + let output_pattern = output_dir.join(format!( + "frame_%06d.{}", + self.output_format.extension() + )); + + let mut command = Command::new(ffmpeg_bin); + command + .arg("-hide_banner") + .arg("-loglevel") + .arg("error") + .arg("-i") + .arg(video_path.as_ref()) + .arg("-vf") + .arg(filter) + .arg("-vsync") + .arg("vfr"); + + if let Some(max_frames) = self.max_frames { + command.arg("-vframes").arg(max_frames.to_string()); + } + + let status = command.arg(output_pattern).status()?; + if !status.success() { + return Err(anyhow!("ffmpeg failed with exit code {:?}", status.code())); + } + + let mut frame_paths = fs::read_dir(output_dir)? + .filter_map(|entry| entry.ok()) + .filter(|entry| entry.file_type().map(|t| t.is_file()).unwrap_or(false)) + .map(|entry| entry.path()) + .filter(|path| { + path.extension() + .and_then(|ext| ext.to_str()) + .map(|ext| ext.eq_ignore_ascii_case(self.output_format.extension())) + .unwrap_or(false) + }) + .collect::>(); + + frame_paths.sort(); + + if frame_paths.is_empty() { + return Err(anyhow!("No frames extracted from video")); + } + + let frames = frame_paths + .into_iter() + .enumerate() + .map(|(index, path)| VideoFrame { index, path }) + .collect(); + + Ok(frames) + } + + pub fn extract_frames_to_temp_dir>( + &self, + video_path: P, + ) -> Result<(tempfile::TempDir, Vec)> { + let temp_dir = tempfile::TempDir::new()?; + let frames = self.extract_frames_to_dir(video_path, temp_dir.path())?; + Ok((temp_dir, frames)) + } +} diff --git a/python/Cargo.toml b/python/Cargo.toml index 91328e46..0eb849ee 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -25,3 +25,4 @@ cudnn = ["embed_anything/cudnn"] metal = 
["embed_anything/metal"] ort = ["embed_anything/ort"] audio = ["embed_anything/audio"] +video = ["embed_anything/video"] \ No newline at end of file diff --git a/python/python/embed_anything/__init__.py b/python/python/embed_anything/__init__.py index 29cd66e4..3267caac 100644 --- a/python/python/embed_anything/__init__.py +++ b/python/python/embed_anything/__init__.py @@ -2,7 +2,7 @@ This module provides functions and classes for embedding queries, files, and directories using different embedding models. It supports text, images, audio, -PDFs, and other media types with various embedding backends (Candle, ONNX, Cloud). +videos, PDFs, and other media types with various embedding backends (Candle, ONNX, Cloud). Main Functions: --------------- @@ -11,6 +11,8 @@ - `embed_directory`: Embeds all files in a directory and returns a list of EmbedData objects. - `embed_image_directory`: Embeds all images in a directory. - `embed_audio_file`: Embeds audio files using Whisper for transcription. +- `embed_video_file`: Embeds a video file by sampling frames. +- `embed_video_directory`: Embeds all videos in a directory. - `embed_webpage`: Embeds content from a webpage URL. Main Classes: @@ -18,6 +20,7 @@ - `EmbeddingModel`: Main class for loading and using embedding models. - `EmbedData`: Represents embedded data with text, embedding vector, and metadata. - `TextEmbedConfig`: Configuration for text embedding (chunking, batching, etc.). +- `VideoEmbedConfig`: Configuration for video embedding (frame sampling, batching). - `ColpaliModel`: Specialized model for document/image-text embedding. - `ColbertModel`: Model for late-interaction embeddings. - `Reranker`: Model for re-ranking search results. 
def embed_video_file(
    video_file: str,
    embedder: EmbeddingModel,
    config: VideoEmbedConfig | None = None,
) -> list[EmbedData]:
    """
    Embeds the given video file by sampling frames and returns a list of EmbedData objects.

    Args:
        video_file: The path to the video file to embed.
        embedder: The embedding model to use.
        config: The configuration for video embedding. Library defaults are used when omitted.

    Returns:
        A list of EmbedData objects.

    Raises:
        FileNotFoundError: If `video_file` does not exist.
        ValueError: If frame extraction or embedding fails.
    """

def embed_video_directory(
    directory: str,
    embedder: EmbeddingModel,
    config: VideoEmbedConfig | None = None,
    adapter: Adapter | None = None,
) -> list[EmbedData] | None:
    """
    Embeds all videos in the given directory and returns a list of EmbedData objects.

    Args:
        directory: The path to the directory containing videos to embed.
        embedder: The embedding model to use.
        config: The configuration for video embedding. Library defaults are used when omitted.
        adapter: The adapter to use for storing the embeddings in a vector database.

    Returns:
        A list of EmbedData objects, or None if an adapter is used.
    """
@@ -585,6 +621,29 @@ class ImageEmbedConfig: buffer_size: int | None batch_size: int | None +class VideoEmbedConfig: + """ + Represents the configuration for the Video Embedding model. + + Attributes: + frame_step: Sample every Nth frame. Default is 30. + max_frames: Maximum number of frames to embed. Default is 300. + batch_size: The batch size for processing frames. Default is 32. + """ + + def __init__( + self, + frame_step: int | None = None, + max_frames: int | None = None, + batch_size: int | None = None, + ): + self.frame_step = frame_step + self.max_frames = max_frames + self.batch_size = batch_size + frame_step: int | None + max_frames: int | None + batch_size: int | None + class EmbeddingModel: """ Represents an embedding model. @@ -760,6 +819,40 @@ class EmbeddingModel: A list of EmbedData objects. """ + def embed_video_file( + self, + video_file: str, + config: VideoEmbedConfig | None = None, + ) -> list[EmbedData]: + """ + Embeds the given video file and returns a list of EmbedData objects. + + Args: + video_file: The path to the video file to embed. + config: The configuration for video embedding. + + Returns: + A list of EmbedData objects. + """ + + def embed_video_directory( + self, + directory: str, + config: VideoEmbedConfig | None = None, + adapter: Adapter | None = None, + ) -> list[EmbedData] | None: + """ + Embeds videos in the given directory and returns a list of EmbedData objects. + + Args: + directory: The path to the directory to embed. + config: The configuration for video embedding. + adapter: The adapter for the embedding. + + Returns: + A list of EmbedData objects, or None if an adapter is used. 
+ """ + def embed_query( self, query: list[str], diff --git a/python/src/config.rs b/python/src/config.rs index 258c4468..f8a749d1 100644 --- a/python/src/config.rs +++ b/python/src/config.rs @@ -92,3 +92,44 @@ impl ImageEmbedConfig { self.inner.batch_size } } + +#[pyclass] +#[derive(Clone, Default)] +pub struct VideoEmbedConfig { + pub inner: embed_anything::config::VideoEmbedConfig, +} + +#[pymethods] +impl VideoEmbedConfig { + #[new] + #[pyo3(signature = (frame_step=None, max_frames=None, batch_size=None))] + pub fn new( + frame_step: Option, + max_frames: Option, + batch_size: Option, + ) -> Self { + let default_config = embed_anything::config::VideoEmbedConfig::default(); + Self { + inner: embed_anything::config::VideoEmbedConfig { + frame_step: frame_step.or(default_config.frame_step), + max_frames: max_frames.or(default_config.max_frames), + batch_size: batch_size.or(default_config.batch_size), + }, + } + } + + #[getter] + pub fn frame_step(&self) -> Option { + self.inner.frame_step + } + + #[getter] + pub fn max_frames(&self) -> Option { + self.inner.max_frames + } + + #[getter] + pub fn batch_size(&self) -> Option { + self.inner.batch_size + } +} diff --git a/python/src/lib.rs b/python/src/lib.rs index 4f2648c5..701ceaed 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -6,6 +6,8 @@ use embed_anything::{ self, config::TextEmbedConfig, emb_audio, + embed_video_directory as embed_video_directory_rs, + embed_video_file as embed_video_file_rs, embeddings::embed::{Embedder, EmbeddingResult}, file_processor::audio::audio_processor, FileLoadingError, @@ -384,6 +386,25 @@ impl EmbeddingModel { ) -> PyResult>> { embed_audio_file(audio_file, audio_decoder, self, config) } + + #[pyo3(signature = (video_file, config=None))] + pub fn embed_video_file( + &self, + video_file: &str, + config: Option<&config::VideoEmbedConfig>, + ) -> PyResult> { + embed_video_file(video_file, self, config) + } + + #[pyo3(signature = (directory, config=None, adapter=None))] + pub 
fn embed_video_directory( + &self, + directory: PathBuf, + config: Option<&config::VideoEmbedConfig>, + adapter: Option, + ) -> PyResult>> { + embed_video_directory(directory, self, config, adapter) + } } #[pyclass] @@ -603,6 +624,90 @@ pub fn embed_audio_file( Ok(data) } +#[pyfunction] +#[pyo3(signature = (video_file, embedder, config=None))] +pub fn embed_video_file( + video_file: &str, + embedder: &EmbeddingModel, + config: Option<&config::VideoEmbedConfig>, +) -> PyResult> { + let config = config.map(|c| &c.inner); + let embedding_model = &embedder.inner; + let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + + if !Path::new(video_file).exists() { + return Err(PyFileNotFoundError::new_err(format!( + "File not found: {:?}", + video_file + ))); + }; + + let data = rt + .block_on(async { embed_video_file_rs(video_file, embedding_model.as_ref(), config).await }) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + + Ok(data.into_iter().map(|data| EmbedData { inner: data }).collect()) +} + +#[pyfunction] +#[pyo3(signature = (directory, embedder, config=None, adapter=None))] +pub fn embed_video_directory( + directory: PathBuf, + embedder: &EmbeddingModel, + config: Option<&config::VideoEmbedConfig>, + adapter: Option, +) -> PyResult>> { + let config = config.map(|c| &c.inner); + let embedding_model = &embedder.inner; + let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + + let adapter = match adapter { + Some(adapter) => { + let callback = move |data: Vec| { + Python::with_gil(|py| { + let upsert_fn = adapter.getattr(py, "upsert").unwrap(); + let converted_data = data + .into_iter() + .map(|data| EmbedData { inner: data }) + .collect::>(); + upsert_fn + .call1(py, (converted_data,)) + .map_err(|e| PyValueError::new_err(e.to_string())) + .unwrap(); + }); + }; + Some(callback) + } + None => None, + }; + + let data = rt.block_on(async { + embed_video_directory_rs( + directory, + embedding_model, + config, + adapter.map(|f| { + 
Box::new(f) + as Box< + dyn FnMut(Vec) + + Send + + Sync, + > + }), + ) + .await + .map_err(|e| PyValueError::new_err(e.to_string())) + .unwrap() + .map(|data| { + data.into_iter() + .map(|data| EmbedData { inner: data }) + .collect::>() + }) + }); + + Ok(data) +} + #[pyfunction] #[pyo3(signature = (directory, embedder, extensions=None, config=None, adapter = None))] pub fn embed_directory( @@ -779,6 +884,8 @@ fn _embed_anything(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(embed_query, m)?)?; m.add_function(wrap_pyfunction!(embed_webpage, m)?)?; m.add_function(wrap_pyfunction!(embed_audio_file, m)?)?; + m.add_function(wrap_pyfunction!(embed_video_file, m)?)?; + m.add_function(wrap_pyfunction!(embed_video_directory, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -786,6 +893,7 @@ fn _embed_anything(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/rust/Cargo.toml b/rust/Cargo.toml index c3a1e84b..2cab468e 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -96,6 +96,7 @@ cudnn = ["candle-core/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] metal = ["candle-core/metal", "candle-nn/metal"] audio = ["dep:symphonia"] +video = ["processors-rs/video"] ort = ["dep:ort",] rustls-tls = [ "reqwest/rustls-tls", diff --git a/rust/src/config.rs b/rust/src/config.rs index 4655519c..0d67e68c 100644 --- a/rust/src/config.rs +++ b/rust/src/config.rs @@ -185,3 +185,38 @@ impl ImageEmbedConfig { } } } + +pub const DEFAULT_VIDEO_FRAME_STEP: usize = 30; +pub const DEFAULT_VIDEO_MAX_FRAMES: usize = 300; +pub const DEFAULT_VIDEO_BATCH_SIZE: usize = 32; + +#[derive(Clone)] +pub struct VideoEmbedConfig { + pub frame_step: Option, + pub max_frames: Option, + pub batch_size: Option, +} + +impl Default for VideoEmbedConfig { + fn default() -> Self { + Self { + frame_step: 
Some(DEFAULT_VIDEO_FRAME_STEP), + max_frames: Some(DEFAULT_VIDEO_MAX_FRAMES), + batch_size: Some(DEFAULT_VIDEO_BATCH_SIZE), + } + } +} + +impl VideoEmbedConfig { + pub fn new( + frame_step: Option, + max_frames: Option, + batch_size: Option, + ) -> Self { + Self { + frame_step, + max_frames, + batch_size, + } + } +} diff --git a/rust/src/file_loader.rs b/rust/src/file_loader.rs index 132c5fa4..5aa5d818 100644 --- a/rust/src/file_loader.rs +++ b/rust/src/file_loader.rs @@ -94,6 +94,27 @@ impl FileParser { Ok(self.files.clone()) } + pub fn get_video_paths(&mut self, directory_path: &PathBuf) -> Result, Error> { + let video_regex = Regex::new(r".*\.(mp4|mov|avi|mkv|webm|m4v|flv|wmv)$").unwrap(); + + let video_paths: Vec = WalkDir::new(directory_path) + .into_iter() + .filter_map(|entry| entry.ok()) + .filter(|entry| entry.file_type().is_file()) + .filter(|entry| video_regex.is_match(entry.file_name().to_str().unwrap_or(""))) + .map(|entry| { + let absolute_path = entry + .path() + .canonicalize() + .unwrap_or_else(|_| entry.path().to_path_buf()); + absolute_path.to_string_lossy().to_string() + }) + .collect(); + + self.files = video_paths; + Ok(self.files.clone()) + } + pub fn get_files_to_index(&self, indexed_files: &HashSet) -> Vec { let files = self .files @@ -209,6 +230,30 @@ mod tests { assert_eq!(audio_files.len(), 2); } + #[test] + fn test_get_video_paths() { + let temp_dir = TempDir::new("example").unwrap(); + let video_file = temp_dir.path().join("clip.mp4"); + let _ignored_file = temp_dir.path().join("note.txt"); + File::create(&video_file).unwrap(); + File::create(&_ignored_file).unwrap(); + + let mut file_parser = FileParser::new(); + let video_files = file_parser + .get_video_paths(&PathBuf::from(temp_dir.path())) + .unwrap(); + + assert_eq!(video_files.len(), 1); + assert_eq!( + video_files[0], + video_file + .canonicalize() + .unwrap() + .to_string_lossy() + .to_string() + ); + } + #[test] fn test_get_files_to_index() { let temp_dir = 
/// Returns `true` when `extension` (without the leading dot) names a supported
/// video container: `mp4`, `mov`, `avi`, `mkv`, `webm`, `m4v`, `flv` or `wmv`.
///
/// The comparison ignores ASCII case. Extensions that are not valid UTF-8 are
/// never treated as video.
fn is_video_extension(extension: &std::ffi::OsStr) -> bool {
    const VIDEO_EXTENSIONS: [&str; 8] = ["mp4", "mov", "avi", "mkv", "webm", "m4v", "flv", "wmv"];

    extension
        .to_str()
        // Compare in place instead of allocating a lower-cased copy on every call.
        .map(|ext| {
            VIDEO_EXTENSIONS
                .iter()
                .any(|known| ext.eq_ignore_ascii_case(known))
        })
        .unwrap_or(false)
}
+ )) + } + } _ => Ok(Some(vec![emb_image(file_name, embedder).await?])), }, } @@ -366,6 +400,144 @@ async fn emb_image>( Ok(embedding) } +#[cfg(feature = "video")] +pub async fn embed_video_file>( + file_name: T, + embedder: &Embedder, + config: Option<&VideoEmbedConfig>, +) -> Result> { + let default_config = VideoEmbedConfig::default(); + let config = config.unwrap_or(&default_config); + match embedder { + Embedder::Vision(embedder) => emb_video(file_name, embedder, config).await, + _ => Err(anyhow::anyhow!( + "Model not supported for video embedding" + )), + } +} + +#[cfg(not(feature = "video"))] +pub async fn embed_video_file>( + _file_name: T, + _embedder: &Embedder, + _config: Option<&VideoEmbedConfig>, +) -> Result> { + Err(anyhow::anyhow!( + "The 'video' feature is not enabled. Please enable it to use video embedding." + )) +} + +#[cfg(feature = "video")] +pub async fn embed_video_directory( + directory: PathBuf, + embedding_model: &Arc, + config: Option<&VideoEmbedConfig>, + adapter: Option) + Send + Sync>>, +) -> Result>> { + let mut file_parser = FileParser::new(); + file_parser.get_video_paths(&directory)?; + + let default_config = VideoEmbedConfig::default(); + let config = config.unwrap_or(&default_config); + let files = file_parser.files.clone(); + let pb = indicatif::ProgressBar::new(files.len() as u64); + pb.set_style(indicatif::ProgressStyle::with_template( + "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})", + )?); + + let mut all_embeddings = Vec::new(); + let mut adapter = adapter; + for file in files { + let embeddings = embed_video_file(&file, embedding_model.as_ref(), Some(config)).await?; + pb.inc(1); + + if let Some(adapter) = adapter.as_mut() { + adapter(embeddings); + } else { + all_embeddings.extend(embeddings); + } + } + + if adapter.is_some() { + Ok(None) + } else { + Ok(Some(all_embeddings)) + } +} + +#[cfg(not(feature = "video"))] +pub async fn embed_video_directory( + _directory: PathBuf, + 
_embedding_model: &Arc, + _config: Option<&VideoEmbedConfig>, + _adapter: Option) + Send + Sync>>, +) -> Result>> { + Err(anyhow::anyhow!( + "The 'video' feature is not enabled. Please enable it to use video embedding." + )) +} + +#[cfg(feature = "video")] +pub async fn emb_video>( + video_path: T, + embedding_model: &VisionEmbedder, + config: &VideoEmbedConfig, +) -> Result> { + let frame_step = config.frame_step.unwrap_or(DEFAULT_VIDEO_FRAME_STEP).max(1); + let max_frames = config.max_frames.filter(|value| *value > 0); + let batch_size = config + .batch_size + .unwrap_or(DEFAULT_VIDEO_BATCH_SIZE) + .max(1); + + let processor = VideoProcessor::new(frame_step); + let processor = match max_frames { + Some(limit) => processor.with_max_frames(limit), + None => processor, + }; + + let video_path = video_path.as_ref(); + let video_path_string = fs::canonicalize(video_path) + .unwrap_or_else(|_| video_path.to_path_buf()) + .to_string_lossy() + .to_string(); + + let (_temp_dir, frames) = processor.extract_frames_to_temp_dir(video_path)?; + if frames.is_empty() { + return Err(anyhow::anyhow!("No frames extracted from video")); + } + + let mut all_embeddings = Vec::new(); + for frame_chunk in frames.chunks(batch_size) { + let frame_paths: Vec<&std::path::Path> = + frame_chunk.iter().map(|frame| frame.path.as_path()).collect(); + let mut embeddings = embedding_model + .embed_image_batch(&frame_paths, Some(batch_size)) + .await?; + + for (embedding, frame) in embeddings.iter_mut().zip(frame_chunk.iter()) { + let metadata = embedding.metadata.get_or_insert_with(HashMap::new); + metadata.insert("video_path".to_string(), video_path_string.clone()); + metadata.insert("frame_index".to_string(), frame.index.to_string()); + } + + all_embeddings.extend(embeddings); + } + + Ok(all_embeddings) +} + +#[cfg(not(feature = "video"))] +pub async fn emb_video>( + _video_path: T, + _embedding_model: &VisionEmbedder, + _config: &VideoEmbedConfig, +) -> Result> { + Err(anyhow::anyhow!( + "The 
'video' feature is not enabled. Please enable it to use the emb_video function." + )) +} + /// Embeds audio content from a file using transcription and temporal segmentation. /// /// Processes audio files by first transcribing them to text and then creating