diff --git a/src/omop_emb/backends/base_backend.py b/src/omop_emb/backends/base_backend.py
index d72ca5e..d761048 100644
--- a/src/omop_emb/backends/base_backend.py
+++ b/src/omop_emb/backends/base_backend.py
@@ -586,6 +586,7 @@ def bulk_upsert_embeddings(
         model_name: str,
         metric_type: MetricType,
         batches: Iterable[Tuple[Sequence[ConceptEmbeddingRecord], ndarray]],
+        total_n_batches: Optional[int] = None,
     ) -> None:
         """Upsert embeddings in multiple batches, delegating to ``upsert_embeddings``.
 
@@ -597,8 +598,18 @@ def bulk_upsert_embeddings(
             Validated once per batch via ``upsert_embeddings``.
         batches : Iterable[tuple[Sequence[ConceptEmbeddingRecord], ndarray]]
             Iterable of ``(records, embeddings)`` pairs.
+        total_n_batches : Optional[int]
+            Number of batches, used only as the progress-bar total. When
+            ``None``, tqdm falls back to ``len(batches)`` if it is defined.
         """
-        for records, embeddings in batches:
+        import tqdm
+
+        pbar = tqdm.tqdm(
+            batches,
+            desc=f"Upserting embeddings into {model_name} ({metric_type.value})",
+            total=total_n_batches,
+        )
+        for records, embeddings in pbar:
             self.upsert_embeddings(
                 model_name=model_name,
                 metric_type=metric_type,
diff --git a/src/omop_emb/backends/index_config.py b/src/omop_emb/backends/index_config.py
index f2e53f6..127fa88 100644
--- a/src/omop_emb/backends/index_config.py
+++ b/src/omop_emb/backends/index_config.py
@@ -6,7 +6,7 @@
 from enum import Enum
 from typing import Any, Callable, Mapping, Optional, Self, cast, get_type_hints
 
-from omop_emb.config import IndexType, MetricType
+from omop_emb.config import IndexType, MetricType, parse_index_type
 
 # ----------------------------------------------------------------------------
 # Resevered metadata keys
@@ -245,6 +245,7 @@ def index_config_from_index_type(index_type: IndexType, **kwargs: Any) -> IndexC
     ValueError
         If ``index_type`` has no registered ``IndexConfig`` subclass.
     """
+    index_type = parse_index_type(index_type)
     if index_type == IndexType.FLAT:
         return FlatIndexConfig()
     if index_type == IndexType.HNSW:
diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py
index ee1bd06..66f7aee 100644
--- a/src/omop_emb/interface.py
+++ b/src/omop_emb/interface.py
@@ -556,12 +556,14 @@ def upsert_concept_embeddings(
     def bulk_upsert_concept_embeddings(
         self,
         batches: Iterable[Tuple[Sequence[ConceptEmbeddingRecord], ndarray]],
+        total_n_batches: Optional[int] = None,
     ) -> None:
         """Upsert from a lazy ``(records, embeddings)`` iterable."""
         self._backend.bulk_upsert_embeddings(
             model_name=self.canonical_model_name,
             metric_type=self._metric_type,
             batches=batches,
+            total_n_batches=total_n_batches,
         )
 
     def embed_and_upsert_concepts(
diff --git a/src/omop_emb/storage/faiss/faiss_cache.py b/src/omop_emb/storage/faiss/faiss_cache.py
index bd55dcb..0ac1b94 100644
--- a/src/omop_emb/storage/faiss/faiss_cache.py
+++ b/src/omop_emb/storage/faiss/faiss_cache.py
@@ -56,7 +56,7 @@
 
 from omop_emb.config import MetricType, IndexType, ProviderType
 from omop_emb.utils.embedding_utils import EmbeddingConceptFilter, NearestConceptMatch
-from omop_emb.backends.index_config import IndexConfig
+from omop_emb.backends.index_config import IndexConfig, index_config_from_index_type
 from omop_emb.backends.base_backend import EmbeddingBackend
 from omop_emb.model_registry.model_registry_types import EmbeddingModelRecord
 
@@ -191,11 +191,14 @@ def from_json(cls, text: str) -> "CacheMetadata":
             If the JSON is malformed or contains an unknown enum value.
         """
         d = json.loads(text)
+        if "index_config" not in d:
+            raise ValueError("Missing 'index_config' field in cache metadata JSON.")
+
         return cls(
             model_name=d.get("model_name", ""),
             dimensions=int(d.get("dimensions", 0)),
             metric_type=MetricType(d["metric_type"]),
-            index_config=IndexConfig.from_dict(d.get("index_config", {})),
+            index_config=index_config_from_index_type(**d["index_config"]),
             row_count=int(d.get("row_count", -1)),
             exported_at=d.get("exported_at", ""),
             model_updated_at=d.get("model_updated_at"),
@@ -834,6 +837,7 @@ def _batches():
             model_name=self._model_name,
             metric_type=metric_type,
             batches=_batches(),
+            total_n_batches=(n + batch_size - 1) // batch_size,
         )
         logger.info(
             "Imported %d vectors for '%s' (metric=%s) into backend.",