diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py
index f7664502f..3306e4cb6 100644
--- a/vectordb_bench/__init__.py
+++ b/vectordb_bench/__init__.py
@@ -54,6 +54,9 @@ class config:
     OPTIMIZE_TIMEOUT_1536D_500K = 15 * 60  # 15min
     OPTIMIZE_TIMEOUT_1536D_5M = 2.5 * 3600  # 2.5h
+
+    CHURN_CYCLES_DEFAULT = 0  # Keep the default at 0, since most clients do not support churn yet
+    CHURN_P_CHURN_DEFAULT = 10
 
     def display(self) -> str:
         tmp = [
             i for i in inspect.getmembers(self)
diff --git a/vectordb_bench/backend/cases.py b/vectordb_bench/backend/cases.py
index 6c43bb910..b93fe26cb 100644
--- a/vectordb_bench/backend/cases.py
+++ b/vectordb_bench/backend/cases.py
@@ -72,7 +72,7 @@ def case_description(self, custom_configs: dict | None = None) -> str:
 class CaseLabel(Enum):
     Load = auto()
     Performance = auto()
-
+    Churn = auto()
 
 class Case(BaseModel):
     """Undefined case
@@ -83,6 +83,8 @@ class Case(BaseModel):
         dataset(DataSet): dataset for this case runner.
         filter_rate(float | None): one of 99% | 1% | None
         filters(dict | None): filters for search
+        cycles(int | None): number of churn cycles to run
+        p_churn(float | None): percentage of the data to delete and re-insert per cycle
     """
 
     case_id: CaseType
@@ -95,6 +97,8 @@ class Case(BaseModel):
     optimize_timeout: float | int | None = None
 
     filter_rate: float | None = None
+    cycles: int | None = None
+    p_churn: float | int | None = None
 
     @property
     def filters(self) -> dict | None:
diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py
index faa36712d..3e2590ec7 100644
--- a/vectordb_bench/backend/clients/api.py
+++ b/vectordb_bench/backend/clients/api.py
@@ -187,6 +187,22 @@ def search_embedding(
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def delete_embeddings(
+        self,
+        metadata: list[int],
+        **kwargs,
+    ) -> (int, Exception):
+        """Delete embeddings from the vector database based on metadata.
+        Args:
+            metadata (list[int]): metadata (ids) of the embeddings to delete.
+            **kwargs (Any): vector database specific parameters.
+        Returns:
+            int: number of deleted embeddings.
+            Exception: an exception if any error occurred during deletion, None otherwise.
+        """
+        raise NotImplementedError
+
     # TODO: remove
     @abstractmethod
     def optimize(self):
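For clients that currently stub this method out with `pass` (see the client diffs below), a minimal sketch of the expected contract might look like the following. This is not part of the patch: `ExampleClient` and `self.client.delete(ids=...)` are placeholders for whatever bulk-delete call the target database actually exposes.

from typing import Any, Optional, Tuple

class ExampleClient:
    """Hypothetical VectorDB subclass body illustrating the delete_embeddings contract."""

    def __init__(self, client):
        self.client = client  # underlying database client, assumed to expose a bulk .delete(ids=...) call

    def delete_embeddings(
        self,
        metadata: list[int],
        **kwargs: Any,
    ) -> Tuple[int, Optional[Exception]]:
        # Mirror insert_embeddings: return (affected_count, error) instead of raising.
        try:
            self.client.delete(ids=metadata)  # placeholder for the real bulk-delete call
            return len(metadata), None
        except Exception as e:
            return 0, e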
+ """ + raise NotImplementedError + # TODO: remove @abstractmethod def optimize(self): diff --git a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py index a27eb01fc..51186c826 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +++ b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py @@ -1,7 +1,7 @@ import logging from contextlib import contextmanager import time -from typing import Iterable, Type +from typing import Any, Iterable, Optional, Tuple, Type from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine from opensearchpy import OpenSearch @@ -151,6 +151,13 @@ def search_embedding( except Exception as e: log.warning(f"Failed to search: {self.index_name} error: {str(e)}") raise e from None + + def delete_embeddings( + self, + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + pass def optimize(self): """optimize will be called between insertion and search in performance cases.""" diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py index 235cb595b..b1f575ba8 100644 --- a/vectordb_bench/backend/clients/chroma/chroma.py +++ b/vectordb_bench/backend/clients/chroma/chroma.py @@ -1,7 +1,7 @@ import chromadb import logging from contextlib import contextmanager -from typing import Any +from typing import Any, Optional, Tuple from ..api import VectorDB, DBCaseConfig log = logging.getLogger(__name__) @@ -63,6 +63,13 @@ def ready_to_load(self) -> bool: def optimize(self) -> None: pass + def delete_embeddings( + self, + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + pass + def insert_embeddings( self, embeddings: list[list[float]], diff --git a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py index 64f27e490..0efb7cf52 100644 --- a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +++ b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py @@ -1,7 +1,7 @@ import logging import time from contextlib import contextmanager -from typing import Iterable +from typing import Any, Iterable, Optional, Tuple from ..api import VectorDB from .config import ElasticCloudIndexConfig from elasticsearch.helpers import bulk @@ -97,6 +97,13 @@ def insert_embeddings( log.warning(f"Failed to insert data: {self.indice} error: {str(e)}") return (0, e) + def delete_embeddings( + self, + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + pass + def search_embedding( self, query: list[float], diff --git a/vectordb_bench/backend/clients/memorydb/memorydb.py b/vectordb_bench/backend/clients/memorydb/memorydb.py index c5f80eb2a..6bb5ef4df 100644 --- a/vectordb_bench/backend/clients/memorydb/memorydb.py +++ b/vectordb_bench/backend/clients/memorydb/memorydb.py @@ -194,6 +194,13 @@ def insert_embeddings( return 0, e return result_len, None + + def delete_embeddings( + self, + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + pass def _post_insert(self): """Wait for indexing to finish""" diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index 4590265ae..7ea49f68e 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -3,7 +3,7 @@ import 
diff --git a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py
index bc042cc57..1e14d27fe 100644
--- a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py
+++ b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py
@@ -265,6 +265,13 @@ def insert_embeddings(
             )
             return 0, e
 
+    def delete_embeddings(
+        self,
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        pass
+
     def search_embedding(
         self,
         query: list[float],
diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py
index 8123acf18..9585b4c7c 100644
--- a/vectordb_bench/backend/clients/pgvector/pgvector.py
+++ b/vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -396,6 +396,38 @@ def insert_embeddings(
             )
             return 0, e
 
+    def delete_embeddings(
+        self,
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Delete embeddings from the pgvector table based on metadata (ids).
+        Args:
+            metadata (list[int]): ids of the embeddings to delete.
+            **kwargs (Any): additional vector database specific parameters.
+        Returns:
+            int: number of deleted embeddings.
+            Exception: an exception if an error occurs, None otherwise.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            # Delete every row whose id appears in the metadata list
+            delete_sql = sql.SQL(
+                "DELETE FROM public.{table_name} WHERE id = ANY(%s)"
+            ).format(table_name=sql.Identifier(self.table_name))
+
+            # Execute the delete statement and commit
+            self.cursor.execute(delete_sql, (metadata,))
+            deleted_count = self.cursor.rowcount  # number of rows deleted
+            self.conn.commit()
+
+            return deleted_count, None
+        except Exception as e:
+            log.warning(f"Failed to delete data from pgvector table ({self.table_name}), error: {e}")
+            return 0, e
+
     def search_embedding(
         self,
         query: list[float],
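The `id = ANY(%s)` form above removes a whole batch with a single statement; psycopg adapts the Python list to a Postgres array. A standalone sketch of the same pattern follows — the table name, connection string, and `delete_ids` helper are illustrative placeholders, not part of this patch.

import psycopg
from psycopg import sql

def delete_ids(conn: psycopg.Connection, table: str, ids: list[int]) -> int:
    """Delete all rows whose id appears in `ids`; return the number of rows removed."""
    stmt = sql.SQL("DELETE FROM public.{table} WHERE id = ANY(%s)").format(
        table=sql.Identifier(table)
    )
    with conn.cursor() as cur:
        cur.execute(stmt, (ids,))  # the list is passed as a single array parameter
        deleted = cur.rowcount
    conn.commit()
    return deleted

# Hypothetical usage:
# with psycopg.connect("postgresql://localhost/vectordb") as conn:
#     removed = delete_ids(conn, "pgvector_table", [1, 2, 3])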
+ """ + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + try: + # Construct SQL for deleting embeddings based on the metadata (IDs) + delete_sql = sql.SQL( + "DELETE FROM public.{table_name} WHERE id = ANY(%s)" + ).format(table_name=sql.Identifier(self.table_name)) + + # Execute the delete statement + self.cursor.execute(delete_sql, (metadata,)) + deleted_count = self.cursor.rowcount # Get the number of rows deleted + self.conn.commit() + + return deleted_count, None + except Exception as e: + log.warning(f"Failed to delete data from pgvector table ({self.table_name}), error: {e}") + return 0, e + def search_embedding( self, query: list[float], diff --git a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py index d8f26394c..98bf80631 100644 --- a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +++ b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py @@ -266,6 +266,13 @@ def insert_embeddings( ) return 0, e + def delete_embeddings( + self, + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + pass + def search_embedding( self, query: list[float], diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index c2653ee27..f3e2b43cc 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -2,7 +2,7 @@ import logging from contextlib import contextmanager -from typing import Type +from typing import Any, Optional, Tuple, Type from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType from .config import PineconeConfig @@ -95,6 +95,13 @@ def insert_embeddings( return (insert_count, e) return (len(embeddings), None) + def delete_embeddings( + self, + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + pass + def search_embedding( self, query: list[float], diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py index a51632bc6..195de1880 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py @@ -3,6 +3,7 @@ import logging import time from contextlib import contextmanager +from typing import Any, Optional, Tuple from ..api import VectorDB, DBCaseConfig from qdrant_client.http.models import ( @@ -127,6 +128,13 @@ def insert_embeddings( else: return len(metadata), None + def delete_embeddings( + self, + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + pass + def search_embedding( self, query: list[float], diff --git a/vectordb_bench/backend/clients/redis/redis.py b/vectordb_bench/backend/clients/redis/redis.py index 8acf669d2..a2972002a 100644 --- a/vectordb_bench/backend/clients/redis/redis.py +++ b/vectordb_bench/backend/clients/redis/redis.py @@ -1,6 +1,6 @@ import logging from contextlib import contextmanager -from typing import Any, Type +from typing import Any, Optional, Tuple, Type from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType from .config import RedisConfig import redis @@ -123,6 +123,13 @@ def insert_embeddings( return 0, e return result_len, None + + def delete_embeddings( + self, + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + pass def search_embedding( 
diff --git a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
index 4c8bd12da..e7f0bf136 100644
--- a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
+++ b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
@@ -1,7 +1,7 @@
 """Wrapper around the Weaviate vector database over VectorDB"""
 
 import logging
-from typing import Iterable
+from typing import Any, Iterable, Optional, Tuple
 from contextlib import contextmanager
 
 import weaviate
@@ -113,6 +113,13 @@ def insert_embeddings(
             log.warning(f"Failed to insert data, error: {str(e)}")
             return (insert_count, e)
 
+    def delete_embeddings(
+        self,
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        pass
+
     def search_embedding(
         self,
         query: list[float],
diff --git a/vectordb_bench/backend/runner/__init__.py b/vectordb_bench/backend/runner/__init__.py
index 77bb25d67..9c602f12c 100644
--- a/vectordb_bench/backend/runner/__init__.py
+++ b/vectordb_bench/backend/runner/__init__.py
@@ -2,11 +2,16 @@
     MultiProcessingSearchRunner,
 )
 
-from .serial_runner import SerialSearchRunner, SerialInsertRunner
+from .serial_runner import (
+    SerialSearchRunner,
+    SerialInsertRunner,
+    SerialChurnRunner,
+)
 
 __all__ = [
     'MultiProcessingSearchRunner',
     'SerialSearchRunner',
     'SerialInsertRunner',
+    'SerialChurnRunner',
 ]
log.info(f"Selecting {churn_size} embeddings to delete for cycle {cycle}.") + deleted_embeddings, deleted_metadata = self._select_random_embeddings() + + # Delete selected embeddings in batches of 500 + log.info(f"Deleting {len(deleted_metadata)} embeddings in batches of 500 in cycle {cycle}.") + batch_size = 500 + deleted_count = 0 + for i in range(0, len(deleted_metadata), batch_size): + batch_metadata = deleted_metadata[i:i + batch_size] + count, delete_error = self.db.delete_embeddings(batch_metadata) + if delete_error: + log.error(f"Failed to delete embeddings in batch {i // batch_size + 1} of cycle {cycle}, error: {delete_error}") + break + else: + deleted_count += count + log.info(f"Successfully deleted batch {i // batch_size + 1} of {len(deleted_metadata) // batch_size + 1} in cycle {cycle}.") + + if deleted_count == len(deleted_metadata): + log.info(f"Successfully deleted all {deleted_count} embeddings in cycle {cycle}.") + else: + log.warning(f"Only {deleted_count} out of {len(deleted_metadata)} embeddings were deleted in cycle {cycle}.") + + # Re-insert deleted embeddings in batches of 500 + log.info(f"Re-inserting {len(deleted_embeddings)} embeddings in batches of 500 in cycle {cycle}.") + inserted_count = 0 + for i in range(0, len(deleted_embeddings), batch_size): + batch_embeddings = deleted_embeddings[i:i + batch_size] + batch_metadata = deleted_metadata[i:i + batch_size] + count, insert_error = self.db.insert_embeddings(batch_embeddings, batch_metadata) + if insert_error: + log.error(f"Failed to insert embeddings in batch {i // batch_size + 1} of cycle {cycle}, error: {insert_error}") + break + else: + inserted_count += count + log.info(f"Successfully inserted batch {i // batch_size + 1} of {len(deleted_embeddings) // batch_size + 1} in cycle {cycle}.") + + if inserted_count == len(deleted_embeddings): + log.info(f"Successfully inserted all {inserted_count} embeddings in cycle {cycle}.") + else: + log.warning(f"Only {inserted_count} out of {len(deleted_embeddings)} embeddings were inserted in cycle {cycle}.") + + # Perform a search to calculate metrics + log.info(f"Calculating metrics (recall, NDCG, p99 latency) after re-insertion in cycle {cycle}.") + recall, ndcg, p99 = self._calculate_metrics() + + + # Store results for the cycle + results.append({ + 'cycle': cycle, + 'recall': recall, + 'ndcg': ndcg, + 'p99': p99 + }) + log.info(f"Cycle {cycle} completed: recall={recall}, NDCG={ndcg}, p99 latency={p99}") + + log.info("Churn process completed.") + return results + + + def _select_random_embeddings(self) -> tuple[list[list[float]], list[int]]: + """Selects random embeddings and metadata for deletion based on self.p_churn.""" + selected_embeddings = [] + selected_metadata = [] + + # Calculate the total number of embeddings in the dataset + total_embeddings = self.dataset.data.size + churn_size = int(total_embeddings * self.p_churn) # Calculate churn size based on self.p_churn + + # Fetch embeddings from the dataset in a memory-efficient way + current_size = 0 + for data_df in self.dataset: + if current_size >= churn_size: + break + + all_metadata = data_df['id'].tolist() + emb_np = np.stack(data_df['emb']) + + # Normalize if necessary + if self.normalize: + all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist() + else: + all_embeddings = emb_np.tolist() + del emb_np + + # Calculate how many embeddings to take from this batch based on self.p_churn + embeddings_in_batch = len(all_metadata) + embeddings_to_take = int(embeddings_in_batch * 
diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py
index a6d94f186..64a83b5bb 100644
--- a/vectordb_bench/backend/task_runner.py
+++ b/vectordb_bench/backend/task_runner.py
@@ -16,7 +16,8 @@
 )
 from ..metric import Metric
 from .runner import MultiProcessingSearchRunner
-from .runner import SerialSearchRunner, SerialInsertRunner
+from .runner import SerialSearchRunner, SerialInsertRunner, SerialChurnRunner
+
 from .data_source import DatasetSource
 
 
@@ -51,6 +52,7 @@ class CaseRunner(BaseModel):
     test_emb: list[list[float]] | None = None
     serial_search_runner: SerialSearchRunner | None = None
     search_runner: MultiProcessingSearchRunner | None = None
+    churn_runner: SerialChurnRunner | None = None
     final_search_runner: MultiProcessingSearchRunner | None = None
 
     def __eq__(self, obj):
@@ -174,6 +176,7 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric:
             if (
                 TaskStage.SEARCH_SERIAL in self.config.stages
                 or TaskStage.SEARCH_CONCURRENT in self.config.stages
+                or TaskStage.CHURN in self.config.stages
             ):
                 self._init_search_runner()
             if TaskStage.SEARCH_SERIAL in self.config.stages:
@@ -186,7 +189,10 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric:
             if TaskStage.SEARCH_CONCURRENT in self.config.stages:
                 search_results = self._conc_search()
                 m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list = search_results
-
+            if TaskStage.CHURN in self.config.stages:
+                churn_results = self._churn_search()
+                m.churn_results = churn_results
+
         except Exception as e:
             log.warning(f"Failed to run performance case, reason = {e}")
             traceback.print_exc()
@@ -235,6 +241,25 @@ def _conc_search(self):
         finally:
             self.stop()
 
+    def _churn_search(self):
+        """Run the churn process over multiple cycles. In each cycle, embeddings are deleted and then
+        re-inserted, and recall, NDCG, and p99 latency are measured.
+        Returns:
+            list[dict]: one dictionary per cycle with the following keys:
+                - 'cycle' (int): the cycle number (0 for the pre-churn baseline, 1+ for churn cycles).
+                - 'recall' (float): the average recall of the search queries after the cycle.
+                - 'ndcg' (float): the average NDCG (Normalized Discounted Cumulative Gain) after the cycle.
+                - 'p99' (float): the 99th percentile of search latency (in seconds) after the cycle.
+        """
+        try:
+            return self.churn_runner.run()
+        except Exception as e:
+            log.warning(f"churn error: {str(e)}")
+            raise e from None
+        finally:
+            self.stop()
+
     @utils.time_it
     def _task(self) -> None:
         with self.db.init():
@@ -279,6 +304,17 @@ def _init_search_runner(self):
                 duration=self.config.case_config.concurrency_search_config.concurrency_duration,
                 k=self.config.case_config.k,
             )
+        if TaskStage.CHURN in self.config.stages:
+            self.churn_runner = SerialChurnRunner(
+                db=self.db,
+                dataset=self.ca.dataset,
+                test_data=self.test_emb,
+                ground_truth=gt_df,
+                p_churn=self.config.case_config.churn_search_config.p_churn,
+                cycles=self.config.case_config.churn_search_config.cycles,
+                normalize=self.normalize,
+                k=self.config.case_config.k,
+            )
 
     def stop(self):
         if self.search_runner:
diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py
index edce758e1..4c3307b48 100644
--- a/vectordb_bench/cli/cli.py
+++ b/vectordb_bench/cli/cli.py
@@ -25,6 +25,7 @@
 from ..models import (
     CaseConfig,
     CaseType,
+    ChurnSearchConfig,
     ConcurrencySearchConfig,
     DBCaseConfig,
     DBConfig,
@@ -132,6 +133,7 @@ def parse_task_stages(
     load: bool,
     search_serial: bool,
     search_concurrent: bool,
+    search_churn: bool,
 ) -> List[TaskStage]:
     stages = []
     if load and not drop_old:
@@ -146,6 +148,8 @@ def parse_task_stages(
         stages.append(TaskStage.SEARCH_SERIAL)
     if search_concurrent:
         stages.append(TaskStage.SEARCH_CONCURRENT)
+    if search_churn:
+        stages.append(TaskStage.CHURN)
     return stages
 
 
@@ -233,6 +237,16 @@ class CommonTypedDict(TypedDict):
             show_default=True,
         ),
     ]
+    search_churn: Annotated[
+        bool,
+        click.option(
+            "--search-churn/--skip-search-churn",
+            type=bool,
+            default=False,
+            help="Run the churn stage, or skip it",
+            show_default=True,
+        ),
+    ]
     case_type: Annotated[
         str,
         click.option(
@@ -394,6 +408,24 @@ class CommonTypedDict(TypedDict):
             show_default=True,
         ),
     ]
+    p_churn: Annotated[
+        float,
+        click.option(
+            "--p-churn",
+            help="Churn percentage: the share of the original dataset to delete and re-insert in each churn cycle",
+            default=10.0,  # default to 10% churn
+            show_default=True,
+        ),
+    ]
+    cycles: Annotated[
+        int,
+        click.option(
+            "--cycles",
+            help="Number of churn cycles to perform",
+            default=1,  # default to 1 cycle
+            show_default=True,
+        ),
+    ]
 
 
 class HNSWBaseTypedDict(TypedDict):
@@ -479,6 +511,10 @@ def run(
                 concurrency_duration=parameters["concurrency_duration"],
                 num_concurrency=[int(s) for s in parameters["num_concurrency"]],
             ),
+            churn_search_config=ChurnSearchConfig(
+                p_churn=parameters["p_churn"],
+                cycles=parameters["cycles"],
+            ),
             custom_case=parameters.get("custom_case", {}),
         ),
         stages=parse_task_stages(
@@ -488,6 +524,7 @@ def run(
             parameters["load"],
             parameters["search_serial"],
            parameters["search_concurrent"],
+            parameters["search_churn"],
        ),
    )
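For illustration, the new flag flows into the stage list roughly as follows; this is a sketch, and the other arguments are just plausible values for whichever subcommand passes them through.

from vectordb_bench.cli.cli import parse_task_stages
from vectordb_bench.models import TaskStage

# --search-churn on the command line arrives here as search_churn=True:
stages = parse_task_stages(
    drop_old=True,
    load=True,
    search_serial=True,
    search_concurrent=False,
    search_churn=True,
)
assert TaskStage.CHURN in stages  # the CHURN stage is appended after the search stages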
diff --git a/vectordb_bench/metric.py b/vectordb_bench/metric.py
index 5c23072e3..27112782f 100644
--- a/vectordb_bench/metric.py
+++ b/vectordb_bench/metric.py
@@ -23,6 +23,7 @@ class Metric:
     conc_num_list: list[int] = field(default_factory=list)
     conc_qps_list: list[float] = field(default_factory=list)
     conc_latency_p99_list: list[float] = field(default_factory=list)
+    churn_results: list[dict] = field(default_factory=list)
 
 
 QURIES_PER_DOLLAR_METRIC = "QP$ (Quries per Dollar)"
diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py
index 7968e3e26..2cbbb21ee 100644
--- a/vectordb_bench/models.py
+++ b/vectordb_bench/models.py
@@ -80,6 +80,10 @@ class ConcurrencySearchConfig(BaseModel):
     num_concurrency: List[int] = config.NUM_CONCURRENCY
     concurrency_duration: int = config.CONCURRENCY_DURATION
 
+class ChurnSearchConfig(BaseModel):
+    p_churn: float = config.CHURN_P_CHURN_DEFAULT
+    cycles: int = config.CHURN_CYCLES_DEFAULT
+
 
 class CaseConfig(BaseModel):
     """cases, dataset, test cases, filter rate, params"""
@@ -88,6 +92,7 @@ class CaseConfig(BaseModel):
     custom_case: dict | None = None
     k: int | None = config.K_DEFAULT
     concurrency_search_config: ConcurrencySearchConfig = ConcurrencySearchConfig()
+    churn_search_config: ChurnSearchConfig = ChurnSearchConfig()
 
     '''
     @property
@@ -112,6 +117,7 @@ class TaskStage(StrEnum):
     LOAD = auto()
     SEARCH_SERIAL = auto()
     SEARCH_CONCURRENT = auto()
+    CHURN = auto()
 
     def __repr__(self) -> str:
         return str.__repr__(self.value)
@@ -123,6 +129,7 @@ def __repr__(self) -> str:
     TaskStage.LOAD,
     TaskStage.SEARCH_SERIAL,
     TaskStage.SEARCH_CONCURRENT,
+    TaskStage.CHURN,
 ]
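A minimal sketch of wiring the new config into a case programmatically (the case id and stage list are illustrative values; everything else about building a task is unchanged by this patch):

from vectordb_bench.models import CaseConfig, CaseType, ChurnSearchConfig, TaskStage

case_config = CaseConfig(
    case_id=CaseType.Performance1536D500K,  # any existing case id works; this one is assumed here
    k=100,
    churn_search_config=ChurnSearchConfig(p_churn=10, cycles=3),
)
stages = [TaskStage.DROP_OLD, TaskStage.LOAD, TaskStage.SEARCH_SERIAL, TaskStage.CHURN]

# The churn stage reruns the serial search after every delete/re-insert pass, so
# Metric.churn_results ends up with cycles + 1 entries (the baseline plus one per cycle).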