diff --git a/.gitignore b/.gitignore index f8602c11a..aa3a72f44 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,7 @@ __MACOSX .DS_Store build/ venv/ +.venv/ .idea/ -results/ \ No newline at end of file +results/ +logs/ \ No newline at end of file diff --git a/README.md b/README.md index 1bf95dc39..3d38d9444 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ Closely mimicking real-world production environments, we've set up diverse testi Prepare to delve into the world of VectorDBBench, and let it guide you in uncovering your perfect vector database match. +VectorDBBench is sponsered by Zilliz,the leading opensource vectorDB company behind Milvus. Choose smarter with VectorDBBench- start your free test on [zilliz cloud](https://zilliz.com/) today! + **Leaderboard:** https://zilliz.com/benchmark ## Quick Start ### Prerequirement @@ -53,6 +55,8 @@ All the database client supported | awsopensearch | `pip install vectordb-bench[opensearch]` | | aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` | | mongodb | `pip install vectordb-bench[mongodb]` | +| tidb | `pip install vectordb-bench[tidb]` | +| vespa | `pip install vectordb-bench[vespa]` | ### Run @@ -110,6 +114,10 @@ Options: --num-concurrency TEXT Comma-separated list of concurrency values to test during concurrent search [default: 1,10,20] + --concurrency-timeout INTEGER Timeout (in seconds) to wait for a + concurrency slot before failing. Set to a + negative value to wait indefinitely. + [default: 3600] --user-name TEXT Db username [required] --password TEXT Db password [required] --host TEXT Db host [required] @@ -129,7 +137,11 @@ Options: --ef-construction INTEGER hnsw ef-construction --ef-search INTEGER hnsw ef-search --quantization-type [none|bit|halfvec] - quantization type for vectors + quantization type for vectors (in index) + --table-quantization-type [none|bit|halfvec] + quantization type for vectors (in table). 
If + equal to bit, the parameter + quantization_type will be set to bit too. --custom-case-name TEXT Custom case name i.e. PerformanceCase1536D50K --custom-case-description TEXT Custom name description --custom-case-load-timeout INTEGER @@ -153,6 +165,48 @@ Options: with-gt] --help Show this message and exit. ``` + +### Run awsopensearch from command line + +```shell +vectordbbench awsopensearch --db-label awsopensearch \ +--m 16 --ef-construction 256 \ +--host search-vector-db-prod-h4f6m4of6x7yp2rz7gdmots7w4.us-west-2.es.amazonaws.com --port 443 \ +--user vector --password '' \ +--case-type Performance1536D5M --num-insert-workers 10 \ +--skip-load --num-concurrency 75 +``` + +To list the options for awsopensearch, execute `vectordbbench awsopensearch --help` + +```text +$ vectordbbench awsopensearch --help +Usage: vectordbbench awsopensearch [OPTIONS] + +Options: + # Sharding and Replication + --number-of-shards INTEGER Number of primary shards for the index + --number-of-replicas INTEGER Number of replica copies for each primary + shard + # Indexing Performance + --index-thread-qty INTEGER Thread count for native engine indexing + --index-thread-qty-during-force-merge INTEGER + Thread count during force merge operations + --number-of-indexing-clients INTEGER + Number of concurrent indexing clients + # Index Management + --number-of-segments INTEGER Target number of segments after merging + --refresh-interval TEXT How often to make new data available for + search + --force-merge-enabled BOOLEAN Whether to perform force merge operation + --flush-threshold-size TEXT Size threshold for flushing the transaction + log + # Memory Management + --cb-threshold TEXT k-NN Memory circuit breaker threshold + + --help Show this message and exit. + ``` + #### Using a configuration file. The vectordbbench command can optionally read some or all the options from a yaml formatted configuration file. 
@@ -218,13 +272,13 @@ pip install -e '.[pinecone]' ``` ### Run test server ``` -$ python -m vectordb_bench +python -m vectordb_bench ``` OR: ```shell -$ init_bench +init_bench ``` OR: @@ -241,13 +295,13 @@ After reopen the repository in container, run `python -m vectordb_bench` in the ### Check coding styles ```shell -$ make lint +make lint ``` To fix the coding styles automatically ```shell -$ make format +make format ``` ## How does it work? @@ -319,6 +373,13 @@ We have strict requirements for the data set format, please follow them. - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format. - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. + - We recommend limiting the number of test query vectors, like 1,000. + When conducting concurrent query tests, Vdbbench creates a large number of processes. + To minimize additional communication overhead during testing, + we prepare a complete set of test queries for each process, allowing them to run independently. + However, this means that as the number of concurrent processes increases, + the number of copied query vectors also increases significantly, + which can place substantial pressure on memory resources. - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files. 
diff --git a/install.py b/install.py index f683a37b2..5807485fd 100644 --- a/install.py +++ b/install.py @@ -1,7 +1,8 @@ -import os import argparse +import os import subprocess + def docker_tag_base(): return 'vdbbench' diff --git a/install/requirements_py3.11.txt b/install/requirements_py3.11.txt index c3a3bbbda..86958ada2 100644 --- a/install/requirements_py3.11.txt +++ b/install/requirements_py3.11.txt @@ -1,4 +1,4 @@ -grpcio==1.53.0 +grpcio==1.53.2 grpcio-tools==1.53.0 qdrant-client pinecone-client @@ -22,3 +22,5 @@ environs pydantic type[VectorDB]: # noqa: PLR0911, PLR0912, C901 + def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 """Import while in use""" if self == DB.Milvus: from .milvus.milvus import Milvus @@ -115,6 +120,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 return AWSOpenSearch + if self == DB.Clickhouse: + from .clickhouse.clickhouse import Clickhouse + + return Clickhouse + if self == DB.AlloyDB: from .alloydb.alloydb import AlloyDB @@ -135,16 +145,36 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 return MongoDB + if self == DB.MariaDB: + from .mariadb.mariadb import MariaDB + + return MariaDB + + if self == DB.TiDB: + from .tidb.tidb import TiDB + + return TiDB + if self == DB.Test: from .test.test import Test return Test + if self == DB.Vespa: + from .vespa.vespa import Vespa + + return Vespa + + if self == DB.LanceDB: + from .lancedb.lancedb import LanceDB + + return LanceDB + msg = f"Unknown DB: {self.name}" raise ValueError(msg) @property - def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 + def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915 """Import while in use""" if self == DB.Milvus: from .milvus.config import MilvusConfig @@ -216,6 +246,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 return AWSOpenSearchConfig + if self == DB.Clickhouse: + from .clickhouse.config import 
ClickhouseConfig + + return ClickhouseConfig + if self == DB.AlloyDB: from .alloydb.config import AlloyDBConfig @@ -236,15 +271,35 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 return MongoDBConfig + if self == DB.MariaDB: + from .mariadb.config import MariaDBConfig + + return MariaDBConfig + + if self == DB.TiDB: + from .tidb.config import TiDBConfig + + return TiDBConfig + if self == DB.Test: from .test.config import TestConfig return TestConfig + if self == DB.Vespa: + from .vespa.config import VespaConfig + + return VespaConfig + + if self == DB.LanceDB: + from .lancedb.config import LanceDBConfig + + return LanceDBConfig + msg = f"Unknown DB: {self.name}" raise ValueError(msg) - def case_config_cls( # noqa: PLR0911 + def case_config_cls( # noqa: C901, PLR0911, PLR0912 self, index_type: IndexType | None = None, ) -> type[DBCaseConfig]: @@ -288,6 +343,11 @@ def case_config_cls( # noqa: PLR0911 return AWSOpenSearchIndexConfig + if self == DB.Clickhouse: + from .clickhouse.config import ClickhouseHNSWConfig + + return ClickhouseHNSWConfig + if self == DB.PgVectorScale: from .pgvectorscale.config import _pgvectorscale_case_config @@ -318,6 +378,26 @@ def case_config_cls( # noqa: PLR0911 return MongoDBIndexConfig + if self == DB.MariaDB: + from .mariadb.config import _mariadb_case_config + + return _mariadb_case_config.get(index_type) + + if self == DB.TiDB: + from .tidb.config import TiDBIndexConfig + + return TiDBIndexConfig + + if self == DB.Vespa: + from .vespa.config import VespaHNSWConfig + + return VespaHNSWConfig + + if self == DB.LanceDB: + from .lancedb.config import _lancedb_case_config + + return _lancedb_case_config.get(index_type) + # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index a86849e96..ff7b378a7 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -16,18 +16,33 @@ 
class MetricType(str, Enum): class IndexType(str, Enum): HNSW = "HNSW" + HNSW_SQ = "HNSW_SQ" + HNSW_PQ = "HNSW_PQ" + HNSW_PRQ = "HNSW_PRQ" DISKANN = "DISKANN" STREAMING_DISKANN = "DISKANN" IVFFlat = "IVF_FLAT" + IVFPQ = "IVF_PQ" IVFSQ8 = "IVF_SQ8" + IVF_RABITQ = "IVF_RABITQ" Flat = "FLAT" AUTOINDEX = "AUTOINDEX" ES_HNSW = "hnsw" ES_IVFFlat = "ivfflat" GPU_IVF_FLAT = "GPU_IVF_FLAT" + GPU_BRUTE_FORCE = "GPU_BRUTE_FORCE" GPU_IVF_PQ = "GPU_IVF_PQ" GPU_CAGRA = "GPU_CAGRA" SCANN = "scann" + NONE = "NONE" + + +class SQType(str, Enum): + SQ6 = "SQ6" + SQ8 = "SQ8" + BF16 = "BF16" + FP16 = "FP16" + FP32 = "FP32" class DBConfig(ABC, BaseModel): @@ -161,7 +176,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert the embeddings to the vector database. The default number of embeddings for each insert_embeddings is 5000. diff --git a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py index 234014f19..adb766300 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +++ b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py @@ -12,6 +12,7 @@ WAITING_FOR_REFRESH_SEC = 30 WAITING_FOR_FORCE_MERGE_SEC = 30 +SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC = 30 class AWSOpenSearch(VectorDB): @@ -52,10 +53,27 @@ def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIn return AWSOpenSearchIndexConfig def _create_index(self, client: OpenSearch): + cluster_settings_body = { + "persistent": { + "knn.algo_param.index_thread_qty": self.case_config.index_thread_qty, + "knn.memory.circuit_breaker.limit": self.case_config.cb_threshold, + } + } + client.cluster.put_settings(cluster_settings_body) settings = { "index": { "knn": True, + "number_of_shards": self.case_config.number_of_shards, + "number_of_replicas": 0, + "translog.flush_threshold_size": 
self.case_config.flush_threshold_size, + # Setting trans log threshold to 5GB + **( + {"knn.algo_param.ef_search": self.case_config.ef_search} + if self.case_config.engine == AWSOS_Engine.nmslib + else {} + ), }, + "refresh_interval": self.case_config.refresh_interval, } mappings = { "properties": { @@ -145,9 +163,9 @@ def search_embedding( docvalue_fields=[self.id_col_name], stored_fields="_none_", ) - log.info(f"Search took: {resp['took']}") - log.info(f"Search shards: {resp['_shards']}") - log.info(f"Search hits total: {resp['hits']['total']}") + log.debug(f"Search took: {resp['took']}") + log.debug(f"Search shards: {resp['_shards']}") + log.debug(f"Search hits total: {resp['hits']['total']}") return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]] except Exception as e: log.warning(f"Failed to search: {self.index_name} error: {e!s}") @@ -157,12 +175,37 @@ def optimize(self, data_size: int | None = None): """optimize will be called between insertion and search in performance cases.""" # Call refresh first to ensure that all segments are created self._refresh_index() - self._do_force_merge() + if self.case_config.force_merge_enabled: + self._do_force_merge() + self._refresh_index() + self._update_replicas() # Call refresh again to ensure that the index is ready after force merge. 
self._refresh_index() # ensure that all graphs are loaded in memory and ready for search self._load_graphs_to_memory() + def _update_replicas(self): + index_settings = self.client.indices.get_settings(index=self.index_name) + current_number_of_replicas = int(index_settings[self.index_name]["settings"]["index"]["number_of_replicas"]) + log.info( + f"Current Number of replicas are {current_number_of_replicas}" + f" and changing the replicas to {self.case_config.number_of_replicas}" + ) + settings_body = {"index": {"number_of_replicas": self.case_config.number_of_replicas}} + self.client.indices.put_settings(index=self.index_name, body=settings_body) + self._wait_till_green() + + def _wait_till_green(self): + log.info("Wait for index to become green..") + while True: + res = self.client.cat.indices(index=self.index_name, h="health", format="json") + health = res[0]["health"] + if health != "green": + break + log.info(f"The index {self.index_name} has health : {health} and is not green. Retrying") + time.sleep(SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC) + log.info(f"Index {self.index_name} is green..") + def _refresh_index(self): log.debug(f"Starting refresh for index {self.index_name}") while True: @@ -179,6 +222,12 @@ def _refresh_index(self): log.debug(f"Completed refresh for index {self.index_name}") def _do_force_merge(self): + log.info(f"Updating the Index thread qty to {self.case_config.index_thread_qty_during_force_merge}.") + + cluster_settings_body = { + "persistent": {"knn.algo_param.index_thread_qty": self.case_config.index_thread_qty_during_force_merge} + } + self.client.cluster.put_settings(cluster_settings_body) log.debug(f"Starting force merge for index {self.index_name}") force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false" force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"] diff --git a/vectordb_bench/backend/clients/aws_opensearch/cli.py 
b/vectordb_bench/backend/clients/aws_opensearch/cli.py index bb0c2450d..fa457154d 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/cli.py +++ b/vectordb_bench/backend/clients/aws_opensearch/cli.py @@ -18,6 +18,79 @@ class AWSOpenSearchTypedDict(TypedDict): port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")] user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")] password: Annotated[str, click.option("--password", type=str, help="Db password")] + number_of_shards: Annotated[ + int, + click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1), + ] + number_of_replicas: Annotated[ + int, + click.option( + "--number-of-replicas", type=int, help="Number of replica copies for each primary shard", default=1 + ), + ] + index_thread_qty: Annotated[ + int, + click.option( + "--index-thread-qty", + type=int, + help="Thread count for native engine indexing", + default=4, + ), + ] + + index_thread_qty_during_force_merge: Annotated[ + int, + click.option( + "--index-thread-qty-during-force-merge", + type=int, + help="Thread count during force merge operations", + default=4, + ), + ] + + number_of_indexing_clients: Annotated[ + int, + click.option( + "--number-of-indexing-clients", + type=int, + help="Number of concurrent indexing clients", + default=1, + ), + ] + + number_of_segments: Annotated[ + int, + click.option("--number-of-segments", type=int, help="Target number of segments after merging", default=1), + ] + + refresh_interval: Annotated[ + int, + click.option( + "--refresh-interval", type=str, help="How often to make new data available for search", default="60s" + ), + ] + + force_merge_enabled: Annotated[ + int, + click.option("--force-merge-enabled", type=bool, help="Whether to perform force merge operation", default=True), + ] + + flush_threshold_size: Annotated[ + int, + click.option( + "--flush-threshold-size", type=str, help="Size threshold for 
flushing the transaction log", default="5120mb" + ), + ] + + cb_threshold: Annotated[ + int, + click.option( + "--cb-threshold", + type=str, + help="k-NN Memory circuit breaker threshold", + default="50%", + ), + ] class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ... @@ -36,6 +109,17 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]): user=parameters["user"], password=SecretStr(parameters["password"]), ), - db_case_config=AWSOpenSearchIndexConfig(), + db_case_config=AWSOpenSearchIndexConfig( + number_of_shards=parameters["number_of_shards"], + number_of_replicas=parameters["number_of_replicas"], + index_thread_qty=parameters["index_thread_qty"], + number_of_segments=parameters["number_of_segments"], + refresh_interval=parameters["refresh_interval"], + force_merge_enabled=parameters["force_merge_enabled"], + flush_threshold_size=parameters["flush_threshold_size"], + number_of_indexing_clients=parameters["number_of_indexing_clients"], + index_thread_qty_during_force_merge=parameters["index_thread_qty_during_force_merge"], + cb_threshold=parameters["cb_threshold"], + ), **parameters, ) diff --git a/vectordb_bench/backend/clients/aws_opensearch/config.py b/vectordb_bench/backend/clients/aws_opensearch/config.py index e9ccc7277..dd51b266d 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/config.py +++ b/vectordb_bench/backend/clients/aws_opensearch/config.py @@ -39,6 +39,16 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig): efConstruction: int = 256 efSearch: int = 256 M: int = 16 + index_thread_qty: int | None = 4 + number_of_shards: int | None = 1 + number_of_replicas: int | None = 0 + number_of_segments: int | None = 1 + refresh_interval: str | None = "60s" + force_merge_enabled: bool | None = True + flush_threshold_size: str | None = "5120mb" + number_of_indexing_clients: int | None = 1 + index_thread_qty_during_force_merge: int + cb_threshold: str | None = "50%" def parse_metric(self) -> 
str: if self.metric_type == MetricType.IP: diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py index 76c810263..26a810065 100644 --- a/vectordb_bench/backend/clients/chroma/chroma.py +++ b/vectordb_bench/backend/clients/chroma/chroma.py @@ -65,7 +65,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert embeddings into the database. Args: @@ -74,7 +74,7 @@ def insert_embeddings( kwargs: other arguments Returns: - (int, Exception): number of embeddings inserted and exception if any + tuple[int, Exception]: number of embeddings inserted and exception if any """ ids = [str(i) for i in metadata] metadata = [{"id": int(i)} for i in metadata] diff --git a/vectordb_bench/backend/clients/clickhouse/cli.py b/vectordb_bench/backend/clients/clickhouse/cli.py new file mode 100644 index 000000000..4b50bc55b --- /dev/null +++ b/vectordb_bench/backend/clients/clickhouse/cli.py @@ -0,0 +1,67 @@ +from typing import Annotated, TypedDict, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + HNSWFlavor2, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. 
import DB +from .config import ClickhouseHNSWConfig + + +class ClickhouseTypedDict(TypedDict): + password: Annotated[str, click.option("--password", type=str, help="DB password")] + host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)] + port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")] + user: Annotated[int, click.option("--user", type=str, default="clickhouse", help="DB user")] + ssl: Annotated[ + bool, + click.option( + "--ssl/--no-ssl", + is_flag=True, + show_default=True, + default=True, + help="Enable or disable SSL for Clickhouse", + ), + ] + ssl_ca_certs: Annotated[ + str, + click.option( + "--ssl-ca-certs", + show_default=True, + help="Path to certificate authority file to use for SSL", + ), + ] + + +class ClickhouseHNSWTypedDict(CommonTypedDict, ClickhouseTypedDict, HNSWFlavor2): ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(ClickhouseHNSWTypedDict) +def Clickhouse(**parameters: Unpack[ClickhouseHNSWTypedDict]): + from .config import ClickhouseConfig + + run( + db=DB.Clickhouse, + db_config=ClickhouseConfig( + db_label=parameters["db_label"], + user=parameters["user"], + password=SecretStr(parameters["password"]) if parameters["password"] else None, + host=parameters["host"], + port=parameters["port"], + ssl=parameters["ssl"], + ssl_ca_certs=parameters["ssl_ca_certs"], + ), + db_case_config=ClickhouseHNSWConfig( + M=parameters["m"], + efConstruction=parameters["ef_construction"], + ef=parameters["ef_runtime"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/clickhouse/clickhouse.py b/vectordb_bench/backend/clients/clickhouse/clickhouse.py new file mode 100644 index 000000000..de09895a8 --- /dev/null +++ b/vectordb_bench/backend/clients/clickhouse/clickhouse.py @@ -0,0 +1,232 @@ +"""Wrapper around the Clickhouse vector database over VectorDB""" + +import logging +from contextlib import contextmanager +from typing import Any + +import 
clickhouse_connect +from clickhouse_connect.driver import Client + +from .. import IndexType +from ..api import VectorDB +from .config import ClickhouseConfigDict, ClickhouseIndexConfig + +log = logging.getLogger(__name__) + + +class Clickhouse(VectorDB): + """Use SQLAlchemy instructions""" + + def __init__( + self, + dim: int, + db_config: ClickhouseConfigDict, + db_case_config: ClickhouseIndexConfig, + collection_name: str = "CHVectorCollection", + drop_old: bool = False, + **kwargs, + ): + self.db_config = db_config + self.case_config = db_case_config + self.table_name = collection_name + self.dim = dim + + self.index_param = self.case_config.index_param() + self.search_param = self.case_config.search_param() + self.session_param = self.case_config.session_param() + + self._index_name = "clickhouse_index" + self._primary_field = "id" + self._vector_field = "embedding" + + # construct basic units + self.conn = self._create_connection(**self.db_config, settings=self.session_param) + + if drop_old: + log.info(f"Clickhouse client drop table : {self.table_name}") + self._drop_table() + self._create_table(dim) + if self.case_config.create_index_before_load: + self._create_index() + + self.conn.close() + self.conn = None + + @contextmanager + def init(self) -> None: + """ + Examples: + >>> with self.init(): + >>> self.insert_embeddings() + >>> self.search_embedding() + """ + + self.conn = self._create_connection(**self.db_config, settings=self.session_param) + + try: + yield + finally: + self.conn.close() + self.conn = None + + def _create_connection(self, settings: dict | None, **kwargs) -> Client: + return clickhouse_connect.get_client(**self.db_config, settings=settings) + + def _drop_index(self): + assert self.conn is not None, "Connection is not initialized" + try: + self.conn.command( + f'ALTER TABLE {self.db_config["database"]}.{self.table_name} DROP INDEX {self._index_name}' + ) + except Exception as e: + log.warning(f"Failed to drop index on table 
{self.db_config['database']}.{self.table_name}: {e}") + raise e from None + + def _drop_table(self): + assert self.conn is not None, "Connection is not initialized" + + try: + self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["database"]}.{self.table_name}') + except Exception as e: + log.warning(f"Failed to drop table {self.db_config['database']}.{self.table_name}: {e}") + raise e from None + + def _perfomance_tuning(self): + self.conn.command("SET materialize_skip_indexes_on_insert = 1") + + def _create_index(self): + assert self.conn is not None, "Connection is not initialized" + try: + if self.index_param["index_type"] == IndexType.HNSW.value: + if ( + self.index_param["quantization"] + and self.index_param["params"]["M"] + and self.index_param["params"]["efConstruction"] + ): + query = f""" + ALTER TABLE {self.db_config["database"]}.{self.table_name} + ADD INDEX {self._index_name} {self._vector_field} + TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}',{self.dim}, + '{self.index_param["quantization"]}', + {self.index_param["params"]["M"]}, {self.index_param["params"]["efConstruction"]}) + GRANULARITY {self.index_param["granularity"]} + """ + else: + query = f""" + ALTER TABLE {self.db_config["database"]}.{self.table_name} + ADD INDEX {self._index_name} {self._vector_field} + TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}', {self.dim}) + GRANULARITY {self.index_param["granularity"]} + """ + self.conn.command(cmd=query) + else: + log.warning("HNSW is only avaliable method in clickhouse now") + except Exception as e: + log.warning(f"Failed to create Clickhouse vector index on table: {self.table_name} error: {e}") + raise e from None + + def _create_table(self, dim: int): + assert self.conn is not None, "Connection is not initialized" + + try: + # create table + self.conn.command( + f'CREATE TABLE IF NOT EXISTS {self.db_config["database"]}.{self.table_name} ' + f"({self._primary_field} UInt32, " + f'{self._vector_field} 
Array({self.index_param["vector_data_type"]}) CODEC(NONE), ' + f"CONSTRAINT same_length CHECK length(embedding) = {dim}) " + f"ENGINE = MergeTree() " + f"ORDER BY {self._primary_field}" + ) + + except Exception as e: + log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}") + raise e from None + + def optimize(self, data_size: int | None = None): + pass + + def _post_insert(self): + pass + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> (int, Exception): + assert self.conn is not None, "Connection is not initialized" + + try: + # do not iterate for bulk insert + items = [metadata, embeddings] + + self.conn.insert( + table=self.table_name, + data=items, + column_names=["id", "embedding"], + column_type_names=["UInt32", f'Array({self.index_param["vector_data_type"]})'], + column_oriented=True, + ) + return len(metadata), None + except Exception as e: + log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}") + return 0, e + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + ) -> list[int]: + assert self.conn is not None, "Connection is not initialized" + parameters = { + "primary_field": self._primary_field, + "vector_field": self._vector_field, + "schema": self.db_config["database"], + "table": self.table_name, + "gt": 0 if filters is None else filters.get("id", 0), + "k": k, + "metric_type": self.search_param["metric_type"], + "query": query, + } + if self.case_config.metric_type == "COSINE": + if filters: + result = self.conn.query( + "SELECT {primary_field:Identifier}, {vector_field:Identifier} " + "FROM {schema:Identifier}.{table:Identifier} " + "WHERE {primary_field:Identifier} > {gt:UInt32} " + "ORDER BY cosineDistance(embedding,{query:Array(Float64)}) " + "LIMIT {k:UInt32}", + parameters=parameters, + ).result_rows + return [int(row[0]) for row in result] + 
+ result = self.conn.query( + "SELECT {primary_field:Identifier}, {vector_field:Identifier} " + "FROM {schema:Identifier}.{table:Identifier} " + "ORDER BY cosineDistance(embedding,{query:Array(Float64)}) " + "LIMIT {k:UInt32}", + parameters=parameters, + ).result_rows + return [int(row[0]) for row in result] + if filters: + result = self.conn.query( + "SELECT {primary_field:Identifier}, {vector_field:Identifier} " + "FROM {schema:Identifier}.{table:Identifier} " + "WHERE {primary_field:Identifier} > {gt:UInt32} " + "ORDER BY L2Distance(embedding,{query:Array(Float64)}) " + "LIMIT {k:UInt32}", + parameters=parameters, + ).result_rows + return [int(row[0]) for row in result] + + result = self.conn.query( + "SELECT {primary_field:Identifier}, {vector_field:Identifier} " + "FROM {schema:Identifier}.{table:Identifier} " + "ORDER BY L2Distance(embedding,{query:Array(Float64)}) " + "LIMIT {k:UInt32}", + parameters=parameters, + ).result_rows + return [int(row[0]) for row in result] diff --git a/vectordb_bench/backend/clients/clickhouse/config.py b/vectordb_bench/backend/clients/clickhouse/config.py new file mode 100644 index 000000000..f9e09812b --- /dev/null +++ b/vectordb_bench/backend/clients/clickhouse/config.py @@ -0,0 +1,89 @@ +from abc import abstractmethod +from typing import TypedDict + +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + + +class ClickhouseConfigDict(TypedDict): + user: str + password: str + host: str + port: int + database: str + secure: bool + + +class ClickhouseConfig(DBConfig): + user: str = "clickhouse" + password: SecretStr + host: str = "localhost" + port: int = 8123 + db_name: str = "default" + secure: bool = False + + def to_dict(self) -> ClickhouseConfigDict: + pwd_str = self.password.get_secret_value() + return { + "host": self.host, + "port": self.port, + "database": self.db_name, + "user": self.user, + "password": pwd_str, + "secure": self.secure, + } + + +class 
ClickhouseIndexConfig(BaseModel, DBCaseConfig): + + metric_type: MetricType | None = None + vector_data_type: str | None = "Float32" # Data type of vectors. Can be Float32 or Float64 or BFloat16 + create_index_before_load: bool = True + create_index_after_load: bool = False + + def parse_metric(self) -> str: + if not self.metric_type: + return "" + return self.metric_type.value + + def parse_metric_str(self) -> str: + if self.metric_type == MetricType.L2: + return "L2Distance" + if self.metric_type == MetricType.COSINE: + return "cosineDistance" + return "cosineDistance" + + @abstractmethod + def session_param(self): + pass + + +class ClickhouseHNSWConfig(ClickhouseIndexConfig): + M: int | None # Default in clickhouse in 32 + efConstruction: int | None # Default in clickhouse in 128 + ef: int | None = None + index: IndexType = IndexType.HNSW + quantization: str | None = "bf16" # Default is bf16. Possible values are f64, f32, f16, bf16, or i8 + granularity: int | None = 10_000_000 # Size of the index granules. 
By default, in CH it's equal 10.000.000 + + def index_param(self) -> dict: + return { + "vector_data_type": self.vector_data_type, + "metric_type": self.parse_metric_str(), + "index_type": self.index.value, + "quantization": self.quantization, + "granularity": self.granularity, + "params": {"M": self.M, "efConstruction": self.efConstruction}, + } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric_str(), + "params": {"ef": self.ef}, + } + + def session_param(self) -> dict: + return { + "allow_experimental_vector_similarity_index": 1, + } diff --git a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py index ea038c587..7d201729f 100644 --- a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +++ b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py @@ -81,7 +81,7 @@ def insert_embeddings( embeddings: Iterable[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert the embeddings to the elasticsearch.""" assert self.client is not None, "should self.init() first" diff --git a/vectordb_bench/backend/clients/lancedb/cli.py b/vectordb_bench/backend/clients/lancedb/cli.py new file mode 100644 index 000000000..573c64d05 --- /dev/null +++ b/vectordb_bench/backend/clients/lancedb/cli.py @@ -0,0 +1,92 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. 
class LanceDBTypedDict(CommonTypedDict):
    """CLI options shared by all LanceDB commands."""

    uri: Annotated[
        str,
        click.option("--uri", type=str, help="URI connection string", required=True),
    ]
    token: Annotated[
        str | None,
        click.option("--token", type=str, help="Authentication token", required=False),
    ]


def _run_lancedb(case_config_key, parameters):
    """Build the LanceDB config for *case_config_key* and launch the run.

    Shared by all LanceDB commands below: the four variants differed only in
    which `_lancedb_case_config` entry they instantiated.
    """
    from .config import LanceDBConfig, _lancedb_case_config

    run(
        db=DB.LanceDB,
        db_config=LanceDBConfig(
            db_label=parameters["db_label"],
            uri=parameters["uri"],
            token=SecretStr(parameters["token"]) if parameters.get("token") else None,
        ),
        db_case_config=_lancedb_case_config.get(case_config_key)(),
        **parameters,
    )


@cli.command()
@click_parameter_decorators_from_typed_dict(LanceDBTypedDict)
def LanceDB(**parameters: Unpack[LanceDBTypedDict]):
    """Run against LanceDB without a vector index."""
    _run_lancedb("NONE", parameters)


@cli.command()
@click_parameter_decorators_from_typed_dict(LanceDBTypedDict)
def LanceDBAutoIndex(**parameters: Unpack[LanceDBTypedDict]):
    """Run against LanceDB with its automatically chosen index."""
    _run_lancedb(IndexType.AUTOINDEX, parameters)


@cli.command()
@click_parameter_decorators_from_typed_dict(LanceDBTypedDict)
def LanceDBIVFPQ(**parameters: Unpack[LanceDBTypedDict]):
    """Run against LanceDB with an IVF_PQ index."""
    _run_lancedb(IndexType.IVFPQ, parameters)


@cli.command()
@click_parameter_decorators_from_typed_dict(LanceDBTypedDict)
def LanceDBHNSW(**parameters: Unpack[LanceDBTypedDict]):
    """Run against LanceDB with an IVF_HNSW_SQ index."""
    _run_lancedb(IndexType.HNSW, parameters)
class LanceDBConfig(DBConfig):
    """LanceDB connection configuration."""

    db_label: str
    uri: str
    token: SecretStr | None = None

    def to_dict(self) -> dict:
        """Connection kwargs for lancedb.connect; secret is unwrapped here."""
        return {
            "uri": self.uri,
            "token": self.token.get_secret_value() if self.token else None,
        }


class LanceDBIndexConfig(BaseModel, DBCaseConfig):
    """Base index settings for LanceDB; defaults describe an IVF_PQ index."""

    index: IndexType = IndexType.IVFPQ
    metric_type: MetricType = MetricType.L2
    num_partitions: int = 0  # 0 -> let LanceDB choose
    num_sub_vectors: int = 0  # 0 -> let LanceDB choose
    nbits: int = 8  # Must be 4 or 8
    sample_rate: int = 256
    max_iterations: int = 50

    def index_param(self) -> dict:
        """Kwargs for Table.create_index; raises for unsupported index types."""
        if self.index not in [
            IndexType.IVFPQ,
            IndexType.HNSW,
            IndexType.AUTOINDEX,
            IndexType.NONE,
        ]:
            msg = f"Index type {self.index} is not supported for LanceDB!"
            raise ValueError(msg)

        # See https://lancedb.github.io/lancedb/python/python/#lancedb.table.Table.create_index
        params = {
            "metric": self.parse_metric(),
            "num_bits": self.nbits,
            "sample_rate": self.sample_rate,
            "max_iterations": self.max_iterations,
        }

        # Only pass sizing knobs when explicitly set; 0 keeps LanceDB defaults.
        if self.num_partitions > 0:
            params["num_partitions"] = self.num_partitions
        if self.num_sub_vectors > 0:
            params["num_sub_vectors"] = self.num_sub_vectors

        return params

    def search_param(self) -> dict:
        # Intentionally returns None: search-time knobs are not used by the
        # LanceDB client in this benchmark.
        pass

    def parse_metric(self) -> str:
        """Map the benchmark metric to LanceDB's metric name."""
        if self.metric_type in [MetricType.L2, MetricType.COSINE]:
            return self.metric_type.value.lower()
        if self.metric_type in [MetricType.IP, MetricType.DP]:
            return "dot"
        msg = f"Metric type {self.metric_type} is not supported for LanceDB!"
        raise ValueError(msg)


class LanceDBNoIndexConfig(LanceDBIndexConfig):
    """No index at all: brute-force scans."""

    index: IndexType = IndexType.NONE

    def index_param(self) -> dict:
        return {}


class LanceDBAutoIndexConfig(LanceDBIndexConfig):
    """Let LanceDB pick index parameters itself."""

    index: IndexType = IndexType.AUTOINDEX

    def index_param(self) -> dict:
        return {}


class LanceDBHNSWIndexConfig(LanceDBIndexConfig):
    """IVF_HNSW_SQ index; 0 means "use the LanceDB default"."""

    index: IndexType = IndexType.HNSW
    m: int = 0
    ef_construction: int = 0

    def index_param(self) -> dict:
        params = LanceDBIndexConfig.index_param(self)

        # See https://lancedb.github.io/lancedb/python/python/#lancedb.index.HnswSq
        params["index_type"] = "IVF_HNSW_SQ"
        if self.m > 0:
            params["m"] = self.m
        if self.ef_construction > 0:
            params["ef_construction"] = self.ef_construction

        return params


# Lookup used by the CLI commands to map an index type to its config class.
_lancedb_case_config = {
    IndexType.IVFPQ: LanceDBIndexConfig,
    IndexType.AUTOINDEX: LanceDBAutoIndexConfig,
    IndexType.HNSW: LanceDBHNSWIndexConfig,
    IndexType.NONE: LanceDBNoIndexConfig,
}
import logging
from contextlib import contextmanager

import lancedb
import pyarrow as pa
from lancedb.pydantic import LanceModel

from ..api import IndexType, VectorDB
from .config import LanceDBConfig, LanceDBIndexConfig

log = logging.getLogger(__name__)


class VectorModel(LanceModel):
    # Row schema: integer primary id plus the embedding vector.
    id: int
    vector: list[float]


class LanceDB(VectorDB):
    """VectorDB benchmark client backed by LanceDB."""

    def __init__(
        self,
        dim: int,
        db_config: LanceDBConfig,
        db_case_config: LanceDBIndexConfig,
        collection_name: str = "vector_bench_test",
        drop_old: bool = False,
        **kwargs,
    ):
        self.name = "LanceDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim
        # db_config arrives as a plain dict (DBConfig.to_dict()) at run time.
        self.uri = db_config["uri"]

        db = lancedb.connect(self.uri)

        if drop_old:
            try:
                db.drop_table(self.table_name)
            except Exception as e:
                log.warning(f"Failed to drop table {self.table_name}: {e}")

        # Open the table if it already exists, otherwise create it fresh.
        try:
            db.open_table(self.table_name)
        except Exception:
            schema = pa.schema(
                [pa.field("id", pa.int64()), pa.field("vector", pa.list_(pa.float64(), list_size=self.dim))]
            )
            db.create_table(self.table_name, schema=schema, mode="overwrite")

    @contextmanager
    def init(self):
        """Hold an open connection/table for the duration of the context.

        Fixed: the handles are now cleared in a ``finally`` block so they are
        reset even when the context body raises (previously an exception left
        stale ``self.db``/``self.table`` references behind).
        """
        self.db = lancedb.connect(self.uri)
        self.table = self.db.open_table(self.table_name)
        try:
            yield
        finally:
            self.db = None
            self.table = None

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
    ) -> tuple[int, Exception | None]:
        """Insert rows; returns (inserted_count, error)."""
        try:
            data = [{"id": meta, "vector": emb} for meta, emb in zip(metadata, embeddings, strict=False)]
            self.table.add(data)
            return len(metadata), None
        except Exception as e:
            log.warning(f"Failed to insert data into LanceDB table ({self.table_name}), error: {e}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
    ) -> list[int]:
        """Return ids of the k nearest rows, optionally pre-filtered by id."""
        if filters:
            results = self.table.search(query).where(f"id >= {filters['id']}", prefilter=True).limit(k).to_list()
        else:
            results = self.table.search(query).limit(k).to_list()
        return [int(result["id"]) for result in results]

    def optimize(self, data_size: int | None = None):
        """Build the configured index (if any) after the load phase."""
        if self.table and hasattr(self, "case_config") and self.case_config.index != IndexType.NONE:
            log.info(f"Creating index for LanceDB table ({self.table_name})")
            self.table.create_index(**self.case_config.index_param())
            # Better recall with IVF_PQ (though still bad) but breaks HNSW: https://github.com/lancedb/lancedb/issues/2369
            if self.case_config.index in (IndexType.IVFPQ, IndexType.AUTOINDEX):
                self.table.optimize()
class MariaDBTypedDict(CommonTypedDict):
    """Connection options shared by all MariaDB commands."""

    user_name: Annotated[
        str,
        click.option(
            "--username",
            type=str,
            help="Username",
            required=True,
        ),
    ]
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="Password",
            required=True,
        ),
    ]

    host: Annotated[
        str,
        click.option(
            "--host",
            type=str,
            help="Db host",
            default="127.0.0.1",
        ),
    ]

    port: Annotated[
        int,
        click.option(
            "--port",
            type=int,
            default=3306,
            help="Db Port",
        ),
    ]

    # Fixed: this option is a string choice, not an int.
    storage_engine: Annotated[
        str,
        click.option(
            "--storage-engine",
            type=click.Choice(["InnoDB", "MyISAM"]),
            help="DB storage engine",
            required=True,
        ),
    ]


class MariaDBHNSWTypedDict(MariaDBTypedDict):
    """Tuning options for the MHNSW vector index."""

    m: Annotated[
        int | None,
        click.option(
            "--m",
            type=int,
            help="M parameter in MHNSW vector indexing",
            required=False,
        ),
    ]

    ef_search: Annotated[
        int | None,
        click.option(
            "--ef-search",
            type=int,
            # Fixed help text: the client sets mhnsw_ef_search (not mhnsw_min_limit).
            help="MariaDB system variable mhnsw_ef_search",
            required=False,
        ),
    ]

    max_cache_size: Annotated[
        int | None,
        click.option(
            "--max-cache-size",
            type=int,
            help="MariaDB system variable mhnsw_max_cache_size",
            required=False,
        ),
    ]


@cli.command()
@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
def MariaDBHNSW(
    **parameters: Unpack[MariaDBHNSWTypedDict],
):
    """Run the benchmark against MariaDB with an MHNSW vector index."""
    from .config import MariaDBConfig, MariaDBHNSWConfig

    run(
        db=DB.MariaDB,
        db_config=MariaDBConfig(
            db_label=parameters["db_label"],
            user_name=parameters["username"],
            password=SecretStr(parameters["password"]),
            host=parameters["host"],
            port=parameters["port"],
        ),
        db_case_config=MariaDBHNSWConfig(
            M=parameters["m"],
            ef_search=parameters["ef_search"],
            storage_engine=parameters["storage_engine"],
            max_cache_size=parameters["max_cache_size"],
        ),
        **parameters,
    )
class MariaDBConfigDict(TypedDict):
    """These keys will be directly used as kwargs in mariadb connection string,
    so the names must match exactly mariadb API"""

    user: str
    password: str
    host: str
    port: int


class MariaDBConfig(DBConfig):
    """Connection settings collected from the CLI / UI."""

    user_name: str = "root"
    password: SecretStr
    host: str = "127.0.0.1"
    port: int = 3306

    def to_dict(self) -> MariaDBConfigDict:
        """Connection kwargs for mariadb.connect; secret is unwrapped here."""
        pwd_str = self.password.get_secret_value()
        return {
            "host": self.host,
            "port": self.port,
            "user": self.user_name,
            "password": pwd_str,
        }


class MariaDBIndexConfig(BaseModel):
    """Base config for MariaDB"""

    metric_type: MetricType | None = None

    def parse_metric(self) -> str:
        """Map the benchmark metric to MariaDB's distance name."""
        if self.metric_type == MetricType.L2:
            return "euclidean"
        if self.metric_type == MetricType.COSINE:
            return "cosine"
        msg = f"Metric type {self.metric_type} is not supported!"
        raise ValueError(msg)


class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
    """MHNSW index settings; None values fall back to server defaults."""

    M: int | None
    ef_search: int | None
    index: IndexType = IndexType.HNSW
    storage_engine: str = "InnoDB"
    max_cache_size: int | None

    def index_param(self) -> dict:
        """Parameters consumed when creating the table and vector key."""
        return {
            "storage_engine": self.storage_engine,
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "M": self.M,
            "max_cache_size": self.max_cache_size,
        }

    def search_param(self) -> dict:
        """Parameters consumed at query time."""
        return {
            "metric_type": self.parse_metric(),
            "ef_search": self.ef_search,
        }


# Lookup used by the CLI to map an index type to its config class.
_mariadb_case_config = {
    IndexType.HNSW: MariaDBHNSWConfig,
}
import logging
from contextlib import contextmanager

import mariadb
import numpy as np

from ..api import VectorDB
from .config import MariaDBConfigDict, MariaDBIndexConfig

log = logging.getLogger(__name__)


class MariaDB(VectorDB):
    """VectorDB benchmark client for MariaDB's VECTOR type / MHNSW index."""

    def __init__(
        self,
        dim: int,
        db_config: MariaDBConfigDict,
        db_case_config: MariaDBIndexConfig,
        collection_name: str = "vec_collection",
        drop_old: bool = False,
        **kwargs,
    ):
        self.name = "MariaDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.db_name = "vectordbbench"
        self.table_name = collection_name
        self.dim = dim

        # construct basic units
        self.conn, self.cursor = self._create_connection(**self.db_config)

        if drop_old:
            self._drop_db()
            self._create_db_table(dim)

        # Connections are per-process; each worker re-opens via init().
        self.cursor.close()
        self.conn.close()
        self.cursor = None
        self.conn = None

    @staticmethod
    def _create_connection(**kwargs) -> tuple[mariadb.Connection, mariadb.Cursor]:
        """Open a fresh connection + cursor from raw connection kwargs."""
        conn = mariadb.connect(**kwargs)
        cursor = conn.cursor()

        assert conn is not None, "Connection is not initialized"
        assert cursor is not None, "Cursor is not initialized"

        return conn, cursor

    def _drop_db(self):
        """Drop the benchmark database entirely."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop db : {self.db_name}")

        # flush tables before dropping database to avoid some locking issue
        self.cursor.execute("FLUSH TABLES")
        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
        self.cursor.execute("COMMIT")
        self.cursor.execute("FLUSH TABLES")

    def _create_db_table(self, dim: int):
        """Create the benchmark database and the (id, VECTOR) table."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()

        try:
            log.info(f"{self.name} client create database : {self.db_name}")
            self.cursor.execute(f"CREATE DATABASE {self.db_name}")

            log.info(f"{self.name} client create table : {self.table_name}")
            self.cursor.execute(f"USE {self.db_name}")

            self.cursor.execute(
                f"""
                CREATE TABLE {self.table_name} (
                    id INT PRIMARY KEY,
                    v VECTOR({self.dim}) NOT NULL
                ) ENGINE={index_param["storage_engine"]}
                """
            )
            self.cursor.execute("COMMIT")

        except Exception as e:
            log.warning(f"Failed to create table: {self.table_name} error: {e}")
            raise e from None

    @contextmanager
    def init(self):
        """create and destory connections to database.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
        """
        self.conn, self.cursor = self._create_connection(**self.db_config)

        index_param = self.case_config.index_param()
        search_param = self.case_config.search_param()

        # maximize allowed package size
        self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")

        if index_param["index_type"] == "HNSW":
            if index_param["max_cache_size"] is not None:
                self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}")
            if search_param["ef_search"] is not None:
                self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
            self.cursor.execute("COMMIT")

        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"  # noqa: S608
        # Fixed: a space is required after the table name — the two f-string
        # fragments previously concatenated into "...<table>ORDER by ...",
        # which is invalid SQL.
        self.select_sql = (
            f"SELECT id FROM {self.db_name}.{self.table_name} "  # noqa: S608
            f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
        )
        self.select_sql_with_filter = (
            f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d "  # noqa: S608
            f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
        )

        try:
            yield
        finally:
            self.cursor.close()
            self.conn.close()
            self.cursor = None
            self.conn = None

    def ready_to_load(self) -> bool:
        pass

    def optimize(self) -> None:
        """Create the vector index after the load phase."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()

        try:
            index_options = f"DISTANCE={index_param['metric_type']}"
            if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
                index_options += f" M={index_param['M']}"

            self.cursor.execute(
                f"""
                ALTER TABLE {self.db_name}.{self.table_name}
                    ADD VECTOR KEY v(v) {index_options}
                """
            )
            self.cursor.execute("COMMIT")

        except Exception as e:
            log.warning(f"Failed to create index: {self.table_name} error: {e}")
            raise e from None

    @staticmethod
    def vector_to_hex(v):  # noqa: ANN001
        # Despite the name, this returns the raw little-endian float32 bytes
        # (MariaDB's binary VECTOR wire format), not a hex string.
        return np.array(v, "float32").tobytes()

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception]:
        """Insert embeddings into the database.
        Should call self.init() first.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            metadata_arr = np.array(metadata)
            embeddings_arr = np.array(embeddings)

            batch_data = []
            for i, row in enumerate(metadata_arr):
                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])))

            self.cursor.executemany(self.insert_sql, batch_data)
            self.cursor.execute("COMMIT")
            self.cursor.execute("FLUSH TABLES")

            return len(metadata), None
        except Exception as e:
            log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs,
    ) -> list[int]:
        """Return ids of the k nearest rows, optionally filtered by id."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        if filters:
            self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), self.vector_to_hex(query), k))
        else:
            self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))

        return [id for (id,) in self.cursor.fetchall()]  # noqa: A001
click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"), + ] + limit: Annotated[ + int, + click.option("--limit", type=int, required=True, help="Top-k limit for search"), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict) +def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): + from .config import GPUBruteForceConfig, MilvusConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), + ), + db_case_config=GPUBruteForceConfig( + metric_type=parameters["metric_type"], + limit=parameters["limit"], # top-k for search + ), + **parameters, + ) + + class MilvusGPUIVFPQTypedDict( CommonTypedDict, MilvusTypedDict, diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 7d0df803a..672becf1b 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, SecretStr, validator -from ..api import DBCaseConfig, DBConfig, IndexType, MetricType +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType class MilvusConfig(DBConfig): @@ -40,6 +40,7 @@ def is_gpu_index(self) -> bool: IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ, + IndexType.GPU_BRUTE_FORCE, ] def parse_metric(self) -> str: @@ -87,6 +88,88 @@ def search_param(self) -> dict: } +class HNSWSQConfig(HNSWConfig, DBCaseConfig): + index: IndexType = IndexType.HNSW_SQ + sq_type: SQType = SQType.SQ8 + refine: bool = True + refine_type: SQType = SQType.FP32 + refine_k: float = 1 + + def index_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "params": { + "M": self.M, + "efConstruction": self.efConstruction, + "sq_type": 
class HNSWSQConfig(HNSWConfig, DBCaseConfig):
    """HNSW with scalar quantization (HNSW_SQ)."""

    index: IndexType = IndexType.HNSW_SQ
    sq_type: SQType = SQType.SQ8  # scalar quantization applied to stored vectors
    refine: bool = True  # keep higher-precision data for re-ranking
    refine_type: SQType = SQType.FP32
    refine_k: float = 1  # refine magnification factor at search time

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {
                "M": self.M,
                "efConstruction": self.efConstruction,
                "sq_type": self.sq_type.value,
                "refine": self.refine,
                "refine_type": self.refine_type.value,
            },
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {"ef": self.ef, "refine_k": self.refine_k},
        }


class HNSWPQConfig(HNSWConfig):
    """HNSW with product quantization (HNSW_PQ)."""

    index: IndexType = IndexType.HNSW_PQ
    m: int = 32  # number of PQ sub-vectors (lowercase; distinct from HNSW M)
    nbits: int = 8
    refine: bool = True
    refine_type: SQType = SQType.FP32
    refine_k: float = 1

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {
                "M": self.M,
                "efConstruction": self.efConstruction,
                "m": self.m,
                "nbits": self.nbits,
                "refine": self.refine,
                "refine_type": self.refine_type.value,
            },
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {"ef": self.ef, "refine_k": self.refine_k},
        }


class HNSWPRQConfig(HNSWPQConfig):
    """HNSW with product residual quantization (HNSW_PRQ)."""

    index: IndexType = IndexType.HNSW_PRQ
    nrq: int = 2  # number of residual quantizers

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {
                "M": self.M,
                "efConstruction": self.efConstruction,
                "m": self.m,
                "nbits": self.nbits,
                "nrq": self.nrq,
                "refine": self.refine,
                "refine_type": self.refine_type.value,
            },
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {"ef": self.ef, "refine_k": self.refine_k},
        }


class IVFPQConfig(MilvusIndexConfig, DBCaseConfig):
    """IVF with product quantization (IVF_PQ)."""

    nlist: int
    nprobe: int | None = None
    m: int = 32  # number of PQ sub-vectors
    nbits: int = 8
    index: IndexType = IndexType.IVFPQ

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {"nlist": self.nlist, "m": self.m, "nbits": self.nbits},
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {"nprobe": self.nprobe},
        }


class IVFRABITQConfig(IVFSQ8Config):
    """IVF with RaBitQ binary quantization (IVF_RABITQ)."""

    index: IndexType = IndexType.IVF_RABITQ
    rbq_bits_query: int = 0  # 0, 1, 2, ..., 8
    refine: bool = True
    refine_type: SQType = SQType.FP32
    refine_k: float = 1

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {
                "nlist": self.nlist,
                "refine": self.refine,
                "refine_type": self.refine_type.value,
            },
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {"nprobe": self.nprobe, "rbq_bits_query": self.rbq_bits_query, "refine_k": self.refine_k},
        }


class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
    limit: int = 10  # Default top-k for search
    metric_type: str  # Metric type (e.g., 'L2', 'IP', etc.)
    index: IndexType = IndexType.GPU_BRUTE_FORCE  # Index type set to GPU_BRUTE_FORCE

    def index_param(self) -> dict:
        """
        Returns the parameters for creating the GPU_BRUTE_FORCE index.
        No additional parameters required for index building.
        """
        return {
            "metric_type": self.parse_metric(),  # Metric type for distance calculation (L2, IP, etc.)
            "index_type": self.index.value,  # GPU_BRUTE_FORCE index type
            "params": {},  # No additional parameters for GPU_BRUTE_FORCE
        }

    def search_param(self) -> dict:
        """
        Returns the parameters for performing a search on the GPU_BRUTE_FORCE index.
        Only metric_type and top-k (limit) are needed for search.
        """
        return {
            "metric_type": self.parse_metric(),  # Metric type for search
            "params": {
                "nprobe": 1,  # For GPU_BRUTE_FORCE, set nprobe to 1 (brute force search)
                "limit": self.limit,  # Top-k for search
            },
        }
+ """ + return { + "metric_type": self.parse_metric(), # Metric type for search + "params": { + "nprobe": 1, # For GPU_BRUTE_FORCE, set nprobe to 1 (brute force search) + "limit": self.limit, # Top-k for search + }, + } + + class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig): nlist: int = 1024 m: int = 0 @@ -254,11 +413,17 @@ def search_param(self) -> dict: _milvus_case_config = { IndexType.AUTOINDEX: AutoIndexConfig, IndexType.HNSW: HNSWConfig, + IndexType.HNSW_SQ: HNSWSQConfig, + IndexType.HNSW_PQ: HNSWPQConfig, + IndexType.HNSW_PRQ: HNSWPRQConfig, IndexType.DISKANN: DISKANNConfig, IndexType.IVFFlat: IVFFlatConfig, + IndexType.IVFPQ: IVFPQConfig, IndexType.IVFSQ8: IVFSQ8Config, + IndexType.IVF_RABITQ: IVFRABITQConfig, IndexType.Flat: FLATConfig, IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig, IndexType.GPU_IVF_PQ: GPUIVFPQConfig, IndexType.GPU_CAGRA: GPUCAGRAConfig, + IndexType.GPU_BRUTE_FORCE: GPUBruteForceConfig, } diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index 4015eb1f3..465c51179 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -61,6 +61,7 @@ def __init__( consistency_level="Session", ) + log.info(f"{self.name} create index: index_params: {self.case_config.index_param()}") col.create_index( self._vector_field, self.case_config.index_param(), @@ -71,7 +72,7 @@ def __init__( connections.disconnect("default") @contextmanager - def init(self) -> None: + def init(self): """ Examples: >>> with self.init(): @@ -126,6 +127,7 @@ def wait_index(): try: self.col.compact() self.col.wait_for_compaction_completed() + log.info("compactation completed. 
waiting for the rest of index buliding.") except Exception as e: log.warning(f"{self.name} compact error: {e}") if hasattr(e, "code"): @@ -155,7 +157,7 @@ def insert_embeddings( embeddings: Iterable[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert embeddings into Milvus. should call self.init() first""" # use the first insert_embeddings to init collection assert self.col is not None diff --git a/vectordb_bench/backend/clients/mongodb/config.py b/vectordb_bench/backend/clients/mongodb/config.py index cc09471a4..a2d8ca57a 100644 --- a/vectordb_bench/backend/clients/mongodb/config.py +++ b/vectordb_bench/backend/clients/mongodb/config.py @@ -1,8 +1,16 @@ +from enum import Enum + from pydantic import BaseModel, SecretStr from ..api import DBCaseConfig, DBConfig, IndexType, MetricType +class QuantizationType(Enum): + NONE = "none" + BINARY = "binary" + SCALAR = "scalar" + + class MongoDBConfig(DBConfig, BaseModel): connection_string: SecretStr = "mongodb+srv://:@.heatl.mongodb.net" database: str = "vdb_bench" @@ -16,9 +24,9 @@ def to_dict(self) -> dict: class MongoDBIndexConfig(BaseModel, DBCaseConfig): index: IndexType = IndexType.HNSW # MongoDB uses HNSW for vector search - metric_type: MetricType | None = None - num_candidates: int | None = 1500 # Default numCandidates for vector search - exact_search: bool = False # Whether to use exact (ENN) search + metric_type: MetricType = MetricType.COSINE + num_candidates_ratio: int = 10 # Default numCandidates ratio for vector search + quantization: QuantizationType = QuantizationType.NONE # Quantization type if applicable def parse_metric(self) -> str: if self.metric_type == MetricType.L2: @@ -36,9 +44,10 @@ def index_param(self) -> dict: "similarity": self.parse_metric(), "numDimensions": None, # Will be set in MongoDB class "path": "vector", # Vector field name + "quantization": self.quantization.value, } ], } def search_param(self) -> dict: - return 
{"numCandidates": self.num_candidates if not self.exact_search else None, "exact": self.exact_search} + return {"num_candidates_ratio": self.num_candidates_ratio} diff --git a/vectordb_bench/backend/clients/mongodb/mongodb.py b/vectordb_bench/backend/clients/mongodb/mongodb.py index dddcc9a4c..0bbfd5d9c 100644 --- a/vectordb_bench/backend/clients/mongodb/mongodb.py +++ b/vectordb_bench/backend/clients/mongodb/mongodb.py @@ -90,7 +90,7 @@ def _create_index(self) -> None: break log.info(f"index deleting {indices}") except Exception: - log.exception("Error dropping index") + log.exception(f"Error dropping index {index_name}") try: # Create vector search index search_index = SearchIndexModel(definition=index_params, name=index_name, type="vectorSearch") @@ -104,7 +104,7 @@ def _create_index(self) -> None: log.info(f"Created index on {self.id_field} field") except Exception: - log.exception("Error creating index") + log.exception(f"Error creating index {index_name}") raise def _wait_for_index_ready(self, index_name: str, check_interval: int = 5) -> None: @@ -167,16 +167,15 @@ def search_embedding( else: # Set numCandidates based on k value and data size # For 50K dataset, use higher multiplier for better recall - num_candidates = min(10000, max(k * 20, search_params["numCandidates"] or 0)) + num_candidates = min(10000, k * search_params["num_candidates_ratio"]) vector_search["numCandidates"] = num_candidates # Add filter if specified if filters: log.info(f"Applying filter: {filters}") vector_search["filter"] = { - "id": {"gt": filters["id"]}, + "id": {"gte": filters["id"]}, } - pipeline = [ {"$vectorSearch": vector_search}, { diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index 55a462055..f8c138802 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -18,8 +18,7 @@ ) -# ruff: noqa -def set_default_quantized_fetch_limit(ctx: any, param: any, value: 
any): +def set_default_quantized_fetch_limit(ctx: any, param: any, value: any): # noqa: ARG001 if ctx.params.get("reranking") and value is None: # ef_search is the default value for quantized_fetch_limit as it's bound by ef_search. # 100 is default value for quantized_fetch_limit for IVFFlat. @@ -82,7 +81,17 @@ class PgVectorTypedDict(CommonTypedDict): click.option( "--quantization-type", type=click.Choice(["none", "bit", "halfvec"]), - help="quantization type for vectors", + help="quantization type for vectors (in index)", + required=False, + ), + ] + table_quantization_type: Annotated[ + str | None, + click.option( + "--table-quantization-type", + type=click.Choice(["none", "bit", "halfvec"]), + help="quantization type for vectors (in table). " + "If equal to bit, the parameter quantization_type will be set to bit too.", required=False, ), ] @@ -146,6 +155,7 @@ def PgVectorIVFFlat( lists=parameters["lists"], probes=parameters["probes"], quantization_type=parameters["quantization_type"], + table_quantization_type=parameters["table_quantization_type"], reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], @@ -182,6 +192,7 @@ def PgVectorHNSW( maintenance_work_mem=parameters["maintenance_work_mem"], max_parallel_workers=parameters["max_parallel_workers"], quantization_type=parameters["quantization_type"], + table_quantization_type=parameters["table_quantization_type"], reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index c386d75ef..abfddc0cf 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -80,7 +80,12 @@ def parse_metric(self) -> str: if d.get(self.quantization_type) is None: return 
d.get("_fallback").get(self.metric_type) - return d.get(self.quantization_type).get(self.metric_type) + metric = d.get(self.quantization_type).get(self.metric_type) + # If using binary quantization for the index, use a bit metric + # no matter what metric was selected for vector or halfvec data + if self.quantization_type == "bit" and metric is None: + return "bit_hamming_ops" + return metric def parse_metric_fun_op(self) -> LiteralString: if self.quantization_type == "bit": @@ -168,14 +173,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig): maintenance_work_mem: str | None = None max_parallel_workers: int | None = None quantization_type: str | None = None + table_quantization_type: str | None reranking: bool | None = None quantized_fetch_limit: int | None = None reranking_metric: str | None = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"lists": self.lists} - if self.quantization_type == "none": - self.quantization_type = None + if self.quantization_type == "none" or self.quantization_type is None: + self.quantization_type = "vector" + if self.table_quantization_type == "none" or self.table_quantization_type is None: + self.table_quantization_type = "vector" + if self.table_quantization_type == "bit": + self.quantization_type = "bit" return { "metric": self.parse_metric(), "index_type": self.index.value, @@ -183,6 +193,7 @@ def index_param(self) -> PgVectorIndexParam: "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "quantization_type": self.quantization_type, + "table_quantization_type": self.table_quantization_type, } def search_param(self) -> PgVectorSearchParam: @@ -212,14 +223,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): maintenance_work_mem: str | None = None max_parallel_workers: int | None = None quantization_type: str | None = None + table_quantization_type: str | None reranking: bool | None = None quantized_fetch_limit: int | None = None reranking_metric: str | None 
= None def index_param(self) -> PgVectorIndexParam: index_parameters = {"m": self.m, "ef_construction": self.ef_construction} - if self.quantization_type == "none": - self.quantization_type = None + if self.quantization_type == "none" or self.quantization_type is None: + self.quantization_type = "vector" + if self.table_quantization_type == "none" or self.table_quantization_type is None: + self.table_quantization_type = "vector" + if self.table_quantization_type == "bit": + self.quantization_type = "bit" return { "metric": self.parse_metric(), "index_type": self.index.value, @@ -227,6 +243,7 @@ def index_param(self) -> PgVectorIndexParam: "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "quantization_type": self.quantization_type, + "table_quantization_type": self.table_quantization_type, } def search_param(self) -> PgVectorSearchParam: diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 4164461fb..7d06a2ba5 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -94,7 +94,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: reranking = self.case_config.search_param()["reranking"] column_name = ( sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding")) - if index_param["quantization_type"] == "bit" + if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit" else sql.SQL("embedding") ) search_vector = ( @@ -104,7 +104,8 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: ) # The following sections assume that the quantization_type value matches the quantization function name - if index_param["quantization_type"] is not None: + if index_param["quantization_type"] != index_param["table_quantization_type"]: + # Reranking makes sense only if table quantization is not "bit" if 
index_param["quantization_type"] == "bit" and reranking: # Embeddings needs to be passed to binary_quantize function if quantization_type is bit search_query = sql.Composed( @@ -113,7 +114,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: """ SELECT i.id FROM ( - SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance + SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) """, @@ -123,6 +124,8 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: reranking_metric_fun_op=sql.SQL( self.case_config.search_param()["reranking_metric_fun_op"], ), + search_vector=search_vector, + table_quantization_type=sql.SQL(index_param["table_quantization_type"]), quantization_type=sql.SQL(index_param["quantization_type"]), dim=sql.Literal(self.dim), where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), @@ -130,7 +133,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL( """ - {search_vector} + {search_vector}::{quantization_type}({dim}) LIMIT {quantized_fetch_limit} ) i ORDER BY i.distance @@ -138,6 +141,8 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: """, ).format( search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), quantized_fetch_limit=sql.Literal( self.case_config.search_param()["quantized_fetch_limit"], ), @@ -160,10 +165,12 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" {search_vector} LIMIT %s::int").format( + sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format( 
search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), ), - ], + ] ) else: search_query = sql.Composed( @@ -175,8 +182,12 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), - ], + sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format( + search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + ), + ] ) return search_query @@ -323,7 +334,7 @@ def _create_index(self): ) with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(()) - if index_param["quantization_type"] is not None: + if index_param["quantization_type"] != index_param["table_quantization_type"]: index_create_sql = sql.SQL( """ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} @@ -365,14 +376,23 @@ def _create_table(self, dim: int): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + index_param = self.case_config.index_param() + try: log.info(f"{self.name} client create table : {self.table_name}") # create table self.cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", - ).format(table_name=sql.Identifier(self.table_name), dim=dim), + """ + CREATE TABLE IF NOT EXISTS public.{table_name} + (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim})); + """ + ).format( + table_name=sql.Identifier(self.table_name), + table_quantization_type=sql.SQL(index_param["table_quantization_type"]), + dim=dim, + ) ) self.cursor.execute( sql.SQL( @@ -393,18 +413,41 @@ def insert_embeddings( assert self.conn is not None, "Connection is not initialized" assert 
self.cursor is not None, "Cursor is not initialized" + index_param = self.case_config.index_param() + try: metadata_arr = np.array(metadata) embeddings_arr = np.array(embeddings) - with self.cursor.copy( - sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( - table_name=sql.Identifier(self.table_name), - ), - ) as copy: - copy.set_types(["bigint", "vector"]) - for i, row in enumerate(metadata_arr): - copy.write_row((row, embeddings_arr[i])) + if index_param["table_quantization_type"] == "bit": + with self.cursor.copy( + sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format( + table_name=sql.Identifier(self.table_name) + ) + ) as copy: + # Same logic as pgvector binary_quantize + for i, row in enumerate(metadata_arr): + embeddings_bit = "" + for embedding in embeddings_arr[i]: + if embedding > 0: + embeddings_bit += "1" + else: + embeddings_bit += "0" + copy.write_row((str(row), embeddings_bit)) + else: + with self.cursor.copy( + sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( + table_name=sql.Identifier(self.table_name) + ) + ) as copy: + if index_param["table_quantization_type"] == "halfvec": + copy.set_types(["bigint", "halfvec"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, np.float16(embeddings_arr[i]))) + else: + copy.set_types(["bigint", "vector"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, embeddings_arr[i])) self.conn.commit() if kwargs.get("last_batch"): diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index 1a681b33f..9fd0afc7c 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -67,7 +67,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: assert len(embeddings) == len(metadata) insert_count = 0 try: diff --git 
a/vectordb_bench/backend/clients/qdrant_cloud/cli.py b/vectordb_bench/backend/clients/qdrant_cloud/cli.py new file mode 100644 index 000000000..32353472b --- /dev/null +++ b/vectordb_bench/backend/clients/qdrant_cloud/cli.py @@ -0,0 +1,43 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB + + +class QdrantTypedDict(CommonTypedDict): + url: Annotated[ + str, + click.option("--url", type=str, help="URL connection string", required=True), + ] + api_key: Annotated[ + str | None, + click.option("--api-key", type=str, help="API key for authentication", required=False), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(QdrantTypedDict) +def QdrantCloud(**parameters: Unpack[QdrantTypedDict]): + from .config import QdrantConfig, QdrantIndexConfig + + config_params = { + "db_label": parameters["db_label"], + "url": SecretStr(parameters["url"]), + } + + config_params["api_key"] = SecretStr(parameters["api_key"]) if parameters["api_key"] else None + + run( + db=DB.QdrantCloud, + db_config=QdrantConfig(**config_params), + db_case_config=QdrantIndexConfig(), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/qdrant_cloud/config.py b/vectordb_bench/backend/clients/qdrant_cloud/config.py index c1d6882c0..b60733bc3 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/config.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/config.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, SecretStr, validator +from pydantic import BaseModel, SecretStr from ..api import DBCaseConfig, DBConfig, MetricType @@ -6,28 +6,20 @@ # Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant. 
class QdrantConfig(DBConfig): url: SecretStr - api_key: SecretStr + api_key: SecretStr | None = None def to_dict(self) -> dict: - api_key = self.api_key.get_secret_value() - if len(api_key) > 0: + api_key_value = self.api_key.get_secret_value() if self.api_key else None + if api_key_value: return { "url": self.url.get_secret_value(), - "api_key": self.api_key.get_secret_value(), + "api_key": api_key_value, "prefer_grpc": True, } return { "url": self.url.get_secret_value(), } - @validator("*") - def not_empty_field(cls, v: any, field: any): - if field.name in ["api_key", "db_label"]: - return v - if isinstance(v, str | SecretStr) and len(v) == 0: - raise ValueError("Empty string!") - return v - class QdrantIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py index 5de72798b..f618c3ba4 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py @@ -111,7 +111,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert embeddings into Milvus. 
should call self.init() first""" assert self.qdrant_client is not None try: diff --git a/vectordb_bench/backend/clients/tidb/cli.py b/vectordb_bench/backend/clients/tidb/cli.py new file mode 100644 index 000000000..cdfcbe432 --- /dev/null +++ b/vectordb_bench/backend/clients/tidb/cli.py @@ -0,0 +1,98 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.backend.clients import DB + +from ....cli.cli import CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run + + +class TiDBTypedDict(CommonTypedDict): + user_name: Annotated[ + str, + click.option( + "--username", + type=str, + help="Username", + default="root", + show_default=True, + required=True, + ), + ] + password: Annotated[ + str, + click.option( + "--password", + type=str, + default="", + show_default=True, + help="Password", + ), + ] + host: Annotated[ + str, + click.option( + "--host", + type=str, + default="127.0.0.1", + show_default=True, + required=True, + help="Db host", + ), + ] + port: Annotated[ + int, + click.option( + "--port", + type=int, + default=4000, + show_default=True, + required=True, + help="Db Port", + ), + ] + db_name: Annotated[ + str, + click.option( + "--db-name", + type=str, + default="test", + show_default=True, + required=True, + help="Db name", + ), + ] + ssl: Annotated[ + bool, + click.option( + "--ssl/--no-ssl", + default=False, + show_default=True, + is_flag=True, + help="Enable or disable SSL, for TiDB Serverless SSL must be enabled", + ), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(TiDBTypedDict) +def TiDB( + **parameters: Unpack[TiDBTypedDict], +): + from .config import TiDBConfig, TiDBIndexConfig + + run( + db=DB.TiDB, + db_config=TiDBConfig( + db_label=parameters["db_label"], + user_name=parameters["username"], + password=SecretStr(parameters["password"]), + host=parameters["host"], + port=parameters["port"], + db_name=parameters["db_name"], + ssl=parameters["ssl"], + ), + 
db_case_config=TiDBIndexConfig(), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/tidb/config.py b/vectordb_bench/backend/clients/tidb/config.py new file mode 100644 index 000000000..693551045 --- /dev/null +++ b/vectordb_bench/backend/clients/tidb/config.py @@ -0,0 +1,46 @@ +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, MetricType + + +class TiDBConfig(DBConfig): + user_name: str = "root" + password: SecretStr + host: str = "127.0.0.1" + port: int = 4000 + db_name: str = "test" + ssl: bool = False + + def to_dict(self) -> dict: + pwd_str = self.password.get_secret_value() + return { + "host": self.host, + "port": self.port, + "user": self.user_name, + "password": pwd_str, + "database": self.db_name, + "ssl_verify_cert": self.ssl, + "ssl_verify_identity": self.ssl, + } + + +class TiDBIndexConfig(BaseModel, DBCaseConfig): + metric_type: MetricType | None = None + + def get_metric_fn(self) -> str: + if self.metric_type == MetricType.L2: + return "vec_l2_distance" + if self.metric_type == MetricType.COSINE: + return "vec_cosine_distance" + msg = f"Unsupported metric type: {self.metric_type}" + raise ValueError(msg) + + def index_param(self) -> dict: + return { + "metric_fn": self.get_metric_fn(), + } + + def search_param(self) -> dict: + return { + "metric_fn": self.get_metric_fn(), + } diff --git a/vectordb_bench/backend/clients/tidb/tidb.py b/vectordb_bench/backend/clients/tidb/tidb.py new file mode 100644 index 000000000..b75605eda --- /dev/null +++ b/vectordb_bench/backend/clients/tidb/tidb.py @@ -0,0 +1,233 @@ +import concurrent.futures +import io +import logging +import time +from contextlib import contextmanager +from typing import Any + +import pymysql + +from ..api import VectorDB +from .config import TiDBIndexConfig + +log = logging.getLogger(__name__) + + +class TiDB(VectorDB): + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: TiDBIndexConfig, + collection_name: str = 
"vector_bench_test", + drop_old: bool = False, + **kwargs, + ): + self.name = "TiDB" + self.db_config = db_config + self.case_config = db_case_config + self.table_name = collection_name + self.dim = dim + self.conn = None # To be inited by init() + self.cursor = None # To be inited by init() + + self.search_fn = db_case_config.search_param()["metric_fn"] + + if drop_old: + self._drop_table() + self._create_table() + + @contextmanager + def init(self): + with self._get_connection() as (conn, cursor): + self.conn = conn + self.cursor = cursor + try: + yield + finally: + self.conn = None + self.cursor = None + + @contextmanager + def _get_connection(self): + with pymysql.connect(**self.db_config) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + yield conn, cursor + + def _drop_table(self): + try: + with self._get_connection() as (conn, cursor): + cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}") + conn.commit() + except Exception as e: + log.warning("Failed to drop table: %s error: %s", self.table_name, e) + raise + + def _create_table(self): + try: + index_param = self.case_config.index_param() + with self._get_connection() as (conn, cursor): + cursor.execute( + f""" + CREATE TABLE {self.table_name} ( + id BIGINT PRIMARY KEY, + embedding VECTOR({self.dim}) NOT NULL, + VECTOR INDEX (({index_param["metric_fn"]}(embedding))) + ); + """ + ) + conn.commit() + except Exception as e: + log.warning("Failed to create table: %s error: %s", self.table_name, e) + raise + + def ready_to_load(self) -> bool: + pass + + def optimize(self, data_size: int | None = None) -> None: + while True: + progress = self._optimize_check_tiflash_replica_progress() + if progress != 1: + log.info("Data replication not ready, progress: %d", progress) + time.sleep(2) + else: + break + + log.info("Waiting TiFlash to catch up...") + self._optimize_wait_tiflash_catch_up() + + log.info("Start compacting TiFlash replica...") + self._optimize_compact_tiflash() + + 
log.info("Waiting index build to finish...") + log_reduce_seq = 0 + while True: + pending_rows = self._optimize_get_tiflash_index_pending_rows() + if pending_rows > 0: + if log_reduce_seq % 15 == 0: + log.info("Index not fully built, pending rows: %d", pending_rows) + log_reduce_seq += 1 + time.sleep(2) + else: + break + + log.info("Index build finished successfully.") + + def _optimize_check_tiflash_replica_progress(self): + try: + database = self.db_config["database"] + with self._get_connection() as (_, cursor): + cursor.execute( + f""" + SELECT PROGRESS FROM information_schema.tiflash_replica + WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}" + """ # noqa: S608 + ) + result = cursor.fetchone() + return result[0] + except Exception as e: + log.warning("Failed to check TiFlash replica progress: %s", e) + raise + + def _optimize_wait_tiflash_catch_up(self): + try: + with self._get_connection() as (conn, cursor): + cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"') + conn.commit() + cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") # noqa: S608 + result = cursor.fetchone() + return result[0] + except Exception as e: + log.warning("Failed to wait TiFlash to catch up: %s", e) + raise + + def _optimize_compact_tiflash(self): + try: + with self._get_connection() as (conn, cursor): + cursor.execute(f"ALTER TABLE {self.table_name} COMPACT") + conn.commit() + except Exception as e: + log.warning("Failed to compact table: %s", e) + raise + + def _optimize_get_tiflash_index_pending_rows(self): + try: + database = self.db_config["database"] + with self._get_connection() as (_, cursor): + cursor.execute( + f""" + SELECT SUM(ROWS_STABLE_NOT_INDEXED) + FROM information_schema.tiflash_indexes + WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}" + """ # noqa: S608 + ) + result = cursor.fetchone() + return result[0] + except Exception as e: + log.warning("Failed to read TiFlash index pending rows: %s", e) + raise + 
+ def _insert_embeddings_serial( + self, + embeddings: list[list[float]], + metadata: list[int], + offset: int, + size: int, + ) -> Exception: + try: + with self._get_connection() as (conn, cursor): + buf = io.StringIO() + buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") # noqa: S608 + for i in range(offset, offset + size): + if i > offset: + buf.write(",") + buf.write(f'({metadata[i]}, "{embeddings[i]!s}")') + cursor.execute(buf.getvalue()) + conn.commit() + except Exception as e: + log.warning("Failed to insert data into table: %s", e) + raise + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> tuple[int, Exception]: + workers = 10 + # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB) + max_batch_size = 64 * 1024 * 1024 // 24 // self.dim + batch_size = len(embeddings) // workers + batch_size = min(batch_size, max_batch_size) + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + futures = [] + for i in range(0, len(embeddings), batch_size): + offset = i + size = min(batch_size, len(embeddings) - i) + future = executor.submit(self._insert_embeddings_serial, embeddings, metadata, offset, size) + futures.append(future) + done, pending = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) + executor.shutdown(wait=False) + for future in done: + future.result() + for future in pending: + future.cancel() + return len(metadata), None + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> list[int]: + self.cursor.execute( + f""" + SELECT id FROM {self.table_name} + ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k}; + """ # noqa: S608 + ) + result = self.cursor.fetchall() + return [int(i[0]) for i in result] diff --git a/vectordb_bench/backend/clients/vespa/cli.py b/vectordb_bench/backend/clients/vespa/cli.py new file mode 
100644 index 000000000..616d46067 --- /dev/null +++ b/vectordb_bench/backend/clients/vespa/cli.py @@ -0,0 +1,47 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.backend.clients import DB +from vectordb_bench.cli.cli import ( + CommonTypedDict, + HNSWFlavor1, + cli, + click_parameter_decorators_from_typed_dict, + run, +) + + +class VespaTypedDict(CommonTypedDict, HNSWFlavor1): + uri: Annotated[ + str, + click.option("--uri", "-u", type=str, help="uri connection string", default="http://127.0.0.1"), + ] + port: Annotated[ + int, + click.option("--port", "-p", type=int, help="connection port", default=8080), + ] + quantization: Annotated[ + str, click.option("--quantization", type=click.Choice(["none", "binary"], case_sensitive=False), default="none") + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(VespaTypedDict) +def Vespa(**params: Unpack[VespaTypedDict]): + from .config import VespaConfig, VespaHNSWConfig + + case_params = { + "quantization_type": params["quantization"], + "M": params["m"], + "efConstruction": params["ef_construction"], + "ef": params["ef_search"], + } + + run( + db=DB.Vespa, + db_config=VespaConfig(url=SecretStr(params["uri"]), port=params["port"]), + db_case_config=VespaHNSWConfig(**{k: v for k, v in case_params.items() if v}), + **params, + ) diff --git a/vectordb_bench/backend/clients/vespa/config.py b/vectordb_bench/backend/clients/vespa/config.py new file mode 100644 index 000000000..3d4a1deaf --- /dev/null +++ b/vectordb_bench/backend/clients/vespa/config.py @@ -0,0 +1,51 @@ +from typing import Literal, TypeAlias + +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, MetricType + +VespaMetric: TypeAlias = Literal["euclidean", "angular", "dotproduct", "prenormalized-angular", "hamming", "geodegrees"] + +VespaQuantizationType: TypeAlias = Literal["none", "binary"] + + +class VespaConfig(DBConfig): + url: SecretStr = 
"http://127.0.0.1" + port: int = 8080 + + def to_dict(self): + return { + "url": self.url.get_secret_value(), + "port": self.port, + } + + +class VespaHNSWConfig(BaseModel, DBCaseConfig): + metric_type: MetricType = MetricType.COSINE + quantization_type: VespaQuantizationType = "none" + M: int = 16 + efConstruction: int = 200 + ef: int = 100 + + def index_param(self) -> dict: + return { + "distance_metric": self.parse_metric(self.metric_type), + "max_links_per_node": self.M, + "neighbors_to_explore_at_insert": self.efConstruction, + } + + def search_param(self) -> dict: + return {} + + def parse_metric(self, metric_type: MetricType) -> VespaMetric: + match metric_type: + case MetricType.COSINE: + return "angular" + case MetricType.L2: + return "euclidean" + case MetricType.DP | MetricType.IP: + return "dotproduct" + case MetricType.HAMMING: + return "hamming" + case _: + raise NotImplementedError diff --git a/vectordb_bench/backend/clients/vespa/util.py b/vectordb_bench/backend/clients/vespa/util.py new file mode 100644 index 000000000..7a64cc30d --- /dev/null +++ b/vectordb_bench/backend/clients/vespa/util.py @@ -0,0 +1,15 @@ +"""Utility functions for supporting binary quantization + +From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8 +""" + +import numpy as np + + +def binarize_tensor(tensor: list[float]) -> list[int]: + """ + Binarize a floating-point list by thresholding at zero + and packing the bits into bytes. 
+ """ + tensor = np.array(tensor) + return np.packbits(np.where(tensor > 0, 1, 0), axis=0).astype(np.int8).tolist() diff --git a/vectordb_bench/backend/clients/vespa/vespa.py b/vectordb_bench/backend/clients/vespa/vespa.py new file mode 100644 index 000000000..5288bc04c --- /dev/null +++ b/vectordb_bench/backend/clients/vespa/vespa.py @@ -0,0 +1,249 @@ +import datetime +import logging +import math +from collections.abc import Generator +from contextlib import contextmanager + +from vespa import application + +from ..api import VectorDB +from . import util +from .config import VespaHNSWConfig + +log = logging.getLogger(__name__) + + +class Vespa(VectorDB): + def __init__( + self, + dim: int, + db_config: dict[str, str], + db_case_config: VespaHNSWConfig | None = None, + collection_name: str = "VectorDBBenchCollection", + drop_old: bool = False, + **kwargs, + ) -> None: + self.dim = dim + self.db_config = db_config + self.case_config = db_case_config or VespaHNSWConfig() + self.schema_name = collection_name + + client = self.deploy_http() + client.wait_for_application_up() + + if drop_old: + try: + client.delete_all_docs("vectordbbench_content", self.schema_name) + except Exception: + drop_old = False + log.exception(f"Vespa client drop_old schema: {self.schema_name}") + + @contextmanager + def init(self) -> Generator[None, None, None]: + """create and destory connections to database. + Why contextmanager: + + In multiprocessing search tasks, vectordbbench might init + totally hundreds of thousands of connections with DB server. + + Too many connections may drain local FDs or server connection resources. + If the DB client doesn't have `close()` method, just set the object to None. 
+ + Examples: + >>> with self.init(): + >>> self.insert_embeddings() + """ + self.client = application.Vespa(self.db_config["url"], port=self.db_config["port"]) + yield + self.client = None + + def need_normalize_cosine(self) -> bool: + """Wheather this database need to normalize dataset to support COSINE""" + return False + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs, + ) -> tuple[int, Exception | None]: + """Insert the embeddings to the vector database. The default number of embeddings for + each insert_embeddings is 5000. + + Args: + embeddings(list[list[float]]): list of embedding to add to the vector database. + metadatas(list[int]): metadata associated with the embeddings, for filtering. + **kwargs(Any): vector database specific parameters. + + Returns: + int: inserted data count + """ + assert self.client is not None + + data = ({"id": str(i), "fields": {"id": i, "embedding": e}} for i, e in zip(metadata, embeddings, strict=True)) + self.client.feed_iterable(data, self.schema_name) + return len(embeddings), None + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + ) -> list[int]: + """Get k most similar embeddings to query vector. + + Args: + query(list[float]): query embedding to look up documents similar to. + k(int): Number of most similar embeddings to return. Defaults to 100. + filters(dict, optional): filtering expression to filter the data while searching. + + Returns: + list[int]: list of k most similar embeddings IDs to the query embedding. 
+ """ + assert self.client is not None + + ef = self.case_config.ef + extra_ef = max(0, ef - k) + embedding_field = "embedding" if self.case_config.quantization_type == "none" else "embedding_binary" + + yql = ( + f"select id from {self.schema_name} where " # noqa: S608 + f"{{targetHits: {k}, hnsw.exploreAdditionalHits: {extra_ef}}}" + f"nearestNeighbor({embedding_field}, query_embedding)" + ) + + if filters: + id_filter = filters.get("id") + yql += f" and id >= {id_filter}" + + query_embedding = query if self.case_config.quantization_type == "none" else util.binarize_tensor(query) + + ranking = self.case_config.quantization_type + + result = self.client.query({"yql": yql, "input.query(query_embedding)": query_embedding, "ranking": ranking}) + return [child["fields"]["id"] for child in result.get_json()["root"]["children"]] + + def optimize(self, data_size: int | None = None): + """optimize will be called between insertion and search in performance cases. + + Should be blocked until the vectorDB is ready to be tested on + heavy performance cases. + + Time(insert the dataset) + Time(optimize) will be recorded as "load_duration" metric + Optimize's execution time is limited, the limited time is based on cases. 
+ """ + + @property + def application_package(self): + if getattr(self, "_application_package", None) is None: + self._application_package = self._create_application_package() + return self._application_package + + def _create_application_package(self): + from vespa.package import ( + HNSW, + ApplicationPackage, + Document, + Field, + RankProfile, + Schema, + Validation, + ValidationID, + ) + + fields = [ + Field( + "id", + "int", + indexing=["summary", "attribute"], + ), + Field( + "embedding", + f"tensor(x[{self.dim}])", + indexing=["summary", "attribute", "index"], + ann=HNSW(**self.case_config.index_param()), + ), + ] + + if self.case_config.quantization_type == "binary": + fields.append( + Field( + "embedding_binary", + f"tensor(x[{math.ceil(self.dim / 8)}])", + indexing=[ + "input embedding", + # convert 32 bit float to 1 bit + "binarize", + # pack 8 bits into one int8 + "pack_bits", + "summary", + "attribute", + "index", + ], + ann=HNSW(**{**self.case_config.index_param(), "distance_metric": "hamming"}), + is_document_field=False, + ) + ) + + tomorrow = datetime.date.today() + datetime.timedelta(days=1) + + return ApplicationPackage( + "vectordbbench", + [ + Schema( + self.schema_name, + Document( + fields, + ), + rank_profiles=[ + RankProfile( + name="none", + first_phase="", + inherits="default", + inputs=[("query(query_embedding)", f"tensor(x[{self.dim}])")], + ), + RankProfile( + name="binary", + first_phase="", + inherits="default", + inputs=[("query(query_embedding)", f"tensor(x[{math.ceil(self.dim / 8)}])")], + ), + ], + ) + ], + validations=[ + Validation(ValidationID.tensorTypeChange, until=tomorrow), + Validation(ValidationID.fieldTypeChange, until=tomorrow), + ], + ) + + def deploy_http(self) -> application.Vespa: + """ + Deploy a Vespa application package via HTTP REST API. 
+ + Returns: + application.Vespa: The deployed Vespa application instance + """ + import requests + + url = self.db_config["url"] + ":19071/application/v2/tenant/default/prepareandactivate" + package_data = self.application_package.to_zip() + headers = {"Content-Type": "application/zip"} + + try: + response = requests.post(url=url, data=package_data, headers=headers, timeout=10) + + response.raise_for_status() + result = response.json() + return application.Vespa( + url=self.db_config["url"], + port=self.db_config["port"], + deployment_message=result.get("message"), + application_package=self.application_package, + ) + + except requests.exceptions.RequestException as e: + error_msg = f"Failed to deploy Vespa application: {e!s}" + if hasattr(e, "response") and e.response is not None: + error_msg += f" - Response: {e.response.text}" + raise RuntimeError(error_msg) from e diff --git a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py index aa4368bb7..c31104d8b 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py @@ -99,7 +99,7 @@ def insert_embeddings( embeddings: Iterable[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert embeddings into Weaviate""" assert self.client.schema.exists(self.collection_name) insert_count = 0 diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py index 62700b0fa..f90580dc6 100644 --- a/vectordb_bench/backend/dataset.py +++ b/vectordb_bench/backend/dataset.py @@ -220,10 +220,12 @@ def prepare( train_files = utils.compose_train_files(file_count, use_shuffled) all_files = train_files - gt_file, test_file = None, None + test_file = "test.parquet" + all_files.extend([test_file]) + gt_file = None if self.data.with_gt: - gt_file, test_file = utils.compose_gt_file(filters), "test.parquet" - 
all_files.extend([gt_file, test_file]) + gt_file = utils.compose_gt_file(filters) + all_files.extend([gt_file]) if not self.data.is_custom: source.reader().read( @@ -232,8 +234,10 @@ def prepare( local_ds_root=self.data_dir, ) - if gt_file is not None and test_file is not None: + if test_file is not None: self.test_data = self._read_file(test_file) + + if gt_file is not None: self.gt_data = self._read_file(gt_file) prefix = "shuffle_train" if use_shuffled else "train" diff --git a/vectordb_bench/backend/runner/mp_runner.py b/vectordb_bench/backend/runner/mp_runner.py index 687a0ecd7..fd87d7ece 100644 --- a/vectordb_bench/backend/runner/mp_runner.py +++ b/vectordb_bench/backend/runner/mp_runner.py @@ -5,10 +5,12 @@ import time import traceback from collections.abc import Iterable +from multiprocessing.queues import Queue import numpy as np from ... import config +from ...models import ConcurrencySlotTimeoutError from ..clients import api NUM_PER_BATCH = config.NUM_PER_BATCH @@ -28,16 +30,18 @@ def __init__( self, db: api.VectorDB, test_data: list[list[float]], - k: int = 100, + k: int = config.K_DEFAULT, filters: dict | None = None, concurrencies: Iterable[int] = config.NUM_CONCURRENCY, - duration: int = 30, + duration: int = config.CONCURRENCY_DURATION, + concurrency_timeout: int = config.CONCURRENCY_TIMEOUT, ): self.db = db self.k = k self.filters = filters self.concurrencies = concurrencies self.duration = duration + self.concurrency_timeout = concurrency_timeout self.test_data = test_data log.debug(f"test dataset columns: {len(test_data)}") @@ -114,9 +118,7 @@ def _run_all_concurrencies_mem_efficient(self): log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}") future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)] # Sync all processes - while q.qsize() < conc: - sleep_t = conc if conc < 10 else 10 - time.sleep(sleep_t) + self._wait_for_queue_fill(q, size=conc) with cond: cond.notify_all() @@ 
-160,6 +162,15 @@ def _run_all_concurrencies_mem_efficient(self): conc_latency_avg_list, ) + def _wait_for_queue_fill(self, q: Queue, size: int): + wait_t = 0 + while q.qsize() < size: + sleep_t = size if size < 10 else 10 + wait_t += sleep_t + if wait_t > self.concurrency_timeout > 0: + raise ConcurrencySlotTimeoutError + time.sleep(sleep_t) + def run(self) -> float: """ Returns: diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py index 365641132..5b418c886 100644 --- a/vectordb_bench/backend/runner/serial_runner.py +++ b/vectordb_bench/backend/runner/serial_runner.py @@ -209,7 +209,8 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: ideal_dcg = get_ideal_dcg(self.k) log.debug(f"test dataset size: {len(test_data)}") - log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}") + if ground_truth is not None: + log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}") latencies, recalls, ndcgs = [], [], [] for idx, emb in enumerate(test_data): @@ -228,9 +229,13 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: latencies.append(time.perf_counter() - s) - gt = ground_truth["neighbors_id"][idx] - recalls.append(calc_recall(self.k, gt[: self.k], results)) - ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg)) + if ground_truth is not None: + gt = ground_truth["neighbors_id"][idx] + recalls.append(calc_recall(self.k, gt[: self.k], results)) + ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg)) + else: + recalls.append(0) + ndcgs.append(0) if len(latencies) % 100 == 0: log.debug( diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 2a583b4f5..1da3bb4cd 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -275,6 +275,7 @@ def _init_search_runner(self): filters=self.ca.filters, 
concurrencies=self.config.case_config.concurrency_search_config.num_concurrency, duration=self.config.case_config.concurrency_search_config.concurrency_duration, + concurrency_timeout=self.config.case_config.concurrency_search_config.concurrency_timeout, k=self.config.case_config.k, ) diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index 3bb7763d8..1b0eb295b 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -1,9 +1,9 @@ import logging -import os import time from collections.abc import Callable from concurrent.futures import wait from datetime import datetime +from pathlib import Path from pprint import pformat from typing import ( Annotated, @@ -17,10 +17,9 @@ import click from yaml import load -from vectordb_bench.backend.clients.api import MetricType - from .. import config from ..backend.clients import DB +from ..backend.clients.api import MetricType from ..interface import benchmark_runner, global_result_future from ..models import ( CaseConfig, @@ -38,18 +37,17 @@ from yaml import Loader -def click_get_defaults_from_file(ctx, param, value): +def click_get_defaults_from_file(ctx, param, value): # noqa: ANN001, ARG001 if value: - if os.path.exists(value): - input_file = value - else: - input_file = os.path.join(config.CONFIG_LOCAL_DIR, value) + path = Path(value) + input_file = path if path.exists() else Path(config.CONFIG_LOCAL_DIR, path) try: - with open(input_file) as f: - _config: dict[str, dict[str, Any]] = load(f.read(), Loader=Loader) + with input_file.open() as f: + _config: dict[str, dict[str, Any]] = load(f.read(), Loader=Loader) # noqa: S506 ctx.default_map = _config.get(ctx.command.name, {}) except Exception as e: - raise click.BadParameter(f"Failed to load config file: {e}") + msg = f"Failed to load config file: {e}" + raise click.BadParameter(msg) from e return value @@ -68,12 +66,16 @@ def click_parameter_decorators_from_typed_dict( For clarity, the key names of the TypedDict will be used to determine the 
type hints for the input parameters. - The actual function parameters are controlled by the click.option definitions. You must manually ensure these are aligned in a sensible way! + The actual function parameters are controlled by the click.option definitions. + You must manually ensure these are aligned in a sensible way! Example: ``` class CommonTypedDict(TypedDict): - z: Annotated[int, click.option("--z/--no-z", is_flag=True, type=bool, help="help z", default=True, show_default=True)] + z: Annotated[ + int, + click.option("--z/--no-z", is_flag=True, type=bool, help="help z", default=True, show_default=True) + ] name: Annotated[str, click.argument("name", required=False, default="Jeff")] class FooTypedDict(CommonTypedDict): @@ -91,14 +93,16 @@ def foo(**parameters: Unpack[FooTypedDict]): for _, t in get_type_hints(typed_dict, include_extras=True).items(): assert get_origin(t) is Annotated if len(t.__metadata__) == 1 and t.__metadata__[0].__module__ == "click.decorators": - # happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] with no additional metadata defined (len=1) + # happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] + # with no additional metadata defined (len=1) decorators.append(t.__metadata__[0]) else: raise RuntimeError( - "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring", + "Click-TypedDict decorator parsing must only contain root type " + "and a click decorator like click.option. See docstring", ) - def deco(f): + def deco(f): # noqa: ANN001 for dec in reversed(decorators): f = dec(f) return f @@ -106,7 +110,7 @@ def deco(f): return deco -def click_arg_split(ctx: click.Context, param: click.core.Option, value): +def click_arg_split(ctx: click.Context, param: click.core.Option, value): # noqa: ANN001, ARG001 """Will split a comma-separated list input into an actual list. 
Args: @@ -145,8 +149,7 @@ def parse_task_stages( return stages -# ruff: noqa -def check_custom_case_parameters(ctx: any, param: any, value: any): +def check_custom_case_parameters(ctx: any, param: any, value: any): # noqa: ARG001 if ctx.params.get("case_type") == "PerformanceCustomDataset" and value is None: raise click.BadParameter( """ Custom case parameters @@ -299,6 +302,17 @@ class CommonTypedDict(TypedDict): callback=lambda *args: list(map(int, click_arg_split(*args))), ), ] + concurrency_timeout: Annotated[ + int, + click.option( + "--concurrency-timeout", + type=int, + default=config.CONCURRENCY_TIMEOUT, + show_default=True, + help="Timeout (in seconds) to wait for a concurrency slot before failing. " + "Set to a negative value to wait indefinitely.", + ), + ] custom_case_name: Annotated[ str, click.option( @@ -401,6 +415,7 @@ class CommonTypedDict(TypedDict): show_default=True, ), ] + task_label: Annotated[str, click.option("--task-label", help="Task label")] class HNSWBaseTypedDict(TypedDict): @@ -485,6 +500,7 @@ def run( concurrency_search_config=ConcurrencySearchConfig( concurrency_duration=parameters["concurrency_duration"], num_concurrency=[int(s) for s in parameters["num_concurrency"]], + concurrency_timeout=parameters["concurrency_timeout"], ), custom_case=get_custom_case_config(parameters), ), @@ -495,10 +511,11 @@ def run( parameters["search_concurrent"], ), ) + task_label = parameters["task_label"] log.info(f"Task:\n{pformat(task)}\n") if not parameters["dry_run"]: - benchmark_runner.run([task]) + benchmark_runner.run([task], task_label) time.sleep(5) if global_result_future: wait([global_result_future]) diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 5e3798691..d4153bc1e 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,13 +1,19 @@ from ..backend.clients.alloydb.cli import AlloyDBScaNN from ..backend.clients.aws_opensearch.cli import AWSOpenSearch +from 
..backend.clients.clickhouse.cli import Clickhouse +from ..backend.clients.lancedb.cli import LanceDB +from ..backend.clients.mariadb.cli import MariaDBHNSW from ..backend.clients.memorydb.cli import MemoryDB from ..backend.clients.milvus.cli import MilvusAutoIndex from ..backend.clients.pgdiskann.cli import PgDiskAnn from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat from ..backend.clients.pgvector.cli import PgVectorHNSW from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn +from ..backend.clients.qdrant_cloud.cli import QdrantCloud from ..backend.clients.redis.cli import Redis from ..backend.clients.test.cli import Test +from ..backend.clients.tidb.cli import TiDB +from ..backend.clients.vespa.cli import Vespa from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex from .cli import cli @@ -25,6 +31,12 @@ cli.add_command(PgVectorScaleDiskAnn) cli.add_command(PgDiskAnn) cli.add_command(AlloyDBScaNN) +cli.add_command(MariaDBHNSW) +cli.add_command(TiDB) +cli.add_command(Clickhouse) +cli.add_command(Vespa) +cli.add_command(LanceDB) +cli.add_command(QdrantCloud) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/components/custom/displaypPrams.py b/vectordb_bench/frontend/components/custom/displaypPrams.py index b677e5909..712b57b00 100644 --- a/vectordb_bench/frontend/components/custom/displaypPrams.py +++ b/vectordb_bench/frontend/components/custom/displaypPrams.py @@ -3,7 +3,7 @@ def displayParams(st): """ - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format. - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. 
- - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. + - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files. @@ -11,3 +11,14 @@ def displayParams(st): - `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order. """ ) + st.caption( + """We recommend limiting the number of test query vectors, like 1,000.""", + help=""" +When conducting concurrent query tests, Vdbbench creates a large number of processes. +To minimize additional communication overhead during testing, +we prepare a complete set of test queries for each process, allowing them to run independently.\n +However, this means that as the number of concurrent processes increases, +the number of copied query vectors also increases significantly, +which can place substantial pressure on memory resources. 
+""", + ) diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py index 800e6dede..a2d2de77f 100644 --- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py +++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py @@ -36,21 +36,27 @@ def dbConfigSettingItem(st, activeDb: DB): columns = st.columns(DB_CONFIG_SETTING_COLUMNS) dbConfigClass = activeDb.config_cls - properties = dbConfigClass.schema().get("properties") + schema = dbConfigClass.schema() + property_items = schema.get("properties").items() + required_fields = set(schema.get("required", [])) dbConfig = {} idx = 0 # db config (unique) - for key, property in properties.items(): + for key, property in property_items: if key not in dbConfigClass.common_short_configs() and key not in dbConfigClass.common_long_configs(): column = columns[idx % DB_CONFIG_SETTING_COLUMNS] idx += 1 - dbConfig[key] = column.text_input( + input_value = column.text_input( key, - key="%s-%s" % (activeDb.name, key), + key=f"{activeDb.name}-{key}", value=property.get("default", ""), type="password" if inputIsPassword(key) else "default", + placeholder="optional" if key not in required_fields else None, ) + if key in required_fields or input_value: + dbConfig[key] = input_value + # db config (common short labels) for key in dbConfigClass.common_short_configs(): column = columns[idx % DB_CONFIG_SETTING_COLUMNS] diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py index 426095397..5827efeb4 100644 --- a/vectordb_bench/frontend/components/run_test/submitTask.py +++ b/vectordb_bench/frontend/components/run_test/submitTask.py @@ -1,6 +1,8 @@ from datetime import datetime +from vectordb_bench import config from vectordb_bench.frontend.config import styles from vectordb_bench.interface import benchmark_runner +from vectordb_bench.models import TaskConfig def 
submitTask(st, tasks, isAllValid): @@ -47,16 +49,31 @@ def advancedSettings(st): k = container[0].number_input("k", min_value=1, value=100, label_visibility="collapsed") container[1].caption("K value for number of nearest neighbors to search") - return index_already_exists, use_aliyun, k + container = st.columns([1, 2]) + defaultconcurrentInput = ",".join(map(str, config.NUM_CONCURRENCY)) + concurrentInput = container[0].text_input( + "Concurrent Input", value=defaultconcurrentInput, label_visibility="collapsed" + ) + container[1].caption("num of concurrencies for search tests to get max-qps") + return index_already_exists, use_aliyun, k, concurrentInput -def controlPanel(st, tasks, taskLabel, isAllValid): - index_already_exists, use_aliyun, k = advancedSettings(st) +def controlPanel(st, tasks: list[TaskConfig], taskLabel, isAllValid): + index_already_exists, use_aliyun, k, concurrentInput = advancedSettings(st) def runHandler(): benchmark_runner.set_drop_old(not index_already_exists) + + try: + concurrentInput_list = [int(item.strip()) for item in concurrentInput.split(",")] + except ValueError: + st.write("please input correct number") + return None + for task in tasks: task.case_config.k = k + task.case_config.concurrency_search_config.num_concurrency = concurrentInput_list + benchmark_runner.set_download_address(use_aliyun) benchmark_runner.run(tasks, taskLabel) diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index e004f2ba7..3c9430b2b 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from vectordb_bench.backend.cases import CaseLabel, CaseType from vectordb_bench.backend.clients import DB -from vectordb_bench.backend.clients.api import IndexType, MetricType +from vectordb_bench.backend.clients.api import IndexType, MetricType, SQType from 
vectordb_bench.frontend.components.custom.getCustomConfig import get_custom_configs from vectordb_bench.models import CaseConfig, CaseConfigParamType @@ -164,15 +164,20 @@ class CaseConfigInput(BaseModel): inputConfig={ "options": [ IndexType.HNSW.value, + IndexType.HNSW_SQ.value, + IndexType.HNSW_PQ.value, + IndexType.HNSW_PRQ.value, IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, + IndexType.IVF_RABITQ.value, IndexType.DISKANN.value, - IndexType.STREAMING_DISKANN.value, IndexType.Flat.value, IndexType.AUTOINDEX.value, IndexType.GPU_IVF_FLAT.value, IndexType.GPU_IVF_PQ.value, IndexType.GPU_CAGRA.value, + IndexType.GPU_BRUTE_FORCE.value, ], }, ) @@ -345,9 +350,16 @@ class CaseConfigInput(BaseModel): "max": 64, "value": 30, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + IndexType.HNSW.value, + IndexType.HNSW_SQ.value, + IndexType.HNSW_PQ.value, + IndexType.HNSW_PRQ.value, + ], ) + CaseConfigParamInput_m = CaseConfigInput( label=CaseConfigParamType.m, inputType=InputType.Number, @@ -368,7 +380,62 @@ class CaseConfigInput(BaseModel): "max": 512, "value": 360, }, - isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] + in [ + IndexType.HNSW.value, + IndexType.HNSW_SQ.value, + IndexType.HNSW_PQ.value, + IndexType.HNSW_PRQ.value, + ], +) + +CaseConfigParamInput_SQType = CaseConfigInput( + label=CaseConfigParamType.sq_type, + inputType=InputType.Option, + inputHelp="Scalar quantizer type.", + inputConfig={ + "options": [SQType.SQ6.value, SQType.SQ8.value, SQType.BF16.value, SQType.FP16.value, SQType.FP32.value] + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.HNSW_SQ.value], +) + +CaseConfigParamInput_Refine = CaseConfigInput( + 
label=CaseConfigParamType.refine, + inputType=InputType.Option, + inputHelp="Whether refined data is reserved during index building.", + inputConfig={"options": [True, False]}, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value], +) + +CaseConfigParamInput_RefineType = CaseConfigInput( + label=CaseConfigParamType.refine_type, + inputType=InputType.Option, + inputHelp="The data type of the refine index.", + inputConfig={ + "options": [SQType.FP32.value, SQType.FP16.value, SQType.BF16.value, SQType.SQ8.value, SQType.SQ6.value] + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value] + and config.get(CaseConfigParamType.refine, True), +) + +CaseConfigParamInput_RefineK = CaseConfigInput( + label=CaseConfigParamType.refine_k, + inputType=InputType.Float, + inputHelp="The magnification factor of refine compared to k.", + inputConfig={"min": 1.0, "max": 10000.0, "value": 1.0}, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value] + and config.get(CaseConfigParamType.refine, True), +) + +CaseConfigParamInput_RBQBitsQuery = CaseConfigInput( + label=CaseConfigParamType.rbq_bits_query, + inputType=InputType.Number, + inputHelp="The magnification factor of refine compared to k.", + inputConfig={"min": 0, "max": 8, "value": 0}, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.IVF_RABITQ.value], ) CaseConfigParamInput_EFConstruction_Weaviate = CaseConfigInput( @@ -518,7 +585,13 @@ class CaseConfigInput(BaseModel): "max": MAX_STREAMLIT_INT, "value": 100, }, - isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == 
IndexType.HNSW.value, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] + in [ + IndexType.HNSW.value, + IndexType.HNSW_SQ.value, + IndexType.HNSW_PQ.value, + IndexType.HNSW_PRQ.value, + ], ) CaseConfigParamInput_EF_Weaviate = CaseConfigInput( @@ -559,9 +632,12 @@ class CaseConfigInput(BaseModel): isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [ IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, + IndexType.IVF_RABITQ.value, IndexType.GPU_IVF_FLAT.value, IndexType.GPU_IVF_PQ.value, + IndexType.GPU_BRUTE_FORCE.value, ], ) @@ -576,9 +652,12 @@ class CaseConfigInput(BaseModel): isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [ IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, + IndexType.IVF_RABITQ.value, IndexType.GPU_IVF_FLAT.value, IndexType.GPU_IVF_PQ.value, + IndexType.GPU_BRUTE_FORCE.value, ], ) @@ -586,11 +665,12 @@ class CaseConfigInput(BaseModel): label=CaseConfigParamType.m, inputType=InputType.Number, inputConfig={ - "min": 0, + "min": 1, "max": 65536, - "value": 0, + "value": 32, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVFPQ.value], ) @@ -602,7 +682,20 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 8, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVFPQ.value], +) + +CaseConfigParamInput_NRQ = CaseConfigInput( + label=CaseConfigParamType.nrq, + inputType=InputType.Number, + inputHelp="The number of residual subquantizers.", + inputConfig={ + 
"min": 1, + "max": 16, + "value": 2, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.HNSW_PRQ.value], ) CaseConfigParamInput_intermediate_graph_degree = CaseConfigInput( @@ -703,6 +796,7 @@ class CaseConfigInput(BaseModel): IndexType.GPU_CAGRA.value, IndexType.GPU_IVF_PQ.value, IndexType.GPU_IVF_FLAT.value, + IndexType.GPU_BRUTE_FORCE.value, ], ) @@ -720,6 +814,7 @@ class CaseConfigInput(BaseModel): IndexType.GPU_CAGRA.value, IndexType.GPU_IVF_PQ.value, IndexType.GPU_IVF_FLAT.value, + IndexType.GPU_BRUTE_FORCE.value, ], ) @@ -818,6 +913,19 @@ class CaseConfigInput(BaseModel): ], ) +CaseConfigParamInput_TableQuantizationType_PgVector = CaseConfigInput( + label=CaseConfigParamType.tableQuantizationType, + inputType=InputType.Option, + inputConfig={ + "options": ["none", "bit", "halfvec"], + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + IndexType.HNSW.value, + IndexType.IVFFlat.value, + ], +) + CaseConfigParamInput_max_parallel_workers_PgVectorRS = CaseConfigInput( label=CaseConfigParamType.max_parallel_workers, displayLabel="Max parallel workers", @@ -1040,6 +1148,122 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_IndexType_MariaDB = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputHelp="Select Index Type", + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.HNSW.value, + ], + }, +) + +CaseConfigParamInput_StorageEngine_MariaDB = CaseConfigInput( + label=CaseConfigParamType.storage_engine, + inputHelp="Select Storage Engine", + inputType=InputType.Option, + inputConfig={ + "options": ["InnoDB", "MyISAM"], + }, +) + +CaseConfigParamInput_M_MariaDB = CaseConfigInput( + label=CaseConfigParamType.M, + inputHelp="M parameter in MHNSW vector indexing", + inputType=InputType.Number, + inputConfig={ + "min": 3, + "max": 200, + "value": 6, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == 
IndexType.HNSW.value, +) + +CaseConfigParamInput_EFSearch_MariaDB = CaseConfigInput( + label=CaseConfigParamType.ef_search, + inputHelp="mhnsw_ef_search", + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 10000, + "value": 20, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_CacheSize_MariaDB = CaseConfigInput( + label=CaseConfigParamType.max_cache_size, + inputHelp="mhnsw_max_cache_size", + inputType=InputType.Number, + inputConfig={ + "min": 1048576, + "max": (1 << 53) - 1, + "value": 16 * 1024**3, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_MongoDBQuantizationType = CaseConfigInput( + label=CaseConfigParamType.mongodb_quantization_type, + inputType=InputType.Option, + inputConfig={ + "options": ["none", "scalar", "binary"], + }, +) + + +CaseConfigParamInput_MongoDBNumCandidatesRatio = CaseConfigInput( + label=CaseConfigParamType.mongodb_num_candidates_ratio, + inputType=InputType.Number, + inputConfig={ + "min": 10, + "max": 20, + "value": 10, + }, +) + + +CaseConfigParamInput_M_Vespa = CaseConfigInput( + label=CaseConfigParamType.M, + inputType=InputType.Number, + inputConfig={ + "min": 4, + "max": 64, + "value": 16, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_IndexType_Vespa = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.HNSW.value, + ], + }, +) + +CaseConfigParamInput_QuantizationType_Vespa = CaseConfigInput( + label=CaseConfigParamType.quantizationType, + inputType=InputType.Option, + inputConfig={ + "options": ["none", "binary"], + }, +) + +CaseConfigParamInput_EFConstruction_Vespa = CaseConfigInput( + label=CaseConfigParamType.EFConstruction, + inputType=InputType.Number, + inputConfig={ + "min": 
8, + "max": 512, + "value": 200, + }, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value, +) + MilvusLoadConfig = [ CaseConfigParamInput_IndexType, @@ -1052,6 +1276,10 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_graph_degree, CaseConfigParamInput_build_algo, CaseConfigParamInput_cache_dataset_on_device, + CaseConfigParamInput_SQType, + CaseConfigParamInput_Refine, + CaseConfigParamInput_RefineType, + CaseConfigParamInput_NRQ, ] MilvusPerformanceConfig = [ CaseConfigParamInput_IndexType, @@ -1063,6 +1291,8 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_Nprobe, CaseConfigParamInput_M_PQ, CaseConfigParamInput_Nbits_PQ, + CaseConfigParamInput_RBQBitsQuery, + CaseConfigParamInput_NRQ, CaseConfigParamInput_intermediate_graph_degree, CaseConfigParamInput_graph_degree, CaseConfigParamInput_itopk_size, @@ -1073,6 +1303,10 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_build_algo, CaseConfigParamInput_cache_dataset_on_device, CaseConfigParamInput_refine_ratio, + CaseConfigParamInput_SQType, + CaseConfigParamInput_Refine, + CaseConfigParamInput_RefineType, + CaseConfigParamInput_RefineK, ] WeaviateLoadConfig = [ @@ -1113,6 +1347,7 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_m, CaseConfigParamInput_EFConstruction_PgVector, CaseConfigParamInput_QuantizationType_PgVector, + CaseConfigParamInput_TableQuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, ] @@ -1124,6 +1359,7 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_Lists_PgVector, CaseConfigParamInput_Probes_PgVector, CaseConfigParamInput_QuantizationType_PgVector, + CaseConfigParamInput_TableQuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, CaseConfigParamInput_reranking_PgVector, @@ -1224,6 +1460,158 @@ class CaseConfigInput(BaseModel): 
CaseConfigParamInput_NumCandidates_AliES, ] +MongoDBLoadingConfig = [ + CaseConfigParamInput_MongoDBQuantizationType, +] +MongoDBPerformanceConfig = [ + CaseConfigParamInput_MongoDBQuantizationType, + CaseConfigParamInput_MongoDBNumCandidatesRatio, +] + +MariaDBLoadingConfig = [ + CaseConfigParamInput_IndexType_MariaDB, + CaseConfigParamInput_StorageEngine_MariaDB, + CaseConfigParamInput_M_MariaDB, + CaseConfigParamInput_CacheSize_MariaDB, +] +MariaDBPerformanceConfig = [ + CaseConfigParamInput_IndexType_MariaDB, + CaseConfigParamInput_StorageEngine_MariaDB, + CaseConfigParamInput_M_MariaDB, + CaseConfigParamInput_CacheSize_MariaDB, + CaseConfigParamInput_EFSearch_MariaDB, +] + +VespaLoadingConfig = [ + CaseConfigParamInput_IndexType_Vespa, + CaseConfigParamInput_QuantizationType_Vespa, + CaseConfigParamInput_M_Vespa, + CaseConfigParamInput_EF_Milvus, + CaseConfigParamInput_EFConstruction_Vespa, +] +VespaPerformanceConfig = VespaLoadingConfig + +CaseConfigParamInput_IndexType_LanceDB = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputHelp="AUTOINDEX = IVFPQ with default parameters", + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.NONE.value, + IndexType.AUTOINDEX.value, + IndexType.IVFPQ.value, + IndexType.HNSW.value, + ], + }, +) + +CaseConfigParamInput_num_partitions_LanceDB = CaseConfigInput( + label=CaseConfigParamType.num_partitions, + displayLabel="Number of Partitions", + inputHelp="Number of partitions (clusters) for IVF_PQ. 
Default (when 0): sqrt(num_rows)", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 10000, + "value": 0, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_num_sub_vectors_LanceDB = CaseConfigInput( + label=CaseConfigParamType.num_sub_vectors, + displayLabel="Number of Sub-vectors", + inputHelp="Number of sub-vectors for PQ. Default (when 0): dim/16 or dim/8", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 1000, + "value": 0, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_num_bits_LanceDB = CaseConfigInput( + label=CaseConfigParamType.nbits, + displayLabel="Number of Bits", + inputHelp="Number of bits per sub-vector.", + inputType=InputType.Option, + inputConfig={ + "options": [4, 8], + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_sample_rate_LanceDB = CaseConfigInput( + label=CaseConfigParamType.sample_rate, + displayLabel="Sample Rate", + inputHelp="Sample rate for training. 
Higher values are more accurate but slower", + inputType=InputType.Number, + inputConfig={ + "min": 16, + "max": 1024, + "value": 256, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_max_iterations_LanceDB = CaseConfigInput( + label=CaseConfigParamType.max_iterations, + displayLabel="Max Iterations", + inputHelp="Maximum iterations for k-means clustering", + inputType=InputType.Number, + inputConfig={ + "min": 10, + "max": 200, + "value": 50, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_m_LanceDB = CaseConfigInput( + label=CaseConfigParamType.m, + displayLabel="m", + inputHelp="m parameter in HNSW", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 1000, + "value": 0, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_ef_construction_LanceDB = CaseConfigInput( + label=CaseConfigParamType.ef_construction, + displayLabel="ef_construction", + inputHelp="ef_construction parameter in HNSW", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 1000, + "value": 0, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +LanceDBLoadConfig = [ + CaseConfigParamInput_IndexType_LanceDB, + CaseConfigParamInput_num_partitions_LanceDB, + CaseConfigParamInput_num_sub_vectors_LanceDB, + CaseConfigParamInput_num_bits_LanceDB, + CaseConfigParamInput_sample_rate_LanceDB, + CaseConfigParamInput_max_iterations_LanceDB, + CaseConfigParamInput_m_LanceDB, + CaseConfigParamInput_ef_construction_LanceDB, +] + +LanceDBPerformanceConfig = LanceDBLoadConfig + CASE_CONFIG_MAP = { DB.Milvus: { 
CaseLabel.Load: MilvusLoadConfig, @@ -1272,4 +1660,20 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: AliyunOpensearchLoadingConfig, CaseLabel.Performance: AliyunOpenSearchPerformanceConfig, }, + DB.MongoDB: { + CaseLabel.Load: MongoDBLoadingConfig, + CaseLabel.Performance: MongoDBPerformanceConfig, + }, + DB.MariaDB: { + CaseLabel.Load: MariaDBLoadingConfig, + CaseLabel.Performance: MariaDBPerformanceConfig, + }, + DB.Vespa: { + CaseLabel.Load: VespaLoadingConfig, + CaseLabel.Performance: VespaPerformanceConfig, + }, + DB.LanceDB: { + CaseLabel.Load: LanceDBLoadConfig, + CaseLabel.Performance: LanceDBPerformanceConfig, + }, } diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index 3e0fdb112..96a5eede4 100644 --- a/vectordb_bench/frontend/config/styles.py +++ b/vectordb_bench/frontend/config/styles.py @@ -47,6 +47,9 @@ def getPatternShape(i): DB.Redis: "https://assets.zilliz.com/Redis_Cloud_74b8bfef39.png", DB.Chroma: "https://assets.zilliz.com/chroma_ceb3f06ed7.png", DB.AWSOpenSearch: "https://assets.zilliz.com/opensearch_1eee37584e.jpeg", + DB.TiDB: "https://img2.pingcap.com/forms/3/d/3d7fd5f9767323d6f037795704211ac44b4923d6.png", + DB.Vespa: "https://vespa.ai/vespa-content/uploads/2025/01/Vespa-symbol-green-rgb.png.webp", + DB.LanceDB: "https://raw.githubusercontent.com/lancedb/lancedb/main/docs/src/assets/logo.png", } # RedisCloud color: #0D6EFD @@ -61,4 +64,6 @@ def getPatternShape(i): DB.PgVector.value: "#4C779A", DB.Redis.value: "#0D6EFD", DB.AWSOpenSearch.value: "#0DCAF0", + DB.TiDB.value: "#0D6EFD", + DB.Vespa.value: "#61d790", } diff --git a/vectordb_bench/log_util.py b/vectordb_bench/log_util.py index d75688137..6ca6ccabf 100644 --- a/vectordb_bench/log_util.py +++ b/vectordb_bench/log_util.py @@ -1,8 +1,13 @@ import logging from logging import config +from pathlib import Path def init(log_level: str): + # Create logs directory if it doesn't exist + log_dir = Path("logs") + log_dir.mkdir(exist_ok=True) 
+ log_config = { "version": 1, "disable_existing_loggers": False, @@ -24,15 +29,23 @@ def init(log_level: str): "class": "logging.StreamHandler", "formatter": "default", }, + "file": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "default", + "filename": "logs/vectordb_bench.log", + "maxBytes": 10485760, # 10MB + "backupCount": 5, + "encoding": "utf8", + }, }, "loggers": { "vectordb_bench": { - "handlers": ["console"], + "handlers": ["console", "file"], "level": log_level, "propagate": False, }, "no_color": { - "handlers": ["no_color_console"], + "handlers": ["no_color_console", "file"], "level": log_level, "propagate": False, }, diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 49bb04ae0..c35c21755 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -12,6 +12,7 @@ DB, DBCaseConfig, DBConfig, + EmptyDBCaseConfig, ) from .base import BaseModel from .metric import Metric @@ -29,6 +30,11 @@ def __init__(self): super().__init__("Performance case optimize timeout") +class ConcurrencySlotTimeoutError(TimeoutError): + def __init__(self): + super().__init__("Timeout while waiting for a concurrency slot to become available") + + class CaseConfigParamType(Enum): """ Value will be the key of CaseConfig.params and displayed in UI @@ -49,11 +55,13 @@ class CaseConfigParamType(Enum): probes = "probes" quantizationType = "quantization_type" quantizationRatio = "quantization_ratio" + tableQuantizationType = "table_quantization_type" reranking = "reranking" rerankingMetric = "reranking_metric" quantizedFetchLimit = "quantized_fetch_limit" m = "m" nbits = "nbits" + nrq = "nrq" intermediate_graph_degree = "intermediate_graph_degree" graph_degree = "graph_degree" itopk_size = "itopk_size" @@ -64,6 +72,11 @@ class CaseConfigParamType(Enum): build_algo = "build_algo" cache_dataset_on_device = "cache_dataset_on_device" refine_ratio = "refine_ratio" + refine = "refine" + refine_type = "refine_type" + refine_k = "refine_k" + 
rbq_bits_query = "rbq_bits_query" + sq_type = "sq_type" level = "level" maintenance_work_mem = "maintenance_work_mem" max_parallel_workers = "max_parallel_workers" @@ -87,6 +100,15 @@ class CaseConfigParamType(Enum): preReorderingNumNeigbors = "pre_reordering_num_neighbors" numSearchThreads = "num_search_threads" maxNumPrefetchDatasets = "max_num_prefetch_datasets" + storage_engine = "storage_engine" + max_cache_size = "max_cache_size" + num_partitions = "num_partitions" + num_sub_vectors = "num_sub_vectors" + sample_rate = "sample_rate" + + # mongodb params + mongodb_quantization_type = "quantization" + mongodb_num_candidates_ratio = "num_candidates_ratio" class CustomizedCase(BaseModel): @@ -96,6 +118,7 @@ class CustomizedCase(BaseModel): class ConcurrencySearchConfig(BaseModel): num_concurrency: list[int] = config.NUM_CONCURRENCY concurrency_duration: int = config.CONCURRENCY_DURATION + concurrency_timeout: int = config.CONCURRENCY_TIMEOUT class CaseConfig(BaseModel): @@ -234,13 +257,19 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: test_result["task_label"] = test_result["run_id"] for case_result in test_result["results"]: - task_config = case_result.get("task_config") - db = DB(task_config.get("db")) + task_config = case_result["task_config"] + db = DB(task_config["db"]) task_config["db_config"] = db.config_cls(**task_config["db_config"]) - task_config["db_case_config"] = db.case_config_cls( - index_type=task_config["db_case_config"].get("index", None), - )(**task_config["db_case_config"]) + + # Safely instantiate DBCaseConfig (fallback to EmptyDBCaseConfig on None) + raw_case_cfg = task_config.get("db_case_config") or {} + index_value = raw_case_cfg.get("index", None) + try: + task_config["db_case_config"] = db.case_config_cls(index_type=index_value)(**raw_case_cfg) + except Exception: + log.exception(f"Couldn't get class for index '{index_value}' ({full_path})") + task_config["db_case_config"] = 
EmptyDBCaseConfig(**raw_case_cfg) case_result["task_config"] = task_config @@ -256,7 +285,6 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: ) return TestResult.validate(test_result) - # ruff: noqa def display(self, dbs: list[DB] | None = None): filter_list = dbs if dbs and isinstance(dbs, list) else None sorted_results = sorted( @@ -287,7 +315,7 @@ def append_return(x: any, y: any): max_qps = 10 if max_qps < 10 else max_qps max_recall = 13 if max_recall < 13 else max_recall - LENGTH = ( + LENGTH = ( # noqa: N806 max_db, max_db_labels, max_case, @@ -300,13 +328,13 @@ def append_return(x: any, y: any): 5, ) - DATA_FORMAT = ( + DATA_FORMAT = ( # noqa: N806 f"%-{max_db}s | %-{max_db_labels}s %-{max_case}s %-{len(self.task_label)}s" f" | %-{max_load_dur}s %-{max_qps}s %-15s %-{max_recall}s %-14s" f" | %-5s" ) - TITLE = DATA_FORMAT % ( + TITLE = DATA_FORMAT % ( # noqa: N806 "DB", "db_label", "case", @@ -318,8 +346,8 @@ def append_return(x: any, y: any): "max_load_count", "label", ) - SPLIT = DATA_FORMAT % tuple(map(lambda x: "-" * x, LENGTH)) - SUMMARY_FORMAT = ("Task summary: run_id=%s, task_label=%s") % ( + SPLIT = DATA_FORMAT % tuple(map(lambda x: "-" * x, LENGTH)) # noqa: C417, N806 + SUMMARY_FORMAT = ("Task summary: run_id=%s, task_label=%s") % ( # noqa: N806 self.run_id[:5], self.task_label, )