Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/omop_emb/backends/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def bulk_upsert_embeddings(
model_name: str,
metric_type: MetricType,
batches: Iterable[Tuple[Sequence[ConceptEmbeddingRecord], ndarray]],
total_n_batches: Optional[int] = None,
) -> None:
"""Upsert embeddings in multiple batches, delegating to ``upsert_embeddings``.

Expand All @@ -597,8 +598,13 @@ def bulk_upsert_embeddings(
Validated once per batch via ``upsert_embeddings``.
batches : Iterable[tuple[Sequence[ConceptEmbeddingRecord], ndarray]]
Iterable of ``(records, embeddings)`` pairs.
total_n_batches : Optional[int]
Total number of batches for the progress bar.
"""
for records, embeddings in batches:
import tqdm

pbar = tqdm.tqdm(batches, desc=f"Upserting embeddings into {model_name} ({metric_type.value})", total=total_n_batches)
for records, embeddings in pbar:
self.upsert_embeddings(
model_name=model_name,
metric_type=metric_type,
Expand Down
3 changes: 2 additions & 1 deletion src/omop_emb/backends/index_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from typing import Any, Callable, Mapping, Optional, Self, cast, get_type_hints

from omop_emb.config import IndexType, MetricType
from omop_emb.config import IndexType, MetricType, parse_index_type

# ----------------------------------------------------------------------------
# Resevered metadata keys
Expand Down Expand Up @@ -245,6 +245,7 @@ def index_config_from_index_type(index_type: IndexType, **kwargs: Any) -> IndexC
ValueError
If ``index_type`` has no registered ``IndexConfig`` subclass.
"""
index_type = parse_index_type(index_type)
if index_type == IndexType.FLAT:
return FlatIndexConfig()
if index_type == IndexType.HNSW:
Expand Down
2 changes: 2 additions & 0 deletions src/omop_emb/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,12 +556,14 @@ def upsert_concept_embeddings(
def bulk_upsert_concept_embeddings(
self,
batches: Iterable[Tuple[Sequence[ConceptEmbeddingRecord], ndarray]],
total_n_batches: Optional[int] = None,
) -> None:
"""Upsert from a lazy ``(records, embeddings)`` iterable."""
self._backend.bulk_upsert_embeddings(
model_name=self.canonical_model_name,
metric_type=self._metric_type,
batches=batches,
total_n_batches=total_n_batches,
)

def embed_and_upsert_concepts(
Expand Down
8 changes: 6 additions & 2 deletions src/omop_emb/storage/faiss/faiss_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

from omop_emb.config import MetricType, IndexType, ProviderType
from omop_emb.utils.embedding_utils import EmbeddingConceptFilter, NearestConceptMatch
from omop_emb.backends.index_config import IndexConfig
from omop_emb.backends.index_config import IndexConfig, index_config_from_index_type
from omop_emb.backends.base_backend import EmbeddingBackend
from omop_emb.model_registry.model_registry_types import EmbeddingModelRecord

Expand Down Expand Up @@ -191,11 +191,14 @@ def from_json(cls, text: str) -> "CacheMetadata":
If the JSON is malformed or contains an unknown enum value.
"""
d = json.loads(text)
if not "index_config" in d:
raise ValueError("Missing 'index_config' field in cache metadata JSON.")

return cls(
model_name=d.get("model_name", ""),
dimensions=int(d.get("dimensions", 0)),
metric_type=MetricType(d["metric_type"]),
index_config=IndexConfig.from_dict(d.get("index_config", {})),
index_config=index_config_from_index_type(**d["index_config"]),
row_count=int(d.get("row_count", -1)),
exported_at=d.get("exported_at", ""),
model_updated_at=d.get("model_updated_at"),
Expand Down Expand Up @@ -834,6 +837,7 @@ def _batches():
model_name=self._model_name,
metric_type=metric_type,
batches=_batches(),
total_n_batches=n//batch_size + (1 if n % batch_size else 0)
)
logger.info(
"Imported %d vectors for '%s' (metric=%s) into backend.",
Expand Down
Loading