diff --git a/README.md b/README.md index 80db92ec5..7a52c3242 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ All the database client supported | awsopensearch | `pip install vectordb-bench[opensearch]` | | aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` | | mongodb | `pip install vectordb-bench[mongodb]` | +| tidb | `pip install vectordb-bench[tidb]` | ### Run diff --git a/pyproject.toml b/pyproject.toml index ad80bacfe..019c2973e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ all = [ "alibabacloud_ha3engine_vector", "alibabacloud_searchengine20211025", "mariadb", + "PyMySQL", ] qdrant = [ "qdrant-client" ] @@ -87,7 +88,8 @@ chromadb = [ "chromadb" ] opensearch = [ "opensearch-py" ] aliyun_opensearch = [ "alibabacloud_ha3engine_vector", "alibabacloud_searchengine20211025"] mongodb = [ "pymongo" ] -mariadb = [ "mariadb" ] +mariadb = [ "mariadb" ] +tidb = [ "PyMySQL" ] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 742da5213..8d667ec23 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -42,6 +42,7 @@ class DB(Enum): Test = "test" AliyunOpenSearch = "AliyunOpenSearch" MongoDB = "MongoDB" + TiDB = "TiDB" @property def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 @@ -141,6 +142,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 return MariaDB + if self == DB.TiDB: + from .tidb.tidb import TiDB + + return TiDB + if self == DB.Test: from .test.test import Test @@ -244,8 +250,14 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 if self == DB.MariaDB: from .mariadb.config import MariaDBConfig + return MariaDBConfig + if self == DB.TiDB: + from .tidb.config import TiDBConfig + + return TiDBConfig + if self == DB.Test: from .test.config import TestConfig @@ -333,6 
class TiDBTypedDict(CommonTypedDict):
    """Command-line options of the TiDB benchmark command.

    Every field pairs its Python type with the ``click.option`` that exposes
    it on the command line; the fields are declared in the order
    username, password, host, port, db-name, ssl.
    """

    # Account used to log in.
    user_name: Annotated[
        str,
        click.option("--username", type=str, help="Username", default="root", show_default=True, required=True),
    ]
    # An empty password is the default.
    password: Annotated[
        str,
        click.option("--password", type=str, help="Password", default="", show_default=True),
    ]
    host: Annotated[
        str,
        click.option("--host", type=str, help="Db host", default="127.0.0.1", show_default=True, required=True),
    ]
    port: Annotated[
        int,
        click.option("--port", type=int, help="Db Port", default=4000, show_default=True, required=True),
    ]
    db_name: Annotated[
        str,
        click.option("--db-name", type=str, help="Db name", default="test", show_default=True, required=True),
    ]
    # Boolean switch: ``--ssl`` turns it on, ``--no-ssl`` (the default) turns it off.
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            is_flag=True,
            help="Enable or disable SSL, for TiDB Serverless SSL must be enabled",
            default=False,
            show_default=True,
        ),
    ]
class TiDBConfig(DBConfig):
    """Connection settings for a TiDB server."""

    user_name: str = "root"
    password: SecretStr
    host: str = "127.0.0.1"
    port: int = 4000
    db_name: str = "test"
    ssl: bool = False

    @validator("*")
    def not_empty_field(cls, v: object, field: object):
        """Accept every field value unchanged.

        NOTE(review): presumably this shadows a stricter validator of the same
        name on ``DBConfig`` so that empty values (such as an empty password)
        are allowed - confirm against the base class.
        """
        return v

    def to_dict(self) -> dict:
        """Return the settings as a plain dict.

        The password is unwrapped from its ``SecretStr``; the ``ssl`` flag is
        copied into both ``ssl_verify_cert`` and ``ssl_verify_identity``.
        """
        return {
            "host": self.host,
            "port": self.port,
            "user": self.user_name,
            "password": self.password.get_secret_value(),
            "database": self.db_name,
            "ssl_verify_cert": self.ssl,
            "ssl_verify_identity": self.ssl,
        }


class TiDBIndexConfig(BaseModel, DBCaseConfig):
    """Case configuration that maps a metric type to a TiDB distance function."""

    metric_type: MetricType | None = None

    def get_metric_fn(self) -> str:
        """Return the SQL distance function name for ``metric_type``.

        Raises:
            ValueError: If ``metric_type`` is neither L2 nor COSINE
                (including when it is unset).
        """
        if self.metric_type == MetricType.L2:
            return "vec_l2_distance"
        if self.metric_type == MetricType.COSINE:
            return "vec_cosine_distance"
        raise ValueError(f"Unsupported metric type: {self.metric_type}")

    def index_param(self) -> dict:
        """Return the parameters used when creating the index."""
        return {
            "metric_fn": self.get_metric_fn(),
        }

    def search_param(self) -> dict:
        """Return the parameters used when searching."""
        return {
            "metric_fn": self.get_metric_fn(),
        }
class TiDB(VectorDB):
    """Benchmark client for TiDB that talks to the server through PyMySQL."""

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: TiDBIndexConfig,
        collection_name: str = "vector_bench_test",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the settings and optionally recreate the table.

        Args:
            dim: Dimension of the ``VECTOR`` column.
            db_config: Keyword arguments passed to ``pymysql.connect``; its
                ``"database"`` entry is also read by the optimize helpers.
            db_case_config: Provides the distance function name through
                ``index_param()`` and ``search_param()``.
            collection_name: Name of the table to work on.
            drop_old: When true, drop the table if it exists and create it again.
            **kwargs: Accepted and ignored.
        """
        self.name = "TiDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim
        self.conn = None  # To be inited by init()
        self.cursor = None  # To be inited by init()

        self.search_fn = db_case_config.search_param()["metric_fn"]

        if drop_old:
            self._drop_table()
            self._create_table()

    @contextmanager
    def init(self):
        """Hold one connection and cursor on ``self`` for the duration of the block."""
        with self._get_connection() as (conn, cursor):
            self.conn = conn
            self.cursor = cursor
            try:
                yield
            finally:
                self.conn = None
                self.cursor = None

    @contextmanager
    def _get_connection(self):
        """Yield a fresh ``(connection, cursor)`` pair.

        Both objects are used as context managers, so they are released when
        the caller's ``with`` block ends.
        """
        with pymysql.connect(**self.db_config) as conn:
            # ``autocommit`` is a method on PyMySQL connections. Assigning to it
            # would replace the method with a bool instead of setting the mode.
            conn.autocommit(False)
            with conn.cursor() as cursor:
                yield conn, cursor

    def _drop_table(self):
        """Drop the table if it exists; log and re-raise on failure."""
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
                conn.commit()
        except Exception as e:
            log.warning("Failed to drop table: %s error: %s", self.table_name, e)
            raise

    def _create_table(self):
        """Create the table with an id column, a vector column and a vector index."""
        try:
            index_param = self.case_config.index_param()
            with self._get_connection() as (conn, cursor):
                cursor.execute(
                    f"""
                    CREATE TABLE {self.table_name} (
                        id BIGINT PRIMARY KEY,
                        embedding VECTOR({self.dim}) NOT NULL,
                        VECTOR INDEX (({index_param["metric_fn"]}(embedding)))
                    );
                    """
                )
                conn.commit()
        except Exception as e:
            log.warning("Failed to create table: %s error: %s", self.table_name, e)
            raise

    def ready_to_load(self) -> bool:
        # NOTE(review): no-op that implicitly returns None despite the ``bool``
        # annotation; kept as is because callers are not visible here.
        pass

    def optimize(self, data_size: int | None = None) -> None:
        """Block until the TiFlash replica is ready, compacted and fully indexed.

        Steps: poll the replica progress every 2 seconds until it equals 1,
        run a count query with TiFlash enabled as a read engine, compact the
        table, then poll the number of not-yet-indexed rows every 2 seconds
        until it is no longer positive.

        Args:
            data_size: Unused.
        """
        while True:
            progress = self._optimize_check_tiflash_replica_progress()
            if progress == 1:
                break
            # %s rather than %d: the progress is fractional and %d would
            # always print 0 until replication completes.
            log.info("Data replication not ready, progress: %s", progress)
            time.sleep(2)

        log.info("Waiting TiFlash to catch up...")
        self._optimize_wait_tiflash_catch_up()

        log.info("Start compacting TiFlash replica...")
        self._optimize_compact_tiflash()

        log.info("Waiting index build to finish...")
        log_reduce_seq = 0
        while True:
            pending_rows = self._optimize_get_tiflash_index_pending_rows()
            if not pending_rows > 0:
                break
            # Only log every 15th poll to keep the output short.
            if log_reduce_seq % 15 == 0:
                log.info("Index not fully built, pending rows: %d", pending_rows)
            log_reduce_seq += 1
            time.sleep(2)

        log.info("Index build finished successfully.")

    def _optimize_check_tiflash_replica_progress(self):
        """Return ``PROGRESS`` of this table from ``information_schema.tiflash_replica``."""
        try:
            database = self.db_config["database"]
            with self._get_connection() as (_, cursor):
                cursor.execute(
                    f"""
                    SELECT PROGRESS FROM information_schema.tiflash_replica
                    WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}"
                    """
                )
                # NOTE(review): if no row matches, fetchone() gives None and the
                # subscript below raises TypeError.
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to check TiFlash replica progress: %s", e)
            raise

    def _optimize_wait_tiflash_catch_up(self):
        """Enable TiFlash as a read engine for the session and return the row count."""
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
                conn.commit()
                cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to wait TiFlash to catch up: %s", e)
            raise

    def _optimize_compact_tiflash(self):
        """Run ``ALTER TABLE ... COMPACT`` on the table."""
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute(f"ALTER TABLE {self.table_name} COMPACT")
                conn.commit()
        except Exception as e:
            log.warning("Failed to compact table: %s", e)
            raise

    def _optimize_get_tiflash_index_pending_rows(self):
        """Return the summed ``ROWS_STABLE_NOT_INDEXED`` for this table."""
        try:
            database = self.db_config["database"]
            with self._get_connection() as (_, cursor):
                cursor.execute(
                    f"""
                    SELECT SUM(ROWS_STABLE_NOT_INDEXED)
                    FROM information_schema.tiflash_indexes
                    WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}"
                    """
                )
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to read TiFlash index pending rows: %s", e)
            raise

    def _insert_embeddings_serial(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        offset: int,
        size: int,
    ) -> None:
        """Insert rows ``offset .. offset + size - 1`` with one multi-row INSERT.

        Opens its own connection, so it can run on a worker thread.
        Failures are logged and re-raised.
        """
        try:
            with self._get_connection() as (conn, cursor):
                buf = io.StringIO()
                buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")
                for i in range(offset, offset + size):
                    if i > offset:
                        buf.write(",")
                    buf.write(f'({metadata[i]}, "{str(embeddings[i])}")')
                cursor.execute(buf.getvalue())
                conn.commit()
        except Exception as e:
            log.warning("Failed to insert data into table: %s", e)
            raise

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> Tuple[int, Optional[Exception]]:
        """Insert the embeddings in parallel batches.

        The input is cut into batches that are submitted to a pool of 10
        threads. If a batch fails, batches that have not started are cancelled
        and the failure is re-raised.

        Args:
            embeddings: Vectors to insert.
            metadata: Row ids, index-aligned with ``embeddings``.
            **kwargs: Accepted and ignored.

        Returns:
            ``(len(metadata), None)`` on success; ``(0, None)`` for empty input.
        """
        total = len(embeddings)
        if total == 0:
            # Nothing to do; also avoids a zero step in range() below.
            return 0, None

        workers = 10
        # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB)
        max_batch_size = 64 * 1024 * 1024 // 24 // self.dim
        # At least one row per batch: with fewer rows than workers the integer
        # division yields 0, which range() rejects as a step.
        batch_size = max(1, min(total // workers, max_batch_size))

        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(
                    self._insert_embeddings_serial,
                    embeddings,
                    metadata,
                    offset,
                    min(batch_size, total - offset),
                )
                for offset in range(0, total, batch_size)
            ]
            done, pending = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION)
            # Cancel queued work before surfacing the error; result() raises,
            # so anything placed after it would be skipped on failure.
            for future in pending:
                future.cancel()
            for future in done:
                future.result()
        return len(metadata), None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Return the ids of the ``k`` rows nearest to ``query``.

        Uses the cursor stored by ``init()``, so it must be called inside
        that context. ``filters`` and ``timeout`` are accepted but unused.
        """
        self.cursor.execute(
            f"""
            SELECT id FROM {self.table_name}
            ORDER BY {self.search_fn}(embedding, "{str(query)}") LIMIT {k};
            """
        )
        result = self.cursor.fetchall()
        return [int(i[0]) for i in result]