From 32a15bbd12559679cd94eba014eee7a327ff8aea Mon Sep 17 00:00:00 2001 From: Nathan LeRoy Date: Tue, 20 Jan 2026 10:50:09 -0500 Subject: [PATCH 1/2] add protein embedding model --- fastembed/__init__.py | 4 +- fastembed/bio/__init__.py | 3 + fastembed/bio/protein_embedding.py | 447 +++++++++++++++++++++++++++++ 3 files changed, 453 insertions(+), 1 deletion(-) create mode 100644 fastembed/bio/__init__.py create mode 100644 fastembed/bio/protein_embedding.py diff --git a/fastembed/__init__.py b/fastembed/__init__.py index 7a2e41340..c961e09e5 100644 --- a/fastembed/__init__.py +++ b/fastembed/__init__.py @@ -1,5 +1,6 @@ import importlib.metadata +from fastembed.bio import ProteinEmbedding from fastembed.image import ImageEmbedding from fastembed.late_interaction import LateInteractionTextEmbedding from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding @@ -19,4 +20,5 @@ "ImageEmbedding", "LateInteractionTextEmbedding", "LateInteractionMultimodalEmbedding", -] + "ProteinEmbedding", +] \ No newline at end of file diff --git a/fastembed/bio/__init__.py b/fastembed/bio/__init__.py new file mode 100644 index 000000000..73c9707df --- /dev/null +++ b/fastembed/bio/__init__.py @@ -0,0 +1,3 @@ +from fastembed.bio.protein_embedding import ProteinEmbedding + +__all__ = ["ProteinEmbedding"] \ No newline at end of file diff --git a/fastembed/bio/protein_embedding.py b/fastembed/bio/protein_embedding.py new file mode 100644 index 000000000..012eed95f --- /dev/null +++ b/fastembed/bio/protein_embedding.py @@ -0,0 +1,447 @@ +from dataclasses import asdict +from pathlib import Path +from typing import Any, Iterable, Sequence, Type + +import numpy as np + +from fastembed.common.model_description import DenseModelDescription, ModelSource +from fastembed.common.model_management import ModelManagement +from fastembed.common.onnx_model import OnnxModel, OnnxOutputContext, EmbeddingWorker +from fastembed.common.types import NumpyArray, OnnxProvider, Device +from fastembed.common.utils import define_cache_dir, iter_batch, normalize + + +supported_protein_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="facebook/esm2_t12_35M_UR50D", + dim=480, + description="Protein embeddings, ESM-2 35M parameters, 480 dimensions, 1024 max sequence length", + license="mit", + size_in_GB=0.13, + sources=ModelSource(hf="facebook/esm2_t12_35M_UR50D"), + model_file="model.onnx", + additional_files=["vocab.txt"], + ), +] + + +class ProteinTokenizer: + """ + Simple tokenizer for protein sequences using ESM-2 vocabulary. + """ + + def __init__(self, vocab_path: Path, max_length: int = 1024): + self.max_length = max_length + self.vocab: dict[str, int] = {} + self.id_to_token: dict[int, str] = {} + + with open(vocab_path) as f: + for idx, line in enumerate(f): + token = line.strip() + self.vocab[token] = idx + self.id_to_token[idx] = token + + self.cls_token_id = self.vocab.get("", 0) + self.eos_token_id = self.vocab.get("", 2) + self.pad_token_id = self.vocab.get("", 1) + self.unk_token_id = self.vocab.get("", 3) + + def encode(self, sequence: str) -> tuple[list[int], list[int]]: + """Encode a single protein sequence. + + Args: + sequence: Protein sequence (amino acid string) + + Returns: + Tuple of (input_ids, attention_mask) + """ + sequence = sequence.upper() + + input_ids = [self.cls_token_id] + for aa in sequence: + input_ids.append(self.vocab.get(aa, self.unk_token_id)) + input_ids.append(self.eos_token_id) + + if len(input_ids) > self.max_length: + input_ids = input_ids[: self.max_length] + + attention_mask = [1] * len(input_ids) + + return input_ids, attention_mask + + def encode_batch( + self, sequences: list[str] + ) -> tuple[list[list[int]], list[list[int]]]: + """Encode a batch of protein sequences with padding. + + Args: + sequences: List of protein sequences + + Returns: + Tuple of (input_ids, attention_masks) with padding + """ + all_input_ids = [] + all_attention_masks = [] + max_len = 0 + + for seq in sequences: + input_ids, attention_mask = self.encode(seq) + all_input_ids.append(input_ids) + all_attention_masks.append(attention_mask) + max_len = max(max_len, len(input_ids)) + + for i in range(len(all_input_ids)): + padding_length = max_len - len(all_input_ids[i]) + all_input_ids[i].extend([self.pad_token_id] * padding_length) + all_attention_masks[i].extend([0] * padding_length) + + return all_input_ids, all_attention_masks + + +class ProteinEmbeddingBase(ModelManagement[DenseModelDescription]): + def __init__( + self, + model_name: str, + cache_dir: str | None = None, + threads: int | None = None, + **kwargs: Any, + ): + self.model_name = model_name + self.cache_dir = cache_dir + self.threads = threads + self._local_files_only = kwargs.pop("local_files_only", False) + self._embedding_size: int | None = None + + def embed( + self, + sequences: str | Iterable[str], + batch_size: int = 32, + parallel: int | None = None, + **kwargs: Any, + ) -> Iterable[NumpyArray]: + """ + Embed protein sequences. + + Args: + sequences: Single protein sequence or iterable of sequences + batch_size: Batch size for encoding + parallel: Number of parallel workers (None for single-threaded) + + Yields: + Embeddings as numpy arrays + """ + raise NotImplementedError() + + @classmethod + def get_embedding_size(cls, model_name: str) -> int: + """ + Returns embedding size of the passed model. + + Args: + model_name: Name of the model + """ + descriptions = cls._list_supported_models() + for description in descriptions: + if description.model.lower() == model_name.lower(): + if description.dim is not None: + return description.dim + raise ValueError(f"Model {model_name} not found") + + @property + def embedding_size(self) -> int: + """ + Returns embedding size for the current model. + """ + if self._embedding_size is None: + self._embedding_size = self.get_embedding_size(self.model_name) + return self._embedding_size + + +class OnnxProteinModel(OnnxModel[NumpyArray]): + """ + ONNX model handler for protein embeddings. + """ + + ONNX_OUTPUT_NAMES: list[str] | None = None + + def __init__(self) -> None: + super().__init__() + self.tokenizer: ProteinTokenizer | None = None + + def _load_onnx_model( + self, + model_dir: Path, + model_file: str, + threads: int | None, + providers: Sequence[OnnxProvider] | None = None, + cuda: bool | Device = Device.AUTO, + device_id: int | None = None, + extra_session_options: dict[str, Any] | None = None, + ) -> None: + super()._load_onnx_model( + model_dir=model_dir, + model_file=model_file, + threads=threads, + providers=providers, + cuda=cuda, + device_id=device_id, + extra_session_options=extra_session_options, + ) + vocab_path = model_dir / "vocab.txt" + if not vocab_path.exists(): + raise ValueError(f"Could not find vocab.txt in {model_dir}") + self.tokenizer = ProteinTokenizer(vocab_path) + + def onnx_embed(self, sequences: list[str], **kwargs: Any) -> OnnxOutputContext: + """ + Run ONNX inference on protein sequences. + + Args: + sequences: List of protein sequences + Returns: + OnnxOutputContext containing model output and inputs + """ + assert self.tokenizer is not None + + input_ids, attention_masks = self.tokenizer.encode_batch(sequences) + + input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] + onnx_input: dict[str, NumpyArray] = { + "input_ids": np.array(input_ids, dtype=np.int64), + } + if "attention_mask" in input_names: + onnx_input["attention_mask"] = np.array(attention_masks, dtype=np.int64) + + model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] + + return OnnxOutputContext( + model_output=model_output[0], + attention_mask=np.array(attention_masks, dtype=np.int64), + input_ids=np.array(input_ids, dtype=np.int64), + ) + + def _post_process_onnx_output( + self, output: OnnxOutputContext, **kwargs: Any + ) -> Iterable[NumpyArray]: + """Convert ONNX output to embeddings with mean pooling.""" + embeddings = output.model_output + attention_mask = output.attention_mask + + if attention_mask is None: + raise ValueError("attention_mask is required for mean pooling") + + mask_expanded = np.expand_dims(attention_mask, axis=-1) + sum_embeddings = np.sum(embeddings * mask_expanded, axis=1) + sum_mask = np.sum(mask_expanded, axis=1) + sum_mask = np.clip(sum_mask, a_min=1e-9, a_max=None) + mean_embeddings = sum_embeddings / sum_mask + + return normalize(mean_embeddings) + + +class OnnxProteinEmbedding(ProteinEmbeddingBase, OnnxProteinModel): + """ + ONNX-based protein embedding implementation. + """ + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + return supported_protein_models + + def __init__( + self, + model_name: str = "facebook/esm2_t12_35M_UR50D", + cache_dir: str | None = None, + threads: int | None = None, + providers: Sequence[OnnxProvider] | None = None, + cuda: bool | Device = Device.AUTO, + device_ids: list[int] | None = None, + lazy_load: bool = False, + device_id: int | None = None, + specific_model_path: str | None = None, + **kwargs: Any, + ): + super().__init__(model_name, cache_dir, threads, **kwargs) + self.providers = providers + self.lazy_load = lazy_load + self._extra_session_options = self._select_exposed_session_options(kwargs) + self.device_ids = device_ids + self.cuda = cuda + + self.device_id: int | None = None + if device_id is not None: + self.device_id = device_id + elif self.device_ids is not None: + self.device_id = self.device_ids[0] + + self.model_description = self._get_model_description(model_name) + self.cache_dir = str(define_cache_dir(cache_dir)) + self._specific_model_path = specific_model_path + self._model_dir = self.download_model( + self.model_description, + self.cache_dir, + local_files_only=self._local_files_only, + specific_model_path=self._specific_model_path, + ) + + if not self.lazy_load: + self.load_onnx_model() + + def load_onnx_model(self) -> None: + self._load_onnx_model( + model_dir=self._model_dir, + model_file=self.model_description.model_file, + threads=self.threads, + providers=self.providers, + cuda=self.cuda, + device_id=self.device_id, + extra_session_options=self._extra_session_options, + ) + + def embed( + self, + sequences: str | Iterable[str], + batch_size: int = 32, + parallel: int | None = None, + **kwargs: Any, + ) -> Iterable[NumpyArray]: + """ + Embed protein sequences. + + Args: + sequences: Single protein sequence or iterable of sequences (amino acid strings) + batch_size: Batch size for encoding + parallel: Number of parallel workers (not yet supported) + + Yields: + Embeddings as numpy arrays, one per sequence + """ + if isinstance(sequences, str): + sequences = [sequences] + + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + for batch in iter_batch(sequences, batch_size): + yield from self._post_process_onnx_output(self.onnx_embed(batch, **kwargs), **kwargs) + + @classmethod + def _get_worker_class(cls) -> Type["ProteinEmbeddingWorker"]: + return ProteinEmbeddingWorker + + +class ProteinEmbeddingWorker(EmbeddingWorker[NumpyArray]): + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ) -> OnnxProteinEmbedding: + return OnnxProteinEmbedding( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) + + def process( + self, items: Iterable[tuple[int, Any]] + ) -> Iterable[tuple[int, OnnxOutputContext]]: + for idx, batch in items: + onnx_output = self.model.onnx_embed(batch) + yield idx, onnx_output + + +class ProteinEmbedding(ProteinEmbeddingBase): + """ + Protein sequence embedding using ESM-2 and similar models. + + Example: + >>> from fastembed.bio import ProteinEmbedding + >>> model = ProteinEmbedding("facebook/esm2_t12_35M_UR50D") + >>> embeddings = list(model.embed(["MKTVRQERLKS", "GKGDPKKPRGKM"])) + >>> print(embeddings[0].shape) + (480,) + """ + + EMBEDDINGS_REGISTRY: list[Type[ProteinEmbeddingBase]] = [OnnxProteinEmbedding] + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + """ + return [asdict(model) for model in cls._list_supported_models()] + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + result: list[DenseModelDescription] = [] + for embedding in cls.EMBEDDINGS_REGISTRY: + result.extend(embedding._list_supported_models()) + return result + + def __init__( + self, + model_name: str = "facebook/esm2_t12_35M_UR50D", + cache_dir: str | None = None, + threads: int | None = None, + providers: Sequence[OnnxProvider] | None = None, + cuda: bool | Device = Device.AUTO, + device_ids: list[int] | None = None, + lazy_load: bool = False, + **kwargs: Any, + ): + """ + Initialize ProteinEmbedding. + + Args: + model_name: Name of the model to use + cache_dir: Path to cache directory + threads: Number of threads for ONNX runtime + providers: ONNX execution providers + cuda: Whether to use CUDA + device_ids: List of device IDs for multi-GPU + lazy_load: Whether to load model lazily + """ + super().__init__(model_name, cache_dir, threads, **kwargs) + + for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: + supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() + if any(model_name.lower() == model.model.lower() for model in supported_models): + self.model = EMBEDDING_MODEL_TYPE( + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + providers=providers, + cuda=cuda, + device_ids=device_ids, + lazy_load=lazy_load, + **kwargs, + ) + return + + raise ValueError( + f"Model {model_name} is not supported in ProteinEmbedding. " + "Please check the supported models using `ProteinEmbedding.list_supported_models()`" + ) + + def embed( + self, + sequences: str | Iterable[str], + batch_size: int = 32, + parallel: int | None = None, + **kwargs: Any, + ) -> Iterable[NumpyArray]: + """Embed protein sequences. + + Args: + sequences: Single protein sequence or iterable of sequences (amino acid strings) + batch_size: Batch size for encoding + parallel: Number of parallel workers + + Yields: + Embeddings as numpy arrays, one per sequence + """ + yield from self.model.embed(sequences, batch_size, parallel, **kwargs) \ No newline at end of file From 23dc2edf7bb1c2b3be30138a8d68cdab1f13b5d8 Mon Sep 17 00:00:00 2001 From: Nathan LeRoy Date: Tue, 20 Jan 2026 11:52:34 -0500 Subject: [PATCH 2/2] add esm2 fully --- README.md | 19 ++++ fastembed/bio/protein_embedding.py | 151 +++++++++++++++-------------- tests/test_protein_embeddings.py | 147 ++++++++++++++++++++++++++++ uv.lock | 3 + 4 files changed, 249 insertions(+), 71 deletions(-) create mode 100644 tests/test_protein_embeddings.py create mode 100644 uv.lock diff --git a/README.md b/README.md index d4c882615..b018df82d 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,25 @@ embeddings = list(model.embed(images)) # ] ``` +### 🧬 Protein embeddings + +```python +from fastembed import ProteinEmbedding + +sequences = [ + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + "GKGDPKKPRGKMSSYAFFVQTSREEHKKKHPDASVNFSEFSKKCSERWKTMSAKEKGKFEDMAK", +] + +model = ProteinEmbedding(model_name="facebook/esm2_t12_35M_UR50D") +embeddings = list(model.embed(sequences)) + +# [ +# array([-0.0055, -0.0144, 0.0355, -0.0049, ...], dtype=float32), +# array([ 0.0114, 0.0020, -0.0247, 0.0060, ...], dtype=float32) +# ] +``` + ### Late interaction multimodal models (ColPali) ```python diff --git a/fastembed/bio/protein_embedding.py b/fastembed/bio/protein_embedding.py index 012eed95f..1b2781ddb 100644 --- a/fastembed/bio/protein_embedding.py +++ b/fastembed/bio/protein_embedding.py @@ -1,8 +1,11 @@ +import json from dataclasses import asdict from pathlib import Path from typing import Any, Iterable, Sequence, Type import numpy as np +from tokenizers import Tokenizer, pre_tokenizers, processors +from tokenizers.models import WordLevel from fastembed.common.model_description import DenseModelDescription, ModelSource from fastembed.common.model_management import ModelManagement @@ -18,84 +21,90 @@ description="Protein embeddings, ESM-2 35M parameters, 480 dimensions, 1024 max sequence length", license="mit", size_in_GB=0.13, - sources=ModelSource(hf="facebook/esm2_t12_35M_UR50D"), + sources=ModelSource(hf="nleroy917/esm2_t12_35M_UR50D-onnx"), model_file="model.onnx", - additional_files=["vocab.txt"], + additional_files=["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"], ), ] -class ProteinTokenizer: - """ - Simple tokenizer for protein sequences using ESM-2 vocabulary. - """ - - def __init__(self, vocab_path: Path, max_length: int = 1024): - self.max_length = max_length - self.vocab: dict[str, int] = {} - self.id_to_token: dict[int, str] = {} - - with open(vocab_path) as f: - for idx, line in enumerate(f): - token = line.strip() - self.vocab[token] = idx - self.id_to_token[idx] = token +def load_protein_tokenizer(model_dir: Path, max_length: int = 1024) -> Tokenizer: + """Load a protein tokenizer from model directory using HuggingFace tokenizers. - self.cls_token_id = self.vocab.get("", 0) - self.eos_token_id = self.vocab.get("", 2) - self.pad_token_id = self.vocab.get("", 1) - self.unk_token_id = self.vocab.get("", 3) + Attempts to load in order: + 1. tokenizer.json (standard HuggingFace fast tokenizer format) + 2. Build from vocab.txt (fallback for models without tokenizer.json) - def encode(self, sequence: str) -> tuple[list[int], list[int]]: - """Encode a single protein sequence. - - Args: - sequence: Protein sequence (amino acid string) - - Returns: - Tuple of (input_ids, attention_mask) - """ - sequence = sequence.upper() + Args: + model_dir: Path to model directory containing tokenizer files + max_length: Maximum sequence length (default, can be overridden by config) - input_ids = [self.cls_token_id] - for aa in sequence: - input_ids.append(self.vocab.get(aa, self.unk_token_id)) - input_ids.append(self.eos_token_id) + Returns: + Configured Tokenizer instance + """ + tokenizer_json_path = model_dir / "tokenizer.json" + tokenizer_config_path = model_dir / "tokenizer_config.json" + vocab_path = model_dir / "vocab.txt" + + # Try to load tokenizer.json directly (preferred) + if tokenizer_json_path.exists(): + tokenizer = Tokenizer.from_file(str(tokenizer_json_path)) + # Read max_length from config if available + if tokenizer_config_path.exists(): + with open(tokenizer_config_path) as f: + config = json.load(f) + config_max_length = config.get("model_max_length", max_length) + # Cap at reasonable value (transformers defaults can be huge) + if config_max_length <= max_length: + max_length = config_max_length + tokenizer.enable_truncation(max_length=max_length) + return tokenizer + + # Fall back to building from vocab.txt + if not vocab_path.exists(): + raise ValueError( + f"Could not find tokenizer.json or vocab.txt in {model_dir}" + ) - if len(input_ids) > self.max_length: - input_ids = input_ids[: self.max_length] + # Read max_length from config if available + if tokenizer_config_path.exists(): + with open(tokenizer_config_path) as f: + config = json.load(f) + max_length = config.get("model_max_length", max_length) - attention_mask = [1] * len(input_ids) + vocab: dict[str, int] = {} + with open(vocab_path) as f: + for idx, line in enumerate(f): + token = line.strip() + vocab[token] = idx - return input_ids, attention_mask + unk_token = "" + cls_token = "" + eos_token = "" + pad_token = "" - def encode_batch( - self, sequences: list[str] - ) -> tuple[list[list[int]], list[list[int]]]: - """Encode a batch of protein sequences with padding. + tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token=unk_token)) - Args: - sequences: List of protein sequences + tokenizer.pre_tokenizer = pre_tokenizers.Split( + pattern="", behavior="isolated", invert=False + ) - Returns: - Tuple of (input_ids, attention_masks) with padding - """ - all_input_ids = [] - all_attention_masks = [] - max_len = 0 + cls_token_id = vocab.get(cls_token, 0) + eos_token_id = vocab.get(eos_token, 2) - for seq in sequences: - input_ids, attention_mask = self.encode(seq) - all_input_ids.append(input_ids) - all_attention_masks.append(attention_mask) - max_len = max(max_len, len(input_ids)) + tokenizer.post_processor = processors.TemplateProcessing( + single=f"{cls_token}:0 $A:0 {eos_token}:0", + special_tokens=[ + (cls_token, cls_token_id), + (eos_token, eos_token_id), + ], + ) - for i in range(len(all_input_ids)): - padding_length = max_len - len(all_input_ids[i]) - all_input_ids[i].extend([self.pad_token_id] * padding_length) - all_attention_masks[i].extend([0] * padding_length) + pad_token_id = vocab.get(pad_token, 1) + tokenizer.enable_padding(pad_id=pad_token_id, pad_token=pad_token) + tokenizer.enable_truncation(max_length=max_length) - return all_input_ids, all_attention_masks + return tokenizer class ProteinEmbeddingBase(ModelManagement[DenseModelDescription]): @@ -166,7 +175,7 @@ class OnnxProteinModel(OnnxModel[NumpyArray]): def __init__(self) -> None: super().__init__() - self.tokenizer: ProteinTokenizer | None = None + self.tokenizer: Tokenizer | None = None def _load_onnx_model( self, @@ -187,10 +196,7 @@ def _load_onnx_model( device_id=device_id, extra_session_options=extra_session_options, ) - vocab_path = model_dir / "vocab.txt" - if not vocab_path.exists(): - raise ValueError(f"Could not find vocab.txt in {model_dir}") - self.tokenizer = ProteinTokenizer(vocab_path) + self.tokenizer = load_protein_tokenizer(model_dir) def onnx_embed(self, sequences: list[str], **kwargs: Any) -> OnnxOutputContext: """ @@ -203,21 +209,24 @@ def onnx_embed(self, sequences: list[str], **kwargs: Any) -> OnnxOutputContext: """ assert self.tokenizer is not None - input_ids, attention_masks = self.tokenizer.encode_batch(sequences) + sequences = [seq.upper() for seq in sequences] + encoded = self.tokenizer.encode_batch(sequences) + input_ids = np.array([e.ids for e in encoded], dtype=np.int64) + attention_mask = np.array([e.attention_mask for e in encoded], dtype=np.int64) input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] onnx_input: dict[str, NumpyArray] = { - "input_ids": np.array(input_ids, dtype=np.int64), + "input_ids": input_ids, } if "attention_mask" in input_names: - onnx_input["attention_mask"] = np.array(attention_masks, dtype=np.int64) + onnx_input["attention_mask"] = attention_mask model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] return OnnxOutputContext( model_output=model_output[0], - attention_mask=np.array(attention_masks, dtype=np.int64), - input_ids=np.array(input_ids, dtype=np.int64), + attention_mask=attention_mask, + input_ids=input_ids, ) def _post_process_onnx_output( diff --git a/tests/test_protein_embeddings.py b/tests/test_protein_embeddings.py new file mode 100644 index 000000000..e638bed3d --- /dev/null +++ b/tests/test_protein_embeddings.py @@ -0,0 +1,147 @@ +import os + +import numpy as np +import pytest + +from fastembed.bio import ProteinEmbedding +from tests.utils import delete_model_cache + + +# Sample protein sequences for testing +SAMPLE_SEQUENCES = [ + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + "GKGDPKKPRGKMSSYAFFVQTSREEHKKKHPDASVNFSEFSKKCSERWKTMSAKEKGKFEDMAKADKARYEREMKTY", +] + + +CANONICAL_VECTOR_VALUES = { + "facebook/esm2_t12_35M_UR50D": np.array( + [-0.0055, -0.0144, 0.0355, -0.0049, 0.0071] + ), +} + + +@pytest.fixture(scope="module") +def model_fixture(): + """ + Fixture that provides the protein embedding model and handles cleanup. + """ + is_ci = os.getenv("CI") + model = ProteinEmbedding("facebook/esm2_t12_35M_UR50D") + yield model + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_protein_embedding(model_fixture) -> None: + """Test basic protein embedding functionality.""" + model = model_fixture + dim = 480 # ESM2 t12 35M has 480 dimensions + + embeddings = list(model.embed(SAMPLE_SEQUENCES)) + embeddings = np.stack(embeddings, axis=0) + + assert embeddings.shape == (2, dim), f"Expected shape (2, {dim}), got {embeddings.shape}" + + # Check that embeddings are normalized (L2 norm close to 1) + norms = np.linalg.norm(embeddings, axis=1) + assert np.allclose(norms, 1.0, atol=1e-5), f"Embeddings should be normalized, got norms: {norms}" + + +def test_protein_embedding_canonical_values(model_fixture) -> None: + """ + Test that embeddings match expected canonical values. + """ + model = model_fixture + canonical_vector = CANONICAL_VECTOR_VALUES["facebook/esm2_t12_35M_UR50D"] + + embeddings = list(model.embed(SAMPLE_SEQUENCES[:1])) + embedding = embeddings[0] + + assert np.allclose( + embedding[: canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), f"First 5 values {embedding[:5]} don't match canonical {canonical_vector}" + + +def test_protein_embedding_single_sequence(model_fixture) -> None: + """ + Test embedding a single sequence passed as a string. + """ + model = model_fixture + dim = 480 + + # Single sequence as string + embedding = list(model.embed("MKTVRQERLKS")) + assert len(embedding) == 1 + assert embedding[0].shape == (dim,) + + +def test_protein_embedding_batch(model_fixture) -> None: + """ + Test batch embedding with different batch sizes. + """ + model = model_fixture + dim = 480 + + sequences = SAMPLE_SEQUENCES * 10 # 20 sequences + + # test with small batch size + embeddings_small_batch = list(model.embed(sequences, batch_size=4)) + embeddings_small_batch = np.stack(embeddings_small_batch, axis=0) + assert embeddings_small_batch.shape == (len(sequences), dim) + + # test with larger batch size + embeddings_large_batch = list(model.embed(sequences, batch_size=16)) + embeddings_large_batch = np.stack(embeddings_large_batch, axis=0) + assert embeddings_large_batch.shape == (len(sequences), dim) + + # results should be the same regardless of batch size + assert np.allclose(embeddings_small_batch, embeddings_large_batch, atol=1e-5) + + +def test_protein_embedding_size() -> None: + """ + Test get_embedding_size class method. + """ + assert ProteinEmbedding.get_embedding_size("facebook/esm2_t12_35M_UR50D") == 480 + + +def test_protein_embedding_lazy_load() -> None: + """ + Test lazy loading functionality. + """ + is_ci = os.getenv("CI") + + model = ProteinEmbedding("facebook/esm2_t12_35M_UR50D", lazy_load=True) + # model should not be loaded yet + assert not hasattr(model.model, "model") or model.model.model is None + + # after embedding, model should be loaded + list(model.embed(SAMPLE_SEQUENCES[:1])) + assert model.model.model is not None + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_list_supported_models() -> None: + """ + Test listing supported protein models. + """ + models = ProteinEmbedding.list_supported_models() + assert len(models) > 0 + assert any(m["model"] == "facebook/esm2_t12_35M_UR50D" for m in models) + + # check required fields + for model_info in models: + assert "model" in model_info + assert "dim" in model_info + assert "description" in model_info + + +def test_unsupported_model() -> None: + """ + Test that unsupported model raises ValueError. + """ + with pytest.raises(ValueError, match="not supported"): + ProteinEmbedding("nonexistent/model") \ No newline at end of file diff --git a/uv.lock b/uv.lock new file mode 100644 index 000000000..bda020730 --- /dev/null +++ b/uv.lock @@ -0,0 +1,3 @@ +version = 1 +revision = 3 +requires-python = ">=3.13"