diff --git a/Cargo.lock b/Cargo.lock index cdc28f73..40b310a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5788,8 +5788,10 @@ version = "0.7.0" dependencies = [ "actix-multipart", "actix-web", + "base64 0.22.1", "embed_anything", "futures-util", + "image", "serde", "tempfile", "tokio", diff --git a/README.md b/README.md index c2d34094..79768c6c 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ EmbedAnything is a minimalist, yet highly performant, modular, lightning-fast, l - **AWS S3 Bucket:** : Directly import AWS S3 bucket files. - **Prebult Docker Image** : Just pull it: starlightsearch/embedanything-server - **SearchAgent** : Example of how you can use index for Searchr1 reasoning. +- **Video guide** : Quick start for frame sampling: https://embed-anything.com/guides/video/ ## 💡What is Vector Streaming @@ -478,7 +479,7 @@ We’re excited to share that we've expanded our platform to support multiple mo - [x] Images -- [ ] Videos +- [x] Videos (frame sampling; enable the `video` feature) - [ ] Graph @@ -498,7 +499,7 @@ We now support both candle and Onnx backend
We had multimodality from day one for our infrastructure. We have already included it for websites, images and audios but we want to expand it further to. ➡️ Graph embedding -- build deepwalks embeddings depth first and word to vec
-➡️ Video Embedding
+➡️ Video embedding improvements (temporal + audio)
➡️ Yolo Clip
diff --git a/docs/guides/video.md b/docs/guides/video.md
new file mode 100644
index 00000000..7059fab0
--- /dev/null
+++ b/docs/guides/video.md
@@ -0,0 +1,84 @@
+# Video Embeddings (Frame Sampling)
+
+EmbedAnything supports video by sampling frames and embedding them with a vision model
+(CLIP/SigLIP). This is opt-in via the `video` feature flag and requires the `ffmpeg`
+CLI to be available on your system. If `ffmpeg` is not on `PATH`, set `FFMPEG_BIN`
+to the full path of the executable.
+
+## Recommended Config
+
+`VideoEmbedConfig` controls how frames are sampled:
+
+- `frame_step`: sample every Nth frame. Default `30`.
+- `max_frames`: maximum frames per video. Default `300`.
+- `batch_size`: frames per embedding batch. Default `32`.
+
+Suggested starting point:
+
+```python
+from embed_anything import VideoEmbedConfig
+
+config = VideoEmbedConfig(frame_step=30, max_frames=300, batch_size=16)
+```
+
+## Python Usage
+
+```python
+import embed_anything
+from embed_anything import VideoEmbedConfig
+
+model = embed_anything.EmbeddingModel.from_pretrained_hf(
+    model_id="openai/clip-vit-base-patch16"
+)
+
+config = VideoEmbedConfig(frame_step=30, max_frames=200, batch_size=16)
+
+data = embed_anything.embed_video_file("path/to/video.mp4", embedder=model, config=config)
+```
+
+## Build with Video Support
+
+You must enable the `video` feature and have the `ffmpeg` CLI installed.
+
+### macOS
+
+```bash
+brew install ffmpeg
+cargo build --features video
+# Python (maturin)
+maturin develop --features "extension-module,video"
+```
+
+### Linux (Debian/Ubuntu)
+
+```bash
+sudo apt-get update
+sudo apt-get install -y ffmpeg
+cargo build --features video
+# Python (maturin)
+maturin develop --features "extension-module,video"
+```
+
+### Windows (prebuilt FFmpeg)
+
+1. Download a static build from https://www.gyan.dev/ffmpeg/builds/
+2. Extract it and set:
+
+```powershell
+$env:FFMPEG_BIN = "C:\path\to\ffmpeg.exe"
+```
+
+Then build:
+
+```powershell
+cargo build --features video
+# Python (maturin)
+maturin develop --features "extension-module,video"
+```
+
+## Output Metadata
+
+Each embedding includes:
+
+- `video_path`: the source video file
+- `frame_index`: the sampled frame index (0-based)
diff --git a/docs/index.md b/docs/index.md index 9b828a13..1dc73cd1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -74,7 +74,7 @@ EmbedAnything is a minimalist, yet highly performant, modular, lightning-fast, l - **Candle Backend** : Supports BERT, Jina, ColPali, Splade, ModernBERT, Reranker, Qwen - **ONNX Backend**: Supports BERT, Jina, ColPali, ColBERT Splade, Reranker, ModernBERT, Qwen - **Cloud Embedding Models:**: Supports OpenAI, Cohere, and Gemini. -- **MultiModality** : Works with text sources like PDFs, txt, md, Images JPG and Audio, .WAV +- **MultiModality** : Works with text sources like PDFs, txt, md, images, audio (.WAV), and videos (frame sampling; enable the `video` feature) - **GPU support** : Hardware acceleration on GPU as well. - **Chunking** : In-built chunking methods like semantic, late-chunking - **Vector Streaming:** Separate file processing, Indexing and Inferencing on different threads, reduces latency. @@ -339,7 +339,7 @@ We’re excited to share that we've expanded our platform to support multiple mo - [x] Images -- [ ] Videos +- [x] Videos (frame sampling; enable the `video` feature) - [ ] Graph @@ -359,7 +359,7 @@ We now support both candle and Onnx backend
We had multimodality from day one for our infrastructure. We have already included it for websites, images and audios but we want to expand it further to. ➡️ Graph embedding -- build deepwalks embeddings depth first and word to vec
-➡️ Video Embedding
+➡️ Video embedding improvements (temporal + audio)
➡️ Yolo Clip
diff --git a/docs/roadmap/roadmap.md b/docs/roadmap/roadmap.md index d3a76d6b..11beae2a 100644 --- a/docs/roadmap/roadmap.md +++ b/docs/roadmap/roadmap.md @@ -17,7 +17,7 @@ We’re excited to share that we've expanded our platform to support multiple mo - [x] Images -- [ ] Videos +- [x] Videos (frame sampling; enable the `video` feature) - [ ] Graph @@ -58,7 +58,7 @@ To address this, we’re excited to announce that we’re introducing Candle-ONN We had multimodality from day one for our infrastructure. We have already included it for websites, images and audios but we want to expand it further to. ☑️Graph embedding -- build deepwalks embeddings depth first and word to vec
-☑️Video Embedding
+☑️Video embedding improvements (temporal + audio)
☑️ Yolo Clip
diff --git a/examples/video.py b/examples/video.py new file mode 100644 index 00000000..cab27670 --- /dev/null +++ b/examples/video.py @@ -0,0 +1,37 @@ +import os +from pathlib import Path + +import embed_anything +from embed_anything import EmbedData, VideoEmbedConfig + +# Load a vision model (CLIP/SigLIP) for frame embeddings +model = embed_anything.EmbeddingModel.from_pretrained_hf( + model_id="openai/clip-vit-base-patch16" +) + +# Sample every 30th frame (~1 fps for 30 fps videos), cap to 200 frames +config = VideoEmbedConfig(frame_step=30, max_frames=200, batch_size=16) + +video_path = os.environ.get("VIDEO_PATH", "path/to/video.mp4") +if not Path(video_path).exists(): + raise FileNotFoundError( + f"Video not found: {video_path}. Set VIDEO_PATH env var to a valid file." + ) + +# Embed a single video +video_embeddings: list[EmbedData] = embed_anything.embed_video_file( + video_path, + embedder=model, + config=config, +) +print(f"Embedded {len(video_embeddings)} frames from video.") + +video_dir = os.environ.get("VIDEO_DIR") +if video_dir: + dir_embeddings = embed_anything.embed_video_directory( + video_dir, + embedder=model, + config=config, + ) + if dir_embeddings is not None: + print(f"Embedded {len(dir_embeddings)} total frames from directory.") diff --git a/mkdocs.yml b/mkdocs.yml index 9278f06e..e3118209 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -52,6 +52,7 @@ nav: - Guides: - guides/colpali.md - guides/images.md + - guides/video.md - guides/semantic.md - guides/adapters.md - guides/onnx_models.md diff --git a/processors/Cargo.toml b/processors/Cargo.toml index 397e655c..c23a7aa4 100644 --- a/processors/Cargo.toml +++ b/processors/Cargo.toml @@ -30,9 +30,11 @@ pdf2image = "0.1.3" image = "0.25.6" thiserror = "2.0.12" tempfile = "3.19.1" +# Video processing (uses external ffmpeg CLI) [dev-dependencies] tempdir = "0.3.7" [features] default = [] +video = [] \ No newline at end of file diff --git a/processors/src/lib.rs b/processors/src/lib.rs index 
831daaea..8a91540a 100644 --- a/processors/src/lib.rs +++ b/processors/src/lib.rs @@ -15,3 +15,7 @@ pub mod html_processor; /// This module contains the file processor for DOCX files. pub mod docx_processor; + +/// This module contains the file processor for video files. +#[cfg(feature = "video")] +pub mod video_processor;
diff --git a/processors/src/video_processor.rs b/processors/src/video_processor.rs
new file mode 100644
index 00000000..6416a4ff
--- /dev/null
+++ b/processors/src/video_processor.rs
@@ -0,0 +1,179 @@
+#![cfg(feature = "video")]
+
+use anyhow::{anyhow, Result};
+use std::env;
+use std::fs;
+use std::path::{Path, PathBuf};
+use std::process::Command;
+
+/// Image format of the frame files written by `ffmpeg`.
+#[derive(Debug, Clone, Copy)]
+pub enum VideoFrameFormat {
+    Jpeg,
+    Png,
+}
+
+impl VideoFrameFormat {
+    /// File extension (without the dot) used for the output files.
+    fn extension(self) -> &'static str {
+        match self {
+            VideoFrameFormat::Jpeg => "jpg",
+            VideoFrameFormat::Png => "png",
+        }
+    }
+}
+
+/// One extracted frame stored on disk.
+#[derive(Debug, Clone)]
+pub struct VideoFrame {
+    /// 0-based position of this frame among the extracted frames, in sorted path order.
+    pub index: usize,
+    /// Location of the frame image file.
+    pub path: PathBuf,
+}
+
+/// Samples frames from a video by running the external `ffmpeg` binary.
+#[derive(Debug, Clone)]
+pub struct VideoProcessor {
+    frame_step: usize,
+    max_frames: Option<usize>,
+    output_format: VideoFrameFormat,
+    ffmpeg_bin: Option<PathBuf>,
+}
+
+impl VideoProcessor {
+    /// Creates a processor that keeps every `frame_step`-th frame (0 is treated as 1),
+    /// writes JPEG files and applies no frame limit.
+    pub fn new(frame_step: usize) -> Self {
+        Self {
+            frame_step: frame_step.max(1),
+            max_frames: None,
+            output_format: VideoFrameFormat::Jpeg,
+            ffmpeg_bin: None,
+        }
+    }
+
+    /// Caps the number of frames written; forwarded to `ffmpeg` as `-vframes`.
+    pub fn with_max_frames(mut self, max_frames: usize) -> Self {
+        self.max_frames = Some(max_frames);
+        self
+    }
+
+    /// Sets the image format of the extracted frames.
+    pub fn with_output_format(mut self, output_format: VideoFrameFormat) -> Self {
+        self.output_format = output_format;
+        self
+    }
+
+    /// Sets the `ffmpeg` executable explicitly, taking priority over `FFMPEG_BIN`.
+    pub fn with_ffmpeg_bin<P: AsRef<Path>>(mut self, ffmpeg_bin: P) -> Self {
+        self.ffmpeg_bin = Some(ffmpeg_bin.as_ref().to_path_buf());
+        self
+    }
+
+    /// Chooses the executable: the explicit setting, then the `FFMPEG_BIN` environment
+    /// variable, then plain `ffmpeg`.
+    fn resolve_ffmpeg_bin(&self) -> Result<PathBuf> {
+        if let Some(bin) = &self.ffmpeg_bin {
+            return Ok(bin.clone());
+        }
+        if let Ok(bin) = env::var("FFMPEG_BIN") {
+            return Ok(PathBuf::from(bin));
+        }
+        Ok(PathBuf::from("ffmpeg"))
+    }
+
+    /// Extracts the sampled frames of `video_path` into `output_dir`, creating the
+    /// directory if needed.
+    ///
+    /// Every file in `output_dir` that has the frame extension is returned, sorted by
+    /// path, so files left there by an earlier run are included as well.
+    ///
+    /// # Errors
+    ///
+    /// Fails when the directory cannot be created or read, when `ffmpeg` cannot be
+    /// started or exits unsuccessfully, or when no frame file is found.
+    pub fn extract_frames_to_dir<P: AsRef<Path>, Q: AsRef<Path>>(
+        &self,
+        video_path: P,
+        output_dir: Q,
+    ) -> Result<Vec<VideoFrame>> {
+        let output_dir = output_dir.as_ref();
+        fs::create_dir_all(output_dir)?;
+
+        let ffmpeg_bin = self.resolve_ffmpeg_bin()?;
+        let frame_step = self.frame_step.max(1);
+        let filter = format!("select=not(mod(n\\,{}))", frame_step);
+        let output_pattern = output_dir.join(format!(
+            "frame_%06d.{}",
+            self.output_format.extension()
+        ));
+
+        let mut command = Command::new(&ffmpeg_bin);
+        command
+            .arg("-hide_banner")
+            .arg("-loglevel")
+            .arg("error")
+            .arg("-i")
+            .arg(video_path.as_ref())
+            .arg("-vf")
+            .arg(filter)
+            .arg("-vsync")
+            .arg("vfr");
+
+        if let Some(max_frames) = self.max_frames {
+            command.arg("-vframes").arg(max_frames.to_string());
+        }
+
+        // A missing binary otherwise surfaces as a bare OS error; name the executable.
+        let status = command.arg(output_pattern).status().map_err(|err| {
+            anyhow!(
+                "failed to run ffmpeg at {:?}: {}; install ffmpeg or set FFMPEG_BIN",
+                ffmpeg_bin,
+                err
+            )
+        })?;
+        if !status.success() {
+            return Err(anyhow!("ffmpeg failed with exit code {:?}", status.code()));
+        }
+
+        let mut frame_paths = fs::read_dir(output_dir)?
+            .filter_map(|entry| entry.ok())
+            .filter(|entry| entry.file_type().map(|t| t.is_file()).unwrap_or(false))
+            .map(|entry| entry.path())
+            .filter(|path| {
+                path.extension()
+                    .and_then(|ext| ext.to_str())
+                    .map(|ext| ext.eq_ignore_ascii_case(self.output_format.extension()))
+                    .unwrap_or(false)
+            })
+            .collect::<Vec<_>>();
+
+        frame_paths.sort();
+
+        if frame_paths.is_empty() {
+            return Err(anyhow!("No frames extracted from video"));
+        }
+
+        let frames = frame_paths
+            .into_iter()
+            .enumerate()
+            .map(|(index, path)| VideoFrame { index, path })
+            .collect();
+
+        Ok(frames)
+    }
+
+    /// Extracts frames into a fresh temporary directory.
+    ///
+    /// The returned `TempDir` must be kept alive while the frame files are in use: the
+    /// directory is deleted when it is dropped.
+    pub fn extract_frames_to_temp_dir<P: AsRef<Path>>(
+        &self,
+        video_path: P,
+    ) -> Result<(tempfile::TempDir, Vec<VideoFrame>)> {
+        let temp_dir = tempfile::TempDir::new()?;
+        let frames = self.extract_frames_to_dir(video_path, temp_dir.path())?;
+        Ok((temp_dir, frames))
+    }
+}
diff --git a/python/Cargo.toml b/python/Cargo.toml
index e460a817..6f7e4a0c 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -26,4 +26,5 @@ cudnn = ["embed_anything/cudnn"]
 metal = ["embed_anything/metal"]
 ort = ["embed_anything/ort"]
 audio = ["embed_anything/audio"]
-aws = ["embed_anything/aws"]
\ No newline at end of file
+aws = ["embed_anything/aws"]
+video = ["embed_anything/video"]
diff --git a/python/python/embed_anything/__init__.py b/python/python/embed_anything/__init__.py index 29cd66e4..3267caac 100644 --- a/python/python/embed_anything/__init__.py +++ b/python/python/embed_anything/__init__.py @@ -2,7 +2,7 @@ This module provides functions and classes for embedding queries, files, and directories using different embedding models. It supports text, images, audio, -PDFs, and other media types with various embedding backends (Candle, ONNX, Cloud). +videos, PDFs, and other media types with various embedding backends (Candle, ONNX, Cloud). Main Functions: --------------- @@ -11,6 +11,8 @@ - `embed_directory`: Embeds all files in a directory and returns a list of EmbedData objects. - `embed_image_directory`: Embeds all images in a directory. - `embed_audio_file`: Embeds audio files using Whisper for transcription. +- `embed_video_file`: Embeds a video file by sampling frames. +- `embed_video_directory`: Embeds all videos in a directory. - `embed_webpage`: Embeds content from a webpage URL. Main Classes: @@ -18,6 +20,7 @@ - `EmbeddingModel`: Main class for loading and using embedding models. - `EmbedData`: Represents embedded data with text, embedding vector, and metadata. - `TextEmbedConfig`: Configuration for text embedding (chunking, batching, etc.). +- `VideoEmbedConfig`: Configuration for video embedding (frame sampling, batching). - `ColpaliModel`: Specialized model for document/image-text embedding. - `ColbertModel`: Model for late-interaction embeddings. - `Reranker`: Model for re-ranking search results.
diff --git a/python/python/embed_anything/_embed_anything.pyi b/python/python/embed_anything/_embed_anything.pyi
index b1a47c16..7e2ecc2b 100644
--- a/python/python/embed_anything/_embed_anything.pyi
+++ b/python/python/embed_anything/_embed_anything.pyi
@@ -585,6 +585,65 @@ class ImageEmbedConfig:
     buffer_size: int | None
     batch_size: int | None
 
+def embed_video_file(
+    video_file: str,
+    embedder: EmbeddingModel,
+    config: VideoEmbedConfig | None = None,
+) -> list[EmbedData]:
+    """
+    Embeds the given video file by sampling frames and returns a list of EmbedData objects.
+
+    Args:
+        video_file: The path to the video file to embed.
+        embedder: The embedding model to use.
+        config: The configuration for video embedding.
+
+    Returns:
+        A list of EmbedData objects.
+    """
+
+def embed_video_directory(
+    directory: str,
+    embedder: EmbeddingModel,
+    config: VideoEmbedConfig | None = None,
+    adapter: Adapter | None = None,
+) -> list[EmbedData] | None:
+    """
+    Embeds all videos in the given directory and returns a list of EmbedData objects.
+
+    Args:
+        directory: The path to the directory containing videos to embed.
+        embedder: The embedding model to use.
+        config: The configuration for video embedding.
+        adapter: The adapter to use for storing the embeddings in a vector database.
+
+    Returns:
+        A list of EmbedData objects, or None if an adapter is used.
+    """
+
+class VideoEmbedConfig:
+    """
+    Represents the configuration for the Video Embedding model.
+
+    Attributes:
+        frame_step: Sample every Nth frame. Default is 30.
+        max_frames: Maximum number of frames to embed. Default is 300.
+        batch_size: The batch size for processing frames. Default is 32.
+ """ + + def __init__( + self, + frame_step: int | None = None, + max_frames: int | None = None, + batch_size: int | None = None, + ): + self.frame_step = frame_step + self.max_frames = max_frames + self.batch_size = batch_size + frame_step: int | None + max_frames: int | None + batch_size: int | None + class EmbeddingModel: """ Represents an embedding model. @@ -760,6 +819,40 @@ class EmbeddingModel: A list of EmbedData objects. """ + def embed_video_file( + self, + video_file: str, + config: VideoEmbedConfig | None = None, + ) -> list[EmbedData]: + """ + Embeds the given video file and returns a list of EmbedData objects. + + Args: + video_file: The path to the video file to embed. + config: The configuration for video embedding. + + Returns: + A list of EmbedData objects. + """ + + def embed_video_directory( + self, + directory: str, + config: VideoEmbedConfig | None = None, + adapter: Adapter | None = None, + ) -> list[EmbedData] | None: + """ + Embeds videos in the given directory and returns a list of EmbedData objects. + + Args: + directory: The path to the directory to embed. + config: The configuration for video embedding. + adapter: The adapter for the embedding. + + Returns: + A list of EmbedData objects, or None if an adapter is used. 
+ """ + def embed_query( self, query: list[str], diff --git a/python/src/config.rs b/python/src/config.rs index 258c4468..f8a749d1 100644 --- a/python/src/config.rs +++ b/python/src/config.rs @@ -92,3 +92,44 @@ impl ImageEmbedConfig { self.inner.batch_size } } + +#[pyclass] +#[derive(Clone, Default)] +pub struct VideoEmbedConfig { + pub inner: embed_anything::config::VideoEmbedConfig, +} + +#[pymethods] +impl VideoEmbedConfig { + #[new] + #[pyo3(signature = (frame_step=None, max_frames=None, batch_size=None))] + pub fn new( + frame_step: Option, + max_frames: Option, + batch_size: Option, + ) -> Self { + let default_config = embed_anything::config::VideoEmbedConfig::default(); + Self { + inner: embed_anything::config::VideoEmbedConfig { + frame_step: frame_step.or(default_config.frame_step), + max_frames: max_frames.or(default_config.max_frames), + batch_size: batch_size.or(default_config.batch_size), + }, + } + } + + #[getter] + pub fn frame_step(&self) -> Option { + self.inner.frame_step + } + + #[getter] + pub fn max_frames(&self) -> Option { + self.inner.max_frames + } + + #[getter] + pub fn batch_size(&self) -> Option { + self.inner.batch_size + } +} diff --git a/python/src/lib.rs b/python/src/lib.rs index f998750f..bef17ff7 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -7,6 +7,8 @@ use embed_anything::{ self, config::TextEmbedConfig, emb_audio, + embed_video_directory as embed_video_directory_rs, + embed_video_file as embed_video_file_rs, embeddings::embed::{Embedder, EmbeddingResult}, file_processor::audio::audio_processor, FileLoadingError, @@ -385,6 +387,25 @@ impl EmbeddingModel { ) -> PyResult>> { embed_audio_file(audio_file, audio_decoder, self, config) } + + #[pyo3(signature = (video_file, config=None))] + pub fn embed_video_file( + &self, + video_file: &str, + config: Option<&config::VideoEmbedConfig>, + ) -> PyResult> { + embed_video_file(video_file, self, config) + } + + #[pyo3(signature = (directory, config=None, adapter=None))] + pub 
fn embed_video_directory( + &self, + directory: PathBuf, + config: Option<&config::VideoEmbedConfig>, + adapter: Option, + ) -> PyResult>> { + embed_video_directory(directory, self, config, adapter) + } } #[pyclass] @@ -604,6 +625,90 @@ pub fn embed_audio_file( Ok(data) } +#[pyfunction] +#[pyo3(signature = (video_file, embedder, config=None))] +pub fn embed_video_file( + video_file: &str, + embedder: &EmbeddingModel, + config: Option<&config::VideoEmbedConfig>, +) -> PyResult> { + let config = config.map(|c| &c.inner); + let embedding_model = &embedder.inner; + let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + + if !Path::new(video_file).exists() { + return Err(PyFileNotFoundError::new_err(format!( + "File not found: {:?}", + video_file + ))); + }; + + let data = rt + .block_on(async { embed_video_file_rs(video_file, embedding_model.as_ref(), config).await }) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + + Ok(data.into_iter().map(|data| EmbedData { inner: data }).collect()) +} + +#[pyfunction] +#[pyo3(signature = (directory, embedder, config=None, adapter=None))] +pub fn embed_video_directory( + directory: PathBuf, + embedder: &EmbeddingModel, + config: Option<&config::VideoEmbedConfig>, + adapter: Option, +) -> PyResult>> { + let config = config.map(|c| &c.inner); + let embedding_model = &embedder.inner; + let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + + let adapter = match adapter { + Some(adapter) => { + let callback = move |data: Vec| { + Python::with_gil(|py| { + let upsert_fn = adapter.getattr(py, "upsert").unwrap(); + let converted_data = data + .into_iter() + .map(|data| EmbedData { inner: data }) + .collect::>(); + upsert_fn + .call1(py, (converted_data,)) + .map_err(|e| PyValueError::new_err(e.to_string())) + .unwrap(); + }); + }; + Some(callback) + } + None => None, + }; + + let data = rt.block_on(async { + embed_video_directory_rs( + directory, + embedding_model, + config, + adapter.map(|f| { + 
Box::new(f) + as Box< + dyn FnMut(Vec) + + Send + + Sync, + > + }), + ) + .await + .map_err(|e| PyValueError::new_err(e.to_string())) + .unwrap() + .map(|data| { + data.into_iter() + .map(|data| EmbedData { inner: data }) + .collect::>() + }) + }); + + Ok(data) +} + #[pyfunction] #[pyo3(signature = (directory, embedder, extensions=None, config=None, adapter = None))] pub fn embed_directory( @@ -780,6 +885,8 @@ fn _embed_anything(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(embed_query, m)?)?; m.add_function(wrap_pyfunction!(embed_webpage, m)?)?; m.add_function(wrap_pyfunction!(embed_audio_file, m)?)?; + m.add_function(wrap_pyfunction!(embed_video_file, m)?)?; + m.add_function(wrap_pyfunction!(embed_video_directory, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -787,6 +894,7 @@ fn _embed_anything(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 2958ea4b..4f01a502 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -96,6 +96,7 @@ cudnn = ["candle-core/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] metal = ["candle-core/metal", "candle-nn/metal"] audio = ["dep:symphonia"] +video = ["processors-rs/video"] ort = ["dep:ort",] aws = ["dep:aws-sdk-s3", "dep:aws-config", "dep:aws-credential-types"] rustls-tls = [ diff --git a/rust/src/config.rs b/rust/src/config.rs index 4655519c..0d67e68c 100644 --- a/rust/src/config.rs +++ b/rust/src/config.rs @@ -185,3 +185,38 @@ impl ImageEmbedConfig { } } } + +pub const DEFAULT_VIDEO_FRAME_STEP: usize = 30; +pub const DEFAULT_VIDEO_MAX_FRAMES: usize = 300; +pub const DEFAULT_VIDEO_BATCH_SIZE: usize = 32; + +#[derive(Clone)] +pub struct VideoEmbedConfig { + pub frame_step: Option, + pub max_frames: Option, + pub batch_size: Option, +} + +impl Default for VideoEmbedConfig { + 
fn default() -> Self { + Self { + frame_step: Some(DEFAULT_VIDEO_FRAME_STEP), + max_frames: Some(DEFAULT_VIDEO_MAX_FRAMES), + batch_size: Some(DEFAULT_VIDEO_BATCH_SIZE), + } + } +} + +impl VideoEmbedConfig { + pub fn new( + frame_step: Option, + max_frames: Option, + batch_size: Option, + ) -> Self { + Self { + frame_step, + max_frames, + batch_size, + } + } +} diff --git a/rust/src/file_loader.rs b/rust/src/file_loader.rs index 132c5fa4..5aa5d818 100644 --- a/rust/src/file_loader.rs +++ b/rust/src/file_loader.rs @@ -94,6 +94,27 @@ impl FileParser { Ok(self.files.clone()) } + pub fn get_video_paths(&mut self, directory_path: &PathBuf) -> Result, Error> { + let video_regex = Regex::new(r".*\.(mp4|mov|avi|mkv|webm|m4v|flv|wmv)$").unwrap(); + + let video_paths: Vec = WalkDir::new(directory_path) + .into_iter() + .filter_map(|entry| entry.ok()) + .filter(|entry| entry.file_type().is_file()) + .filter(|entry| video_regex.is_match(entry.file_name().to_str().unwrap_or(""))) + .map(|entry| { + let absolute_path = entry + .path() + .canonicalize() + .unwrap_or_else(|_| entry.path().to_path_buf()); + absolute_path.to_string_lossy().to_string() + }) + .collect(); + + self.files = video_paths; + Ok(self.files.clone()) + } + pub fn get_files_to_index(&self, indexed_files: &HashSet) -> Vec { let files = self .files @@ -209,6 +230,30 @@ mod tests { assert_eq!(audio_files.len(), 2); } + #[test] + fn test_get_video_paths() { + let temp_dir = TempDir::new("example").unwrap(); + let video_file = temp_dir.path().join("clip.mp4"); + let _ignored_file = temp_dir.path().join("note.txt"); + File::create(&video_file).unwrap(); + File::create(&_ignored_file).unwrap(); + + let mut file_parser = FileParser::new(); + let video_files = file_parser + .get_video_paths(&PathBuf::from(temp_dir.path())) + .unwrap(); + + assert_eq!(video_files.len(), 1); + assert_eq!( + video_files[0], + video_file + .canonicalize() + .unwrap() + .to_string_lossy() + .to_string() + ); + } + #[test] fn 
test_get_files_to_index() { let temp_dir = TempDir::new("example").unwrap(); diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 3a591f93..75eb8275 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -108,7 +108,10 @@ pub mod text_loader; pub mod s3_loader; use anyhow::{Error, Result}; -use config::{ImageEmbedConfig, TextEmbedConfig}; +use config::{ + ImageEmbedConfig, TextEmbedConfig, VideoEmbedConfig, DEFAULT_VIDEO_BATCH_SIZE, + DEFAULT_VIDEO_FRAME_STEP, +}; use embeddings::{ embed::{EmbedData, EmbedImage, Embedder, TextEmbedder, VisionEmbedder}, get_text_metadata, @@ -132,6 +135,8 @@ use processors_rs::{ processor::{Document, FileProcessor, UrlProcessor}, txt_processor::TxtProcessor, }; +#[cfg(feature = "video")] +use processors_rs::video_processor::VideoProcessor; /// Numerical precision types for model weights and computations. pub enum Dtype { @@ -204,6 +209,16 @@ pub async fn embed_query( Ok(embeddings) } +fn is_video_extension(extension: &std::ffi::OsStr) -> bool { + match extension.to_str().map(|ext| ext.to_ascii_lowercase()) { + Some(ext) => matches!( + ext.as_str(), + "mp4" | "mov" | "avi" | "mkv" | "webm" | "m4v" | "flv" | "wmv" + ), + None => false, + } +} + /// Embeds the text from a file using the specified embedding model. /// /// # Arguments @@ -256,6 +271,25 @@ pub async fn embed_file>( ); Ok(Some(embedder.embed_pdf(file_name, Some(batch_size)).await?)) } + Some(extension) if is_video_extension(extension) => { + #[cfg(feature = "video")] + { + let batch_size = config.and_then(|cfg| cfg.batch_size); + let video_config = VideoEmbedConfig::default(); + let video_config = VideoEmbedConfig { + batch_size, + ..video_config + }; + let embeddings = emb_video(file_name, embedder, &video_config).await?; + Ok(Some(embeddings)) + } + #[cfg(not(feature = "video"))] + { + Err(anyhow::anyhow!( + "The 'video' feature is not enabled. Rebuild with --features video to embed videos." 
+ )) + } + } _ => Ok(Some(vec![emb_image(file_name, embedder).await?])), }, } @@ -394,6 +428,144 @@ async fn emb_image>( Ok(embedding) } +#[cfg(feature = "video")] +pub async fn embed_video_file>( + file_name: T, + embedder: &Embedder, + config: Option<&VideoEmbedConfig>, +) -> Result> { + let default_config = VideoEmbedConfig::default(); + let config = config.unwrap_or(&default_config); + match embedder { + Embedder::Vision(embedder) => emb_video(file_name, embedder, config).await, + _ => Err(anyhow::anyhow!( + "Model not supported for video embedding" + )), + } +} + +#[cfg(not(feature = "video"))] +pub async fn embed_video_file>( + _file_name: T, + _embedder: &Embedder, + _config: Option<&VideoEmbedConfig>, +) -> Result> { + Err(anyhow::anyhow!( + "The 'video' feature is not enabled. Please enable it to use video embedding." + )) +} + +#[cfg(feature = "video")] +pub async fn embed_video_directory( + directory: PathBuf, + embedding_model: &Arc, + config: Option<&VideoEmbedConfig>, + adapter: Option) + Send + Sync>>, +) -> Result>> { + let mut file_parser = FileParser::new(); + file_parser.get_video_paths(&directory)?; + + let default_config = VideoEmbedConfig::default(); + let config = config.unwrap_or(&default_config); + let files = file_parser.files.clone(); + let pb = indicatif::ProgressBar::new(files.len() as u64); + pb.set_style(indicatif::ProgressStyle::with_template( + "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})", + )?); + + let mut all_embeddings = Vec::new(); + let mut adapter = adapter; + for file in files { + let embeddings = embed_video_file(&file, embedding_model.as_ref(), Some(config)).await?; + pb.inc(1); + + if let Some(adapter) = adapter.as_mut() { + adapter(embeddings); + } else { + all_embeddings.extend(embeddings); + } + } + + if adapter.is_some() { + Ok(None) + } else { + Ok(Some(all_embeddings)) + } +} + +#[cfg(not(feature = "video"))] +pub async fn embed_video_directory( + _directory: PathBuf, + 
_embedding_model: &Arc, + _config: Option<&VideoEmbedConfig>, + _adapter: Option) + Send + Sync>>, +) -> Result>> { + Err(anyhow::anyhow!( + "The 'video' feature is not enabled. Please enable it to use video embedding." + )) +} + +#[cfg(feature = "video")] +pub async fn emb_video>( + video_path: T, + embedding_model: &VisionEmbedder, + config: &VideoEmbedConfig, +) -> Result> { + let frame_step = config.frame_step.unwrap_or(DEFAULT_VIDEO_FRAME_STEP).max(1); + let max_frames = config.max_frames.filter(|value| *value > 0); + let batch_size = config + .batch_size + .unwrap_or(DEFAULT_VIDEO_BATCH_SIZE) + .max(1); + + let processor = VideoProcessor::new(frame_step); + let processor = match max_frames { + Some(limit) => processor.with_max_frames(limit), + None => processor, + }; + + let video_path = video_path.as_ref(); + let video_path_string = fs::canonicalize(video_path) + .unwrap_or_else(|_| video_path.to_path_buf()) + .to_string_lossy() + .to_string(); + + let (_temp_dir, frames) = processor.extract_frames_to_temp_dir(video_path)?; + if frames.is_empty() { + return Err(anyhow::anyhow!("No frames extracted from video")); + } + + let mut all_embeddings = Vec::new(); + for frame_chunk in frames.chunks(batch_size) { + let frame_paths: Vec<&std::path::Path> = + frame_chunk.iter().map(|frame| frame.path.as_path()).collect(); + let mut embeddings = embedding_model + .embed_image_batch(&frame_paths, Some(batch_size)) + .await?; + + for (embedding, frame) in embeddings.iter_mut().zip(frame_chunk.iter()) { + let metadata = embedding.metadata.get_or_insert_with(HashMap::new); + metadata.insert("video_path".to_string(), video_path_string.clone()); + metadata.insert("frame_index".to_string(), frame.index.to_string()); + } + + all_embeddings.extend(embeddings); + } + + Ok(all_embeddings) +} + +#[cfg(not(feature = "video"))] +pub async fn emb_video>( + _video_path: T, + _embedding_model: &VisionEmbedder, + _config: &VideoEmbedConfig, +) -> Result> { + Err(anyhow::anyhow!( + "The 
'video' feature is not enabled. Please enable it to use the emb_video function." + )) +} + /// Embeds audio content from a file using transcription and temporal segmentation. /// /// Processes audio files by first transcribing them to text and then creating diff --git a/server-cuda.Dockerfile b/server-cuda.Dockerfile new file mode 100644 index 00000000..c4ddc28b --- /dev/null +++ b/server-cuda.Dockerfile @@ -0,0 +1,85 @@ +FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 AS base-builder + +ENV SCCACHE=0.10.0 +ENV RUSTC_WRAPPER=/usr/local/bin/sccache +ENV PATH="/root/.cargo/bin:${PATH}" +# aligned with `cargo-chef` version in `lukemathwalker/cargo-chef:latest-rust-1.85-bookworm` +ENV CARGO_CHEF=0.1.71 +ENV CUDA_COMPUTE_CAP=75 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + libssl-dev \ + pkg-config \ + python3 \ + python3-dev \ + && rm -rf /var/lib/apt/lists/* + +# Download and configure sccache +RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ + chmod +x /usr/local/bin/sccache + +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +RUN cargo install cargo-chef --version $CARGO_CHEF --locked + +FROM base-builder AS planner + +WORKDIR /app + +COPY processors processors +COPY rust rust +COPY python python +COPY server server +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN cargo chef prepare --recipe-path recipe.json + +FROM base-builder AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +# sccache specific variables +ARG SCCACHE_GHA_ENABLED + +# Limit parallelism +ARG RAYON_NUM_THREADS=4 +ARG CARGO_BUILD_JOBS +ARG CARGO_BUILD_INCREMENTAL + +WORKDIR /app + +COPY --from=planner /app/recipe.json recipe.json + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + 
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo chef cook --release --recipe-path recipe.json --package server --features cuda && sccache -s; + +COPY processors processors +COPY rust rust +COPY server server +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo build --release --bin server --features cuda && sccache -s; + +RUN strip /app/target/release/server + +FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 AS runtime + +WORKDIR /app + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ca-certificates \ + libssl-dev \ + curl \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=builder /app/target/release/server /usr/local/bin/server + +EXPOSE 8080 + +CMD ["server"] diff --git a/server.Dockerfile b/server.Dockerfile index 850ded43..3e85d322 100644 --- a/server.Dockerfile +++ b/server.Dockerfile @@ -17,7 +17,7 @@ COPY --from=planner /app/recipe.json recipe.json RUN cargo chef cook --release --recipe-path recipe.json --package server # Build application COPY . . -RUN cargo build --release --package server +RUN cargo build --release -p server RUN strip target/release/server # We do not need the Rust toolchain to run the binary! 
diff --git a/server/Cargo.toml b/server/Cargo.toml index 471ba7fd..a5e9d85a 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -14,6 +14,8 @@ futures-util = "0.3" serde = { version = "1.0", features = ["derive"] } tempfile = "3" tokio = { version = "1", features = ["fs", "io-util"] } +base64 = "0.22.1" +image = "0.25.6" [target.'cfg(not(target_os = "macos"))'.dependencies] embed_anything = {path = "../rust"} diff --git a/server/src/lib.rs b/server/src/lib.rs index 12511dc7..a37104ee 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -7,10 +7,11 @@ use std::time::{SystemTime, UNIX_EPOCH}; use actix_multipart::Multipart; use actix_web::dev::Server; use actix_web::{get, post, web, App, HttpResponse, HttpServer}; +use base64::Engine; use embed_anything::config::TextEmbedConfig; use embed_anything::{ embed_files_batch, embed_query, - embeddings::embed::{EmbedData, EmbedderBuilder, EmbeddingResult}, + embeddings::embed::{EmbedData, EmbedderBuilder, EmbeddingResult, EmbedImage}, }; use futures_util::StreamExt; use serde::{Deserialize, Serialize}; @@ -88,6 +89,29 @@ enum PdfEmbeddingVector { Multi(Vec>), } +// Image embedding request structure +#[derive(Deserialize)] +struct ImageEmbedRequest { + model: String, + images: Vec, // Base64 encoded images +} + +// Image embedding response structure +#[derive(Serialize)] +struct ImageEmbeddingResponse { + object: String, + data: Vec, + model: String, +} + +#[derive(Serialize)] +struct ImageEmbeddingData { + object: String, + index: usize, + embedding: Vec, + metadata: Option>, +} + fn pdf_embedding_response(model: String, embeddings: Vec) -> PdfEmbeddingResponse { let data = embeddings .into_iter() @@ -133,6 +157,21 @@ async fn create_embeddings(req: web::Json) -> HttpResponse { }); } + // Detect input type: check if all inputs are base64 images + let all_images = req.input.iter().all(|input| is_base64_image(input)); + let all_text = req.input.iter().all(|input| !is_base64_image(input)); + + // If mixed input 
types, return error + if !all_images && !all_text { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: "Mixed input types detected. Please provide either all text inputs or all base64 image inputs.".to_string(), + error_type: "invalid_request_error".to_string(), + code: Some("mixed_input_types".to_string()), + }, + }); + } + // Create embedder let embedder = match EmbedderBuilder::new() .model_id(Some(req.model.as_str())) @@ -150,21 +189,91 @@ async fn create_embeddings(req: web::Json) -> HttpResponse { } }; - // Convert input to string slices - let input_slices: Vec<&str> = req.input.iter().map(|s| s.as_str()).collect(); + // Route based on input type and model type + let embeddings = if all_images { + // Handle image embeddings + match embedder { + embed_anything::embeddings::embed::Embedder::Vision(vision_embedder) => { + // Create temp directory for image files + let temp_dir = match tempfile::tempdir() { + Ok(dir) => dir, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to create temp directory: {}", e), + error_type: "server_error".to_string(), + code: Some("temp_dir_creation_failed".to_string()), + }, + }); + } + }; - // Generate embeddings - let config = TextEmbedConfig::default(); - let embeddings = match embed_query(&input_slices, &embedder, Some(&config)).await { - Ok(embeddings) => embeddings, - Err(e) => { - return HttpResponse::InternalServerError().json(ErrorResponse { - error: ErrorDetail { - message: format!("Failed to generate embeddings: {}", e), - error_type: "server_error".to_string(), - code: Some("embedding_generation_failed".to_string()), - }, - }); + // Decode base64 images to temporary files + let mut image_paths = Vec::new(); + for (index, base64_image) in req.input.iter().enumerate() { + match decode_base64_to_temp_file(base64_image, index, &temp_dir).await { + Ok(path) => image_paths.push(path), + Err(e) => { + return 
HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to decode image at index {}: {}", index, e), + error_type: "invalid_request_error".to_string(), + code: Some("base64_decode_failed".to_string()), + }, + }); + } + } + } + + // Generate embeddings for images + match vision_embedder.embed_image_batch(&image_paths, None).await { + Ok(embeddings) => embeddings, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to generate image embeddings: {}", e), + error_type: "server_error".to_string(), + code: Some("embedding_generation_failed".to_string()), + }, + }); + } + } + } + _ => { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: format!( + "Model '{}' does not support image embeddings. Please use a vision model (e.g., CLIP, SigLIP, DinoV2, ColPali).", + req.model + ), + error_type: "invalid_request_error".to_string(), + code: Some("unsupported_model".to_string()), + }, + }); + } + } + } else { + // Handle text embeddings + match embedder { + embed_anything::embeddings::embed::Embedder::Text(_) | embed_anything::embeddings::embed::Embedder::Vision(_) => { + // Convert input to string slices + let input_slices: Vec<&str> = req.input.iter().map(|s| s.as_str()).collect(); + + // Generate embeddings + let config = TextEmbedConfig::default(); + match embed_query(&input_slices, &embedder, Some(&config)).await { + Ok(embeddings) => embeddings, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to generate text embeddings: {}", e), + error_type: "server_error".to_string(), + code: Some("embedding_generation_failed".to_string()), + }, + }); + } + } + } } }; @@ -190,7 +299,11 @@ async fn create_embeddings(req: web::Json) -> HttpResponse { .collect(); // Calculate usage (simplified - you might want to implement proper token counting) - let 
total_tokens = req.input.iter().map(|s| s.split_whitespace().count()).sum(); + let total_tokens = if all_images { + req.input.len() // For images, count as 1 token per image (simplified) + } else { + req.input.iter().map(|s| s.split_whitespace().count()).sum() + }; let response = OpenAIEmbedResponse { object: "list".to_string(), @@ -507,6 +620,221 @@ async fn create_pdf_embeddings_upload(mut payload: Multipart) -> HttpResponse { HttpResponse::Ok().json(response) } +/// Detects if a string is likely a base64-encoded image +fn is_base64_image(input: &str) -> bool { + // Check if it starts with data URL prefix + if input.starts_with("data:image/") { + return true; + } + + // Check if it's valid base64 and try to decode and validate as image + let base64_data = input.trim(); + + // Base64 strings should be reasonably long and contain only base64 characters + if base64_data.len() < 100 { + return false; + } + + // Check if it contains only base64 characters (with optional padding) + if !base64_data.chars().all(|c| { + c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '=' + }) { + return false; + } + + // Try to decode and validate as image + if let Ok(image_bytes) = base64::engine::general_purpose::STANDARD.decode(base64_data) { + // Try to detect image format from bytes + if image::ImageReader::new(std::io::Cursor::new(&image_bytes)) + .with_guessed_format() + .is_ok() + { + return true; + } + } + + false +} + +/// Decodes a base64 image string and writes it to a temporary file +async fn decode_base64_to_temp_file( + base64_str: &str, + index: usize, + temp_dir: &tempfile::TempDir, +) -> Result { + // Remove data URL prefix if present (e.g., "data:image/png;base64,") + let base64_data = if base64_str.starts_with("data:") { + base64_str + .split(',') + .nth(1) + .ok_or_else(|| "Invalid data URL format".to_string())? 
+ } else { + base64_str + }; + + // Decode base64 + let image_bytes = base64::engine::general_purpose::STANDARD + .decode(base64_data.trim()) + .map_err(|e| format!("Failed to decode base64: {}", e))?; + + // Try to load image from memory to validate and determine format + let image_format = image::ImageReader::new(std::io::Cursor::new(&image_bytes)) + .with_guessed_format() + .map_err(|e| format!("Failed to read image: {}", e))? + .format(); + + // Determine file extension from format + let extension = match image_format { + Some(image::ImageFormat::Png) => "png", + Some(image::ImageFormat::Jpeg) => "jpg", + Some(image::ImageFormat::WebP) => "webp", + Some(image::ImageFormat::Gif) => "gif", + Some(image::ImageFormat::Bmp) => "bmp", + _ => "png", // Default to PNG if format cannot be determined + }; + + // Create temporary file + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + let filename = format!("image_{}_{}.{}", timestamp, index, extension); + let file_path = temp_dir.path().join(filename); + + // Write image bytes to file + tokio::fs::write(&file_path, image_bytes) + .await + .map_err(|e| format!("Failed to write temp file: {}", e))?; + + Ok(file_path) +} + +#[post("/v1/image_embeddings")] +async fn create_image_embeddings(req: web::Json) -> HttpResponse { + // Validate input + if req.images.is_empty() { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: "Images cannot be empty".to_string(), + error_type: "invalid_request_error".to_string(), + code: Some("empty_images".to_string()), + }, + }); + } + + // Create temp directory for image files + let temp_dir = match tempfile::tempdir() { + Ok(dir) => dir, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to create temp directory: {}", e), + error_type: "server_error".to_string(), + code: Some("temp_dir_creation_failed".to_string()), + }, 
+ }); + } + }; + + // Decode base64 images to temporary files + let mut image_paths = Vec::new(); + for (index, base64_image) in req.images.iter().enumerate() { + match decode_base64_to_temp_file(base64_image, index, &temp_dir).await { + Ok(path) => image_paths.push(path), + Err(e) => { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to decode image at index {}: {}", index, e), + error_type: "invalid_request_error".to_string(), + code: Some("base64_decode_failed".to_string()), + }, + }); + } + } + } + + // Create embedder - try to determine if it's a vision model + let embedder = match EmbedderBuilder::new() + .model_id(Some(req.model.as_str())) + .from_pretrained_hf() + { + Ok(embedder) => embedder, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to initialize embedder: {}", e), + error_type: "server_error".to_string(), + code: Some("embedder_init_failed".to_string()), + }, + }); + } + }; + + // Check if embedder supports vision + let vision_embedder = match embedder { + embed_anything::embeddings::embed::Embedder::Vision(embedder) => embedder, + _ => { + return HttpResponse::BadRequest().json(ErrorResponse { + error: ErrorDetail { + message: format!( + "Model '{}' does not support image embeddings. 
Please use a vision model (e.g., CLIP, SigLIP, DinoV2, ColPali).", + req.model + ), + error_type: "invalid_request_error".to_string(), + code: Some("unsupported_model".to_string()), + }, + }); + } + }; + + // Generate embeddings for images + let embeddings = match vision_embedder + .embed_image_batch(&image_paths, None) + .await + { + Ok(embeddings) => embeddings, + Err(e) => { + return HttpResponse::InternalServerError().json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to generate embeddings: {}", e), + error_type: "server_error".to_string(), + code: Some("embedding_generation_failed".to_string()), + }, + }); + } + }; + + // Convert to response format + let embedding_data: Vec = embeddings + .into_iter() + .enumerate() + .map(|(index, embed_data)| { + let embedding_vector = match embed_data.embedding { + EmbeddingResult::DenseVector(vec) => vec, + EmbeddingResult::MultiVector(_) => { + // For multi-vector embeddings, return empty (or handle differently) + vec![] + } + }; + + ImageEmbeddingData { + object: "embedding".to_string(), + index, + embedding: embedding_vector, + metadata: embed_data.metadata, + } + }) + .collect(); + + let response = ImageEmbeddingResponse { + object: "list".to_string(), + data: embedding_data, + model: req.model.clone(), + }; + + HttpResponse::Ok().json(response) +} + pub fn run(listener: TcpListener) -> std::io::Result { let server = HttpServer::new(|| { App::new() @@ -514,6 +842,7 @@ pub fn run(listener: TcpListener) -> std::io::Result { .service(create_embeddings) .service(create_pdf_embeddings) .service(create_pdf_embeddings_upload) + .service(create_image_embeddings) }) .listen(listener)? .run();