From dd5b162b3095fa76717a02bf8dfdd0ea37e5101e Mon Sep 17 00:00:00 2001 From: yangxuan Date: Fri, 10 Jan 2025 14:21:24 +0800 Subject: [PATCH 01/36] fix: Unable to run vebbench and cli fix: remove comma of logging str fix cli unable to run #444 Signed-off-by: yangxuan --- pyproject.toml | 1 + vectordb_bench/backend/clients/__init__.py | 14 +++++- .../backend/clients/memorydb/cli.py | 4 +- .../backend/clients/pgvecto_rs/pgvecto_rs.py | 9 +--- .../backend/clients/pgvector/pgvector.py | 4 +- .../clients/pgvectorscale/pgvectorscale.py | 4 +- .../clients/qdrant_cloud/qdrant_cloud.py | 4 +- vectordb_bench/backend/clients/test/cli.py | 2 +- vectordb_bench/backend/data_source.py | 16 ++---- vectordb_bench/backend/runner/mp_runner.py | 50 ++++++------------- vectordb_bench/backend/runner/rate_runner.py | 8 +-- .../backend/runner/read_write_runner.py | 24 ++++----- .../backend/runner/serial_runner.py | 46 ++++++++--------- vectordb_bench/backend/task_runner.py | 26 ++-------- vectordb_bench/interface.py | 27 ++++------ 15 files changed, 88 insertions(+), 151 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 312940634..6259bcea6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,6 +133,7 @@ lint.ignore = [ "RUF017", "C416", "PLW0603", + "COM812", ] # Allow autofix for all enabled rules (when `--fix`) is provided. 
diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 773cd4948..e796aa069 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -42,7 +42,7 @@ class DB(Enum): AliyunOpenSearch = "AliyunOpenSearch" @property - def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912 + def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 """Import while in use""" if self == DB.Milvus: from .milvus.milvus import Milvus @@ -129,11 +129,16 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912 return AliyunOpenSearch + if self == DB.Test: + from .test.test import Test + + return Test + msg = f"Unknown DB: {self.name}" raise ValueError(msg) @property - def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912 + def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 """Import while in use""" if self == DB.Milvus: from .milvus.config import MilvusConfig @@ -220,6 +225,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912 return AliyunOpenSearchConfig + if self == DB.Test: + from .test.config import TestConfig + + return TestConfig + msg = f"Unknown DB: {self.name}" raise ValueError(msg) diff --git a/vectordb_bench/backend/clients/memorydb/cli.py b/vectordb_bench/backend/clients/memorydb/cli.py index ae00bfd17..568eec2a3 100644 --- a/vectordb_bench/backend/clients/memorydb/cli.py +++ b/vectordb_bench/backend/clients/memorydb/cli.py @@ -43,8 +43,8 @@ class MemoryDBTypedDict(TypedDict): show_default=True, default=False, help=( - "Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance.", - " In production, MemoryDB only supports cluster mode (CME)", + "Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance." 
+ " In production, MemoryDB only supports cluster mode (CME)" ), ), ] diff --git a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py index fc4f17807..64e95a1be 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py @@ -200,10 +200,7 @@ def _create_index(self): self.cursor.execute(index_create_sql) self.conn.commit() except Exception as e: - log.warning( - f"Failed to create pgvecto.rs index {self._index_name} \ - at table {self.table_name} error: {e}", - ) + log.warning(f"Failed to create pgvecto.rs index {self._index_name} at table {self.table_name} error: {e}") raise e from None def _create_table(self, dim: int): @@ -258,9 +255,7 @@ def insert_embeddings( return len(metadata), None except Exception as e: - log.warning( - f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}", - ) + log.warning(f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}") return 0, e def search_embedding( diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 62a7971bb..bd024175c 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -415,9 +415,7 @@ def insert_embeddings( return len(metadata), None except Exception as e: - log.warning( - f"Failed to insert data into pgvector table ({self.table_name}), error: {e}", - ) + log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}") return 0, e def search_embedding( diff --git a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py index 981accc2e..ca7d809b4 100644 --- a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +++ b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py @@ -255,9 +255,7 
@@ def insert_embeddings( return len(metadata), None except Exception as e: - log.warning( - f"Failed to insert data into pgvector table ({self.table_name}), error: {e}", - ) + log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}") return 0, e def search_embedding( diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py index 0861e8938..a0d146a73 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py @@ -76,8 +76,8 @@ def optimize(self): continue if info.status == CollectionStatus.GREEN: msg = ( - f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, ", - f"Collection status: {info.indexed_vectors_count}", + f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, " + f"Collection status: {info.indexed_vectors_count}" ) log.info(msg) return diff --git a/vectordb_bench/backend/clients/test/cli.py b/vectordb_bench/backend/clients/test/cli.py index e5cd4c78b..2dcc4c407 100644 --- a/vectordb_bench/backend/clients/test/cli.py +++ b/vectordb_bench/backend/clients/test/cli.py @@ -17,7 +17,7 @@ class TestTypedDict(CommonTypedDict): ... 
@click_parameter_decorators_from_typed_dict(TestTypedDict) def Test(**parameters: Unpack[TestTypedDict]): run( - db=DB.NewClient, + db=DB.Test, db_config=TestConfig(db_label=parameters["db_label"]), db_case_config=TestIndexConfig(), **parameters, diff --git a/vectordb_bench/backend/data_source.py b/vectordb_bench/backend/data_source.py index b98dc7d7a..139d2e308 100644 --- a/vectordb_bench/backend/data_source.py +++ b/vectordb_bench/backend/data_source.py @@ -63,9 +63,7 @@ def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool: # check size equal remote_size, local_size = info.content_length, local.stat().st_size if remote_size != local_size: - log.info( - f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]", - ) + log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]") return False return True @@ -89,9 +87,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): local_file = local_ds_root.joinpath(file) if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)): - log.info( - f"local file: {local_file} not match with remote: {remote_file}; add to downloading list", - ) + log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list") downloads.append((remote_file, local_file)) if len(downloads) == 0: @@ -135,9 +131,7 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): local_file = local_ds_root.joinpath(file) if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)): - log.info( - f"local file: {local_file} not match with remote: {remote_file}; add to downloading list", - ) + log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list") downloads.append(remote_file) if len(downloads) == 0: @@ -157,9 +151,7 @@ def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool: # check size equal 
remote_size, local_size = info.get("size"), local.stat().st_size if remote_size != local_size: - log.info( - f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]", - ) + log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]") return False return True diff --git a/vectordb_bench/backend/runner/mp_runner.py b/vectordb_bench/backend/runner/mp_runner.py index 5b69b5481..687a0ecd7 100644 --- a/vectordb_bench/backend/runner/mp_runner.py +++ b/vectordb_bench/backend/runner/mp_runner.py @@ -79,14 +79,14 @@ def search( if count % 500 == 0: log.debug( - f"({mp.current_process().name:16}) ", - f"search_count: {count}, latest_latency={time.perf_counter()-s}", + f"({mp.current_process().name:16}) " + f"search_count: {count}, latest_latency={time.perf_counter()-s}" ) total_dur = round(time.perf_counter() - start_time, 4) log.info( f"{mp.current_process().name:16} search {self.duration}s: " - f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}", + f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}" ) return (count, total_dur, latencies) @@ -94,9 +94,7 @@ def search( @staticmethod def get_mp_context(): mp_start_method = "spawn" - log.debug( - f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}", - ) + log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}") return mp.get_context(mp_start_method) def _run_all_concurrencies_mem_efficient(self): @@ -113,9 +111,7 @@ def _run_all_concurrencies_mem_efficient(self): mp_context=self.get_mp_context(), max_workers=conc, ) as executor: - log.info( - f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}", - ) + log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}") future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)] # Sync all 
processes while q.qsize() < conc: @@ -124,9 +120,7 @@ def _run_all_concurrencies_mem_efficient(self): with cond: cond.notify_all() - log.info( - f"Syncing all process and start concurrency search, concurrency={conc}", - ) + log.info(f"Syncing all process and start concurrency search, concurrency={conc}") start = time.perf_counter() all_count = sum([r.result()[0] for r in future_iter]) @@ -140,18 +134,14 @@ def _run_all_concurrencies_mem_efficient(self): conc_qps_list.append(qps) conc_latency_p99_list.append(latency_p99) conc_latency_avg_list.append(latency_avg) - log.info( - f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}", - ) + log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}") if qps > max_qps: max_qps = qps - log.info( - f"Update largest qps with concurrency {conc}: current max_qps={max_qps}", - ) + log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}") except Exception as e: log.warning( - f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}", + f"Fail to search, concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}" ) traceback.print_exc() @@ -193,9 +183,7 @@ def _run_by_dur(self, duration: int) -> float: mp_context=self.get_mp_context(), max_workers=conc, ) as executor: - log.info( - f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}", - ) + log.info(f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}") future_iter = [ executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc) ] @@ -206,24 +194,18 @@ def _run_by_dur(self, duration: int) -> float: with cond: cond.notify_all() - log.info( - f"Syncing all process and start concurrency search, concurrency={conc}", - ) + log.info(f"Syncing all process and start concurrency search, concurrency={conc}") start = time.perf_counter() all_count = 
sum([r.result() for r in future_iter]) cost = time.perf_counter() - start qps = round(all_count / cost, 4) - log.info( - f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}", - ) + log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}") if qps > max_qps: max_qps = qps - log.info( - f"Update largest qps with concurrency {conc}: current max_qps={max_qps}", - ) + log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}") except Exception as e: log.warning( f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}", @@ -275,14 +257,14 @@ def search_by_dur( if count % 500 == 0: log.debug( - f"({mp.current_process().name:16}) search_count: {count}, ", - f"latest_latency={time.perf_counter()-s}", + f"({mp.current_process().name:16}) search_count: {count}, " + f"latest_latency={time.perf_counter()-s}" ) total_dur = round(time.perf_counter() - start_time, 4) log.debug( f"{mp.current_process().name:16} search {self.duration}s: " - f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}", + f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}" ) return count diff --git a/vectordb_bench/backend/runner/rate_runner.py b/vectordb_bench/backend/runner/rate_runner.py index 0145af4ce..4b32bcd9f 100644 --- a/vectordb_bench/backend/runner/rate_runner.py +++ b/vectordb_bench/backend/runner/rate_runner.py @@ -73,14 +73,14 @@ def submit_by_rate() -> bool: if len(not_done) > 0: log.warning( - f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] ", - f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round", + f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] " + f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round" ) executing_futures = list(not_done) else: 
log.debug( - f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} ", - f"task in 1s, wait_interval={wait_interval:.2f}", + f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} " + f"task in 1s, wait_interval={wait_interval:.2f}" ) executing_futures = [] except Exception as e: diff --git a/vectordb_bench/backend/runner/read_write_runner.py b/vectordb_bench/backend/runner/read_write_runner.py index e916f45d6..d7584459a 100644 --- a/vectordb_bench/backend/runner/read_write_runner.py +++ b/vectordb_bench/backend/runner/read_write_runner.py @@ -45,8 +45,8 @@ def __init__( self.read_dur_after_write = read_dur_after_write log.info( - f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, ", - f"stage_search_dur={read_dur_after_write}", + f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, " + f"stage_search_dur={read_dur_after_write}" ) test_emb = np.stack(dataset.test_data["emb"]) @@ -88,12 +88,10 @@ def run_search(self): res, ssearch_dur = self.serial_search_runner.run() recall, ndcg, p99_latency = res log.info( - f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, ", + f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, " f"dur={ssearch_dur:.4f}", ) - log.info( - f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}", - ) + log.info(f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}") max_qps = self.run_by_dur(self.read_dur_after_write) log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}") @@ -157,9 +155,7 @@ def wait_next_target(start: int, target_batch: int) -> bool: got = wait_next_target(start_batch, target_batch) if got is False: - log.warning( - f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}", - ) + log.warning(f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}") return None log.info(f"Insert 
{perc}% done, total batch={total_batch}") @@ -167,8 +163,8 @@ def wait_next_target(start: int, target_batch: int) -> bool: res, ssearch_dur = self.serial_search_runner.run() recall, ndcg, p99_latency = res log.info( - f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ", - f"ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}", + f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, " + f"ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}" ) # Search duration for non-last search stage is carefully calculated. @@ -183,8 +179,8 @@ def wait_next_target(start: int, target_batch: int) -> bool: each_conc_search_dur = csearch_dur / len(self.concurrencies) if each_conc_search_dur < 30: warning_msg = ( - f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, ", - f"total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}.", + f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, " + f"total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}." 
) log.warning(warning_msg) @@ -193,7 +189,7 @@ def wait_next_target(start: int, target_batch: int) -> bool: each_conc_search_dur = 60 log.info( - f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}", + f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}" ) max_qps = self.run_by_dur(each_conc_search_dur) result.append((perc, max_qps, recall, ndcg, p99_latency)) diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py index 7eb59432b..08d42e14c 100644 --- a/vectordb_bench/backend/runner/serial_runner.py +++ b/vectordb_bench/backend/runner/serial_runner.py @@ -40,9 +40,7 @@ def __init__( def task(self) -> int: count = 0 with self.db.init(): - log.info( - f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}", - ) + log.info(f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}") start = time.perf_counter() for data_df in self.dataset: all_metadata = data_df["id"].tolist() @@ -66,13 +64,11 @@ def task(self) -> int: assert insert_count == len(all_metadata) count += insert_count if count % 100_000 == 0: - log.info( - f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB", - ) + log.info(f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB") log.info( - f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, ", - f"dur={time.perf_counter()-start}", + f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, " + f"dur={time.perf_counter()-start}" ) return count @@ -83,8 +79,8 @@ def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: num_batches = math.ceil(len(all_embeddings) / NUM_PER_BATCH) log.info( - f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} ", - f"embeddings in batch {NUM_PER_BATCH}", + 
f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} " + f"embeddings in batch {NUM_PER_BATCH}" ) count = 0 for batch_id in range(num_batches): @@ -94,8 +90,8 @@ def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: embeddings = all_embeddings[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH] log.debug( - f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ", - f"Start inserting {len(metadata)} embeddings", + f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], " + f"Start inserting {len(metadata)} embeddings" ) while retry_count < LOAD_MAX_TRY_COUNT: insert_count, error = self.db.insert_embeddings( @@ -113,15 +109,15 @@ def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: else: break log.debug( - f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ", - f"Finish inserting {len(metadata)} embeddings", + f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], " + f"Finish inserting {len(metadata)} embeddings" ) assert already_insert_count == len(metadata) count += already_insert_count log.info( - f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in ", - f"batch {NUM_PER_BATCH}", + f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in " + f"batch {NUM_PER_BATCH}" ) return count @@ -171,13 +167,13 @@ def run_endlessness(self) -> int: max_load_count += count times += 1 log.info( - f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, ", - f"{max_load_count}", + f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, " + f"{max_load_count}" ) except Exception as e: log.info( - f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, ", - f"{max_load_count}, err={e}", + f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, 
" + f"{max_load_count}, err={e}" ) traceback.print_exc() return max_load_count @@ -209,9 +205,7 @@ def __init__( self.ground_truth = ground_truth def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: - log.info( - f"{mp.current_process().name:14} start search the entire test_data to get recall and latency", - ) + log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency") with self.db.init(): test_data, ground_truth = args ideal_dcg = get_ideal_dcg(self.k) @@ -242,8 +236,8 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: if len(latencies) % 100 == 0: log.debug( - f"({mp.current_process().name:14}) search_count={len(latencies):3}, ", - f"latest_latency={latencies[-1]}, latest recall={recalls[-1]}", + f"({mp.current_process().name:14}) search_count={len(latencies):3}, " + f"latest_latency={latencies[-1]}, latest recall={recalls[-1]}" ) avg_latency = round(np.mean(latencies), 4) @@ -258,7 +252,7 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: f"avg_recall={avg_recall}, " f"avg_ndcg={avg_ndcg}," f"avg_latency={avg_latency}, " - f"p99={p99}", + f"p99={p99}" ) return (avg_recall, avg_ndcg, p99) diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index e24d74f03..e8be9f07d 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -98,9 +98,7 @@ def _pre_run(self, drop_old: bool = True): self.init_db(drop_old) self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate) except ModuleNotFoundError as e: - log.warning( - f"pre run case error: please install client for db: {self.config.db}, error={e}", - ) + log.warning(f"pre run case error: please install client for db: {self.config.db}, error={e}") raise e from None def run(self, drop_old: bool = True) -> Metric: @@ -136,9 +134,7 @@ def _run_capacity_case(self) -> Metric: log.warning(f"Failed 
to run capacity case, reason = {e}") raise e from None else: - log.info( - f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}", - ) + log.info(f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}") return Metric(max_load_count=count) def _run_perf_case(self, drop_old: bool = True) -> Metric: @@ -147,22 +143,6 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric: Returns: Metric: load_duration, recall, serial_latency_p99, and, qps """ - """ - if drop_old: - _, load_dur = self._load_train_data() - build_dur = self._optimize() - m.load_duration = round(load_dur+build_dur, 4) - log.info( - f"Finish loading the entire dataset into VectorDB," - f" insert_duration={load_dur}, optimize_duration={build_dur}" - f" load_duration(insert + optimize) = {m.load_duration}" - ) - - self._init_search_runner() - - m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list = self._conc_search() - m.recall, m.serial_latency_p99 = self._serial_search() - """ log.info("Start performance case") try: @@ -175,7 +155,7 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric: log.info( f"Finish loading the entire dataset into VectorDB," f" insert_duration={load_dur}, optimize_duration={build_dur}" - f" load_duration(insert + optimize) = {m.load_duration}", + f" load_duration(insert + optimize) = {m.load_duration}" ) else: log.info("Data loading skipped") diff --git a/vectordb_bench/interface.py b/vectordb_bench/interface.py index ebe12d2e6..2e573fdc7 100644 --- a/vectordb_bench/interface.py +++ b/vectordb_bench/interface.py @@ -65,9 +65,7 @@ def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool: log.warning("Empty tasks submitted") return False - log.debug( - f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}", - ) + log.debug(f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}") # Generate run_id run_id = uuid.uuid4().hex @@ -169,14 
+167,13 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non drop_old = TaskStage.DROP_OLD in runner.config.stages if (latest_runner and runner == latest_runner) or not self.drop_old: drop_old = False + num_cases = running_task.num_cases() try: - log.info( - f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}", - ) + log.info(f"[{idx+1}/{num_cases}] start case: {runner.display()}, drop_old={drop_old}") case_res.metrics = runner.run(drop_old) log.info( - f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, " - f"result={case_res.metrics}, label={case_res.label}", + f"[{idx+1}/{num_cases}] finish case: {runner.display()}, " + f"result={case_res.metrics}, label={case_res.label}" ) # cache the latest succeeded runner @@ -189,16 +186,12 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non if not drop_old: case_res.metrics.load_duration = cached_load_duration if cached_load_duration else 0.0 except (LoadTimeoutError, PerformanceTimeoutError) as e: - log.warning( - f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}", - ) + log.warning(f"[{idx+1}/{num_cases}] case {runner.display()} failed to run, reason={e}") case_res.label = ResultLabel.OUTOFRANGE continue except Exception as e: - log.warning( - f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}", - ) + log.warning(f"[{idx+1}/{num_cases}] case {runner.display()} failed to run, reason={e}") traceback.print_exc() case_res.label = ResultLabel.FAILED continue @@ -217,9 +210,7 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non send_conn.send((SIGNAL.SUCCESS, None)) send_conn.close() - log.info( - f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}", - ) + log.info(f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}") except 
Exception as e: err_msg = ( @@ -250,7 +241,7 @@ def _clear_running_task(self): def _run_async(self, conn: Connection) -> bool: log.info( f"task submitted: id={self.running_task.run_id}, {self.running_task.task_label}, " - f"case number: {len(self.running_task.case_runners)}", + f"case number: {len(self.running_task.case_runners)}" ) global global_result_future executor = concurrent.futures.ProcessPoolExecutor( From 0095bd781bcc9f4f804e3575f5e2d771b186e944 Mon Sep 17 00:00:00 2001 From: yangxuan Date: Mon, 13 Jan 2025 19:45:51 +0800 Subject: [PATCH 02/36] enhance: Unify optimize and remove ready_to_load PyMilvus used to be the only client that uses ready_to_load. Not it'll load the collection when creating it, so this PR removes `ready_to_load` from the client.API Also this PR enhance optimize and remove the optimize_with_size Signed-off-by: yangxuan --- .../aliyun_opensearch/aliyun_opensearch.py | 8 +------ .../backend/clients/alloydb/alloydb.py | 5 +--- vectordb_bench/backend/clients/api.py | 23 +++++++------------ .../clients/aws_opensearch/aws_opensearch.py | 11 ++++----- .../backend/clients/chroma/chroma.py | 5 +--- .../clients/elastic_cloud/elastic_cloud.py | 5 +--- .../backend/clients/memorydb/memorydb.py | 7 ++---- .../backend/clients/milvus/milvus.py | 21 +---------------- .../backend/clients/pgdiskann/pgdiskann.py | 5 +--- .../backend/clients/pgvecto_rs/pgvecto_rs.py | 5 +--- .../backend/clients/pgvector/pgvector.py | 5 +--- .../clients/pgvectorscale/pgvectorscale.py | 5 +--- .../backend/clients/pinecone/pinecone.py | 5 +--- .../clients/qdrant_cloud/qdrant_cloud.py | 5 +--- vectordb_bench/backend/clients/redis/redis.py | 5 +--- vectordb_bench/backend/clients/test/test.py | 5 +--- .../clients/weaviate_cloud/weaviate_cloud.py | 5 +--- .../backend/runner/read_write_runner.py | 2 +- .../backend/runner/serial_runner.py | 4 +--- vectordb_bench/backend/task_runner.py | 6 ++--- 20 files changed, 33 insertions(+), 109 deletions(-) diff --git 
a/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py b/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py index 00227cfff..324871934 100644 --- a/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +++ b/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py @@ -325,10 +325,7 @@ def need_normalize_cosine(self) -> bool: return False - def optimize(self): - pass - - def optimize_with_size(self, data_size: int): + def optimize(self, data_size: int): log.info(f"optimize count: {data_size}") retry_times = 0 while True: @@ -340,6 +337,3 @@ def optimize_with_size(self, data_size: int): if total_count == data_size: log.info("optimize table finish.") return - - def ready_to_load(self): - """ready_to_load will be called before load in load cases.""" diff --git a/vectordb_bench/backend/clients/alloydb/alloydb.py b/vectordb_bench/backend/clients/alloydb/alloydb.py index c81f77675..b9808ce54 100644 --- a/vectordb_bench/backend/clients/alloydb/alloydb.py +++ b/vectordb_bench/backend/clients/alloydb/alloydb.py @@ -149,10 +149,7 @@ def _drop_table(self): ) self.conn.commit() - def ready_to_load(self): - pass - - def optimize(self): + def optimize(self, data_size: int | None = None): self._post_insert() def _post_insert(self): diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index aa93abc12..a86849e96 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -137,6 +137,13 @@ def __init__( @contextmanager def init(self) -> None: """create and destory connections to database. + Why contextmanager: + + In multiprocessing search tasks, vectordbbench might init + totally hundreds of thousands of connections with DB server. + + Too many connections may drain local FDs or server connection resources. + If the DB client doesn't have `close()` method, just set the object to None. 
Examples: >>> with self.init(): @@ -187,9 +194,8 @@ def search_embedding( """ raise NotImplementedError - # TODO: remove @abstractmethod - def optimize(self): + def optimize(self, data_size: int | None = None): """optimize will be called between insertion and search in performance cases. Should be blocked until the vectorDB is ready to be tested on @@ -199,16 +205,3 @@ def optimize(self): Optimize's execution time is limited, the limited time is based on cases. """ raise NotImplementedError - - def optimize_with_size(self, data_size: int): - self.optimize() - - # TODO: remove - @abstractmethod - def ready_to_load(self): - """ready_to_load will be called before load in load cases. - - Should be blocked until the vectorDB is ready to be tested on - heavy load cases. - """ - raise NotImplementedError diff --git a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py index 487ec67cc..234014f19 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +++ b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py @@ -145,15 +145,15 @@ def search_embedding( docvalue_fields=[self.id_col_name], stored_fields="_none_", ) - log.info(f'Search took: {resp["took"]}') - log.info(f'Search shards: {resp["_shards"]}') - log.info(f'Search hits total: {resp["hits"]["total"]}') + log.info(f"Search took: {resp['took']}") + log.info(f"Search shards: {resp['_shards']}") + log.info(f"Search hits total: {resp['hits']['total']}") return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]] except Exception as e: log.warning(f"Failed to search: {self.index_name} error: {e!s}") raise e from None - def optimize(self): + def optimize(self, data_size: int | None = None): """optimize will be called between insertion and search in performance cases.""" # Call refresh first to ensure that all segments are created self._refresh_index() @@ -194,6 +194,3 @@ def 
_load_graphs_to_memory(self): log.info("Calling warmup API to load graphs into memory") warmup_endpoint = f"/_plugins/_knn/warmup/{self.index_name}" self.client.transport.perform_request("GET", warmup_endpoint) - - def ready_to_load(self): - """ready_to_load will be called before load in load cases.""" diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py index a148fa141..76c810263 100644 --- a/vectordb_bench/backend/clients/chroma/chroma.py +++ b/vectordb_bench/backend/clients/chroma/chroma.py @@ -57,10 +57,7 @@ def init(self) -> None: def ready_to_search(self) -> bool: pass - def ready_to_load(self) -> bool: - pass - - def optimize(self) -> None: + def optimize(self, data_size: int | None = None): pass def insert_embeddings( diff --git a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py index a3183bcb7..ea038c587 100644 --- a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +++ b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py @@ -143,7 +143,7 @@ def search_embedding( log.warning(f"Failed to search: {self.indice} error: {e!s}") raise e from None - def optimize(self): + def optimize(self, data_size: int | None = None): """optimize will be called between insertion and search in performance cases.""" assert self.client is not None, "should self.init() first" self.client.indices.refresh(index=self.indice) @@ -158,6 +158,3 @@ def optimize(self): task_status = self.client.tasks.get(task_id=force_merge_task_id) if task_status["completed"]: return - - def ready_to_load(self): - """ready_to_load will be called before load in load cases.""" diff --git a/vectordb_bench/backend/clients/memorydb/memorydb.py b/vectordb_bench/backend/clients/memorydb/memorydb.py index d05e30be1..9d077f5df 100644 --- a/vectordb_bench/backend/clients/memorydb/memorydb.py +++ b/vectordb_bench/backend/clients/memorydb/memorydb.py @@ -157,17 
+157,14 @@ def init(self) -> Generator[None, None, None]: self.conn = self.get_client() search_param = self.case_config.search_param() if search_param["ef_runtime"]: - self.ef_runtime_str = f'EF_RUNTIME {search_param["ef_runtime"]}' + self.ef_runtime_str = f"EF_RUNTIME {search_param['ef_runtime']}" else: self.ef_runtime_str = "" yield self.conn.close() self.conn = None - def ready_to_load(self) -> bool: - pass - - def optimize(self) -> None: + def optimize(self, data_size: int | None = None): self._post_insert() def insert_embeddings( diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index 45fe7269b..4015eb1f3 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -138,26 +138,7 @@ def wait_index(): log.warning(f"{self.name} optimize error: {e}") raise e from None - def ready_to_load(self): - assert self.col, "Please call self.init() before" - self._pre_load(self.col) - - def _pre_load(self, coll: Collection): - try: - if not coll.has_index(index_name=self._index_name): - log.info(f"{self.name} create index") - coll.create_index( - self._vector_field, - self.case_config.index_param(), - index_name=self._index_name, - ) - coll.load() - log.info(f"{self.name} load") - except Exception as e: - log.warning(f"{self.name} pre load error: {e}") - raise e from None - - def optimize(self): + def optimize(self, data_size: int | None = None): assert self.col, "Please call self.init() before" self._optimize() diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index c21972902..8bede0f01 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -143,10 +143,7 @@ def _drop_table(self): ) self.conn.commit() - def ready_to_load(self): - pass - - def optimize(self): + def optimize(self, data_size: int | None = None): 
self._post_insert() def _post_insert(self): diff --git a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py index 64e95a1be..3006b861a 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py @@ -153,10 +153,7 @@ def _drop_table(self): ) self.conn.commit() - def ready_to_load(self): - pass - - def optimize(self): + def optimize(self, data_size: int | None = None): self._post_insert() def _post_insert(self): diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index bd024175c..4164461fb 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -228,10 +228,7 @@ def _drop_table(self): ) self.conn.commit() - def ready_to_load(self): - pass - - def optimize(self): + def optimize(self, data_size: int | None = None): self._post_insert() def _post_insert(self): diff --git a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py index ca7d809b4..3985c0716 100644 --- a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +++ b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py @@ -143,10 +143,7 @@ def _drop_table(self): ) self.conn.commit() - def ready_to_load(self): - pass - - def optimize(self): + def optimize(self, data_size: int | None = None): self._post_insert() def _post_insert(self): diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index c59ee8760..1a681b33f 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -59,10 +59,7 @@ def init(self): self.index = pc.Index(self.index_name) yield - def ready_to_load(self): - pass - - def optimize(self): + def optimize(self, 
data_size: int | None = None): pass def insert_embeddings( diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py index a0d146a73..5de72798b 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py @@ -62,10 +62,7 @@ def init(self) -> None: self.qdrant_client = None del self.qdrant_client - def ready_to_load(self): - pass - - def optimize(self): + def optimize(self, data_size: int | None = None): assert self.qdrant_client, "Please call self.init() before" # wait for vectors to be fully indexed try: diff --git a/vectordb_bench/backend/clients/redis/redis.py b/vectordb_bench/backend/clients/redis/redis.py index 139850d2f..ef0aad9aa 100644 --- a/vectordb_bench/backend/clients/redis/redis.py +++ b/vectordb_bench/backend/clients/redis/redis.py @@ -95,10 +95,7 @@ def init(self) -> None: def ready_to_search(self) -> bool: """Check if the database is ready to search.""" - def ready_to_load(self) -> bool: - pass - - def optimize(self) -> None: + def optimize(self, data_size: int | None = None): pass def insert_embeddings( diff --git a/vectordb_bench/backend/clients/test/test.py b/vectordb_bench/backend/clients/test/test.py index ee5a523f3..d2bcb74b5 100644 --- a/vectordb_bench/backend/clients/test/test.py +++ b/vectordb_bench/backend/clients/test/test.py @@ -33,10 +33,7 @@ def init(self) -> Generator[None, None, None]: yield - def ready_to_load(self) -> bool: - return True - - def optimize(self) -> None: + def optimize(self, data_size: int | None = None): pass def insert_embeddings( diff --git a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py index b42f70af1..aa4368bb7 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py @@ -67,10 +67,7 @@ 
def init(self) -> None: self.client = None del self.client - def ready_to_load(self): - """Should call insert first, do nothing""" - - def optimize(self): + def optimize(self, data_size: int | None = None): assert self.client.schema.exists(self.collection_name) self.client.schema.update_config( self.collection_name, diff --git a/vectordb_bench/backend/runner/read_write_runner.py b/vectordb_bench/backend/runner/read_write_runner.py index d7584459a..eaba51f5f 100644 --- a/vectordb_bench/backend/runner/read_write_runner.py +++ b/vectordb_bench/backend/runner/read_write_runner.py @@ -80,7 +80,7 @@ def run_optimize(self): """Optimize needs to run in differenct process for pymilvus schema recursion problem""" with self.db.init(): log.info("Search after write - Optimize start") - self.db.optimize() + self.db.optimize(data_size=self.data_volume) log.info("Search after write - Optimize finished") def run_search(self): diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py index 08d42e14c..365641132 100644 --- a/vectordb_bench/backend/runner/serial_runner.py +++ b/vectordb_bench/backend/runner/serial_runner.py @@ -68,7 +68,7 @@ def task(self) -> int: log.info( f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, " - f"dur={time.perf_counter()-start}" + f"dur={time.perf_counter() - start}" ) return count @@ -156,8 +156,6 @@ def run_endlessness(self) -> int: start_time = time.perf_counter() max_load_count, times = 0, 0 try: - with self.db.init(): - self.db.ready_to_load() while time.perf_counter() - start_time < self.timeout: count = self.endless_insert_data( all_embeddings, diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index e8be9f07d..2a583b4f5 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -234,13 +234,13 @@ def _conc_search(self): self.stop() @utils.time_it - def _task(self) -> None: + def 
_optimize_task(self) -> None: with self.db.init(): - self.db.optimize_with_size(data_size=self.ca.dataset.data.size) + self.db.optimize(data_size=self.ca.dataset.data.size) def _optimize(self) -> float: with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: - future = executor.submit(self._task) + future = executor.submit(self._optimize_task) try: return future.result(timeout=self.ca.optimize_timeout)[1] except TimeoutError as e: From d9fc5e1243599b2554273fe0a2b712518abc6f5c Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Tue, 14 Jan 2025 10:34:43 +0800 Subject: [PATCH 03/36] add mongodb client Signed-off-by: zhuwenxing --- pyproject.toml | 2 +- vectordb_bench/backend/clients/__init__.py | 16 ++ .../backend/clients/mongodb/config.py | 44 ++++ .../backend/clients/mongodb/mongodb.py | 201 ++++++++++++++++++ 4 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 vectordb_bench/backend/clients/mongodb/config.py create mode 100644 vectordb_bench/backend/clients/mongodb/mongodb.py diff --git a/pyproject.toml b/pyproject.toml index 6259bcea6..aafe70750 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ memorydb = [ "memorydb" ] chromadb = [ "chromadb" ] opensearch = [ "opensearch-py" ] aliyun_opensearch = [ "alibabacloud_ha3engine_vector", "alibabacloud_searchengine20211025"] +mongodb = [ "pymongo" ] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" @@ -207,4 +208,3 @@ builtins-ignorelist = [ # "dict", # TODO # "filter", ] - diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index e796aa069..f2f480dcf 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -40,6 +40,7 @@ class DB(Enum): AliyunElasticsearch = "AliyunElasticsearch" Test = "test" AliyunOpenSearch = "AliyunOpenSearch" + MongoDB = "MongoDB" @property def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 @@ -129,6 +130,11 @@ 
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 return AliyunOpenSearch + if self == DB.MongoDB: + from .mongodb.mongodb import MongoDB + + return MongoDB + if self == DB.Test: from .test.test import Test @@ -225,6 +231,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 return AliyunOpenSearchConfig + if self == DB.MongoDB: + from .mongodb.config import MongoDBConfig + + return MongoDBConfig + if self == DB.Test: from .test.config import TestConfig @@ -302,6 +313,11 @@ def case_config_cls( # noqa: PLR0911 return AliyunOpenSearchIndexConfig + if self == DB.MongoDB: + from .mongodb.config import MongoDBIndexConfig + + return MongoDBIndexConfig + # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/mongodb/config.py b/vectordb_bench/backend/clients/mongodb/config.py new file mode 100644 index 000000000..cc09471a4 --- /dev/null +++ b/vectordb_bench/backend/clients/mongodb/config.py @@ -0,0 +1,44 @@ +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + + +class MongoDBConfig(DBConfig, BaseModel): + connection_string: SecretStr = "mongodb+srv://:@.heatl.mongodb.net" + database: str = "vdb_bench" + + def to_dict(self) -> dict: + return { + "connection_string": self.connection_string.get_secret_value(), + "database": self.database, + } + + +class MongoDBIndexConfig(BaseModel, DBCaseConfig): + index: IndexType = IndexType.HNSW # MongoDB uses HNSW for vector search + metric_type: MetricType | None = None + num_candidates: int | None = 1500 # Default numCandidates for vector search + exact_search: bool = False # Whether to use exact (ENN) search + + def parse_metric(self) -> str: + if self.metric_type == MetricType.L2: + return "euclidean" + if self.metric_type == MetricType.IP: + return "dotProduct" + return "cosine" # Default to cosine similarity + + def index_param(self) -> dict: + return { + "type": "vectorSearch", + 
"fields": [ + { + "type": "vector", + "similarity": self.parse_metric(), + "numDimensions": None, # Will be set in MongoDB class + "path": "vector", # Vector field name + } + ], + } + + def search_param(self) -> dict: + return {"numCandidates": self.num_candidates if not self.exact_search else None, "exact": self.exact_search} diff --git a/vectordb_bench/backend/clients/mongodb/mongodb.py b/vectordb_bench/backend/clients/mongodb/mongodb.py new file mode 100644 index 000000000..dddcc9a4c --- /dev/null +++ b/vectordb_bench/backend/clients/mongodb/mongodb.py @@ -0,0 +1,201 @@ +import logging +import time +from contextlib import contextmanager + +from pymongo import MongoClient +from pymongo.operations import SearchIndexModel + +from ..api import VectorDB +from .config import MongoDBIndexConfig + +log = logging.getLogger(__name__) + + +class MongoDBError(Exception): + """Custom exception class for MongoDB client errors.""" + + +class MongoDB(VectorDB): + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: MongoDBIndexConfig, + collection_name: str = "vdb_bench_collection", + id_field: str = "id", + vector_field: str = "vector", + drop_old: bool = False, + **kwargs, + ): + self.dim = dim + self.db_config = db_config + self.case_config = db_case_config + self.collection_name = collection_name + self.id_field = id_field + self.vector_field = vector_field + self.drop_old = drop_old + + # Update index dimensions + index_params = self.case_config.index_param() + log.info(f"index params: {index_params}") + index_params["fields"][0]["numDimensions"] = dim + self.index_params = index_params + + # Initialize - they'll also be set in init() + uri = self.db_config["connection_string"] + self.client = MongoClient(uri) + self.db = self.client[self.db_config["database"]] + self.collection = self.db[self.collection_name] + if self.drop_old and self.collection_name in self.db.list_collection_names(): + log.info(f"MongoDB client dropping old collection: 
{self.collection_name}") + self.db.drop_collection(self.collection_name) + self.client = None + self.db = None + self.collection = None + + @contextmanager + def init(self): + """Initialize MongoDB client and cleanup when done""" + try: + uri = self.db_config["connection_string"] + self.client = MongoClient(uri) + self.db = self.client[self.db_config["database"]] + self.collection = self.db[self.collection_name] + + yield + finally: + if self.client is not None: + self.client.close() + self.client = None + self.db = None + self.collection = None + + def _create_index(self) -> None: + """Create vector search index""" + index_name = "vector_index" + index_params = self.index_params + log.info(f"index params {index_params}") + # drop index if already exists + if self.collection.list_indexes(): + all_indexes = self.collection.list_search_indexes() + if any(idx.get("name") == index_name for idx in all_indexes): + log.info(f"Drop index: {index_name}") + try: + self.collection.drop_search_index(index_name) + while True: + indices = list(self.collection.list_search_indexes()) + indices = [idx for idx in indices if idx["name"] == index_name] + log.debug(f"index status {indices}") + if len(indices) == 0: + break + log.info(f"index deleting {indices}") + except Exception: + log.exception("Error dropping index") + try: + # Create vector search index + search_index = SearchIndexModel(definition=index_params, name=index_name, type="vectorSearch") + + self.collection.create_search_index(search_index) + log.info(f"Created vector search index: {index_name}") + self._wait_for_index_ready(index_name) + + # Create regular index on id field for faster lookups + self.collection.create_index(self.id_field) + log.info(f"Created index on {self.id_field} field") + + except Exception: + log.exception("Error creating index") + raise + + def _wait_for_index_ready(self, index_name: str, check_interval: int = 5) -> None: + """Wait for index to be ready""" + while True: + indices = 
list(self.collection.list_search_indexes()) + log.debug(f"index status {indices}") + if indices and any(idx.get("name") == index_name and idx.get("queryable") for idx in indices): + break + for idx in indices: + if idx.get("name") == index_name and idx.get("status") == "FAILED": + error_msg = f"Index {index_name} failed to build" + raise MongoDBError(error_msg) + + time.sleep(check_interval) + log.info(f"Index {index_name} is ready") + + def need_normalize_cosine(self) -> bool: + return False + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs, + ) -> (int, Exception | None): + """Insert embeddings into MongoDB""" + + # Prepare documents in bulk + documents = [ + { + self.id_field: id_, + self.vector_field: embedding, + } + for id_, embedding in zip(metadata, embeddings, strict=False) + ] + + # Use ordered=False for better insert performance + try: + self.collection.insert_many(documents, ordered=False) + except Exception as e: + return 0, e + return len(documents), None + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + **kwargs, + ) -> list[int]: + """Search for similar vectors""" + search_params = self.case_config.search_param() + + vector_search = {"queryVector": query, "index": "vector_index", "path": self.vector_field, "limit": k} + + # Add exact search parameter if specified + if search_params["exact"]: + vector_search["exact"] = True + else: + # Set numCandidates based on k value and data size + # For 50K dataset, use higher multiplier for better recall + num_candidates = min(10000, max(k * 20, search_params["numCandidates"] or 0)) + vector_search["numCandidates"] = num_candidates + + # Add filter if specified + if filters: + log.info(f"Applying filter: {filters}") + vector_search["filter"] = { + "id": {"gt": filters["id"]}, + } + + pipeline = [ + {"$vectorSearch": vector_search}, + { + "$project": { + "_id": 0, + self.id_field: 1, + "score": {"$meta": 
"vectorSearchScore"}, # Include similarity score + } + }, + ] + + results = list(self.collection.aggregate(pipeline)) + return [doc[self.id_field] for doc in results] + + def optimize(self, data_size: int | None = None) -> None: + """MongoDB vector search indexes are self-optimizing""" + log.info("optimize for search") + self._create_index() + self._wait_for_index_ready("vector_index") + + def ready_to_load(self) -> None: + """MongoDB is always ready to load""" From 811564a1d1a829429e6bfd62c8796a002872f4e4 Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Tue, 14 Jan 2025 11:54:31 +0800 Subject: [PATCH 04/36] add mongodb client in readme Signed-off-by: zhuwenxing --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 737fc6064..1bf95dc39 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ All the database client supported | chromadb | `pip install vectordb-bench[chromadb]` | | awsopensearch | `pip install vectordb-bench[opensearch]` | | aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` | +| mongodb | `pip install vectordb-bench[mongodb]` | ### Run From 4f21fcf3d61b6b334e99c83fb2b548ce69bc8db6 Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Sun, 19 Jan 2025 17:05:50 +0800 Subject: [PATCH 05/36] add some risk warnings for custom dataset - limit the number of test query vectors. Signed-off-by: min.tian --- README.md | 7 +++++++ .../frontend/components/custom/displaypPrams.py | 13 ++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1bf95dc39..56ae88753 100644 --- a/README.md +++ b/README.md @@ -319,6 +319,13 @@ We have strict requirements for the data set format, please follow them. - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format. 
- Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. + - We recommend limiting the number of test query vectors, like 1,000. + When conducting concurrent query tests, Vdbbench creates a large number of processes. + To minimize additional communication overhead during testing, + we prepare a complete set of test queries for each process, allowing them to run independently. + However, this means that as the number of concurrent processes increases, + the number of copied query vectors also increases significantly, + which can place substantial pressure on memory resources. - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files. diff --git a/vectordb_bench/frontend/components/custom/displaypPrams.py b/vectordb_bench/frontend/components/custom/displaypPrams.py index b677e5909..712b57b00 100644 --- a/vectordb_bench/frontend/components/custom/displaypPrams.py +++ b/vectordb_bench/frontend/components/custom/displaypPrams.py @@ -3,7 +3,7 @@ def displayParams(st): """ - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format. - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. 
- - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. + - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files. @@ -11,3 +11,14 @@ def displayParams(st): - `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order. """ ) + st.caption( + """We recommend limiting the number of test query vectors, like 1,000.""", + help=""" +When conducting concurrent query tests, Vdbbench creates a large number of processes. +To minimize additional communication overhead during testing, +we prepare a complete set of test queries for each process, allowing them to run independently.\n +However, this means that as the number of concurrent processes increases, +the number of copied query vectors also increases significantly, +which can place substantial pressure on memory resources. 
+""", + ) From 491ef6b109a0cc4e31c8573e115efc9edadd952e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Jan 2025 06:28:13 +0000 Subject: [PATCH 06/36] Bump grpcio from 1.53.0 to 1.53.2 in /install Bumps [grpcio](https://github.com/grpc/grpc) from 1.53.0 to 1.53.2. - [Release notes](https://github.com/grpc/grpc/releases) - [Changelog](https://github.com/grpc/grpc/blob/master/doc/grpc_release_schedule.md) - [Commits](https://github.com/grpc/grpc/compare/v1.53.0...v1.53.2) --- updated-dependencies: - dependency-name: grpcio dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- install/requirements_py3.11.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install/requirements_py3.11.txt b/install/requirements_py3.11.txt index c3a3bbbda..18138382f 100644 --- a/install/requirements_py3.11.txt +++ b/install/requirements_py3.11.txt @@ -1,4 +1,4 @@ -grpcio==1.53.0 +grpcio==1.53.2 grpcio-tools==1.53.0 qdrant-client pinecone-client From 5eeab7e8135f241306fbaeee0b1e1b75a23cdee7 Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Tue, 14 Jan 2025 17:23:35 +0800 Subject: [PATCH 07/36] add mongodb config Signed-off-by: zhuwenxing --- .gitignore | 4 ++- install.py | 3 +- .../backend/clients/mongodb/config.py | 17 +++++++--- .../backend/clients/mongodb/mongodb.py | 9 +++--- .../frontend/config/dbCaseConfigs.py | 32 +++++++++++++++++++ vectordb_bench/log_util.py | 17 ++++++++-- vectordb_bench/models.py | 4 +++ 7 files changed, 73 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index f8602c11a..aa3a72f44 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,7 @@ __MACOSX .DS_Store build/ venv/ +.venv/ .idea/ -results/ \ No newline at end of file +results/ +logs/ \ No newline at end of file diff --git a/install.py b/install.py index f683a37b2..5807485fd 100644 --- a/install.py +++ b/install.py @@ -1,7 +1,8 @@ -import os import argparse +import os import subprocess + 
def docker_tag_base(): return 'vdbbench' diff --git a/vectordb_bench/backend/clients/mongodb/config.py b/vectordb_bench/backend/clients/mongodb/config.py index cc09471a4..a2d8ca57a 100644 --- a/vectordb_bench/backend/clients/mongodb/config.py +++ b/vectordb_bench/backend/clients/mongodb/config.py @@ -1,8 +1,16 @@ +from enum import Enum + from pydantic import BaseModel, SecretStr from ..api import DBCaseConfig, DBConfig, IndexType, MetricType +class QuantizationType(Enum): + NONE = "none" + BINARY = "binary" + SCALAR = "scalar" + + class MongoDBConfig(DBConfig, BaseModel): connection_string: SecretStr = "mongodb+srv://:@.heatl.mongodb.net" database: str = "vdb_bench" @@ -16,9 +24,9 @@ def to_dict(self) -> dict: class MongoDBIndexConfig(BaseModel, DBCaseConfig): index: IndexType = IndexType.HNSW # MongoDB uses HNSW for vector search - metric_type: MetricType | None = None - num_candidates: int | None = 1500 # Default numCandidates for vector search - exact_search: bool = False # Whether to use exact (ENN) search + metric_type: MetricType = MetricType.COSINE + num_candidates_ratio: int = 10 # Default numCandidates ratio for vector search + quantization: QuantizationType = QuantizationType.NONE # Quantization type if applicable def parse_metric(self) -> str: if self.metric_type == MetricType.L2: @@ -36,9 +44,10 @@ def index_param(self) -> dict: "similarity": self.parse_metric(), "numDimensions": None, # Will be set in MongoDB class "path": "vector", # Vector field name + "quantization": self.quantization.value, } ], } def search_param(self) -> dict: - return {"numCandidates": self.num_candidates if not self.exact_search else None, "exact": self.exact_search} + return {"num_candidates_ratio": self.num_candidates_ratio} diff --git a/vectordb_bench/backend/clients/mongodb/mongodb.py b/vectordb_bench/backend/clients/mongodb/mongodb.py index dddcc9a4c..0bbfd5d9c 100644 --- a/vectordb_bench/backend/clients/mongodb/mongodb.py +++ 
b/vectordb_bench/backend/clients/mongodb/mongodb.py @@ -90,7 +90,7 @@ def _create_index(self) -> None: break log.info(f"index deleting {indices}") except Exception: - log.exception("Error dropping index") + log.exception(f"Error dropping index {index_name}") try: # Create vector search index search_index = SearchIndexModel(definition=index_params, name=index_name, type="vectorSearch") @@ -104,7 +104,7 @@ def _create_index(self) -> None: log.info(f"Created index on {self.id_field} field") except Exception: - log.exception("Error creating index") + log.exception(f"Error creating index {index_name}") raise def _wait_for_index_ready(self, index_name: str, check_interval: int = 5) -> None: @@ -167,16 +167,15 @@ def search_embedding( else: # Set numCandidates based on k value and data size # For 50K dataset, use higher multiplier for better recall - num_candidates = min(10000, max(k * 20, search_params["numCandidates"] or 0)) + num_candidates = min(10000, k * search_params["num_candidates_ratio"]) vector_search["numCandidates"] = num_candidates # Add filter if specified if filters: log.info(f"Applying filter: {filters}") vector_search["filter"] = { - "id": {"gt": filters["id"]}, + "id": {"gte": filters["id"]}, } - pipeline = [ {"$vectorSearch": vector_search}, { diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index e004f2ba7..13858f879 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -1041,6 +1041,26 @@ class CaseConfigInput(BaseModel): ) +CaseConfigParamInput_MongoDBQuantizationType = CaseConfigInput( + label=CaseConfigParamType.mongodb_quantization_type, + inputType=InputType.Option, + inputConfig={ + "options": ["none", "scalar", "binary"], + }, +) + + +CaseConfigParamInput_MongoDBNumCandidatesRatio = CaseConfigInput( + label=CaseConfigParamType.mongodb_num_candidates_ratio, + inputType=InputType.Number, + inputConfig={ + "min": 10, + 
"max": 20, + "value": 10, + }, +) + + MilvusLoadConfig = [ CaseConfigParamInput_IndexType, CaseConfigParamInput_M, @@ -1224,6 +1244,14 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_NumCandidates_AliES, ] +MongoDBLoadingConfig = [ + CaseConfigParamInput_MongoDBQuantizationType, +] +MongoDBPerformanceConfig = [ + CaseConfigParamInput_MongoDBQuantizationType, + CaseConfigParamInput_MongoDBNumCandidatesRatio, +] + CASE_CONFIG_MAP = { DB.Milvus: { CaseLabel.Load: MilvusLoadConfig, @@ -1272,4 +1300,8 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: AliyunOpensearchLoadingConfig, CaseLabel.Performance: AliyunOpenSearchPerformanceConfig, }, + DB.MongoDB: { + CaseLabel.Load: MongoDBLoadingConfig, + CaseLabel.Performance: MongoDBPerformanceConfig, + }, } diff --git a/vectordb_bench/log_util.py b/vectordb_bench/log_util.py index d75688137..6ca6ccabf 100644 --- a/vectordb_bench/log_util.py +++ b/vectordb_bench/log_util.py @@ -1,8 +1,13 @@ import logging from logging import config +from pathlib import Path def init(log_level: str): + # Create logs directory if it doesn't exist + log_dir = Path("logs") + log_dir.mkdir(exist_ok=True) + log_config = { "version": 1, "disable_existing_loggers": False, @@ -24,15 +29,23 @@ def init(log_level: str): "class": "logging.StreamHandler", "formatter": "default", }, + "file": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "default", + "filename": "logs/vectordb_bench.log", + "maxBytes": 10485760, # 10MB + "backupCount": 5, + "encoding": "utf8", + }, }, "loggers": { "vectordb_bench": { - "handlers": ["console"], + "handlers": ["console", "file"], "level": log_level, "propagate": False, }, "no_color": { - "handlers": ["no_color_console"], + "handlers": ["no_color_console", "file"], "level": log_level, "propagate": False, }, diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 49bb04ae0..bf71ebb89 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -88,6 +88,10 @@ class 
CaseConfigParamType(Enum):
     numSearchThreads = "num_search_threads"
     maxNumPrefetchDatasets = "max_num_prefetch_datasets"
 
+    # mongodb params
+    mongodb_quantization_type = "quantization"
+    mongodb_num_candidates_ratio = "num_candidates_ratio"
+
 
 class CustomizedCase(BaseModel):
     pass

From 111048d924f5a5d39b452f19e4b77091acf2323a Mon Sep 17 00:00:00 2001
From: Xavierantony1982
Date: Thu, 30 Jan 2025 19:54:56 -0800
Subject: [PATCH 08/36] Opensearch internal configuration parameters (#463)

* Added the configuration parameters to create Opensearch dynamically with
right replicas, shards and other opensearch related configurations. Added the
feature to create OS index with 0 replica and once the data is loaded update
the replicas according to the parameter.

* Updated the readme for config parameters

---------

Co-authored-by: xavrathi
---
 README.md                                     | 41 +++++++++
 .../clients/aws_opensearch/aws_opensearch.py  | 57 +++++++++++-
 .../backend/clients/aws_opensearch/cli.py     | 86 ++++++++++++++++++-
 .../backend/clients/aws_opensearch/config.py  | 10 +++
 4 files changed, 189 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 56ae88753..e8be99cec 100644
--- a/README.md
+++ b/README.md
@@ -153,6 +153,47 @@ Options:
                                   with-gt]
   --help                          Show this message and exit.
``` + +### Run awsopensearch from command line + +```shell +vectordbbench awsopensearch --db-label awsopensearch \ +--m 16 --ef-construction 256 \ +--host search-vector-db-prod-h4f6m4of6x7yp2rz7gdmots7w4.us-west-2.es.amazonaws.com --port 443 \ +--user vector --password '' \ +--case-type Performance1536D5M --num-insert-workers 10 \ +--skip-load --num-concurrency 75 +``` + +To list the options for awsopensearch, execute `vectordbbench awsopensearch --help` + +```text +$ vectordbbench awsopensearch --help +Usage: vectordbbench awsopensearch [OPTIONS] + +Options: + # Sharding and Replication + --number-of-shards INTEGER Number of primary shards for the index + --number-of-replicas INTEGER Number of replica copies for each primary + shard + # Indexing Performance + --index-thread-qty INTEGER Thread count for native engine indexing + --index-thread-qty-during-force-merge INTEGER + Thread count during force merge operations + --number-of-indexing-clients INTEGER + Number of concurrent indexing clients + # Index Management + --number-of-segments INTEGER Target number of segments after merging + --refresh-interval TEXT How often to make new data available for + search + --force-merge-enabled BOOLEAN Whether to perform force merge operation + --flush-threshold-size TEXT Size threshold for flushing the transaction + log + # Memory Management + --cb-threshold TEXT k-NN Memory circuit breaker threshold + + --help Show this message and exit.``` + #### Using a configuration file. The vectordbbench command can optionally read some or all the options from a yaml formatted configuration file. 
diff --git a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py index 234014f19..adb766300 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +++ b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py @@ -12,6 +12,7 @@ WAITING_FOR_REFRESH_SEC = 30 WAITING_FOR_FORCE_MERGE_SEC = 30 +SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC = 30 class AWSOpenSearch(VectorDB): @@ -52,10 +53,27 @@ def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIn return AWSOpenSearchIndexConfig def _create_index(self, client: OpenSearch): + cluster_settings_body = { + "persistent": { + "knn.algo_param.index_thread_qty": self.case_config.index_thread_qty, + "knn.memory.circuit_breaker.limit": self.case_config.cb_threshold, + } + } + client.cluster.put_settings(cluster_settings_body) settings = { "index": { "knn": True, + "number_of_shards": self.case_config.number_of_shards, + "number_of_replicas": 0, + "translog.flush_threshold_size": self.case_config.flush_threshold_size, + # Setting trans log threshold to 5GB + **( + {"knn.algo_param.ef_search": self.case_config.ef_search} + if self.case_config.engine == AWSOS_Engine.nmslib + else {} + ), }, + "refresh_interval": self.case_config.refresh_interval, } mappings = { "properties": { @@ -145,9 +163,9 @@ def search_embedding( docvalue_fields=[self.id_col_name], stored_fields="_none_", ) - log.info(f"Search took: {resp['took']}") - log.info(f"Search shards: {resp['_shards']}") - log.info(f"Search hits total: {resp['hits']['total']}") + log.debug(f"Search took: {resp['took']}") + log.debug(f"Search shards: {resp['_shards']}") + log.debug(f"Search hits total: {resp['hits']['total']}") return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]] except Exception as e: log.warning(f"Failed to search: {self.index_name} error: {e!s}") @@ -157,12 +175,37 @@ def optimize(self, data_size: int | None = 
None):
         """optimize will be called between insertion and search in performance cases."""
         # Call refresh first to ensure that all segments are created
         self._refresh_index()
-        self._do_force_merge()
+        if self.case_config.force_merge_enabled:
+            self._do_force_merge()
+            self._refresh_index()
+        self._update_replicas()
         # Call refresh again to ensure that the index is ready after force merge.
         self._refresh_index()
         # ensure that all graphs are loaded in memory and ready for search
         self._load_graphs_to_memory()
 
+    def _update_replicas(self):
+        index_settings = self.client.indices.get_settings(index=self.index_name)
+        current_number_of_replicas = int(index_settings[self.index_name]["settings"]["index"]["number_of_replicas"])
+        log.info(
+            f"Current Number of replicas are {current_number_of_replicas}"
+            f" and changing the replicas to {self.case_config.number_of_replicas}"
+        )
+        settings_body = {"index": {"number_of_replicas": self.case_config.number_of_replicas}}
+        self.client.indices.put_settings(index=self.index_name, body=settings_body)
+        self._wait_till_green()
+
+    def _wait_till_green(self):
+        log.info("Wait for index to become green..")
+        while True:
+            res = self.client.cat.indices(index=self.index_name, h="health", format="json")
+            health = res[0]["health"]
+            if health == "green":
+                break
+            log.info(f"The index {self.index_name} has health : {health} and is not green.
Retrying") + time.sleep(SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC) + log.info(f"Index {self.index_name} is green..") + def _refresh_index(self): log.debug(f"Starting refresh for index {self.index_name}") while True: @@ -179,6 +222,12 @@ def _refresh_index(self): log.debug(f"Completed refresh for index {self.index_name}") def _do_force_merge(self): + log.info(f"Updating the Index thread qty to {self.case_config.index_thread_qty_during_force_merge}.") + + cluster_settings_body = { + "persistent": {"knn.algo_param.index_thread_qty": self.case_config.index_thread_qty_during_force_merge} + } + self.client.cluster.put_settings(cluster_settings_body) log.debug(f"Starting force merge for index {self.index_name}") force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false" force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"] diff --git a/vectordb_bench/backend/clients/aws_opensearch/cli.py b/vectordb_bench/backend/clients/aws_opensearch/cli.py index bb0c2450d..fa457154d 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/cli.py +++ b/vectordb_bench/backend/clients/aws_opensearch/cli.py @@ -18,6 +18,79 @@ class AWSOpenSearchTypedDict(TypedDict): port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")] user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")] password: Annotated[str, click.option("--password", type=str, help="Db password")] + number_of_shards: Annotated[ + int, + click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1), + ] + number_of_replicas: Annotated[ + int, + click.option( + "--number-of-replicas", type=int, help="Number of replica copies for each primary shard", default=1 + ), + ] + index_thread_qty: Annotated[ + int, + click.option( + "--index-thread-qty", + type=int, + help="Thread count for native engine indexing", + default=4, + ), + ] + + 
index_thread_qty_during_force_merge: Annotated[ + int, + click.option( + "--index-thread-qty-during-force-merge", + type=int, + help="Thread count during force merge operations", + default=4, + ), + ] + + number_of_indexing_clients: Annotated[ + int, + click.option( + "--number-of-indexing-clients", + type=int, + help="Number of concurrent indexing clients", + default=1, + ), + ] + + number_of_segments: Annotated[ + int, + click.option("--number-of-segments", type=int, help="Target number of segments after merging", default=1), + ] + + refresh_interval: Annotated[ + int, + click.option( + "--refresh-interval", type=str, help="How often to make new data available for search", default="60s" + ), + ] + + force_merge_enabled: Annotated[ + int, + click.option("--force-merge-enabled", type=bool, help="Whether to perform force merge operation", default=True), + ] + + flush_threshold_size: Annotated[ + int, + click.option( + "--flush-threshold-size", type=str, help="Size threshold for flushing the transaction log", default="5120mb" + ), + ] + + cb_threshold: Annotated[ + int, + click.option( + "--cb-threshold", + type=str, + help="k-NN Memory circuit breaker threshold", + default="50%", + ), + ] class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ... 
@@ -36,6 +109,17 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]): user=parameters["user"], password=SecretStr(parameters["password"]), ), - db_case_config=AWSOpenSearchIndexConfig(), + db_case_config=AWSOpenSearchIndexConfig( + number_of_shards=parameters["number_of_shards"], + number_of_replicas=parameters["number_of_replicas"], + index_thread_qty=parameters["index_thread_qty"], + number_of_segments=parameters["number_of_segments"], + refresh_interval=parameters["refresh_interval"], + force_merge_enabled=parameters["force_merge_enabled"], + flush_threshold_size=parameters["flush_threshold_size"], + number_of_indexing_clients=parameters["number_of_indexing_clients"], + index_thread_qty_during_force_merge=parameters["index_thread_qty_during_force_merge"], + cb_threshold=parameters["cb_threshold"], + ), **parameters, ) diff --git a/vectordb_bench/backend/clients/aws_opensearch/config.py b/vectordb_bench/backend/clients/aws_opensearch/config.py index e9ccc7277..dd51b266d 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/config.py +++ b/vectordb_bench/backend/clients/aws_opensearch/config.py @@ -39,6 +39,16 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig): efConstruction: int = 256 efSearch: int = 256 M: int = 16 + index_thread_qty: int | None = 4 + number_of_shards: int | None = 1 + number_of_replicas: int | None = 0 + number_of_segments: int | None = 1 + refresh_interval: str | None = "60s" + force_merge_enabled: bool | None = True + flush_threshold_size: str | None = "5120mb" + number_of_indexing_clients: int | None = 1 + index_thread_qty_during_force_merge: int + cb_threshold: str | None = "50%" def parse_metric(self) -> str: if self.metric_type == MetricType.IP: From 075651619a2c1c2037a9f9a0f6a66fa34018fce8 Mon Sep 17 00:00:00 2001 From: "siqi.an" Date: Mon, 10 Feb 2025 17:30:35 +0800 Subject: [PATCH 09/36] ui control num of concurrencies Signed-off-by: siqi.an --- .../components/run_test/submitTask.py | 23 
++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py index 426095397..5827efeb4 100644 --- a/vectordb_bench/frontend/components/run_test/submitTask.py +++ b/vectordb_bench/frontend/components/run_test/submitTask.py @@ -1,6 +1,8 @@ from datetime import datetime +from vectordb_bench import config from vectordb_bench.frontend.config import styles from vectordb_bench.interface import benchmark_runner +from vectordb_bench.models import TaskConfig def submitTask(st, tasks, isAllValid): @@ -47,16 +49,31 @@ def advancedSettings(st): k = container[0].number_input("k", min_value=1, value=100, label_visibility="collapsed") container[1].caption("K value for number of nearest neighbors to search") - return index_already_exists, use_aliyun, k + container = st.columns([1, 2]) + defaultconcurrentInput = ",".join(map(str, config.NUM_CONCURRENCY)) + concurrentInput = container[0].text_input( + "Concurrent Input", value=defaultconcurrentInput, label_visibility="collapsed" + ) + container[1].caption("num of concurrencies for search tests to get max-qps") + return index_already_exists, use_aliyun, k, concurrentInput -def controlPanel(st, tasks, taskLabel, isAllValid): - index_already_exists, use_aliyun, k = advancedSettings(st) +def controlPanel(st, tasks: list[TaskConfig], taskLabel, isAllValid): + index_already_exists, use_aliyun, k, concurrentInput = advancedSettings(st) def runHandler(): benchmark_runner.set_drop_old(not index_already_exists) + + try: + concurrentInput_list = [int(item.strip()) for item in concurrentInput.split(",")] + except ValueError: + st.write("please input correct number") + return None + for task in tasks: task.case_config.k = k + task.case_config.concurrency_search_config.num_concurrency = concurrentInput_list + benchmark_runner.set_download_address(use_aliyun) benchmark_runner.run(tasks, taskLabel) From 
62454b35515cfcc0c34583e232bc87a623cee60e Mon Sep 17 00:00:00 2001
From: Xiaofan <83447078+xiaofan-luan@users.noreply.github.com>
Date: Wed, 12 Feb 2025 12:36:48 +0800
Subject: [PATCH 10/36] Update README.md

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index e8be99cec..86c6ab529 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,8 @@ Closely mimicking real-world production environments, we've set up diverse testi
 
 Prepare to delve into the world of VectorDBBench, and let it guide you in uncovering your perfect vector database match.
 
+VectorDBBench is sponsored by Zilliz, the leading open-source vector database company behind Milvus. Choose smarter with VectorDBBench - start your free test on [zilliz cloud](https://zilliz.com/) today!
+
 **Leaderboard:** https://zilliz.com/benchmark
 ## Quick Start
 ### Prerequirement

From 6832120456b3bc4e52aca23bfaef5b44d8daac2a Mon Sep 17 00:00:00 2001
From: "min.tian"
Date: Thu, 13 Feb 2025 10:26:39 +0800
Subject: [PATCH 11/36] environs version should <14.1.0

Signed-off-by: min.tian
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index aafe70750..8bc83f29b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
     "psutil",
     "polars",
     "plotly",
-    "environs",
+    "environs<14.1.0",
    "pydantic
Date: Mon, 24 Feb 2025 14:14:59 +0530
Subject: [PATCH 12/36] Support GPU_BRUTE_FORCE index for Milvus (#476)

Signed-off-by: Rachit Chaudhary

Co-authored-by: Signed-off-by: Rachit Chaudhary - r0c0axe
---
 vectordb_bench/backend/clients/api.py        |  1 +
 vectordb_bench/backend/clients/milvus/cli.py | 20 +++++++++++
 .../backend/clients/milvus/config.py         | 33 +++++++++++++++++++
 .../frontend/config/dbCaseConfigs.py         |  5 +++
 4 files changed, 59 insertions(+)

diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py
index a86849e96..e498ab077 100644
--- a/vectordb_bench/backend/clients/api.py
+++
b/vectordb_bench/backend/clients/api.py @@ -25,6 +25,7 @@ class IndexType(str, Enum): ES_HNSW = "hnsw" ES_IVFFlat = "ivfflat" GPU_IVF_FLAT = "GPU_IVF_FLAT" + GPU_BRUTE_FORCE = "GPU_BRUTE_FORCE" GPU_IVF_PQ = "GPU_IVF_PQ" GPU_CAGRA = "GPU_CAGRA" SCANN = "scann" diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py index 51ea82eff..303eec5f9 100644 --- a/vectordb_bench/backend/clients/milvus/cli.py +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -194,6 +194,26 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]): **parameters, ) +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict) +def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): + from .config import GPUBruteForceConfig, MilvusConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), + ), + db_case_config=GPUBruteForceConfig( + metric_type=parameters["metric_type"], + limit=parameters["limit"], # top-k for search + ), + **parameters, + ) + class MilvusGPUIVFPQTypedDict( CommonTypedDict, diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 7d0df803a..1ff3bea5f 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -40,6 +40,7 @@ def is_gpu_index(self) -> bool: IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ, + IndexType.GPU_BRUTE_FORCE, ] def parse_metric(self) -> str: @@ -184,6 +185,37 @@ def search_param(self) -> dict: } +class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig): + limit: int = 10 # Default top-k for search + metric_type: str # Metric type (e.g., 'L2', 'IP', etc.) 
+ index: IndexType = IndexType.GPU_BRUTE_FORCE # Index type set to GPU_BRUTE_FORCE + + def index_param(self) -> dict: + """ + Returns the parameters for creating the GPU_BRUTE_FORCE index. + No additional parameters required for index building. + """ + return { + "metric_type": self.parse_metric(), # Metric type for distance calculation (L2, IP, etc.) + "index_type": self.index.value, # GPU_BRUTE_FORCE index type + "params": {}, # No additional parameters for GPU_BRUTE_FORCE + } + + def search_param(self) -> dict: + """ + Returns the parameters for performing a search on the GPU_BRUTE_FORCE index. + Only metric_type and top-k (limit) are needed for search. + """ + return { + "metric_type": self.parse_metric(), # Metric type for search + "params": { + "nprobe": 1, # For GPU_BRUTE_FORCE, set nprobe to 1 (brute force search) + "limit": self.limit, # Top-k for search + }, + } + + + class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig): nlist: int = 1024 m: int = 0 @@ -261,4 +293,5 @@ def search_param(self) -> dict: IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig, IndexType.GPU_IVF_PQ: GPUIVFPQConfig, IndexType.GPU_CAGRA: GPUCAGRAConfig, + IndexType.GPU_BRUTE_FORCE: GPUBruteForceConfig, } diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 13858f879..a1c67e89a 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -173,6 +173,7 @@ class CaseConfigInput(BaseModel): IndexType.GPU_IVF_FLAT.value, IndexType.GPU_IVF_PQ.value, IndexType.GPU_CAGRA.value, + IndexType.GPU_BRUTE_FORCE.value, ], }, ) @@ -562,6 +563,7 @@ class CaseConfigInput(BaseModel): IndexType.IVFSQ8.value, IndexType.GPU_IVF_FLAT.value, IndexType.GPU_IVF_PQ.value, + IndexType.GPU_BRUTE_FORCE.value, ], ) @@ -579,6 +581,7 @@ class CaseConfigInput(BaseModel): IndexType.IVFSQ8.value, IndexType.GPU_IVF_FLAT.value, IndexType.GPU_IVF_PQ.value, + IndexType.GPU_BRUTE_FORCE.value, ], ) @@ -703,6 +706,7 
@@ class CaseConfigInput(BaseModel): IndexType.GPU_CAGRA.value, IndexType.GPU_IVF_PQ.value, IndexType.GPU_IVF_FLAT.value, + IndexType.GPU_BRUTE_FORCE.value, ], ) @@ -720,6 +724,7 @@ class CaseConfigInput(BaseModel): IndexType.GPU_CAGRA.value, IndexType.GPU_IVF_PQ.value, IndexType.GPU_IVF_FLAT.value, + IndexType.GPU_BRUTE_FORCE.value, ], ) From 7bda989feb1f433350b9196e2e81d14c0b86abe5 Mon Sep 17 00:00:00 2001 From: Luca Giacchino Date: Tue, 5 Nov 2024 15:44:32 -0800 Subject: [PATCH 13/36] Add table quantization type --- README.md | 6 +- .../backend/clients/pgvector/cli.py | 14 +++- .../backend/clients/pgvector/config.py | 27 +++++-- .../backend/clients/pgvector/pgvector.py | 81 ++++++++++++++----- .../frontend/config/dbCaseConfigs.py | 15 ++++ vectordb_bench/models.py | 1 + 6 files changed, 118 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 86c6ab529..80db92ec5 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,11 @@ Options: --ef-construction INTEGER hnsw ef-construction --ef-search INTEGER hnsw ef-search --quantization-type [none|bit|halfvec] - quantization type for vectors + quantization type for vectors (in index) + --table-quantization-type [none|bit|halfvec] + quantization type for vectors (in table). If + equal to bit, the parameter + quantization_type will be set to bit too. --custom-case-name TEXT Custom case name i.e. 
PerformanceCase1536D50K --custom-case-description TEXT Custom name description --custom-case-load-timeout INTEGER diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index 55a462055..1780af991 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -82,7 +82,17 @@ class PgVectorTypedDict(CommonTypedDict): click.option( "--quantization-type", type=click.Choice(["none", "bit", "halfvec"]), - help="quantization type for vectors", + help="quantization type for vectors (in index)", + required=False, + ), + ] + table_quantization_type: Annotated[ + str | None, + click.option( + "--table-quantization-type", + type=click.Choice(["none", "bit", "halfvec"]), + help="quantization type for vectors (in table). " + "If equal to bit, the parameter quantization_type will be set to bit too.", required=False, ), ] @@ -146,6 +156,7 @@ def PgVectorIVFFlat( lists=parameters["lists"], probes=parameters["probes"], quantization_type=parameters["quantization_type"], + table_quantization_type=parameters["table_quantization_type"], reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], @@ -182,6 +193,7 @@ def PgVectorHNSW( maintenance_work_mem=parameters["maintenance_work_mem"], max_parallel_workers=parameters["max_parallel_workers"], quantization_type=parameters["quantization_type"], + table_quantization_type=parameters["table_quantization_type"], reranking=parameters["reranking"], reranking_metric=parameters["reranking_metric"], quantized_fetch_limit=parameters["quantized_fetch_limit"], diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index c386d75ef..abfddc0cf 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -80,7 +80,12 @@ def parse_metric(self) -> 
str: if d.get(self.quantization_type) is None: return d.get("_fallback").get(self.metric_type) - return d.get(self.quantization_type).get(self.metric_type) + metric = d.get(self.quantization_type).get(self.metric_type) + # If using binary quantization for the index, use a bit metric + # no matter what metric was selected for vector or halfvec data + if self.quantization_type == "bit" and metric is None: + return "bit_hamming_ops" + return metric def parse_metric_fun_op(self) -> LiteralString: if self.quantization_type == "bit": @@ -168,14 +173,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig): maintenance_work_mem: str | None = None max_parallel_workers: int | None = None quantization_type: str | None = None + table_quantization_type: str | None reranking: bool | None = None quantized_fetch_limit: int | None = None reranking_metric: str | None = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"lists": self.lists} - if self.quantization_type == "none": - self.quantization_type = None + if self.quantization_type == "none" or self.quantization_type is None: + self.quantization_type = "vector" + if self.table_quantization_type == "none" or self.table_quantization_type is None: + self.table_quantization_type = "vector" + if self.table_quantization_type == "bit": + self.quantization_type = "bit" return { "metric": self.parse_metric(), "index_type": self.index.value, @@ -183,6 +193,7 @@ def index_param(self) -> PgVectorIndexParam: "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "quantization_type": self.quantization_type, + "table_quantization_type": self.table_quantization_type, } def search_param(self) -> PgVectorSearchParam: @@ -212,14 +223,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): maintenance_work_mem: str | None = None max_parallel_workers: int | None = None quantization_type: str | None = None + table_quantization_type: str | None reranking: bool | None = None 
quantized_fetch_limit: int | None = None reranking_metric: str | None = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"m": self.m, "ef_construction": self.ef_construction} - if self.quantization_type == "none": - self.quantization_type = None + if self.quantization_type == "none" or self.quantization_type is None: + self.quantization_type = "vector" + if self.table_quantization_type == "none" or self.table_quantization_type is None: + self.table_quantization_type = "vector" + if self.table_quantization_type == "bit": + self.quantization_type = "bit" return { "metric": self.parse_metric(), "index_type": self.index.value, @@ -227,6 +243,7 @@ def index_param(self) -> PgVectorIndexParam: "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "quantization_type": self.quantization_type, + "table_quantization_type": self.table_quantization_type, } def search_param(self) -> PgVectorSearchParam: diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 4164461fb..7d06a2ba5 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -94,7 +94,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: reranking = self.case_config.search_param()["reranking"] column_name = ( sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding")) - if index_param["quantization_type"] == "bit" + if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit" else sql.SQL("embedding") ) search_vector = ( @@ -104,7 +104,8 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: ) # The following sections assume that the quantization_type value matches the quantization function name - if index_param["quantization_type"] is not None: + if index_param["quantization_type"] != index_param["table_quantization_type"]: + 
# Reranking makes sense only if table quantization is not "bit" if index_param["quantization_type"] == "bit" and reranking: # Embeddings needs to be passed to binary_quantize function if quantization_type is bit search_query = sql.Composed( @@ -113,7 +114,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: """ SELECT i.id FROM ( - SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance + SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) """, @@ -123,6 +124,8 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: reranking_metric_fun_op=sql.SQL( self.case_config.search_param()["reranking_metric_fun_op"], ), + search_vector=search_vector, + table_quantization_type=sql.SQL(index_param["table_quantization_type"]), quantization_type=sql.SQL(index_param["quantization_type"]), dim=sql.Literal(self.dim), where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), @@ -130,7 +133,7 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL( """ - {search_vector} + {search_vector}::{quantization_type}({dim}) LIMIT {quantized_fetch_limit} ) i ORDER BY i.distance @@ -138,6 +141,8 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: """, ).format( search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), quantized_fetch_limit=sql.Literal( self.case_config.search_param()["quantized_fetch_limit"], ), @@ -160,10 +165,12 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" {search_vector} LIMIT %s::int").format( + sql.SQL(" 
{search_vector}::{quantization_type}({dim}) LIMIT %s::int").format( search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), ), - ], + ] ) else: search_query = sql.Composed( @@ -175,8 +182,12 @@ def _generate_search_query(self, filtered: bool = False) -> sql.Composed: where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int"), - ], + sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format( + search_vector=search_vector, + quantization_type=sql.SQL(index_param["quantization_type"]), + dim=sql.Literal(self.dim), + ), + ] ) return search_query @@ -323,7 +334,7 @@ def _create_index(self): ) with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(()) - if index_param["quantization_type"] is not None: + if index_param["quantization_type"] != index_param["table_quantization_type"]: index_create_sql = sql.SQL( """ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} @@ -365,14 +376,23 @@ def _create_table(self, dim: int): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + index_param = self.case_config.index_param() + try: log.info(f"{self.name} client create table : {self.table_name}") # create table self.cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", - ).format(table_name=sql.Identifier(self.table_name), dim=dim), + """ + CREATE TABLE IF NOT EXISTS public.{table_name} + (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim})); + """ + ).format( + table_name=sql.Identifier(self.table_name), + table_quantization_type=sql.SQL(index_param["table_quantization_type"]), + dim=dim, + ) ) self.cursor.execute( sql.SQL( @@ -393,18 +413,41 @@ def insert_embeddings( 
assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + index_param = self.case_config.index_param() + try: metadata_arr = np.array(metadata) embeddings_arr = np.array(embeddings) - with self.cursor.copy( - sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( - table_name=sql.Identifier(self.table_name), - ), - ) as copy: - copy.set_types(["bigint", "vector"]) - for i, row in enumerate(metadata_arr): - copy.write_row((row, embeddings_arr[i])) + if index_param["table_quantization_type"] == "bit": + with self.cursor.copy( + sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format( + table_name=sql.Identifier(self.table_name) + ) + ) as copy: + # Same logic as pgvector binary_quantize + for i, row in enumerate(metadata_arr): + embeddings_bit = "" + for embedding in embeddings_arr[i]: + if embedding > 0: + embeddings_bit += "1" + else: + embeddings_bit += "0" + copy.write_row((str(row), embeddings_bit)) + else: + with self.cursor.copy( + sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( + table_name=sql.Identifier(self.table_name) + ) + ) as copy: + if index_param["table_quantization_type"] == "halfvec": + copy.set_types(["bigint", "halfvec"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, np.float16(embeddings_arr[i]))) + else: + copy.set_types(["bigint", "vector"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, embeddings_arr[i])) self.conn.commit() if kwargs.get("last_batch"): diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index a1c67e89a..c3918e7d6 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -823,6 +823,19 @@ class CaseConfigInput(BaseModel): ], ) +CaseConfigParamInput_TableQuantizationType_PgVector = CaseConfigInput( + label=CaseConfigParamType.tableQuantizationType, + 
inputType=InputType.Option, + inputConfig={ + "options": ["none", "bit", "halfvec"], + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + IndexType.HNSW.value, + IndexType.IVFFlat.value, + ], +) + CaseConfigParamInput_max_parallel_workers_PgVectorRS = CaseConfigInput( label=CaseConfigParamType.max_parallel_workers, displayLabel="Max parallel workers", @@ -1138,6 +1151,7 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_m, CaseConfigParamInput_EFConstruction_PgVector, CaseConfigParamInput_QuantizationType_PgVector, + CaseConfigParamInput_TableQuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, ] @@ -1149,6 +1163,7 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_Lists_PgVector, CaseConfigParamInput_Probes_PgVector, CaseConfigParamInput_QuantizationType_PgVector, + CaseConfigParamInput_TableQuantizationType_PgVector, CaseConfigParamInput_maintenance_work_mem_PgVector, CaseConfigParamInput_max_parallel_workers_PgVector, CaseConfigParamInput_reranking_PgVector, diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index bf71ebb89..fa80da2d5 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -49,6 +49,7 @@ class CaseConfigParamType(Enum): probes = "probes" quantizationType = "quantization_type" quantizationRatio = "quantization_ratio" + tableQuantizationType = "table_quantization_type" reranking = "reranking" rerankingMetric = "reranking_metric" quantizedFetchLimit = "quantized_fetch_limit" From 7f501043ab9190b659fa968ebe639d3a5c0dc588 Mon Sep 17 00:00:00 2001 From: Hugo Wen <46255328+HugoWenTD@users.noreply.github.com> Date: Mon, 10 Mar 2025 18:51:47 -0700 Subject: [PATCH 14/36] Support MariaDB database (#375) MariaDB introduced vector support in version 11.7, enabling MariaDB Server to function as a relational vector database. 
https://mariadb.com/kb/en/vectors/ Now add support for MariaDB server, verified against MariaDB server of version 11.7.1: - Support MariaDB vector search with HNSW algorithm, support filter search. - Support index and search parameters: - storage_engine: InnoDB or MyISAM - M: M parameter in MHNSW vector indexing - ef_search: minimal number of result candidates to look for in the vector index for ORDER BY ... LIMIT N queries. - max_cache_size: Upper limit for one MHNSW vector index cache - Support CLI of `vectordbbench mariadbhnsw`. --- pyproject.toml | 2 + vectordb_bench/backend/clients/__init__.py | 15 ++ vectordb_bench/backend/clients/mariadb/cli.py | 107 +++++++++ .../backend/clients/mariadb/config.py | 71 ++++++ .../backend/clients/mariadb/mariadb.py | 214 ++++++++++++++++++ vectordb_bench/cli/vectordbbench.py | 2 + .../frontend/config/dbCaseConfigs.py | 76 +++++++ vectordb_bench/models.py | 2 + 8 files changed, 489 insertions(+) create mode 100644 vectordb_bench/backend/clients/mariadb/cli.py create mode 100644 vectordb_bench/backend/clients/mariadb/config.py create mode 100644 vectordb_bench/backend/clients/mariadb/mariadb.py diff --git a/pyproject.toml b/pyproject.toml index 8bc83f29b..ad80bacfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ all = [ "memorydb", "alibabacloud_ha3engine_vector", "alibabacloud_searchengine20211025", + "mariadb", ] qdrant = [ "qdrant-client" ] @@ -86,6 +87,7 @@ chromadb = [ "chromadb" ] opensearch = [ "opensearch-py" ] aliyun_opensearch = [ "alibabacloud_ha3engine_vector", "alibabacloud_searchengine20211025"] mongodb = [ "pymongo" ] +mariadb = [ "mariadb" ] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index f2f480dcf..742da5213 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -38,6 +38,7 @@ class DB(Enum): Chroma = 
"Chroma" AWSOpenSearch = "OpenSearch" AliyunElasticsearch = "AliyunElasticsearch" + MariaDB = "MariaDB" Test = "test" AliyunOpenSearch = "AliyunOpenSearch" MongoDB = "MongoDB" @@ -135,6 +136,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 return MongoDB + if self == DB.MariaDB: + from .mariadb.mariadb import MariaDB + + return MariaDB + if self == DB.Test: from .test.test import Test @@ -236,6 +242,10 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 return MongoDBConfig + if self == DB.MariaDB: + from .mariadb.config import MariaDBConfig + return MariaDBConfig + if self == DB.Test: from .test.config import TestConfig @@ -318,6 +328,11 @@ def case_config_cls( # noqa: PLR0911 return MongoDBIndexConfig + if self == DB.MariaDB: + from .mariadb.config import _mariadb_case_config + + return _mariadb_case_config.get(index_type) + # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/mariadb/cli.py b/vectordb_bench/backend/clients/mariadb/cli.py new file mode 100644 index 000000000..c5439f37d --- /dev/null +++ b/vectordb_bench/backend/clients/mariadb/cli.py @@ -0,0 +1,107 @@ +from typing import Annotated, Optional, Unpack + +import click +import os +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + HNSWFlavor1, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from vectordb_bench.backend.clients import DB + + +class MariaDBTypedDict(CommonTypedDict): + user_name: Annotated[ + str, click.option("--username", + type=str, + help="Username", + required=True, + ), + ] + password: Annotated[ + str, click.option("--password", + type=str, + help="Password", + required=True, + ), + ] + + host: Annotated[ + str, click.option("--host", + type=str, + help="Db host", + default="127.0.0.1", + ), + ] + + port: Annotated[ + int, click.option("--port", + type=int, + default=3306, + help="Db Port", + ), + ] + + storage_engine: Annotated[ + int, 
click.option("--storage-engine", + type=click.Choice(["InnoDB", "MyISAM"]), + help="DB storage engine", + required=True, + ), + ] + +class MariaDBHNSWTypedDict(MariaDBTypedDict): + ... + m: Annotated[ + Optional[int], click.option("--m", + type=int, + help="M parameter in MHNSW vector indexing", + required=False, + ), + ] + + ef_search: Annotated[ + Optional[int], click.option("--ef-search", + type=int, + help="MariaDB system variable mhnsw_min_limit", + required=False, + ), + ] + + max_cache_size: Annotated[ + Optional[int], click.option("--max-cache-size", + type=int, + help="MariaDB system variable mhnsw_max_cache_size", + required=False, + ), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict) +def MariaDBHNSW( + **parameters: Unpack[MariaDBHNSWTypedDict], +): + from .config import MariaDBConfig, MariaDBHNSWConfig + + run( + db=DB.MariaDB, + db_config=MariaDBConfig( + db_label=parameters["db_label"], + user_name=parameters["username"], + password=SecretStr(parameters["password"]), + host=parameters["host"], + port=parameters["port"], + ), + db_case_config=MariaDBHNSWConfig( + M=parameters["m"], + ef_search=parameters["ef_search"], + storage_engine=parameters["storage_engine"], + max_cache_size=parameters["max_cache_size"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/mariadb/config.py b/vectordb_bench/backend/clients/mariadb/config.py new file mode 100644 index 000000000..c7b2cd5fe --- /dev/null +++ b/vectordb_bench/backend/clients/mariadb/config.py @@ -0,0 +1,71 @@ +from pydantic import SecretStr, BaseModel +from typing import TypedDict +from ..api import DBConfig, DBCaseConfig, MetricType, IndexType + +class MariaDBConfigDict(TypedDict): + """These keys will be directly used as kwargs in mariadb connection string, + so the names must match exactly mariadb API""" + + user: str + password: str + host: str + port: int + + +class MariaDBConfig(DBConfig): + user_name: str = "root" + password: SecretStr 
+ host: str = "127.0.0.1" + port: int = 3306 + + def to_dict(self) -> MariaDBConfigDict: + pwd_str = self.password.get_secret_value() + return { + "host": self.host, + "port": self.port, + "user": self.user_name, + "password": pwd_str, + } + + +class MariaDBIndexConfig(BaseModel): + """Base config for MariaDB""" + + metric_type: MetricType | None = None + + def parse_metric(self) -> str: + if self.metric_type == MetricType.L2: + return "euclidean" + elif self.metric_type == MetricType.COSINE: + return "cosine" + else: + raise ValueError(f"Metric type {self.metric_type} is not supported!") + +class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig): + M: int | None + ef_search: int | None + index: IndexType = IndexType.HNSW + storage_engine: str = "InnoDB" + max_cache_size: int | None + + def index_param(self) -> dict: + return { + "storage_engine": self.storage_engine, + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "M": self.M, + "max_cache_size": self.max_cache_size, + } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "ef_search": self.ef_search, + } + + +_mariadb_case_config = { + IndexType.HNSW: MariaDBHNSWConfig, +} + + diff --git a/vectordb_bench/backend/clients/mariadb/mariadb.py b/vectordb_bench/backend/clients/mariadb/mariadb.py new file mode 100644 index 000000000..42b621d9c --- /dev/null +++ b/vectordb_bench/backend/clients/mariadb/mariadb.py @@ -0,0 +1,214 @@ +from ..api import VectorDB + +import logging +from contextlib import contextmanager +from typing import Any, Optional, Tuple +from ..api import VectorDB +from .config import MariaDBConfigDict, MariaDBIndexConfig +import numpy as np + +import mariadb + +log = logging.getLogger(__name__) + +class MariaDB(VectorDB): + def __init__( + self, + dim: int, + db_config: MariaDBConfigDict, + db_case_config: MariaDBIndexConfig, + collection_name: str = "vec_collection", + drop_old: bool = False, + **kwargs, + ): + + self.name = "MariaDB" + 
self.db_config = db_config + self.case_config = db_case_config + self.db_name = "vectordbbench" + self.table_name = collection_name + self.dim = dim + + # construct basic units + self.conn, self.cursor = self._create_connection(**self.db_config) + + if drop_old: + self._drop_db() + self._create_db_table(dim) + + self.cursor.close() + self.conn.close() + self.cursor = None + self.conn = None + + + @staticmethod + def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]: + conn = mariadb.connect(**kwargs) + cursor = conn.cursor() + + assert conn is not None, "Connection is not initialized" + assert cursor is not None, "Cursor is not initialized" + + return conn, cursor + + + def _drop_db(self): + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client drop db : {self.db_name}") + + # flush tables before dropping database to avoid some locking issue + self.cursor.execute("FLUSH TABLES") + self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}") + self.cursor.execute("COMMIT") + self.cursor.execute("FLUSH TABLES") + + def _create_db_table(self, dim: int): + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + index_param = self.case_config.index_param() + + try: + log.info(f"{self.name} client create database : {self.db_name}") + self.cursor.execute(f"CREATE DATABASE {self.db_name}") + + log.info(f"{self.name} client create table : {self.table_name}") + self.cursor.execute(f"USE {self.db_name}") + + self.cursor.execute(f""" + CREATE TABLE {self.table_name} ( + id INT PRIMARY KEY, + v VECTOR({self.dim}) NOT NULL + ) ENGINE={index_param["storage_engine"]} + """) + self.cursor.execute("COMMIT") + + except Exception as e: + log.warning( + f"Failed to create table: {self.table_name} error: {e}" + ) + raise e from None + + + @contextmanager + def init(self) -> None: + """ 
create and destroy connections to database. + + Examples: + >>> with self.init(): + >>> self.insert_embeddings() + """ + self.conn, self.cursor = self._create_connection(**self.db_config) + + index_param = self.case_config.index_param() + search_param = self.case_config.search_param() + + # maximize allowed packet size + self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824") + + if index_param["index_type"] == "HNSW": + if index_param["max_cache_size"] != None: + self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param["max_cache_size"]}") + if search_param["ef_search"] != None: + self.cursor.execute(f"SET mhnsw_ef_search = {search_param["ef_search"]}") + self.cursor.execute("COMMIT") + + self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" + self.select_sql = f"SELECT id FROM {self.db_name}.{self.table_name} ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d" + self.select_sql_with_filter = f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d" + + try: + yield + finally: + self.cursor.close() + self.conn.close() + self.cursor = None + self.conn = None + + + def ready_to_load(self) -> bool: + pass + + def optimize(self) -> None: + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + index_param = self.case_config.index_param() + + try: + index_options = f"DISTANCE={index_param['metric_type']}" + if index_param["index_type"] == "HNSW" and index_param["M"] != None: + index_options += f" M={index_param['M']}" + + self.cursor.execute(f""" + ALTER TABLE {self.db_name}.{self.table_name} + ADD VECTOR KEY v(v) {index_options} + """) + self.cursor.execute("COMMIT") + + except Exception as e: + log.warning( + f"Failed to create index: {self.table_name} error: {e}" + ) + raise e from None + + pass + + @staticmethod + def vector_to_hex(v): + return
np.array(v, 'float32').tobytes() + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + """Insert embeddings into the database. + Should call self.init() first. + """ + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + try: + metadata_arr = np.array(metadata) + embeddings_arr = np.array(embeddings) + + batch_data = [] + for i, row in enumerate(metadata_arr): + batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i]))); + + self.cursor.executemany(self.insert_sql, batch_data) + self.cursor.execute("COMMIT") + self.cursor.execute("FLUSH TABLES") + + return len(metadata), None + except Exception as e: + log.warning( + f"Failed to insert data into Vector table ({self.table_name}), error: {e}" + ) + return 0, e + + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> (list[int]): + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + search_param = self.case_config.search_param() + + if filters: + self.cursor.execute(self.select_sql_with_filter, (filters.get('id'), self.vector_to_hex(query), k)) + else: + self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k)) + + return [id for id, in self.cursor.fetchall()] + diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 5e3798691..7934b3871 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,5 +1,6 @@ from ..backend.clients.alloydb.cli import AlloyDBScaNN from ..backend.clients.aws_opensearch.cli import AWSOpenSearch +from ..backend.clients.mariadb.cli import MariaDBHNSW from ..backend.clients.memorydb.cli import MemoryDB from ..backend.clients.milvus.cli import MilvusAutoIndex 
from ..backend.clients.pgdiskann.cli import PgDiskAnn @@ -25,6 +26,7 @@ cli.add_command(PgVectorScaleDiskAnn) cli.add_command(PgDiskAnn) cli.add_command(AlloyDBScaNN) +cli.add_command(MariaDBHNSW) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index c3918e7d6..0ab3a932b 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -1058,6 +1058,64 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_IndexType_MariaDB = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputHelp="Select Index Type", + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.HNSW.value, + ], + }, +) + +CaseConfigParamInput_StorageEngine_MariaDB = CaseConfigInput( + label=CaseConfigParamType.storage_engine, + inputHelp="Select Storage Engine", + inputType=InputType.Option, + inputConfig={ + "options": ["InnoDB", "MyISAM"], + }, +) + +CaseConfigParamInput_M_MariaDB = CaseConfigInput( + label=CaseConfigParamType.M, + inputHelp="M parameter in MHNSW vector indexing", + inputType=InputType.Number, + inputConfig={ + "min": 3, + "max": 200, + "value": 6, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + +CaseConfigParamInput_EFSearch_MariaDB = CaseConfigInput( + label=CaseConfigParamType.ef_search, + inputHelp="mhnsw_ef_search", + inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 10000, + "value": 20, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + +CaseConfigParamInput_CacheSize_MariaDB = CaseConfigInput( + label=CaseConfigParamType.max_cache_size, + inputHelp="mhnsw_max_cache_size", + inputType=InputType.Number, + inputConfig={ + "min": 1048576, + "max": (1 << 53) - 1, + "value": 16 * 1024 ** 3, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, 
None) + == IndexType.HNSW.value, +) CaseConfigParamInput_MongoDBQuantizationType = CaseConfigInput( label=CaseConfigParamType.mongodb_quantization_type, @@ -1272,6 +1330,20 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_MongoDBNumCandidatesRatio, ] +MariaDBLoadingConfig = [ + CaseConfigParamInput_IndexType_MariaDB, + CaseConfigParamInput_StorageEngine_MariaDB, + CaseConfigParamInput_M_MariaDB, + CaseConfigParamInput_CacheSize_MariaDB, +] +MariaDBPerformanceConfig = [ + CaseConfigParamInput_IndexType_MariaDB, + CaseConfigParamInput_StorageEngine_MariaDB, + CaseConfigParamInput_M_MariaDB, + CaseConfigParamInput_CacheSize_MariaDB, + CaseConfigParamInput_EFSearch_MariaDB, +] + CASE_CONFIG_MAP = { DB.Milvus: { CaseLabel.Load: MilvusLoadConfig, @@ -1324,4 +1396,8 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: MongoDBLoadingConfig, CaseLabel.Performance: MongoDBPerformanceConfig, }, + DB.MariaDB: { + CaseLabel.Load: MariaDBLoadingConfig, + CaseLabel.Performance: MariaDBPerformanceConfig, + }, } diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index fa80da2d5..e206919ac 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -88,6 +88,8 @@ class CaseConfigParamType(Enum): preReorderingNumNeigbors = "pre_reordering_num_neighbors" numSearchThreads = "num_search_threads" maxNumPrefetchDatasets = "max_num_prefetch_datasets" + storage_engine = "storage_engine" + max_cache_size = "max_cache_size" # mongodb params mongodb_quantization_type = "quantization" From b8221d1d25b1b0507e8238725d838f10f2b76f41 Mon Sep 17 00:00:00 2001 From: Wenxuan Date: Thu, 13 Mar 2025 10:21:05 +0800 Subject: [PATCH 15/36] Add TiDB backend (#484) * Add TiDB backend Signed-off-by: Wish * Fix Signed-off-by: Wish * Fix Signed-off-by: Wish * Improve error handling Signed-off-by: Wish --------- Signed-off-by: Wish --- README.md | 1 + pyproject.toml | 4 +- vectordb_bench/backend/clients/__init__.py | 17 ++ vectordb_bench/backend/clients/tidb/cli.py | 98 
++++++++ vectordb_bench/backend/clients/tidb/config.py | 49 ++++ vectordb_bench/backend/clients/tidb/tidb.py | 234 ++++++++++++++++++ vectordb_bench/cli/vectordbbench.py | 2 + vectordb_bench/frontend/config/styles.py | 2 + 8 files changed, 406 insertions(+), 1 deletion(-) create mode 100644 vectordb_bench/backend/clients/tidb/cli.py create mode 100644 vectordb_bench/backend/clients/tidb/config.py create mode 100644 vectordb_bench/backend/clients/tidb/tidb.py diff --git a/README.md b/README.md index 80db92ec5..7a52c3242 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ All the database client supported | awsopensearch | `pip install vectordb-bench[opensearch]` | | aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` | | mongodb | `pip install vectordb-bench[mongodb]` | +| tidb | `pip install vectordb-bench[tidb]` | ### Run diff --git a/pyproject.toml b/pyproject.toml index ad80bacfe..019c2973e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ all = [ "alibabacloud_ha3engine_vector", "alibabacloud_searchengine20211025", "mariadb", + "PyMySQL", ] qdrant = [ "qdrant-client" ] @@ -87,7 +88,8 @@ chromadb = [ "chromadb" ] opensearch = [ "opensearch-py" ] aliyun_opensearch = [ "alibabacloud_ha3engine_vector", "alibabacloud_searchengine20211025"] mongodb = [ "pymongo" ] -mariadb = [ "mariadb" ] +mariadb = [ "mariadb" ] +tidb = [ "PyMySQL" ] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 742da5213..8d667ec23 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -42,6 +42,7 @@ class DB(Enum): Test = "test" AliyunOpenSearch = "AliyunOpenSearch" MongoDB = "MongoDB" + TiDB = "TiDB" @property def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 @@ -141,6 +142,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 return 
MariaDB + if self == DB.TiDB: + from .tidb.tidb import TiDB + + return TiDB + if self == DB.Test: from .test.test import Test @@ -244,8 +250,14 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 if self == DB.MariaDB: from .mariadb.config import MariaDBConfig + return MariaDBConfig + if self == DB.TiDB: + from .tidb.config import TiDBConfig + + return TiDBConfig + if self == DB.Test: from .test.config import TestConfig @@ -333,6 +345,11 @@ def case_config_cls( # noqa: PLR0911 return _mariadb_case_config.get(index_type) + if self == DB.TiDB: + from .tidb.config import TiDBIndexConfig + + return TiDBIndexConfig + # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/tidb/cli.py b/vectordb_bench/backend/clients/tidb/cli.py new file mode 100644 index 000000000..cdfcbe432 --- /dev/null +++ b/vectordb_bench/backend/clients/tidb/cli.py @@ -0,0 +1,98 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.backend.clients import DB + +from ....cli.cli import CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run + + +class TiDBTypedDict(CommonTypedDict): + user_name: Annotated[ + str, + click.option( + "--username", + type=str, + help="Username", + default="root", + show_default=True, + required=True, + ), + ] + password: Annotated[ + str, + click.option( + "--password", + type=str, + default="", + show_default=True, + help="Password", + ), + ] + host: Annotated[ + str, + click.option( + "--host", + type=str, + default="127.0.0.1", + show_default=True, + required=True, + help="Db host", + ), + ] + port: Annotated[ + int, + click.option( + "--port", + type=int, + default=4000, + show_default=True, + required=True, + help="Db Port", + ), + ] + db_name: Annotated[ + str, + click.option( + "--db-name", + type=str, + default="test", + show_default=True, + required=True, + help="Db name", + ), + ] + ssl: Annotated[ + bool, + click.option( 
+ "--ssl/--no-ssl", + default=False, + show_default=True, + is_flag=True, + help="Enable or disable SSL, for TiDB Serverless SSL must be enabled", + ), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(TiDBTypedDict) +def TiDB( + **parameters: Unpack[TiDBTypedDict], +): + from .config import TiDBConfig, TiDBIndexConfig + + run( + db=DB.TiDB, + db_config=TiDBConfig( + db_label=parameters["db_label"], + user_name=parameters["username"], + password=SecretStr(parameters["password"]), + host=parameters["host"], + port=parameters["port"], + db_name=parameters["db_name"], + ssl=parameters["ssl"], + ), + db_case_config=TiDBIndexConfig(), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/tidb/config.py b/vectordb_bench/backend/clients/tidb/config.py new file mode 100644 index 000000000..213a18bc5 --- /dev/null +++ b/vectordb_bench/backend/clients/tidb/config.py @@ -0,0 +1,49 @@ +from pydantic import SecretStr, BaseModel, validator +from ..api import DBConfig, DBCaseConfig, MetricType + + +class TiDBConfig(DBConfig): + user_name: str = "root" + password: SecretStr + host: str = "127.0.0.1" + port: int = 4000 + db_name: str = "test" + ssl: bool = False + + @validator("*") + def not_empty_field(cls, v: any, field: any): + return v + + def to_dict(self) -> dict: + pwd_str = self.password.get_secret_value() + return { + "host": self.host, + "port": self.port, + "user": self.user_name, + "password": pwd_str, + "database": self.db_name, + "ssl_verify_cert": self.ssl, + "ssl_verify_identity": self.ssl, + } + + +class TiDBIndexConfig(BaseModel, DBCaseConfig): + metric_type: MetricType | None = None + + def get_metric_fn(self) -> str: + if self.metric_type == MetricType.L2: + return "vec_l2_distance" + elif self.metric_type == MetricType.COSINE: + return "vec_cosine_distance" + else: + raise ValueError(f"Unsupported metric type: {self.metric_type}") + + def index_param(self) -> dict: + return { + "metric_fn": self.get_metric_fn(), + } + + def 
search_param(self) -> dict: + return { + "metric_fn": self.get_metric_fn(), + } diff --git a/vectordb_bench/backend/clients/tidb/tidb.py b/vectordb_bench/backend/clients/tidb/tidb.py new file mode 100644 index 000000000..d1f26084e --- /dev/null +++ b/vectordb_bench/backend/clients/tidb/tidb.py @@ -0,0 +1,234 @@ +import concurrent.futures +import io +import logging +import time +from contextlib import contextmanager +from typing import Any, Optional, Tuple + +import pymysql + +from ..api import VectorDB +from .config import TiDBIndexConfig + +log = logging.getLogger(__name__) + + +class TiDB(VectorDB): + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: TiDBIndexConfig, + collection_name: str = "vector_bench_test", + drop_old: bool = False, + **kwargs, + ): + self.name = "TiDB" + self.db_config = db_config + self.case_config = db_case_config + self.table_name = collection_name + self.dim = dim + self.conn = None # To be inited by init() + self.cursor = None # To be inited by init() + + self.search_fn = db_case_config.search_param()["metric_fn"] + + if drop_old: + self._drop_table() + self._create_table() + + @contextmanager + def init(self): + with self._get_connection() as (conn, cursor): + self.conn = conn + self.cursor = cursor + try: + yield + finally: + self.conn = None + self.cursor = None + + @contextmanager + def _get_connection(self): + with pymysql.connect(**self.db_config) as conn: + conn.autocommit = False + with conn.cursor() as cursor: + yield conn, cursor + + def _drop_table(self): + try: + with self._get_connection() as (conn, cursor): + cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}") + conn.commit() + except Exception as e: + log.warning("Failed to drop table: %s error: %s", self.table_name, e) + raise e + + def _create_table(self): + try: + index_param = self.case_config.index_param() + with self._get_connection() as (conn, cursor): + cursor.execute( + f""" + CREATE TABLE {self.table_name} ( + id BIGINT PRIMARY KEY, + 
embedding VECTOR({self.dim}) NOT NULL, + VECTOR INDEX (({index_param["metric_fn"]}(embedding))) + ); + """ + ) + conn.commit() + except Exception as e: + log.warning("Failed to create table: %s error: %s", self.table_name, e) + raise e + + def ready_to_load(self) -> bool: + pass + + def optimize(self, data_size: int | None = None) -> None: + while True: + progress = self._optimize_check_tiflash_replica_progress() + if progress != 1: + log.info("Data replication not ready, progress: %d", progress) + time.sleep(2) + else: + break + + log.info("Waiting TiFlash to catch up...") + self._optimize_wait_tiflash_catch_up() + + log.info("Start compacting TiFlash replica...") + self._optimize_compact_tiflash() + + log.info("Waiting index build to finish...") + log_reduce_seq = 0 + while True: + pending_rows = self._optimize_get_tiflash_index_pending_rows() + if pending_rows > 0: + if log_reduce_seq % 15 == 0: + log.info("Index not fully built, pending rows: %d", pending_rows) + log_reduce_seq += 1 + time.sleep(2) + else: + break + + log.info("Index build finished successfully.") + + def _optimize_check_tiflash_replica_progress(self): + try: + database = self.db_config["database"] + with self._get_connection() as (_, cursor): + cursor.execute( + f""" + SELECT PROGRESS FROM information_schema.tiflash_replica + WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}" + """ + ) + result = cursor.fetchone() + return result[0] + except Exception as e: + log.warning("Failed to check TiFlash replica progress: %s", e) + raise e + + def _optimize_wait_tiflash_catch_up(self): + try: + with self._get_connection() as (conn, cursor): + cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"') + conn.commit() + cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") + result = cursor.fetchone() + return result[0] + except Exception as e: + log.warning("Failed to wait TiFlash to catch up: %s", e) + raise e + + def _optimize_compact_tiflash(self): + try: + with 
self._get_connection() as (conn, cursor): + cursor.execute(f"ALTER TABLE {self.table_name} COMPACT") + conn.commit() + except Exception as e: + log.warning("Failed to compact table: %s", e) + raise e + + def _optimize_get_tiflash_index_pending_rows(self): + try: + database = self.db_config["database"] + with self._get_connection() as (_, cursor): + cursor.execute( + f""" + SELECT SUM(ROWS_STABLE_NOT_INDEXED) + FROM information_schema.tiflash_indexes + WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}" + """ + ) + result = cursor.fetchone() + return result[0] + except Exception as e: + log.warning("Failed to read TiFlash index pending rows: %s", e) + raise e + + def _insert_embeddings_serial( + self, + embeddings: list[list[float]], + metadata: list[int], + offset: int, + size: int, + ) -> Exception: + try: + with self._get_connection() as (conn, cursor): + buf = io.StringIO() + buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") + for i in range(offset, offset + size): + if i > offset: + buf.write(",") + buf.write(f'({metadata[i]}, "{str(embeddings[i])}")') + cursor.execute(buf.getvalue()) + conn.commit() + except Exception as e: + log.warning("Failed to insert data into table: %s", e) + raise e + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + workers = 10 + # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB) + max_batch_size = 64 * 1024 * 1024 // 24 // self.dim + batch_size = len(embeddings) // workers + if batch_size > max_batch_size: + batch_size = max_batch_size + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + futures = [] + for i in range(0, len(embeddings), batch_size): + offset = i + size = min(batch_size, len(embeddings) - i) + future = executor.submit(self._insert_embeddings_serial, embeddings, metadata, offset, size) + futures.append(future) + done, pending = 
concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) + executor.shutdown(wait=False) + for future in done: + future.result() + for future in pending: + future.cancel() + return len(metadata), None + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> list[int]: + self.cursor.execute( + f""" + SELECT id FROM {self.table_name} + ORDER BY {self.search_fn}(embedding, "{str(query)}") LIMIT {k}; + """ + ) + result = self.cursor.fetchall() + return [int(i[0]) for i in result] diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 7934b3871..49428b678 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -11,6 +11,7 @@ from ..backend.clients.test.cli import Test from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex +from ..backend.clients.tidb.cli import TiDB from .cli import cli cli.add_command(PgVectorHNSW) @@ -27,6 +28,7 @@ cli.add_command(PgDiskAnn) cli.add_command(AlloyDBScaNN) cli.add_command(MariaDBHNSW) +cli.add_command(TiDB) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index 3e0fdb112..57456722f 100644 --- a/vectordb_bench/frontend/config/styles.py +++ b/vectordb_bench/frontend/config/styles.py @@ -47,6 +47,7 @@ def getPatternShape(i): DB.Redis: "https://assets.zilliz.com/Redis_Cloud_74b8bfef39.png", DB.Chroma: "https://assets.zilliz.com/chroma_ceb3f06ed7.png", DB.AWSOpenSearch: "https://assets.zilliz.com/opensearch_1eee37584e.jpeg", + DB.TiDB: "https://img2.pingcap.com/forms/3/d/3d7fd5f9767323d6f037795704211ac44b4923d6.png", } # RedisCloud color: #0D6EFD @@ -61,4 +62,5 @@ def getPatternShape(i): DB.PgVector.value: "#4C779A", DB.Redis.value: "#0D6EFD", DB.AWSOpenSearch.value: "#0DCAF0", + DB.TiDB.value: "#0D6EFD", } 
From dba738b88bb6bdd290708d4b6590e2ec67773505 Mon Sep 17 00:00:00 2001 From: Rachit Chaudhary <65501028+Rachit-Chaudhary11@users.noreply.github.com> Date: Fri, 14 Mar 2025 07:14:36 +0530 Subject: [PATCH 16/36] CLI fix for GPU index (#485) * Support GPU_BRUTE_FORCE index for Milvus Signed-off-by: Rachit Chaudhary * MilvusGPUBruteForceTypedDict addition Signed-off-by: Rachit Chaudhary --------- Signed-off-by: Rachit Chaudhary Co-authored-by: Signed-off-by: Rachit Chaudhary - r0c0axe --- vectordb_bench/backend/clients/milvus/cli.py | 30 ++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py index 303eec5f9..1bec8ebe5 100644 --- a/vectordb_bench/backend/clients/milvus/cli.py +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -214,6 +214,36 @@ def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): **parameters, ) +class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict): + metric_type: Annotated[ + str, + click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"), + ] + limit: Annotated[ + int, + click.option("--limit", type=int, required=True, help="Top-k limit for search"), + ] + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict) +def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): + from .config import GPUBruteForceConfig, MilvusConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), + ), + db_case_config=GPUBruteForceConfig( + metric_type=parameters["metric_type"], + limit=parameters["limit"], # top-k for search + ), + **parameters, + ) + class MilvusGPUIVFPQTypedDict( CommonTypedDict, From 4cbfef78c25438d90923ab06099316d48a61ffa3 Mon Sep 17 00:00:00 2001 From: yuyuankang Date: 
Tue, 25 Mar 2025 20:28:00 +0000 Subject: [PATCH 17/36] remove duplicated code --- vectordb_bench/backend/clients/milvus/cli.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py index 1bec8ebe5..52524e785 100644 --- a/vectordb_bench/backend/clients/milvus/cli.py +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -194,26 +194,6 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]): **parameters, ) -@cli.command() -@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict) -def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): - from .config import GPUBruteForceConfig, MilvusConfig - - run( - db=DBTYPE, - db_config=MilvusConfig( - db_label=parameters["db_label"], - uri=SecretStr(parameters["uri"]), - user=parameters["user_name"], - password=SecretStr(parameters["password"]), - ), - db_case_config=GPUBruteForceConfig( - metric_type=parameters["metric_type"], - limit=parameters["limit"], # top-k for search - ), - **parameters, - ) - class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict): metric_type: Annotated[ str, From a39fe83973d63d5806c3ef49e489732987c51d09 Mon Sep 17 00:00:00 2001 From: Arseniy Ahtaryanov Date: Tue, 8 Apr 2025 17:56:23 +0300 Subject: [PATCH 18/36] feat: initial commit --- install/requirements_py3.11.txt | 1 + pyproject.toml | 2 + vectordb_bench/backend/clients/__init__.py | 16 ++ .../backend/clients/clickhouse/cli.py | 66 ++++++++ .../backend/clients/clickhouse/clickhouse.py | 149 ++++++++++++++++++ .../backend/clients/clickhouse/config.py | 56 +++++++ vectordb_bench/cli/vectordbbench.py | 2 + 7 files changed, 292 insertions(+) create mode 100644 vectordb_bench/backend/clients/clickhouse/cli.py create mode 100644 vectordb_bench/backend/clients/clickhouse/clickhouse.py create mode 100644 vectordb_bench/backend/clients/clickhouse/config.py diff --git 
a/install/requirements_py3.11.txt b/install/requirements_py3.11.txt index 18138382f..6969bf9dd 100644 --- a/install/requirements_py3.11.txt +++ b/install/requirements_py3.11.txt @@ -22,3 +22,4 @@ environs pydantic type[VectorDB]: # noqa: PLR0911, PLR0912, C901 @@ -117,6 +118,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 return AWSOpenSearch + if self == DB.Clickhouse: + from .clickhouse.clickhouse import Clickhouse + + return Clickhouse + if self == DB.AlloyDB: from .alloydb.alloydb import AlloyDB @@ -228,6 +234,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 return AWSOpenSearchConfig + if self == DB.Clickhouse: + from .clickhouse.config import ClickhouseConfig + + return ClickhouseConfig + if self == DB.AlloyDB: from .alloydb.config import AlloyDBConfig @@ -310,6 +321,11 @@ def case_config_cls( # noqa: PLR0911 return AWSOpenSearchIndexConfig + if self == DB.Clickhouse: + from .clickhouse.config import ClickhouseHNSWConfig + + return ClickhouseHNSWConfig + if self == DB.PgVectorScale: from .pgvectorscale.config import _pgvectorscale_case_config diff --git a/vectordb_bench/backend/clients/clickhouse/cli.py b/vectordb_bench/backend/clients/clickhouse/cli.py new file mode 100644 index 000000000..e454f5a75 --- /dev/null +++ b/vectordb_bench/backend/clients/clickhouse/cli.py @@ -0,0 +1,66 @@ +from typing import Annotated, TypedDict, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + HNSWFlavor2, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. 
import DB +from .config import ClickhouseHNSWConfig + + +class ClickhouseTypedDict(TypedDict): + password: Annotated[str, click.option("--password", type=str, help="DB password")] + host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)] + port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")] + user: Annotated[int, click.option("--user", type=str, default='clickhouse', help="DB user")] + ssl: Annotated[ + bool, + click.option( + "--ssl/--no-ssl", + is_flag=True, + show_default=True, + default=True, + help="Enable or disable SSL for Clickhouse", + ), + ] + ssl_ca_certs: Annotated[ + str, + click.option( + "--ssl-ca-certs", + show_default=True, + help="Path to certificate authority file to use for SSL", + ), + ] + + +class ClickhouseHNSWTypedDict(CommonTypedDict, ClickhouseTypedDict, HNSWFlavor2): ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(ClickhouseHNSWTypedDict) +def Clickhouse(**parameters: Unpack[ClickhouseHNSWTypedDict]): + from .config import ClickhouseConfig + + run( + db=DB.Clickhouse, + db_config=ClickhouseConfig( + db_label=parameters["db_label"], + password=SecretStr(parameters["password"]) if parameters["password"] else None, + host=parameters["host"], + port=parameters["port"], + ssl=parameters["ssl"], + ssl_ca_certs=parameters["ssl_ca_certs"], + ), + db_case_config=ClickhouseHNSWConfig( + M=parameters["m"], + efConstruction=parameters["ef_construction"], + ef=parameters["ef_runtime"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/clickhouse/clickhouse.py b/vectordb_bench/backend/clients/clickhouse/clickhouse.py new file mode 100644 index 000000000..e5339170b --- /dev/null +++ b/vectordb_bench/backend/clients/clickhouse/clickhouse.py @@ -0,0 +1,149 @@ +"""Wrapper around the Clickhouse vector database over VectorDB""" + +import io +import logging +from contextlib import contextmanager +from typing import Any +import clickhouse_connect +import 
numpy as np + +from ..api import VectorDB, DBCaseConfig + +log = logging.getLogger(__name__) + +class Clickhouse(VectorDB): + """Use SQLAlchemy instructions""" + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: DBCaseConfig, + collection_name: str = "CHVectorCollection", + drop_old: bool = False, + **kwargs, + ): + self.db_config = db_config + self.case_config = db_case_config + self.table_name = collection_name + self.dim = dim + + self._index_name = "clickhouse_index" + self._primary_field = "id" + self._vector_field = "embedding" + + # construct basic units + self.conn = clickhouse_connect.get_client( + host=self.db_config["host"], + port=self.db_config["port"], + username=self.db_config["user"], + password=self.db_config["password"], + database=self.db_config["dbname"]) + + if drop_old: + log.info(f"Clickhouse client drop table : {self.table_name}") + self._drop_table() + self._create_table(dim) + + self.conn.close() + self.conn = None + + @contextmanager + def init(self) -> None: + """ + Examples: + >>> with self.init(): + >>> self.insert_embeddings() + >>> self.search_embedding() + """ + + self.conn = clickhouse_connect.get_client( + host=self.db_config["host"], + port=self.db_config["port"], + username=self.db_config["user"], + password=self.db_config["password"], + database=self.db_config["dbname"]) + + try: + yield + finally: + self.conn.close() + self.conn = None + + def _drop_table(self): + assert self.conn is not None, "Connection is not initialized" + + self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["dbname"]}.{self.table_name}') + + def _create_table(self, dim: int): + assert self.conn is not None, "Connection is not initialized" + + try: + # create table + self.conn.command( + f'CREATE TABLE IF NOT EXISTS {self.db_config["dbname"]}.{self.table_name} \ + (id UInt32, embedding Array(Float64)) ENGINE = MergeTree() ORDER BY id;' + ) + + except Exception as e: + log.warning( + f"Failed to create Clickhouse table: 
{self.table_name} error: {e}" + ) + raise e from None + + def ready_to_load(self): + pass + + def optimize(self, data_size: int | None = None): + pass + + def ready_to_search(self): + pass + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> (int, Exception): + assert self.conn is not None, "Connection is not initialized" + + try: + # do not iterate for bulk insert + items = [metadata, embeddings] + + self.conn.insert(table=self.table_name, data=items, + column_names=['id', 'embedding'], column_type_names=['UInt32', 'Array(Float64)'], + column_oriented=True) + return len(metadata), None + except Exception as e: + log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}") + return 0, e + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + ) -> list[int]: + assert self.conn is not None, "Connection is not initialized" + + index_param = self.case_config.index_param() + search_param = self.case_config.search_param() + + if filters: + gt = filters.get("id") + filterSql = (f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' + f'FROM {self.db_config["dbname"]}.{self.table_name} ' + f'WHERE id > {gt} ' + f'ORDER BY score LIMIT {k};' + ) + result = self.conn.query(filterSql).result_rows + return [int(row[0]) for row in result] + else: + selectSql = (f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' + f'FROM {self.db_config["dbname"]}.{self.table_name} ' + f'ORDER BY score LIMIT {k};' + ) + result = self.conn.query(selectSql).result_rows + return [int(row[0]) for row in result] diff --git a/vectordb_bench/backend/clients/clickhouse/config.py b/vectordb_bench/backend/clients/clickhouse/config.py new file mode 100644 index 000000000..7ce0919ea --- /dev/null +++ b/vectordb_bench/backend/clients/clickhouse/config.py @@ -0,0 +1,56 @@ +from typing import TypedDict 
+from pydantic import BaseModel, SecretStr +from ..api import DBConfig, DBCaseConfig, MetricType, IndexType + +class ClickhouseConfig(DBConfig): + user_name: str = "clickhouse" + password: SecretStr + host: str = "localhost" + port: int = 8123 + db_name: str = "default" + + def to_dict(self) -> dict: + pwd_str = self.password.get_secret_value() + return { + "host": self.host, + "port": self.port, + "dbname": self.db_name, + "user": self.user_name, + "password": pwd_str + } + + +class ClickhouseIndexConfig(BaseModel): + + metric_type: MetricType | None = None + + def parse_metric(self) -> str: + if not self.metric_type: + return "" + return self.metric_type.value + + def parse_metric_str(self) -> str: + if self.metric_type == MetricType.L2: + return "L2Distance" + elif self.metric_type == MetricType.COSINE: + return "cosineDistance" + + +class ClickhouseHNSWConfig(ClickhouseIndexConfig, DBCaseConfig): + M: int | None + efConstruction: int | None + ef: int | None = None + index: IndexType = IndexType.HNSW + + def index_param(self) -> dict: + return { + "metric_type": self.parse_metric_str(), + "index_type": self.index.value, + "params": {"M": self.M, "efConstruction": self.efConstruction}, + } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric_str(), + "params": {"ef": self.ef}, + } \ No newline at end of file diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 49428b678..7f5e24fc2 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -12,6 +12,7 @@ from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex from ..backend.clients.tidb.cli import TiDB +from ..backend.clients.clickhouse.cli import Clickhouse from .cli import cli cli.add_command(PgVectorHNSW) @@ -29,6 +30,7 @@ cli.add_command(AlloyDBScaNN) cli.add_command(MariaDBHNSW) cli.add_command(TiDB) +cli.add_command(Clickhouse) if __name__ == 
"__main__": From 1446c6e4f2fae6bc21990694ac921bbcb3ea2b01 Mon Sep 17 00:00:00 2001 From: nuvotex-tk <161840620+nuvotex-tk@users.noreply.github.com> Date: Tue, 8 Apr 2025 11:28:41 +0200 Subject: [PATCH 19/36] Add vespa integration --- README.md | 1 + install/requirements_py3.11.txt | 1 + pyproject.toml | 4 +- vectordb_bench/backend/clients/__init__.py | 17 ++ vectordb_bench/backend/clients/vespa/cli.py | 47 ++++ .../backend/clients/vespa/config.py | 51 ++++ vectordb_bench/backend/clients/vespa/util.py | 16 ++ vectordb_bench/backend/clients/vespa/vespa.py | 254 ++++++++++++++++++ vectordb_bench/cli/vectordbbench.py | 4 +- .../frontend/config/dbCaseConfigs.py | 57 ++++ vectordb_bench/frontend/config/styles.py | 2 + 11 files changed, 452 insertions(+), 2 deletions(-) create mode 100644 vectordb_bench/backend/clients/vespa/cli.py create mode 100644 vectordb_bench/backend/clients/vespa/config.py create mode 100644 vectordb_bench/backend/clients/vespa/util.py create mode 100644 vectordb_bench/backend/clients/vespa/vespa.py diff --git a/README.md b/README.md index 7a52c3242..6d83671da 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ All the database client supported | aliyun_opensearch | `pip install vectordb-bench[aliyun_opensearch]` | | mongodb | `pip install vectordb-bench[mongodb]` | | tidb | `pip install vectordb-bench[tidb]` | +| vespa | `pip install vectordb-bench[vespa]` | ### Run diff --git a/install/requirements_py3.11.txt b/install/requirements_py3.11.txt index 6969bf9dd..86958ada2 100644 --- a/install/requirements_py3.11.txt +++ b/install/requirements_py3.11.txt @@ -23,3 +23,4 @@ pydantic type[VectorDB]: # noqa: PLR0911, PLR0912, C901 @@ -157,6 +158,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 from .test.test import Test return Test + + if self == DB.Vespa: + from .vespa.vespa import Vespa + + return Vespa msg = f"Unknown DB: {self.name}" raise ValueError(msg) @@ -273,6 +279,12 @@ def config_cls(self) -> type[DBConfig]: # 
noqa: PLR0911, PLR0912, C901 from .test.config import TestConfig return TestConfig + + if self == DB.Vespa: + from .vespa.config import VespaConfig + + return VespaConfig + msg = f"Unknown DB: {self.name}" raise ValueError(msg) @@ -365,6 +377,11 @@ def case_config_cls( # noqa: PLR0911 from .tidb.config import TiDBIndexConfig return TiDBIndexConfig + + if self == DB.Vespa: + from .vespa.config import VespaHNSWConfig + + return VespaHNSWConfig # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/vespa/cli.py b/vectordb_bench/backend/clients/vespa/cli.py new file mode 100644 index 000000000..616d46067 --- /dev/null +++ b/vectordb_bench/backend/clients/vespa/cli.py @@ -0,0 +1,47 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.backend.clients import DB +from vectordb_bench.cli.cli import ( + CommonTypedDict, + HNSWFlavor1, + cli, + click_parameter_decorators_from_typed_dict, + run, +) + + +class VespaTypedDict(CommonTypedDict, HNSWFlavor1): + uri: Annotated[ + str, + click.option("--uri", "-u", type=str, help="uri connection string", default="http://127.0.0.1"), + ] + port: Annotated[ + int, + click.option("--port", "-p", type=int, help="connection port", default=8080), + ] + quantization: Annotated[ + str, click.option("--quantization", type=click.Choice(["none", "binary"], case_sensitive=False), default="none") + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(VespaTypedDict) +def Vespa(**params: Unpack[VespaTypedDict]): + from .config import VespaConfig, VespaHNSWConfig + + case_params = { + "quantization_type": params["quantization"], + "M": params["m"], + "efConstruction": params["ef_construction"], + "ef": params["ef_search"], + } + + run( + db=DB.Vespa, + db_config=VespaConfig(url=SecretStr(params["uri"]), port=params["port"]), + db_case_config=VespaHNSWConfig(**{k: v for k, v in case_params.items() if v}), + **params, + ) diff 
--git a/vectordb_bench/backend/clients/vespa/config.py b/vectordb_bench/backend/clients/vespa/config.py new file mode 100644 index 000000000..3d4a1deaf --- /dev/null +++ b/vectordb_bench/backend/clients/vespa/config.py @@ -0,0 +1,51 @@ +from typing import Literal, TypeAlias + +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, MetricType + +VespaMetric: TypeAlias = Literal["euclidean", "angular", "dotproduct", "prenormalized-angular", "hamming", "geodegrees"] + +VespaQuantizationType: TypeAlias = Literal["none", "binary"] + + +class VespaConfig(DBConfig): + url: SecretStr = "http://127.0.0.1" + port: int = 8080 + + def to_dict(self): + return { + "url": self.url.get_secret_value(), + "port": self.port, + } + + +class VespaHNSWConfig(BaseModel, DBCaseConfig): + metric_type: MetricType = MetricType.COSINE + quantization_type: VespaQuantizationType = "none" + M: int = 16 + efConstruction: int = 200 + ef: int = 100 + + def index_param(self) -> dict: + return { + "distance_metric": self.parse_metric(self.metric_type), + "max_links_per_node": self.M, + "neighbors_to_explore_at_insert": self.efConstruction, + } + + def search_param(self) -> dict: + return {} + + def parse_metric(self, metric_type: MetricType) -> VespaMetric: + match metric_type: + case MetricType.COSINE: + return "angular" + case MetricType.L2: + return "euclidean" + case MetricType.DP | MetricType.IP: + return "dotproduct" + case MetricType.HAMMING: + return "hamming" + case _: + raise NotImplementedError diff --git a/vectordb_bench/backend/clients/vespa/util.py b/vectordb_bench/backend/clients/vespa/util.py new file mode 100644 index 000000000..97d03d644 --- /dev/null +++ b/vectordb_bench/backend/clients/vespa/util.py @@ -0,0 +1,16 @@ +"""Utility functions for supporting binary quantization + +From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8 +""" +import numpy as np + + +def binarize_tensor(tensor: list[float]) -> list[int]: + """ + 
Binarize a floating-point list by thresholding at zero + and packing the bits into bytes. + """ + tensor = np.array(tensor) + return ( + np.packbits(np.where(tensor > 0, 1, 0), axis=0).astype(np.int8).tolist() + ) diff --git a/vectordb_bench/backend/clients/vespa/vespa.py b/vectordb_bench/backend/clients/vespa/vespa.py new file mode 100644 index 000000000..19bb3f5a4 --- /dev/null +++ b/vectordb_bench/backend/clients/vespa/vespa.py @@ -0,0 +1,254 @@ +import datetime +import logging +import math +from collections.abc import Generator +from contextlib import contextmanager + +from vespa import application + +from ..api import VectorDB +from .config import VespaHNSWConfig +from . import util + +log = logging.getLogger(__name__) + + +class Vespa(VectorDB): + def __init__( + self, + dim: int, + db_config: dict[str, str], + db_case_config: VespaHNSWConfig | None = None, + collection_name: str = "VectorDBBenchCollection", + drop_old: bool = False, + **kwargs, + ) -> None: + self.dim = dim + self.db_config = db_config + self.case_config = db_case_config or VespaHNSWConfig() + self.schema_name = collection_name + + client = self.deploy_http() + client.wait_for_application_up() + + if drop_old: + try: + client.delete_all_docs("vectordbbench_content", self.schema_name) + except Exception: + drop_old = False + log.exception(f"Vespa client drop_old schema: {self.schema_name}") + + @contextmanager + def init(self) -> Generator[None, None, None]: + """create and destory connections to database. + Why contextmanager: + + In multiprocessing search tasks, vectordbbench might init + totally hundreds of thousands of connections with DB server. + + Too many connections may drain local FDs or server connection resources. + If the DB client doesn't have `close()` method, just set the object to None. 
+ + Examples: + >>> with self.init(): + >>> self.insert_embeddings() + """ + self.client = application.Vespa(self.db_config["url"], port=self.db_config["port"]) + yield + self.client = None + + def need_normalize_cosine(self) -> bool: + """Wheather this database need to normalize dataset to support COSINE""" + return False + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs, + ) -> tuple[int, Exception | None]: + """Insert the embeddings to the vector database. The default number of embeddings for + each insert_embeddings is 5000. + + Args: + embeddings(list[list[float]]): list of embedding to add to the vector database. + metadatas(list[int]): metadata associated with the embeddings, for filtering. + **kwargs(Any): vector database specific parameters. + + Returns: + int: inserted data count + """ + assert self.client is not None + + data = ({"id": str(i), "fields": {"id": i, "embedding": e}} for i, e in zip(metadata, embeddings, strict=True)) + self.client.feed_iterable(data, self.schema_name) + return len(embeddings), None + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + ) -> list[int]: + """Get k most similar embeddings to query vector. + + Args: + query(list[float]): query embedding to look up documents similar to. + k(int): Number of most similar embeddings to return. Defaults to 100. + filters(dict, optional): filtering expression to filter the data while searching. + + Returns: + list[int]: list of k most similar embeddings IDs to the query embedding. 
+ """ + assert self.client is not None + + ef = self.case_config.ef + extra_ef = max(0, ef - k) + embedding_field = "embedding" if self.case_config.quantization_type == "none" else "embedding_binary" + + yql = ( + f"select id from {self.schema_name} where " # noqa: S608 + f"{{targetHits: {k}, hnsw.exploreAdditionalHits: {extra_ef}}}" + f"nearestNeighbor({embedding_field}, query_embedding)" + ) + + if filters: + id_filter = filters.get("id") + yql += f" and id >= {id_filter}" + + query_embedding = ( + query + if self.case_config.quantization_type == "none" + else util.binarize_tensor(query) + ) + + ranking = self.case_config.quantization_type + + result = self.client.query({"yql": yql, "input.query(query_embedding)": query_embedding, "ranking": ranking}) + result_ids = [child["fields"]["id"] for child in result.get_json()["root"]["children"]] + return result_ids + + def optimize(self, data_size: int | None = None): + """optimize will be called between insertion and search in performance cases. + + Should be blocked until the vectorDB is ready to be tested on + heavy performance cases. + + Time(insert the dataset) + Time(optimize) will be recorded as "load_duration" metric + Optimize's execution time is limited, the limited time is based on cases. 
+ """ + + @property + def application_package(self): + if getattr(self, "_application_package", None) is None: + self._application_package = self._create_application_package() + return self._application_package + + def _create_application_package(self): + from vespa.package import ( + HNSW, + ApplicationPackage, + Document, + Field, + RankProfile, + Schema, + Validation, + ValidationID, + ) + + fields = [ + Field( + "id", + "int", + indexing=["summary", "attribute"], + ), + Field( + "embedding", + f"tensor(x[{self.dim}])", + indexing=["summary", "attribute", "index"], + ann=HNSW(**self.case_config.index_param()), + ), + ] + + if self.case_config.quantization_type == "binary": + fields.append( + Field( + "embedding_binary", + f"tensor(x[{math.ceil(self.dim / 8)}])", + indexing=[ + "input embedding", + # convert 32 bit float to 1 bit + "binarize", + # pack 8 bits into one int8 + "pack_bits", + "summary", + "attribute", + "index", + ], + ann=HNSW(**{**self.case_config.index_param(), "distance_metric": "hamming"}), + is_document_field=False, + ) + ) + + tomorrow = datetime.date.today() + datetime.timedelta(days=1) + + return ApplicationPackage( + "vectordbbench", + [ + Schema( + self.schema_name, + Document( + fields, + ), + rank_profiles=[ + RankProfile( + name="none", + first_phase="", + inherits="default", + inputs=[("query(query_embedding)", f"tensor(x[{self.dim}])")], + ), + RankProfile( + name="binary", + first_phase="", + inherits="default", + inputs=[("query(query_embedding)", f"tensor(x[{math.ceil(self.dim / 8)}])")], + ), + ], + ) + ], + validations=[ + Validation(ValidationID.tensorTypeChange, until=tomorrow), + Validation(ValidationID.fieldTypeChange, until=tomorrow), + ], + ) + + def deploy_http(self) -> application.Vespa: + """ + Deploy a Vespa application package via HTTP REST API. 
+ + Returns: + application.Vespa: The deployed Vespa application instance + """ + import requests + + url = self.db_config["url"] + ":19071/application/v2/tenant/default/prepareandactivate" + package_data = self.application_package.to_zip() + headers = {"Content-Type": "application/zip"} + + try: + response = requests.post(url=url, data=package_data, headers=headers, timeout=10) + + response.raise_for_status() + result = response.json() + return application.Vespa( + url=self.db_config["url"], + port=self.db_config["port"], + deployment_message=result.get("message"), + application_package=self.application_package, + ) + + except requests.exceptions.RequestException as e: + error_msg = f"Failed to deploy Vespa application: {e!s}" + if hasattr(e, "response") and e.response is not None: + error_msg += f" - Response: {e.response.text}" + raise RuntimeError(error_msg) from e diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 7f5e24fc2..7469ec9d5 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -9,9 +9,10 @@ from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn from ..backend.clients.redis.cli import Redis from ..backend.clients.test.cli import Test +from ..backend.clients.tidb.cli import TiDB +from ..backend.clients.vespa.cli import Vespa from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex -from ..backend.clients.tidb.cli import TiDB from ..backend.clients.clickhouse.cli import Clickhouse from .cli import cli @@ -31,6 +32,7 @@ cli.add_command(MariaDBHNSW) cli.add_command(TiDB) cli.add_command(Clickhouse) +cli.add_command(Vespa) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 0ab3a932b..e85b42ff3 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py 
@@ -1137,6 +1137,50 @@ class CaseConfigInput(BaseModel): ) +CaseConfigParamInput_M_Vespa = CaseConfigInput( + label=CaseConfigParamType.M, + inputType=InputType.Number, + inputConfig={ + "min": 4, + "max": 64, + "value": 16, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_IndexType_Vespa = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.HNSW.value, + ], + }, +) + +CaseConfigParamInput_QuantizationType_Vespa = CaseConfigInput( + label=CaseConfigParamType.quantizationType, + inputType=InputType.Option, + inputConfig={ + "options": [ + "none", + "binary" + ], + }, +) + +CaseConfigParamInput_EFConstruction_Vespa = CaseConfigInput( + label=CaseConfigParamType.EFConstruction, + inputType=InputType.Number, + inputConfig={ + "min": 8, + "max": 512, + "value": 200, + }, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value, +) + + MilvusLoadConfig = [ CaseConfigParamInput_IndexType, CaseConfigParamInput_M, @@ -1344,6 +1388,15 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_EFSearch_MariaDB, ] +VespaLoadingConfig = [ + CaseConfigParamInput_IndexType_Vespa, + CaseConfigParamInput_QuantizationType_Vespa, + CaseConfigParamInput_M_Vespa, + CaseConfigParamInput_EF_Milvus, + CaseConfigParamInput_EFConstruction_Vespa, +] +VespaPerformanceConfig = VespaLoadingConfig + CASE_CONFIG_MAP = { DB.Milvus: { CaseLabel.Load: MilvusLoadConfig, @@ -1400,4 +1453,8 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: MariaDBLoadingConfig, CaseLabel.Performance: MariaDBPerformanceConfig, }, + DB.Vespa: { + CaseLabel.Load: VespaLoadingConfig, + CaseLabel.Performance: VespaPerformanceConfig, + }, } diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index 57456722f..4418c19da 100644 --- a/vectordb_bench/frontend/config/styles.py +++ 
b/vectordb_bench/frontend/config/styles.py @@ -48,6 +48,7 @@ def getPatternShape(i): DB.Chroma: "https://assets.zilliz.com/chroma_ceb3f06ed7.png", DB.AWSOpenSearch: "https://assets.zilliz.com/opensearch_1eee37584e.jpeg", DB.TiDB: "https://img2.pingcap.com/forms/3/d/3d7fd5f9767323d6f037795704211ac44b4923d6.png", + DB.Vespa: "https://vespa.ai/vespa-content/uploads/2025/01/Vespa-symbol-green-rgb.png.webp", } # RedisCloud color: #0D6EFD @@ -63,4 +64,5 @@ def getPatternShape(i): DB.Redis.value: "#0D6EFD", DB.AWSOpenSearch.value: "#0DCAF0", DB.TiDB.value: "#0D6EFD", + DB.Vespa.value: "#61d790" } From 1ab262785696368c8770c8454fd26488e8ee5cdc Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Mon, 14 Apr 2025 02:49:50 +0000 Subject: [PATCH 20/36] remove redundant empty_field config check for qdrant and tidb Signed-off-by: min.tian --- .../backend/clients/clickhouse/config.py | 18 +++++++++++------- vectordb_bench/backend/clients/mariadb/cli.py | 5 ++--- .../backend/clients/mariadb/config.py | 13 ++++++++----- .../backend/clients/qdrant_cloud/config.py | 10 +--------- vectordb_bench/backend/clients/tidb/config.py | 15 ++++++--------- 5 files changed, 28 insertions(+), 33 deletions(-) diff --git a/vectordb_bench/backend/clients/clickhouse/config.py b/vectordb_bench/backend/clients/clickhouse/config.py index 7ce0919ea..fad446049 100644 --- a/vectordb_bench/backend/clients/clickhouse/config.py +++ b/vectordb_bench/backend/clients/clickhouse/config.py @@ -1,9 +1,10 @@ -from typing import TypedDict from pydantic import BaseModel, SecretStr -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + class ClickhouseConfig(DBConfig): - user_name: str = "clickhouse" + user_name: str = "clickhouse" password: SecretStr host: str = "localhost" port: int = 8123 @@ -16,7 +17,7 @@ def to_dict(self) -> dict: "port": self.port, "dbname": self.db_name, "user": self.user_name, - "password": pwd_str + "password": pwd_str, 
} @@ -32,8 +33,11 @@ def parse_metric(self) -> str: def parse_metric_str(self) -> str: if self.metric_type == MetricType.L2: return "L2Distance" - elif self.metric_type == MetricType.COSINE: + if self.metric_type == MetricType.COSINE: return "cosineDistance" + msg = f"Not Support for {self.metric_type}" + raise RuntimeError(msg) + return None class ClickhouseHNSWConfig(ClickhouseIndexConfig, DBCaseConfig): @@ -51,6 +55,6 @@ def index_param(self) -> dict: def search_param(self) -> dict: return { - "metric_type": self.parse_metric_str(), + "metric_type": self.parse_metric_str(), "params": {"ef": self.ef}, - } \ No newline at end of file + } diff --git a/vectordb_bench/backend/clients/mariadb/cli.py b/vectordb_bench/backend/clients/mariadb/cli.py index c5439f37d..17717c38d 100644 --- a/vectordb_bench/backend/clients/mariadb/cli.py +++ b/vectordb_bench/backend/clients/mariadb/cli.py @@ -1,17 +1,16 @@ from typing import Annotated, Optional, Unpack import click -import os from pydantic import SecretStr +from vectordb_bench.backend.clients import DB + from ....cli.cli import ( CommonTypedDict, - HNSWFlavor1, cli, click_parameter_decorators_from_typed_dict, run, ) -from vectordb_bench.backend.clients import DB class MariaDBTypedDict(CommonTypedDict): diff --git a/vectordb_bench/backend/clients/mariadb/config.py b/vectordb_bench/backend/clients/mariadb/config.py index c7b2cd5fe..50d0b55c5 100644 --- a/vectordb_bench/backend/clients/mariadb/config.py +++ b/vectordb_bench/backend/clients/mariadb/config.py @@ -1,6 +1,9 @@ -from pydantic import SecretStr, BaseModel from typing import TypedDict -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType + +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + class MariaDBConfigDict(TypedDict): """These keys will be directly used as kwargs in mariadb connection string, so the names must match exactly mariadb API""" user: str password: str @@ -36,10 +39,10 @@ class MariaDBIndexConfig(BaseModel): def parse_metric(self) -> str: if
self.metric_type == MetricType.L2: return "euclidean" - elif self.metric_type == MetricType.COSINE: + if self.metric_type == MetricType.COSINE: return "cosine" - else: - raise ValueError(f"Metric type {self.metric_type} is not supported!") + msg = f"Metric type {self.metric_type} is not supported!" + raise ValueError(msg) class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig): M: int | None diff --git a/vectordb_bench/backend/clients/qdrant_cloud/config.py b/vectordb_bench/backend/clients/qdrant_cloud/config.py index c1d6882c0..d4e27cb3c 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/config.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/config.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, SecretStr, validator +from pydantic import BaseModel, SecretStr from ..api import DBCaseConfig, DBConfig, MetricType @@ -20,14 +20,6 @@ def to_dict(self) -> dict: "url": self.url.get_secret_value(), } - @validator("*") - def not_empty_field(cls, v: any, field: any): - if field.name in ["api_key", "db_label"]: - return v - if isinstance(v, str | SecretStr) and len(v) == 0: - raise ValueError("Empty string!") - return v - class QdrantIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None diff --git a/vectordb_bench/backend/clients/tidb/config.py b/vectordb_bench/backend/clients/tidb/config.py index 213a18bc5..693551045 100644 --- a/vectordb_bench/backend/clients/tidb/config.py +++ b/vectordb_bench/backend/clients/tidb/config.py @@ -1,5 +1,6 @@ -from pydantic import SecretStr, BaseModel, validator -from ..api import DBConfig, DBCaseConfig, MetricType +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, MetricType class TiDBConfig(DBConfig): @@ -10,10 +11,6 @@ class TiDBConfig(DBConfig): db_name: str = "test" ssl: bool = False - @validator("*") - def not_empty_field(cls, v: any, field: any): - return v - def to_dict(self) -> dict: pwd_str = self.password.get_secret_value() return { @@ -33,10 +30,10 @@ 
class TiDBIndexConfig(BaseModel, DBCaseConfig): def get_metric_fn(self) -> str: if self.metric_type == MetricType.L2: return "vec_l2_distance" - elif self.metric_type == MetricType.COSINE: + if self.metric_type == MetricType.COSINE: return "vec_cosine_distance" - else: - raise ValueError(f"Unsupported metric type: {self.metric_type}") + msg = f"Unsupported metric type: {self.metric_type}" + raise ValueError(msg) def index_param(self) -> dict: return { From 05203c012fed5cc092b0e3151b0ec66a447171c6 Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Mon, 14 Apr 2025 03:39:01 +0000 Subject: [PATCH 21/36] reformat all Signed-off-by: min.tian --- vectordb_bench/backend/clients/__init__.py | 9 +- vectordb_bench/backend/clients/api.py | 2 +- .../backend/clients/chroma/chroma.py | 4 +- .../backend/clients/clickhouse/cli.py | 2 +- .../backend/clients/clickhouse/clickhouse.py | 79 +++++++------ .../clients/elastic_cloud/elastic_cloud.py | 2 +- vectordb_bench/backend/clients/mariadb/cli.py | 100 +++++++++------- .../backend/clients/mariadb/config.py | 7 +- .../backend/clients/mariadb/mariadb.py | 110 +++++++++--------- vectordb_bench/backend/clients/milvus/cli.py | 2 + .../backend/clients/milvus/config.py | 1 - .../backend/clients/milvus/milvus.py | 2 +- .../backend/clients/pgvector/cli.py | 3 +- .../backend/clients/pinecone/pinecone.py | 2 +- .../clients/qdrant_cloud/qdrant_cloud.py | 2 +- vectordb_bench/backend/clients/tidb/tidb.py | 35 +++--- vectordb_bench/backend/clients/vespa/util.py | 5 +- vectordb_bench/backend/clients/vespa/vespa.py | 11 +- .../clients/weaviate_cloud/weaviate_cloud.py | 2 +- vectordb_bench/cli/cli.py | 37 +++--- vectordb_bench/cli/vectordbbench.py | 2 +- .../frontend/config/dbCaseConfigs.py | 16 +-- vectordb_bench/frontend/config/styles.py | 2 +- vectordb_bench/models.py | 11 +- 24 files changed, 226 insertions(+), 222 deletions(-) diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 
732f427a7..cfef0283f 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -158,7 +158,7 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 from .test.test import Test return Test - + if self == DB.Vespa: from .vespa.vespa import Vespa @@ -279,17 +279,16 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 from .test.config import TestConfig return TestConfig - + if self == DB.Vespa: from .vespa.config import VespaConfig return VespaConfig - msg = f"Unknown DB: {self.name}" raise ValueError(msg) - def case_config_cls( # noqa: PLR0911 + def case_config_cls( # noqa: C901, PLR0911, PLR0912 self, index_type: IndexType | None = None, ) -> type[DBCaseConfig]: @@ -377,7 +376,7 @@ def case_config_cls( # noqa: PLR0911 from .tidb.config import TiDBIndexConfig return TiDBIndexConfig - + if self == DB.Vespa: from .vespa.config import VespaHNSWConfig diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index e498ab077..ce2b05650 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -162,7 +162,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert the embeddings to the vector database. The default number of embeddings for each insert_embeddings is 5000. diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py index 76c810263..26a810065 100644 --- a/vectordb_bench/backend/clients/chroma/chroma.py +++ b/vectordb_bench/backend/clients/chroma/chroma.py @@ -65,7 +65,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert embeddings into the database. 
Args: @@ -74,7 +74,7 @@ def insert_embeddings( kwargs: other arguments Returns: - (int, Exception): number of embeddings inserted and exception if any + tuple[int, Exception]: number of embeddings inserted and exception if any """ ids = [str(i) for i in metadata] metadata = [{"id": int(i)} for i in metadata] diff --git a/vectordb_bench/backend/clients/clickhouse/cli.py b/vectordb_bench/backend/clients/clickhouse/cli.py index e454f5a75..6fc1a84d7 100644 --- a/vectordb_bench/backend/clients/clickhouse/cli.py +++ b/vectordb_bench/backend/clients/clickhouse/cli.py @@ -18,7 +18,7 @@ class ClickhouseTypedDict(TypedDict): password: Annotated[str, click.option("--password", type=str, help="DB password")] host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)] port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")] - user: Annotated[int, click.option("--user", type=str, default='clickhouse', help="DB user")] + user: Annotated[int, click.option("--user", type=str, default="clickhouse", help="DB user")] ssl: Annotated[ bool, click.option( diff --git a/vectordb_bench/backend/clients/clickhouse/clickhouse.py b/vectordb_bench/backend/clients/clickhouse/clickhouse.py index e5339170b..e241cdb8d 100644 --- a/vectordb_bench/backend/clients/clickhouse/clickhouse.py +++ b/vectordb_bench/backend/clients/clickhouse/clickhouse.py @@ -1,18 +1,19 @@ """Wrapper around the Clickhouse vector database over VectorDB""" -import io import logging from contextlib import contextmanager from typing import Any + import clickhouse_connect -import numpy as np -from ..api import VectorDB, DBCaseConfig +from ..api import DBCaseConfig, VectorDB log = logging.getLogger(__name__) + class Clickhouse(VectorDB): """Use SQLAlchemy instructions""" + def __init__( self, dim: int, @@ -32,12 +33,13 @@ def __init__( self._vector_field = "embedding" # construct basic units - self.conn = clickhouse_connect.get_client( - host=self.db_config["host"], - 
port=self.db_config["port"], - username=self.db_config["user"], - password=self.db_config["password"], - database=self.db_config["dbname"]) + self.conn = clickhouse_connect.get_client( + host=self.db_config["host"], + port=self.db_config["port"], + username=self.db_config["user"], + password=self.db_config["password"], + database=self.db_config["dbname"], + ) if drop_old: log.info(f"Clickhouse client drop table : {self.table_name}") @@ -48,7 +50,7 @@ def __init__( self.conn = None @contextmanager - def init(self) -> None: + def init(self): """ Examples: >>> with self.init(): @@ -56,12 +58,13 @@ def init(self) -> None: >>> self.search_embedding() """ - self.conn = clickhouse_connect.get_client( - host=self.db_config["host"], - port=self.db_config["port"], - username=self.db_config["user"], - password=self.db_config["password"], - database=self.db_config["dbname"]) + self.conn = clickhouse_connect.get_client( + host=self.db_config["host"], + port=self.db_config["port"], + username=self.db_config["user"], + password=self.db_config["password"], + database=self.db_config["dbname"], + ) try: yield @@ -85,9 +88,7 @@ def _create_table(self, dim: int): ) except Exception as e: - log.warning( - f"Failed to create Clickhouse table: {self.table_name} error: {e}" - ) + log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}") raise e from None def ready_to_load(self): @@ -104,16 +105,20 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> (int, Exception): + ) -> tuple[int, Exception]: assert self.conn is not None, "Connection is not initialized" try: # do not iterate for bulk insert items = [metadata, embeddings] - self.conn.insert(table=self.table_name, data=items, - column_names=['id', 'embedding'], column_type_names=['UInt32', 'Array(Float64)'], - column_oriented=True) + self.conn.insert( + table=self.table_name, + data=items, + column_names=["id", "embedding"], + column_type_names=["UInt32", 
"Array(Float64)"], + column_oriented=True, + ) return len(metadata), None except Exception as e: log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}") @@ -128,22 +133,24 @@ def search_embedding( ) -> list[int]: assert self.conn is not None, "Connection is not initialized" - index_param = self.case_config.index_param() + index_param = self.case_config.index_param() # noqa: F841 search_param = self.case_config.search_param() if filters: gt = filters.get("id") - filterSql = (f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' - f'FROM {self.db_config["dbname"]}.{self.table_name} ' - f'WHERE id > {gt} ' - f'ORDER BY score LIMIT {k};' - ) - result = self.conn.query(filterSql).result_rows + filter_sql = ( + f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608 + f'FROM {self.db_config["dbname"]}.{self.table_name} ' + f"WHERE id > {gt} " + f"ORDER BY score LIMIT {k};" + ) + result = self.conn.query(filter_sql).result_rows return [int(row[0]) for row in result] - else: - selectSql = (f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' - f'FROM {self.db_config["dbname"]}.{self.table_name} ' - f'ORDER BY score LIMIT {k};' - ) - result = self.conn.query(selectSql).result_rows + else: # noqa: RET505 + select_sql = ( + f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608 + f'FROM {self.db_config["dbname"]}.{self.table_name} ' + f"ORDER BY score LIMIT {k};" + ) + result = self.conn.query(select_sql).result_rows return [int(row[0]) for row in result] diff --git a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py index ea038c587..7d201729f 100644 --- a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +++ b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py @@ -81,7 +81,7 @@ def insert_embeddings( embeddings: Iterable[list[float]], metadata: 
list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert the embeddings to the elasticsearch.""" assert self.client is not None, "should self.init() first" diff --git a/vectordb_bench/backend/clients/mariadb/cli.py b/vectordb_bench/backend/clients/mariadb/cli.py index 17717c38d..969247271 100644 --- a/vectordb_bench/backend/clients/mariadb/cli.py +++ b/vectordb_bench/backend/clients/mariadb/cli.py @@ -1,4 +1,4 @@ -from typing import Annotated, Optional, Unpack +from typing import Annotated, Unpack import click from pydantic import SecretStr @@ -15,68 +15,84 @@ class MariaDBTypedDict(CommonTypedDict): user_name: Annotated[ - str, click.option("--username", - type=str, - help="Username", - required=True, - ), + str, + click.option( + "--username", + type=str, + help="Username", + required=True, + ), ] password: Annotated[ - str, click.option("--password", - type=str, - help="Password", - required=True, - ), + str, + click.option( + "--password", + type=str, + help="Password", + required=True, + ), ] host: Annotated[ - str, click.option("--host", - type=str, - help="Db host", - default="127.0.0.1", - ), + str, + click.option( + "--host", + type=str, + help="Db host", + default="127.0.0.1", + ), ] port: Annotated[ - int, click.option("--port", - type=int, - default=3306, - help="Db Port", - ), + int, + click.option( + "--port", + type=int, + default=3306, + help="Db Port", + ), ] storage_engine: Annotated[ - int, click.option("--storage-engine", - type=click.Choice(["InnoDB", "MyISAM"]), - help="DB storage engine", - required=True, - ), + int, + click.option( + "--storage-engine", + type=click.Choice(["InnoDB", "MyISAM"]), + help="DB storage engine", + required=True, + ), ] + class MariaDBHNSWTypedDict(MariaDBTypedDict): - ... 
m: Annotated[ - Optional[int], click.option("--m", - type=int, - help="M parameter in MHNSW vector indexing", - required=False, - ), + int | None, + click.option( + "--m", + type=int, + help="M parameter in MHNSW vector indexing", + required=False, + ), ] ef_search: Annotated[ - Optional[int], click.option("--ef-search", - type=int, - help="MariaDB system variable mhnsw_min_limit", - required=False, - ), + int | None, + click.option( + "--ef-search", + type=int, + help="MariaDB system variable mhnsw_min_limit", + required=False, + ), ] max_cache_size: Annotated[ - Optional[int], click.option("--max-cache-size", - type=int, - help="MariaDB system variable mhnsw_max_cache_size", - required=False, - ), + int | None, + click.option( + "--max-cache-size", + type=int, + help="MariaDB system variable mhnsw_max_cache_size", + required=False, + ), ] diff --git a/vectordb_bench/backend/clients/mariadb/config.py b/vectordb_bench/backend/clients/mariadb/config.py index 50d0b55c5..d183adc76 100644 --- a/vectordb_bench/backend/clients/mariadb/config.py +++ b/vectordb_bench/backend/clients/mariadb/config.py @@ -7,7 +7,7 @@ class MariaDBConfigDict(TypedDict): """These keys will be directly used as kwargs in mariadb connection string, - so the names must match exactly mariadb API""" + so the names must match exactly mariadb API""" user: str password: str @@ -44,6 +44,7 @@ def parse_metric(self) -> str: msg = f"Metric type {self.metric_type} is not supported!" 
raise ValueError(msg) + class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig): M: int | None ef_search: int | None @@ -68,7 +69,5 @@ def search_param(self) -> dict: _mariadb_case_config = { - IndexType.HNSW: MariaDBHNSWConfig, + IndexType.HNSW: MariaDBHNSWConfig, } - - diff --git a/vectordb_bench/backend/clients/mariadb/mariadb.py b/vectordb_bench/backend/clients/mariadb/mariadb.py index 42b621d9c..5ccddfe7a 100644 --- a/vectordb_bench/backend/clients/mariadb/mariadb.py +++ b/vectordb_bench/backend/clients/mariadb/mariadb.py @@ -1,27 +1,25 @@ -from ..api import VectorDB - import logging from contextlib import contextmanager -from typing import Any, Optional, Tuple -from ..api import VectorDB -from .config import MariaDBConfigDict, MariaDBIndexConfig -import numpy as np import mariadb +import numpy as np + +from ..api import VectorDB +from .config import MariaDBConfigDict, MariaDBIndexConfig log = logging.getLogger(__name__) + class MariaDB(VectorDB): def __init__( - self, - dim: int, - db_config: MariaDBConfigDict, - db_case_config: MariaDBIndexConfig, - collection_name: str = "vec_collection", - drop_old: bool = False, - **kwargs, - ): - + self, + dim: int, + db_config: MariaDBConfigDict, + db_case_config: MariaDBIndexConfig, + collection_name: str = "vec_collection", + drop_old: bool = False, + **kwargs, + ): self.name = "MariaDB" self.db_config = db_config self.case_config = db_case_config @@ -31,7 +29,7 @@ def __init__( # construct basic units self.conn, self.cursor = self._create_connection(**self.db_config) - + if drop_old: self._drop_db() self._create_db_table(dim) @@ -41,9 +39,8 @@ def __init__( self.cursor = None self.conn = None - @staticmethod - def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]: + def _create_connection(**kwargs) -> tuple[mariadb.Connection, mariadb.Cursor]: conn = mariadb.connect(**kwargs) cursor = conn.cursor() @@ -52,7 +49,6 @@ def _create_connection(**kwargs) -> Tuple[mariadb.Connection, 
mariadb.Cursor]: return conn, cursor - def _drop_db(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -77,24 +73,23 @@ def _create_db_table(self, dim: int): log.info(f"{self.name} client create table : {self.table_name}") self.cursor.execute(f"USE {self.db_name}") - self.cursor.execute(f""" + self.cursor.execute( + f""" CREATE TABLE {self.table_name} ( id INT PRIMARY KEY, v VECTOR({self.dim}) NOT NULL ) ENGINE={index_param["storage_engine"]} - """) + """ + ) self.cursor.execute("COMMIT") except Exception as e: - log.warning( - f"Failed to create table: {self.table_name} error: {e}" - ) + log.warning(f"Failed to create table: {self.table_name} error: {e}") raise e from None - @contextmanager - def init(self) -> None: - """ create and destory connections to database. + def init(self): + """create and destory connections to database. Examples: >>> with self.init(): @@ -109,15 +104,21 @@ def init(self) -> None: self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824") if index_param["index_type"] == "HNSW": - if index_param["max_cache_size"] != None: - self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param["max_cache_size"]}") - if search_param["ef_search"] != None: - self.cursor.execute(f"SET mhnsw_ef_search = {search_param["ef_search"]}") + if index_param["max_cache_size"] is not None: + self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}") + if search_param["ef_search"] is not None: + self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}") self.cursor.execute("COMMIT") - self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" - self.select_sql = f"SELECT id FROM {self.db_name}.{self.table_name} ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d" - self.select_sql_with_filter = f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d ORDER by 
vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d" + self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" # noqa: S608 + self.select_sql = ( + f"SELECT id FROM {self.db_name}.{self.table_name}" # noqa: S608 + f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d" + ) + self.select_sql_with_filter = ( + f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d " # noqa: S608 + f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d" + ) try: yield @@ -126,7 +127,6 @@ def init(self) -> None: self.conn.close() self.cursor = None self.conn = None - def ready_to_load(self) -> bool: pass @@ -139,33 +139,31 @@ def optimize(self) -> None: try: index_options = f"DISTANCE={index_param['metric_type']}" - if index_param["index_type"] == "HNSW" and index_param["M"] != None: + if index_param["index_type"] == "HNSW" and index_param["M"] is not None: index_options += f" M={index_param['M']}" - self.cursor.execute(f""" + self.cursor.execute( + f""" ALTER TABLE {self.db_name}.{self.table_name} ADD VECTOR KEY v(v) {index_options} - """) + """ + ) self.cursor.execute("COMMIT") except Exception as e: - log.warning( - f"Failed to create index: {self.table_name} error: {e}" - ) + log.warning(f"Failed to create index: {self.table_name} error: {e}") raise e from None - pass - @staticmethod - def vector_to_hex(v): - return np.array(v, 'float32').tobytes() + def vector_to_hex(v): # noqa: ANN001 + return np.array(v, "float32").tobytes() def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], - **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + **kwargs, + ) -> tuple[int, Exception]: """Insert embeddings into the database. Should call self.init() first. 
""" @@ -178,7 +176,7 @@ def insert_embeddings( batch_data = [] for i, row in enumerate(metadata_arr): - batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i]))); + batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i]))) self.cursor.executemany(self.insert_sql, batch_data) self.cursor.execute("COMMIT") @@ -186,11 +184,8 @@ def insert_embeddings( return len(metadata), None except Exception as e: - log.warning( - f"Failed to insert data into Vector table ({self.table_name}), error: {e}" - ) + log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}") return 0, e - def search_embedding( self, @@ -198,17 +193,16 @@ def search_embedding( k: int = 100, filters: dict | None = None, timeout: int | None = None, - **kwargs: Any, - ) -> (list[int]): + **kwargs, + ) -> list[int]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" - search_param = self.case_config.search_param() + search_param = self.case_config.search_param() # noqa: F841 if filters: - self.cursor.execute(self.select_sql_with_filter, (filters.get('id'), self.vector_to_hex(query), k)) + self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), self.vector_to_hex(query), k)) else: self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k)) - return [id for id, in self.cursor.fetchall()] - + return [id for (id,) in self.cursor.fetchall()] # noqa: A001 diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py index 52524e785..24a61566f 100644 --- a/vectordb_bench/backend/clients/milvus/cli.py +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -194,6 +194,7 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]): **parameters, ) + class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict): metric_type: Annotated[ str, @@ -204,6 +205,7 @@ class MilvusGPUBruteForceTypedDict(CommonTypedDict, 
MilvusTypedDict): click.option("--limit", type=int, required=True, help="Top-k limit for search"), ] + @cli.command() @click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict) def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]): diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 1ff3bea5f..e3a3f9b19 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -215,7 +215,6 @@ def search_param(self) -> dict: } - class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig): nlist: int = 1024 m: int = 0 diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index 4015eb1f3..c812698fe 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -155,7 +155,7 @@ def insert_embeddings( embeddings: Iterable[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert embeddings into Milvus. should call self.init() first""" # use the first insert_embeddings to init collection assert self.col is not None diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index 1780af991..f8c138802 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -18,8 +18,7 @@ ) -# ruff: noqa -def set_default_quantized_fetch_limit(ctx: any, param: any, value: any): +def set_default_quantized_fetch_limit(ctx: any, param: any, value: any): # noqa: ARG001 if ctx.params.get("reranking") and value is None: # ef_search is the default value for quantized_fetch_limit as it's bound by ef_search. # 100 is default value for quantized_fetch_limit for IVFFlat. 
diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index 1a681b33f..9fd0afc7c 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -67,7 +67,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: assert len(embeddings) == len(metadata) insert_count = 0 try: diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py index 5de72798b..f618c3ba4 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py @@ -111,7 +111,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert embeddings into Milvus. should call self.init() first""" assert self.qdrant_client is not None try: diff --git a/vectordb_bench/backend/clients/tidb/tidb.py b/vectordb_bench/backend/clients/tidb/tidb.py index d1f26084e..b75605eda 100644 --- a/vectordb_bench/backend/clients/tidb/tidb.py +++ b/vectordb_bench/backend/clients/tidb/tidb.py @@ -3,7 +3,7 @@ import logging import time from contextlib import contextmanager -from typing import Any, Optional, Tuple +from typing import Any import pymysql @@ -62,7 +62,7 @@ def _drop_table(self): conn.commit() except Exception as e: log.warning("Failed to drop table: %s error: %s", self.table_name, e) - raise e + raise def _create_table(self): try: @@ -80,7 +80,7 @@ def _create_table(self): conn.commit() except Exception as e: log.warning("Failed to create table: %s error: %s", self.table_name, e) - raise e + raise def ready_to_load(self) -> bool: pass @@ -122,25 +122,25 @@ def _optimize_check_tiflash_replica_progress(self): f""" SELECT PROGRESS FROM information_schema.tiflash_replica 
WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}" - """ + """ # noqa: S608 ) result = cursor.fetchone() return result[0] except Exception as e: log.warning("Failed to check TiFlash replica progress: %s", e) - raise e + raise def _optimize_wait_tiflash_catch_up(self): try: with self._get_connection() as (conn, cursor): cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"') conn.commit() - cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") + cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") # noqa: S608 result = cursor.fetchone() return result[0] except Exception as e: log.warning("Failed to wait TiFlash to catch up: %s", e) - raise e + raise def _optimize_compact_tiflash(self): try: @@ -149,7 +149,7 @@ def _optimize_compact_tiflash(self): conn.commit() except Exception as e: log.warning("Failed to compact table: %s", e) - raise e + raise def _optimize_get_tiflash_index_pending_rows(self): try: @@ -160,13 +160,13 @@ def _optimize_get_tiflash_index_pending_rows(self): SELECT SUM(ROWS_STABLE_NOT_INDEXED) FROM information_schema.tiflash_indexes WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}" - """ + """ # noqa: S608 ) result = cursor.fetchone() return result[0] except Exception as e: log.warning("Failed to read TiFlash index pending rows: %s", e) - raise e + raise def _insert_embeddings_serial( self, @@ -178,29 +178,28 @@ def _insert_embeddings_serial( try: with self._get_connection() as (conn, cursor): buf = io.StringIO() - buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") + buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") # noqa: S608 for i in range(offset, offset + size): if i > offset: buf.write(",") - buf.write(f'({metadata[i]}, "{str(embeddings[i])}")') + buf.write(f'({metadata[i]}, "{embeddings[i]!s}")') cursor.execute(buf.getvalue()) conn.commit() except Exception as e: log.warning("Failed to insert data into table: %s", e) - raise e + raise def 
insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + ) -> tuple[int, Exception]: workers = 10 # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB) max_batch_size = 64 * 1024 * 1024 // 24 // self.dim batch_size = len(embeddings) // workers - if batch_size > max_batch_size: - batch_size = max_batch_size + batch_size = min(batch_size, max_batch_size) with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: futures = [] for i in range(0, len(embeddings), batch_size): @@ -227,8 +226,8 @@ def search_embedding( self.cursor.execute( f""" SELECT id FROM {self.table_name} - ORDER BY {self.search_fn}(embedding, "{str(query)}") LIMIT {k}; - """ + ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k}; + """ # noqa: S608 ) result = self.cursor.fetchall() return [int(i[0]) for i in result] diff --git a/vectordb_bench/backend/clients/vespa/util.py b/vectordb_bench/backend/clients/vespa/util.py index 97d03d644..7a64cc30d 100644 --- a/vectordb_bench/backend/clients/vespa/util.py +++ b/vectordb_bench/backend/clients/vespa/util.py @@ -2,6 +2,7 @@ From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8 """ + import numpy as np @@ -11,6 +12,4 @@ def binarize_tensor(tensor: list[float]) -> list[int]: and packing the bits into bytes. """ tensor = np.array(tensor) - return ( - np.packbits(np.where(tensor > 0, 1, 0), axis=0).astype(np.int8).tolist() - ) + return np.packbits(np.where(tensor > 0, 1, 0), axis=0).astype(np.int8).tolist() diff --git a/vectordb_bench/backend/clients/vespa/vespa.py b/vectordb_bench/backend/clients/vespa/vespa.py index 19bb3f5a4..5288bc04c 100644 --- a/vectordb_bench/backend/clients/vespa/vespa.py +++ b/vectordb_bench/backend/clients/vespa/vespa.py @@ -7,8 +7,8 @@ from vespa import application from ..api import VectorDB -from .config import VespaHNSWConfig from . 
import util +from .config import VespaHNSWConfig log = logging.getLogger(__name__) @@ -116,17 +116,12 @@ def search_embedding( id_filter = filters.get("id") yql += f" and id >= {id_filter}" - query_embedding = ( - query - if self.case_config.quantization_type == "none" - else util.binarize_tensor(query) - ) + query_embedding = query if self.case_config.quantization_type == "none" else util.binarize_tensor(query) ranking = self.case_config.quantization_type result = self.client.query({"yql": yql, "input.query(query_embedding)": query_embedding, "ranking": ranking}) - result_ids = [child["fields"]["id"] for child in result.get_json()["root"]["children"]] - return result_ids + return [child["fields"]["id"] for child in result.get_json()["root"]["children"]] def optimize(self, data_size: int | None = None): """optimize will be called between insertion and search in performance cases. diff --git a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py index aa4368bb7..c31104d8b 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py @@ -99,7 +99,7 @@ def insert_embeddings( embeddings: Iterable[list[float]], metadata: list[int], **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: """Insert embeddings into Weaviate""" assert self.client.schema.exists(self.collection_name) insert_count = 0 diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index 3bb7763d8..bdf1a25f4 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -1,9 +1,9 @@ import logging -import os import time from collections.abc import Callable from concurrent.futures import wait from datetime import datetime +from pathlib import Path from pprint import pformat from typing import ( Annotated, @@ -38,18 +38,16 @@ from yaml import Loader -def click_get_defaults_from_file(ctx, param, value): +def 
click_get_defaults_from_file(ctx, param, value): # noqa: ANN001, ARG001 if value: - if os.path.exists(value): - input_file = value - else: - input_file = os.path.join(config.CONFIG_LOCAL_DIR, value) + input_file = value if Path.exists(value) else Path.join(config.CONFIG_LOCAL_DIR, value) try: - with open(input_file) as f: - _config: dict[str, dict[str, Any]] = load(f.read(), Loader=Loader) + with Path.open(input_file) as f: + _config: dict[str, dict[str, Any]] = load(f.read(), Loader=Loader) # noqa: S506 ctx.default_map = _config.get(ctx.command.name, {}) except Exception as e: - raise click.BadParameter(f"Failed to load config file: {e}") + msg = f"Failed to load config file: {e}" + raise click.BadParameter(msg) from e return value @@ -68,12 +66,16 @@ def click_parameter_decorators_from_typed_dict( For clarity, the key names of the TypedDict will be used to determine the type hints for the input parameters. - The actual function parameters are controlled by the click.option definitions. You must manually ensure these are aligned in a sensible way! + The actual function parameters are controlled by the click.option definitions. + You must manually ensure these are aligned in a sensible way! 
Example: ``` class CommonTypedDict(TypedDict): - z: Annotated[int, click.option("--z/--no-z", is_flag=True, type=bool, help="help z", default=True, show_default=True)] + z: Annotated[ + int, + click.option("--z/--no-z", is_flag=True, type=bool, help="help z", default=True, show_default=True) + ] name: Annotated[str, click.argument("name", required=False, default="Jeff")] class FooTypedDict(CommonTypedDict): @@ -91,14 +93,16 @@ def foo(**parameters: Unpack[FooTypedDict]): for _, t in get_type_hints(typed_dict, include_extras=True).items(): assert get_origin(t) is Annotated if len(t.__metadata__) == 1 and t.__metadata__[0].__module__ == "click.decorators": - # happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] with no additional metadata defined (len=1) + # happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] + # with no additional metadata defined (len=1) decorators.append(t.__metadata__[0]) else: raise RuntimeError( - "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring", + "Click-TypedDict decorator parsing must only contain root type " + "and a click decorator like click.option. See docstring", ) - def deco(f): + def deco(f): # noqa: ANN001 for dec in reversed(decorators): f = dec(f) return f @@ -106,7 +110,7 @@ def deco(f): return deco -def click_arg_split(ctx: click.Context, param: click.core.Option, value): +def click_arg_split(ctx: click.Context, param: click.core.Option, value): # noqa: ANN001, ARG001 """Will split a comma-separated list input into an actual list. 
Args: @@ -145,8 +149,7 @@ def parse_task_stages( return stages -# ruff: noqa -def check_custom_case_parameters(ctx: any, param: any, value: any): +def check_custom_case_parameters(ctx: any, param: any, value: any): # noqa: ARG001 if ctx.params.get("case_type") == "PerformanceCustomDataset" and value is None: raise click.BadParameter( """ Custom case parameters diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 7469ec9d5..1c6ff1260 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,5 +1,6 @@ from ..backend.clients.alloydb.cli import AlloyDBScaNN from ..backend.clients.aws_opensearch.cli import AWSOpenSearch +from ..backend.clients.clickhouse.cli import Clickhouse from ..backend.clients.mariadb.cli import MariaDBHNSW from ..backend.clients.memorydb.cli import MemoryDB from ..backend.clients.milvus.cli import MilvusAutoIndex @@ -13,7 +14,6 @@ from ..backend.clients.vespa.cli import Vespa from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex -from ..backend.clients.clickhouse.cli import Clickhouse from .cli import cli cli.add_command(PgVectorHNSW) diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index e85b42ff3..bb9cfa44b 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -1087,8 +1087,7 @@ class CaseConfigInput(BaseModel): "max": 200, "value": 6, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) CaseConfigParamInput_EFSearch_MariaDB = CaseConfigInput( @@ -1100,8 +1099,7 @@ class CaseConfigInput(BaseModel): "max": 10000, "value": 20, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == 
IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) CaseConfigParamInput_CacheSize_MariaDB = CaseConfigInput( @@ -1111,10 +1109,9 @@ class CaseConfigInput(BaseModel): inputConfig={ "min": 1048576, "max": (1 << 53) - 1, - "value": 16 * 1024 ** 3, + "value": 16 * 1024**3, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) CaseConfigParamInput_MongoDBQuantizationType = CaseConfigInput( @@ -1162,10 +1159,7 @@ class CaseConfigInput(BaseModel): label=CaseConfigParamType.quantizationType, inputType=InputType.Option, inputConfig={ - "options": [ - "none", - "binary" - ], + "options": ["none", "binary"], }, ) diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index 4418c19da..03bda0fec 100644 --- a/vectordb_bench/frontend/config/styles.py +++ b/vectordb_bench/frontend/config/styles.py @@ -64,5 +64,5 @@ def getPatternShape(i): DB.Redis.value: "#0D6EFD", DB.AWSOpenSearch.value: "#0DCAF0", DB.TiDB.value: "#0D6EFD", - DB.Vespa.value: "#61d790" + DB.Vespa.value: "#61d790", } diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index e206919ac..b28521096 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -263,7 +263,6 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: ) return TestResult.validate(test_result) - # ruff: noqa def display(self, dbs: list[DB] | None = None): filter_list = dbs if dbs and isinstance(dbs, list) else None sorted_results = sorted( @@ -294,7 +293,7 @@ def append_return(x: any, y: any): max_qps = 10 if max_qps < 10 else max_qps max_recall = 13 if max_recall < 13 else max_recall - LENGTH = ( + LENGTH = ( # noqa: N806 max_db, max_db_labels, max_case, @@ -307,13 +306,13 @@ def append_return(x: any, y: any): 5, ) - 
DATA_FORMAT = ( + DATA_FORMAT = ( # noqa: N806 f"%-{max_db}s | %-{max_db_labels}s %-{max_case}s %-{len(self.task_label)}s" f" | %-{max_load_dur}s %-{max_qps}s %-15s %-{max_recall}s %-14s" f" | %-5s" ) - TITLE = DATA_FORMAT % ( + TITLE = DATA_FORMAT % ( # noqa: N806 "DB", "db_label", "case", @@ -325,8 +324,8 @@ def append_return(x: any, y: any): "max_load_count", "label", ) - SPLIT = DATA_FORMAT % tuple(map(lambda x: "-" * x, LENGTH)) - SUMMARY_FORMAT = ("Task summary: run_id=%s, task_label=%s") % ( + SPLIT = DATA_FORMAT % tuple(map(lambda x: "-" * x, LENGTH)) # noqa: C417, N806 + SUMMARY_FORMAT = ("Task summary: run_id=%s, task_label=%s") % ( # noqa: N806 self.run_id[:5], self.task_label, ) From 1a9aa48267c6fa9cd3ee8629c2dceb07a8a18f34 Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Wed, 16 Apr 2025 10:43:36 +0800 Subject: [PATCH 22/36] fix cli crush Signed-off-by: min.tian --- vectordb_bench/cli/cli.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index bdf1a25f4..4b42a912c 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -40,9 +40,10 @@ def click_get_defaults_from_file(ctx, param, value): # noqa: ANN001, ARG001 if value: - input_file = value if Path.exists(value) else Path.join(config.CONFIG_LOCAL_DIR, value) + path = Path(value) + input_file = path if path.exists() else Path(config.CONFIG_LOCAL_DIR, path) try: - with Path.open(input_file) as f: + with input_file.open() as f: _config: dict[str, dict[str, Any]] = load(f.read(), Loader=Loader) # noqa: S506 ctx.default_map = _config.get(ctx.command.name, {}) except Exception as e: From 90879f77b9c893c353dc645392c40006d0d5eac6 Mon Sep 17 00:00:00 2001 From: Polo Vezia Date: Thu, 17 Apr 2025 08:19:39 +0000 Subject: [PATCH 23/36] downgrade streamlit version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a66cfac1a..09fa66972 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "click", "pytz", "streamlit-autorefresh", - "streamlit!=1.34.0", + "streamlit<1.44,!=1.34.0", # There is a breaking change in 1.44 related to get_page https://discuss.streamlit.io/t/from-streamlit-source-util-import-get-pages-gone-in-v-1-44-0-need-urgent-help/98399 "streamlit_extras", "tqdm", "s3fs", From 1a1ba0d0e470c9df31a2030c5cded767c4a11d51 Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Fri, 18 Apr 2025 11:30:30 +0800 Subject: [PATCH 24/36] add more milvus index types: hnsw sq/pq/prq; ivf rabitq Signed-off-by: min.tian --- vectordb_bench/backend/clients/api.py | 12 ++ .../backend/clients/milvus/config.py | 113 +++++++++++++++++- .../backend/clients/milvus/milvus.py | 4 +- .../frontend/config/dbCaseConfigs.py | 111 +++++++++++++++-- vectordb_bench/models.py | 6 + 5 files changed, 237 insertions(+), 9 deletions(-) diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index ce2b05650..8070164dd 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -16,10 +16,14 @@ class MetricType(str, Enum): class IndexType(str, Enum): HNSW = "HNSW" + HNSW_SQ = "HNSW_SQ" + HNSW_PQ = "HNSW_PQ" + HNSW_PRQ = "HNSW_PRQ" DISKANN = "DISKANN" STREAMING_DISKANN = "DISKANN" IVFFlat = "IVF_FLAT" IVFSQ8 = "IVF_SQ8" + IVF_RABITQ = "IVF_RABITQ" Flat = "FLAT" AUTOINDEX = "AUTOINDEX" ES_HNSW = "hnsw" @@ -31,6 +35,14 @@ class IndexType(str, Enum): SCANN = "scann" +class SQType(str, Enum): + SQ6 = "SQ6" + SQ8 = "SQ8" + BF16 = "BF16" + FP16 = "FP16" + FP32 = "FP32" + + class DBConfig(ABC, BaseModel): """DBConfig contains the connection info of vector database diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index e3a3f9b19..07cd9aad8 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -1,6 +1,6 @@ from pydantic import 
BaseModel, SecretStr, validator -from ..api import DBCaseConfig, DBConfig, IndexType, MetricType +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType class MilvusConfig(DBConfig): @@ -88,6 +88,88 @@ def search_param(self) -> dict: } +class HNSWSQConfig(HNSWConfig, DBCaseConfig): + index: IndexType = IndexType.HNSW_SQ + sq_type: SQType = SQType.SQ8 + refine: bool = True + refine_type: SQType = SQType.FP32 + refine_k: float = 1 + + def index_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "params": { + "M": self.M, + "efConstruction": self.efConstruction, + "sq_type": self.sq_type.value, + "refine": self.refine, + "refine_type": self.refine_type.value, + }, + } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "params": {"ef": self.ef, "refine_k": self.refine_k}, + } + + +class HNSWPQConfig(HNSWConfig): + index: IndexType = IndexType.HNSW_PQ + m: int = 32 + nbits: int = 8 + refine: bool = True + refine_type: SQType = SQType.FP32 + refine_k: float = 1 + + def index_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "params": { + "M": self.M, + "efConstruction": self.efConstruction, + "m": self.m, + "nbits": self.nbits, + "refine": self.refine, + "refine_type": self.refine_type.value, + }, + } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "params": {"ef": self.ef, "refine_k": self.refine_k}, + } + + +class HNSWPRQConfig(HNSWPQConfig): + index: IndexType = IndexType.HNSW_PRQ + nrq: int = 2 + + def index_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "params": { + "M": self.M, + "efConstruction": self.efConstruction, + "m": self.m, + "nbits": self.nbits, + "nrq": self.nrq, + "refine": self.refine, + "refine_type": self.refine_type.value, + }, + } + + def search_param(self) -> dict: + return { + 
"metric_type": self.parse_metric(), + "params": {"ef": self.ef, "refine_k": self.refine_k}, + } + + class DISKANNConfig(MilvusIndexConfig, DBCaseConfig): search_list: int | None = None index: IndexType = IndexType.DISKANN @@ -144,6 +226,31 @@ def search_param(self) -> dict: } +class IVFRABITQConfig(IVFSQ8Config): + index: IndexType = IndexType.IVF_RABITQ + rbq_bits_query: int = 0 # 0, 1, 2, ..., 8 + refine: bool = True + refine_type: SQType = SQType.FP32 + refine_k: float = 1 + + def index_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "params": { + "nlist": self.nlist, + "refine": self.refine, + "refine_type": self.refine_type.value, + }, + } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "params": {"nprobe": self.nprobe, "rbq_bits_query": self.rbq_bits_query, "refine_k": self.refine_k}, + } + + class FLATConfig(MilvusIndexConfig, DBCaseConfig): index: IndexType = IndexType.Flat @@ -285,9 +392,13 @@ def search_param(self) -> dict: _milvus_case_config = { IndexType.AUTOINDEX: AutoIndexConfig, IndexType.HNSW: HNSWConfig, + IndexType.HNSW_SQ: HNSWSQConfig, + IndexType.HNSW_PQ: HNSWPQConfig, + IndexType.HNSW_PRQ: HNSWPRQConfig, IndexType.DISKANN: DISKANNConfig, IndexType.IVFFlat: IVFFlatConfig, IndexType.IVFSQ8: IVFSQ8Config, + IndexType.IVF_RABITQ: IVFRABITQConfig, IndexType.Flat: FLATConfig, IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig, IndexType.GPU_IVF_PQ: GPUIVFPQConfig, diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index c812698fe..465c51179 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -61,6 +61,7 @@ def __init__( consistency_level="Session", ) + log.info(f"{self.name} create index: index_params: {self.case_config.index_param()}") col.create_index( self._vector_field, self.case_config.index_param(), @@ -71,7 +72,7 @@ def __init__( 
connections.disconnect("default") @contextmanager - def init(self) -> None: + def init(self): """ Examples: >>> with self.init(): @@ -126,6 +127,7 @@ def wait_index(): try: self.col.compact() self.col.wait_for_compaction_completed() + log.info("compactation completed. waiting for the rest of index buliding.") except Exception as e: log.warning(f"{self.name} compact error: {e}") if hasattr(e, "code"): diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index bb9cfa44b..bd59c3470 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from vectordb_bench.backend.cases import CaseLabel, CaseType from vectordb_bench.backend.clients import DB -from vectordb_bench.backend.clients.api import IndexType, MetricType +from vectordb_bench.backend.clients.api import IndexType, MetricType, SQType from vectordb_bench.frontend.components.custom.getCustomConfig import get_custom_configs from vectordb_bench.models import CaseConfig, CaseConfigParamType @@ -164,10 +164,13 @@ class CaseConfigInput(BaseModel): inputConfig={ "options": [ IndexType.HNSW.value, + IndexType.HNSW_SQ.value, + IndexType.HNSW_PQ.value, + IndexType.HNSW_PRQ.value, IndexType.IVFFlat.value, IndexType.IVFSQ8.value, + IndexType.IVF_RABITQ.value, IndexType.DISKANN.value, - IndexType.STREAMING_DISKANN.value, IndexType.Flat.value, IndexType.AUTOINDEX.value, IndexType.GPU_IVF_FLAT.value, @@ -346,9 +349,16 @@ class CaseConfigInput(BaseModel): "max": 64, "value": 30, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [ + IndexType.HNSW.value, + IndexType.HNSW_SQ.value, + IndexType.HNSW_PQ.value, + IndexType.HNSW_PRQ.value, + ], ) + CaseConfigParamInput_m = CaseConfigInput( label=CaseConfigParamType.m, 
inputType=InputType.Number, @@ -369,7 +379,62 @@ class CaseConfigInput(BaseModel): "max": 512, "value": 360, }, - isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] + in [ + IndexType.HNSW.value, + IndexType.HNSW_SQ.value, + IndexType.HNSW_PQ.value, + IndexType.HNSW_PRQ.value, + ], +) + +CaseConfigParamInput_SQType = CaseConfigInput( + label=CaseConfigParamType.sq_type, + inputType=InputType.Option, + inputHelp="Scalar quantizer type.", + inputConfig={ + "options": [SQType.SQ6.value, SQType.SQ8.value, SQType.BF16.value, SQType.FP16.value, SQType.FP32.value] + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.HNSW_SQ.value], +) + +CaseConfigParamInput_Refine = CaseConfigInput( + label=CaseConfigParamType.refine, + inputType=InputType.Option, + inputHelp="Whether refined data is reserved during index building.", + inputConfig={"options": [True, False]}, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value], +) + +CaseConfigParamInput_RefineType = CaseConfigInput( + label=CaseConfigParamType.refine_type, + inputType=InputType.Option, + inputHelp="The data type of the refine index.", + inputConfig={ + "options": [SQType.FP32.value, SQType.FP16.value, SQType.BF16.value, SQType.SQ8.value, SQType.SQ6.value] + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value] + and config.get(CaseConfigParamType.refine, True), +) + +CaseConfigParamInput_RefineK = CaseConfigInput( + label=CaseConfigParamType.refine_k, + inputType=InputType.Float, + inputHelp="The magnification factor of refine compared to k.", + inputConfig={"min": 1.0, "max": 10000.0, "value": 1.0}, + 
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.HNSW_SQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVF_RABITQ.value] + and config.get(CaseConfigParamType.refine, True), +) + +CaseConfigParamInput_RBQBitsQuery = CaseConfigInput( + label=CaseConfigParamType.rbq_bits_query, + inputType=InputType.Number, + inputHelp="The magnification factor of refine compared to k.", + inputConfig={"min": 0, "max": 8, "value": 0}, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.IVF_RABITQ.value], ) CaseConfigParamInput_EFConstruction_Weaviate = CaseConfigInput( @@ -519,7 +584,13 @@ class CaseConfigInput(BaseModel): "max": MAX_STREAMLIT_INT, "value": 100, }, - isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] + in [ + IndexType.HNSW.value, + IndexType.HNSW_SQ.value, + IndexType.HNSW_PQ.value, + IndexType.HNSW_PRQ.value, + ], ) CaseConfigParamInput_EF_Weaviate = CaseConfigInput( @@ -561,6 +632,7 @@ class CaseConfigInput(BaseModel): in [ IndexType.IVFFlat.value, IndexType.IVFSQ8.value, + IndexType.IVF_RABITQ.value, IndexType.GPU_IVF_FLAT.value, IndexType.GPU_IVF_PQ.value, IndexType.GPU_BRUTE_FORCE.value, @@ -579,6 +651,7 @@ class CaseConfigInput(BaseModel): in [ IndexType.IVFFlat.value, IndexType.IVFSQ8.value, + IndexType.IVF_RABITQ.value, IndexType.GPU_IVF_FLAT.value, IndexType.GPU_IVF_PQ.value, IndexType.GPU_BRUTE_FORCE.value, @@ -593,7 +666,8 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 0, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value], ) @@ -605,7 +679,20 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 8, }, - 
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value], +) + +CaseConfigParamInput_NRQ = CaseConfigInput( + label=CaseConfigParamType.nrq, + inputType=InputType.Number, + inputHelp="The number of residual subquantizers.", + inputConfig={ + "min": 1, + "max": 16, + "value": 2, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.HNSW_PRQ.value], ) CaseConfigParamInput_intermediate_graph_degree = CaseConfigInput( @@ -1186,6 +1273,10 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_graph_degree, CaseConfigParamInput_build_algo, CaseConfigParamInput_cache_dataset_on_device, + CaseConfigParamInput_SQType, + CaseConfigParamInput_Refine, + CaseConfigParamInput_RefineType, + CaseConfigParamInput_NRQ, ] MilvusPerformanceConfig = [ CaseConfigParamInput_IndexType, @@ -1197,6 +1288,8 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_Nprobe, CaseConfigParamInput_M_PQ, CaseConfigParamInput_Nbits_PQ, + CaseConfigParamInput_RBQBitsQuery, + CaseConfigParamInput_NRQ, CaseConfigParamInput_intermediate_graph_degree, CaseConfigParamInput_graph_degree, CaseConfigParamInput_itopk_size, @@ -1207,6 +1300,10 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_build_algo, CaseConfigParamInput_cache_dataset_on_device, CaseConfigParamInput_refine_ratio, + CaseConfigParamInput_SQType, + CaseConfigParamInput_Refine, + CaseConfigParamInput_RefineType, + CaseConfigParamInput_RefineK, ] WeaviateLoadConfig = [ diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index b28521096..ca00c4b55 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -55,6 +55,7 @@ class CaseConfigParamType(Enum): quantizedFetchLimit = "quantized_fetch_limit" m = "m" nbits = "nbits" + nrq = "nrq" 
intermediate_graph_degree = "intermediate_graph_degree" graph_degree = "graph_degree" itopk_size = "itopk_size" @@ -65,6 +66,11 @@ class CaseConfigParamType(Enum): build_algo = "build_algo" cache_dataset_on_device = "cache_dataset_on_device" refine_ratio = "refine_ratio" + refine = "refine" + refine_type = "refine_type" + refine_k = "refine_k" + rbq_bits_query = "rbq_bits_query" + sq_type = "sq_type" level = "level" maintenance_work_mem = "maintenance_work_mem" max_parallel_workers = "max_parallel_workers" From e42845f9bf7af9840c4d1a93f8056e8bf084db01 Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Wed, 23 Apr 2025 10:14:15 +0800 Subject: [PATCH 25/36] add more milvus index types: ivf_pq Signed-off-by: min.tian --- vectordb_bench/backend/clients/api.py | 1 + .../backend/clients/milvus/config.py | 22 +++++++++++++++++++ .../frontend/config/dbCaseConfigs.py | 11 ++++++---- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index 8070164dd..790da891b 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -22,6 +22,7 @@ class IndexType(str, Enum): DISKANN = "DISKANN" STREAMING_DISKANN = "DISKANN" IVFFlat = "IVF_FLAT" + IVFPQ = "IVF_PQ" IVFSQ8 = "IVF_SQ8" IVF_RABITQ = "IVF_RABITQ" Flat = "FLAT" diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 07cd9aad8..672becf1b 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -207,6 +207,27 @@ def search_param(self) -> dict: } +class IVFPQConfig(MilvusIndexConfig, DBCaseConfig): + nlist: int + nprobe: int | None = None + m: int = 32 + nbits: int = 8 + index: IndexType = IndexType.IVFPQ + + def index_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "index_type": self.index.value, + "params": {"nlist": self.nlist, "m": self.m, "nbits": self.nbits}, 
+ } + + def search_param(self) -> dict: + return { + "metric_type": self.parse_metric(), + "params": {"nprobe": self.nprobe}, + } + + class IVFSQ8Config(MilvusIndexConfig, DBCaseConfig): nlist: int nprobe: int | None = None @@ -397,6 +418,7 @@ def search_param(self) -> dict: IndexType.HNSW_PRQ: HNSWPRQConfig, IndexType.DISKANN: DISKANNConfig, IndexType.IVFFlat: IVFFlatConfig, + IndexType.IVFPQ: IVFPQConfig, IndexType.IVFSQ8: IVFSQ8Config, IndexType.IVF_RABITQ: IVFRABITQConfig, IndexType.Flat: FLATConfig, diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index bd59c3470..da5e91d91 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -168,6 +168,7 @@ class CaseConfigInput(BaseModel): IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, IndexType.IVF_RABITQ.value, IndexType.DISKANN.value, @@ -631,6 +632,7 @@ class CaseConfigInput(BaseModel): isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [ IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, IndexType.IVF_RABITQ.value, IndexType.GPU_IVF_FLAT.value, @@ -650,6 +652,7 @@ class CaseConfigInput(BaseModel): isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [ IndexType.IVFFlat.value, + IndexType.IVFPQ.value, IndexType.IVFSQ8.value, IndexType.IVF_RABITQ.value, IndexType.GPU_IVF_FLAT.value, @@ -662,12 +665,12 @@ class CaseConfigInput(BaseModel): label=CaseConfigParamType.m, inputType=InputType.Number, inputConfig={ - "min": 0, + "min": 1, "max": 65536, - "value": 0, + "value": 32, }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value], + in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVFPQ.value], ) @@ 
-680,7 +683,7 @@ class CaseConfigInput(BaseModel): "value": 8, }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value], + in [IndexType.GPU_IVF_PQ.value, IndexType.HNSW_PQ.value, IndexType.HNSW_PRQ.value, IndexType.IVFPQ.value], ) CaseConfigParamInput_NRQ = CaseConfigInput( From 7f83936cd1229fcdbd07928abda73a4d65164e48 Mon Sep 17 00:00:00 2001 From: MansorY <119126888+MansorY23@users.noreply.github.com> Date: Thu, 24 Apr 2025 04:49:40 +0300 Subject: [PATCH 26/36] Add HNSW support for Clickhouse client (#500) * feat: add hnsw support * refactor: minor fixes * feat: reformat code * fix: remove sql injections, reformat code --- .../backend/clients/clickhouse/clickhouse.py | 170 +++++++++++++----- .../backend/clients/clickhouse/config.py | 49 +++-- 2 files changed, 162 insertions(+), 57 deletions(-) diff --git a/vectordb_bench/backend/clients/clickhouse/clickhouse.py b/vectordb_bench/backend/clients/clickhouse/clickhouse.py index e241cdb8d..498132ca9 100644 --- a/vectordb_bench/backend/clients/clickhouse/clickhouse.py +++ b/vectordb_bench/backend/clients/clickhouse/clickhouse.py @@ -5,8 +5,11 @@ from typing import Any import clickhouse_connect +from clickhouse_connect.driver import Client -from ..api import DBCaseConfig, VectorDB +from .. 
import IndexType +from ..api import VectorDB +from .config import ClickhouseConfigDict, ClickhouseIndexConfig log = logging.getLogger(__name__) @@ -17,8 +20,8 @@ class Clickhouse(VectorDB): def __init__( self, dim: int, - db_config: dict, - db_case_config: DBCaseConfig, + db_config: ClickhouseConfigDict, + db_case_config: ClickhouseIndexConfig, collection_name: str = "CHVectorCollection", drop_old: bool = False, **kwargs, @@ -28,29 +31,29 @@ def __init__( self.table_name = collection_name self.dim = dim + self.index_param = self.case_config.index_param() + self.search_param = self.case_config.search_param() + self.session_param = self.case_config.session_param() + self._index_name = "clickhouse_index" self._primary_field = "id" self._vector_field = "embedding" # construct basic units - self.conn = clickhouse_connect.get_client( - host=self.db_config["host"], - port=self.db_config["port"], - username=self.db_config["user"], - password=self.db_config["password"], - database=self.db_config["dbname"], - ) + self.conn = self._create_connection(**self.db_config, settings=self.session_param) if drop_old: log.info(f"Clickhouse client drop table : {self.table_name}") self._drop_table() self._create_table(dim) + if self.case_config.create_index_before_load: + self._create_index() self.conn.close() self.conn = None @contextmanager - def init(self): + def init(self) -> None: """ Examples: >>> with self.init(): @@ -58,13 +61,7 @@ def init(self): >>> self.search_embedding() """ - self.conn = clickhouse_connect.get_client( - host=self.db_config["host"], - port=self.db_config["port"], - username=self.db_config["user"], - password=self.db_config["password"], - database=self.db_config["dbname"], - ) + self.conn = self._create_connection(**self.db_config, settings=self.session_param) try: yield @@ -72,10 +69,61 @@ def init(self): self.conn.close() self.conn = None + def _create_connection(self, settings: dict | None, **kwargs) -> Client: + return 
clickhouse_connect.get_client(**self.db_config, settings=settings) + + def _drop_index(self): + assert self.conn is not None, "Connection is not initialized" + try: + self.conn.command( + f'ALTER TABLE {self.db_config["database"]}.{self.table_name} DROP INDEX {self._index_name}' + ) + except Exception as e: + log.warning(f"Failed to drop index on table {self.db_config['database']}.{self.table_name}: {e}") + raise e from None + def _drop_table(self): assert self.conn is not None, "Connection is not initialized" - self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["dbname"]}.{self.table_name}') + try: + self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["database"]}.{self.table_name}') + except Exception as e: + log.warning(f"Failed to drop table {self.db_config['database']}.{self.table_name}: {e}") + raise e from None + + def _perfomance_tuning(self): + self.conn.command("SET materialize_skip_indexes_on_insert = 1") + + def _create_index(self): + assert self.conn is not None, "Connection is not initialized" + try: + if self.index_param["index_type"] == IndexType.HNSW.value: + if ( + self.index_param["quantization"] + and self.index_param["params"]["M"] + and self.index_param["params"]["efConstruction"] + ): + query = f""" + ALTER TABLE {self.db_config["database"]}.{self.table_name} + ADD INDEX {self._index_name} {self._vector_field} + TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}', + '{self.index_param["quantization"]}', + {self.index_param["params"]["M"]}, {self.index_param["params"]["efConstruction"]}) + GRANULARITY {self.index_param["granularity"]} + """ + else: + query = f""" + ALTER TABLE {self.db_config["database"]}.{self.table_name} + ADD INDEX {self._index_name} {self._vector_field} + TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}') + GRANULARITY {self.index_param["granularity"]} + """ + self.conn.command(cmd=query) + else: + log.warning("HNSW is only avaliable method in clickhouse now") + except Exception 
as e: + log.warning(f"Failed to create Clickhouse vector index on table: {self.table_name} error: {e}") + raise e from None def _create_table(self, dim: int): assert self.conn is not None, "Connection is not initialized" @@ -83,21 +131,22 @@ def _create_table(self, dim: int): try: # create table self.conn.command( - f'CREATE TABLE IF NOT EXISTS {self.db_config["dbname"]}.{self.table_name} \ - (id UInt32, embedding Array(Float64)) ENGINE = MergeTree() ORDER BY id;' + f'CREATE TABLE IF NOT EXISTS {self.db_config["database"]}.{self.table_name} ' + f"({self._primary_field} UInt32, " + f'{self._vector_field} Array({self.index_param["vector_data_type"]}) CODEC(NONE), ' + f"CONSTRAINT same_length CHECK length(embedding) = {dim}) " + f"ENGINE = MergeTree() " + f"ORDER BY {self._primary_field}" ) except Exception as e: log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}") raise e from None - def ready_to_load(self): - pass - def optimize(self, data_size: int | None = None): pass - def ready_to_search(self): + def _post_insert(self): pass def insert_embeddings( @@ -105,7 +154,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> tuple[int, Exception]: + ) -> (int, Exception): assert self.conn is not None, "Connection is not initialized" try: @@ -116,7 +165,7 @@ def insert_embeddings( table=self.table_name, data=items, column_names=["id", "embedding"], - column_type_names=["UInt32", "Array(Float64)"], + column_type_names=["UInt32", f'Array({self.index_param["vector_data_type"]})'], column_oriented=True, ) return len(metadata), None @@ -132,25 +181,52 @@ def search_embedding( timeout: int | None = None, ) -> list[int]: assert self.conn is not None, "Connection is not initialized" - - index_param = self.case_config.index_param() # noqa: F841 - search_param = self.case_config.search_param() - - if filters: - gt = filters.get("id") - filter_sql = ( - f'SELECT id, 
{search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608 - f'FROM {self.db_config["dbname"]}.{self.table_name} ' - f"WHERE id > {gt} " - f"ORDER BY score LIMIT {k};" - ) - result = self.conn.query(filter_sql).result_rows + parameters = { + "primary_field": self._primary_field, + "vector_field": self._vector_field, + "schema": self.db_config["database"], + "table": self.table_name, + "gt": filters.get("id"), + "k": k, + "metric_type": self.search_param["metric_type"], + "query": query, + } + if self.case_config.metric_type == "COSINE": + if filters: + result = self.conn.query( + "SELECT {primary_field:Identifier}, {vector_field:Identifier} " + "FROM {schema:Identifier}.{table:Identifier} " + "WHERE {primary_field:Identifier} > {gt:UInt32} " + "ORDER BY cosineDistance(embedding,{query:Array(Float64)}) " + "LIMIT {k:UInt32}", + parameters=parameters, + ).result_rows + return [int(row[0]) for row in result] + + result = self.conn.query( + "SELECT {primary_field:Identifier}, {vector_field:Identifier} " + "FROM {schema:Identifier}.{table:Identifier} " + "ORDER BY cosineDistance(embedding,{query:Array(Float64)}) " + "LIMIT {k:UInt32}", + parameters=parameters, + ).result_rows return [int(row[0]) for row in result] - else: # noqa: RET505 - select_sql = ( - f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608 - f'FROM {self.db_config["dbname"]}.{self.table_name} ' - f"ORDER BY score LIMIT {k};" - ) - result = self.conn.query(select_sql).result_rows + if filters: + result = self.conn.query( + "SELECT {primary_field:Identifier}, {vector_field:Identifier} " + "FROM {schema:Identifier}.{table:Identifier} " + "WHERE {primary_field:Identifier} > {gt:UInt32} " + "ORDER BY L2Distance(embedding,{query:Array(Float64)}) " + "LIMIT {k:UInt32}", + parameters=parameters, + ).result_rows return [int(row[0]) for row in result] + + result = self.conn.query( + "SELECT {primary_field:Identifier}, {vector_field:Identifier} " + "FROM 
{schema:Identifier}.{table:Identifier} " + "ORDER BY L2Distance(embedding,{query:Array(Float64)}) " + "LIMIT {k:UInt32}", + parameters=parameters, + ).result_rows + return [int(row[0]) for row in result] diff --git a/vectordb_bench/backend/clients/clickhouse/config.py b/vectordb_bench/backend/clients/clickhouse/config.py index fad446049..a4c5fe499 100644 --- a/vectordb_bench/backend/clients/clickhouse/config.py +++ b/vectordb_bench/backend/clients/clickhouse/config.py @@ -1,29 +1,46 @@ +from abc import abstractmethod +from typing import TypedDict + from pydantic import BaseModel, SecretStr from ..api import DBCaseConfig, DBConfig, IndexType, MetricType +class ClickhouseConfigDict(TypedDict): + user: str + password: str + host: str + port: int + database: str + secure: bool + + class ClickhouseConfig(DBConfig): user_name: str = "clickhouse" password: SecretStr host: str = "localhost" port: int = 8123 db_name: str = "default" + secure: bool = False - def to_dict(self) -> dict: + def to_dict(self) -> ClickhouseConfigDict: pwd_str = self.password.get_secret_value() return { "host": self.host, "port": self.port, - "dbname": self.db_name, + "database": self.db_name, "user": self.user_name, "password": pwd_str, + "secure": self.secure, } -class ClickhouseIndexConfig(BaseModel): +class ClickhouseIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None + vector_data_type: str | None = "Float32" # Data type of vectors. 
Can be Float32 or Float64 or BFloat16 + create_index_before_load: bool = True + create_index_after_load: bool = False def parse_metric(self) -> str: if not self.metric_type: @@ -35,26 +52,38 @@ def parse_metric_str(self) -> str: return "L2Distance" if self.metric_type == MetricType.COSINE: return "cosineDistance" - msg = f"Not Support for {self.metric_type}" - raise RuntimeError(msg) - return None + return "cosineDistance" + + @abstractmethod + def session_param(self): + pass -class ClickhouseHNSWConfig(ClickhouseIndexConfig, DBCaseConfig): - M: int | None - efConstruction: int | None +class ClickhouseHNSWConfig(ClickhouseIndexConfig): + M: int | None # Default in clickhouse in 32 + efConstruction: int | None # Default in clickhouse in 128 ef: int | None = None index: IndexType = IndexType.HNSW + quantization: str | None = "bf16" # Default is bf16. Possible values are f64, f32, f16, bf16, or i8 + granularity: int | None = 10_000_000 # Size of the index granules. By default, in CH it's equal 10.000.000 def index_param(self) -> dict: return { + "vector_data_type": self.vector_data_type, "metric_type": self.parse_metric_str(), "index_type": self.index.value, + "quantization": self.quantization, + "granularity": self.granularity, "params": {"M": self.M, "efConstruction": self.efConstruction}, } def search_param(self) -> dict: return { - "met˝ric_type": self.parse_metric_str(), + "metric_type": self.parse_metric_str(), "params": {"ef": self.ef}, } + + def session_param(self) -> dict: + return { + "allow_experimental_vector_similarity_index": 1, + } From b7bad93f71c82e32ca4443c0bdbac8a02a79ea83 Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Wed, 30 Apr 2025 13:01:13 +0800 Subject: [PATCH 27/36] fix bugs when use custom_dataset without groundtruth file Signed-off-by: min.tian --- vectordb_bench/backend/dataset.py | 12 ++++++++---- vectordb_bench/backend/runner/serial_runner.py | 13 +++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git 
a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py index 62700b0fa..f90580dc6 100644 --- a/vectordb_bench/backend/dataset.py +++ b/vectordb_bench/backend/dataset.py @@ -220,10 +220,12 @@ def prepare( train_files = utils.compose_train_files(file_count, use_shuffled) all_files = train_files - gt_file, test_file = None, None + test_file = "test.parquet" + all_files.extend([test_file]) + gt_file = None if self.data.with_gt: - gt_file, test_file = utils.compose_gt_file(filters), "test.parquet" - all_files.extend([gt_file, test_file]) + gt_file = utils.compose_gt_file(filters) + all_files.extend([gt_file]) if not self.data.is_custom: source.reader().read( @@ -232,8 +234,10 @@ def prepare( local_ds_root=self.data_dir, ) - if gt_file is not None and test_file is not None: + if test_file is not None: self.test_data = self._read_file(test_file) + + if gt_file is not None: self.gt_data = self._read_file(gt_file) prefix = "shuffle_train" if use_shuffled else "train" diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py index 365641132..5b418c886 100644 --- a/vectordb_bench/backend/runner/serial_runner.py +++ b/vectordb_bench/backend/runner/serial_runner.py @@ -209,7 +209,8 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: ideal_dcg = get_ideal_dcg(self.k) log.debug(f"test dataset size: {len(test_data)}") - log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}") + if ground_truth is not None: + log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}") latencies, recalls, ndcgs = [], [], [] for idx, emb in enumerate(test_data): @@ -228,9 +229,13 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: latencies.append(time.perf_counter() - s) - gt = ground_truth["neighbors_id"][idx] - recalls.append(calc_recall(self.k, gt[: self.k], results)) - ndcgs.append(calc_ndcg(gt[: self.k], results, 
ideal_dcg)) + if ground_truth is not None: + gt = ground_truth["neighbors_id"][idx] + recalls.append(calc_recall(self.k, gt[: self.k], results)) + ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg)) + else: + recalls.append(0) + ndcgs.append(0) if len(latencies) % 100 == 0: log.debug( From 024455f67cd741b1f30a54db5fb205a1372d6ade Mon Sep 17 00:00:00 2001 From: Andreas Opferkuch Date: Sat, 3 May 2025 12:46:44 +0200 Subject: [PATCH 28/36] fix: prevent the frontend from crashing on invalid indexes in results --- vectordb_bench/models.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index ca00c4b55..e99b3c789 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -6,12 +6,15 @@ import ujson +from vectordb_bench.backend.clients.api import EmptyDBCaseConfig + from . import config from .backend.cases import CaseType from .backend.clients import ( DB, DBCaseConfig, DBConfig, + EmptyDBCaseConfig, ) from .base import BaseModel from .metric import Metric @@ -247,13 +250,21 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: test_result["task_label"] = test_result["run_id"] for case_result in test_result["results"]: - task_config = case_result.get("task_config") - db = DB(task_config.get("db")) + task_config = case_result["task_config"] + db = DB(task_config["db"]) task_config["db_config"] = db.config_cls(**task_config["db_config"]) - task_config["db_case_config"] = db.case_config_cls( - index_type=task_config["db_case_config"].get("index", None), - )(**task_config["db_case_config"]) + + # Safely instantiate DBCaseConfig (fallback to EmptyDBCaseConfig on None) + raw_case_cfg = task_config.get("db_case_config") or {} + index_value = raw_case_cfg.get("index", None) + try: + task_config["db_case_config"] = db.case_config_cls(index_type=index_value)(**raw_case_cfg) + except: + log.error( + f"Couldn't get class for index '{index_value}' 
({full_path})" + ) + task_config["db_case_config"] = EmptyDBCaseConfig(**raw_case_cfg) case_result["task_config"] = task_config From 4ef378b2cdd3121efa745d78335f4735a7391c58 Mon Sep 17 00:00:00 2001 From: Andreas Opferkuch Date: Tue, 6 May 2025 11:45:15 +0200 Subject: [PATCH 29/36] fix ruff warnings --- vectordb_bench/models.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index e99b3c789..001469f3d 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -6,8 +6,6 @@ import ujson -from vectordb_bench.backend.clients.api import EmptyDBCaseConfig - from . import config from .backend.cases import CaseType from .backend.clients import ( @@ -260,8 +258,8 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: index_value = raw_case_cfg.get("index", None) try: task_config["db_case_config"] = db.case_config_cls(index_type=index_value)(**raw_case_cfg) - except: - log.error( + except Exception: + log.exception( f"Couldn't get class for index '{index_value}' ({full_path})" ) task_config["db_case_config"] = EmptyDBCaseConfig(**raw_case_cfg) @@ -368,3 +366,4 @@ def append_return(x: any, y: any): tmp_logger = logging.getLogger("no_color") for f in fmt: tmp_logger.info(f) + From b1e5cb73159bf0ba879f3e554ae68c632e49f880 Mon Sep 17 00:00:00 2001 From: Andreas Opferkuch Date: Tue, 6 May 2025 12:02:21 +0200 Subject: [PATCH 30/36] Fix formatting --- vectordb_bench/models.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 001469f3d..76ceaaddc 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -259,9 +259,7 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: try: task_config["db_case_config"] = db.case_config_cls(index_type=index_value)(**raw_case_cfg) except Exception: - log.exception( - f"Couldn't get class for index '{index_value}' 
({full_path})" - ) + log.exception(f"Couldn't get class for index '{index_value}' ({full_path})") task_config["db_case_config"] = EmptyDBCaseConfig(**raw_case_cfg) case_result["task_config"] = task_config @@ -366,4 +364,3 @@ def append_return(x: any, y: any): tmp_logger = logging.getLogger("no_color") for f in fmt: tmp_logger.info(f) - From 617e57e634c0fe12f998a7fa22821bc84175449e Mon Sep 17 00:00:00 2001 From: Andreas Opferkuch Date: Sat, 26 Apr 2025 19:07:26 +0200 Subject: [PATCH 31/36] Add lancedb --- README.md | 8 +- pyproject.toml | 2 + vectordb_bench/backend/clients/__init__.py | 20 ++- vectordb_bench/backend/clients/api.py | 1 + vectordb_bench/backend/clients/lancedb/cli.py | 92 +++++++++++++ .../backend/clients/lancedb/config.py | 103 +++++++++++++++ .../backend/clients/lancedb/lancedb.py | 91 +++++++++++++ vectordb_bench/cli/vectordbbench.py | 2 + .../frontend/config/dbCaseConfigs.py | 125 ++++++++++++++++++ vectordb_bench/frontend/config/styles.py | 1 + vectordb_bench/models.py | 3 + 11 files changed, 442 insertions(+), 6 deletions(-) create mode 100644 vectordb_bench/backend/clients/lancedb/cli.py create mode 100644 vectordb_bench/backend/clients/lancedb/config.py create mode 100644 vectordb_bench/backend/clients/lancedb/lancedb.py diff --git a/README.md b/README.md index 6d83671da..662461a73 100644 --- a/README.md +++ b/README.md @@ -267,13 +267,13 @@ pip install -e '.[pinecone]' ``` ### Run test server ``` -$ python -m vectordb_bench +python -m vectordb_bench ``` OR: ```shell -$ init_bench +init_bench ``` OR: @@ -290,13 +290,13 @@ After reopen the repository in container, run `python -m vectordb_bench` in the ### Check coding styles ```shell -$ make lint +make lint ``` To fix the coding styles automatically ```shell -$ make format +make format ``` ## How does it work? 
diff --git a/pyproject.toml b/pyproject.toml index 09fa66972..8cab39194 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ all = [ "PyMySQL", "clickhouse-connect", "pyvespa", + "lancedb", ] qdrant = [ "qdrant-client" ] @@ -94,6 +95,7 @@ mariadb = [ "mariadb" ] tidb = [ "PyMySQL" ] clickhouse = [ "clickhouse-connect" ] vespa = [ "pyvespa" ] +lancedb = [ "lancedb" ] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index cfef0283f..f05913a06 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -45,9 +45,10 @@ class DB(Enum): TiDB = "TiDB" Clickhouse = "Clickhouse" Vespa = "Vespa" + LanceDB = "LanceDB" @property - def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 + def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915 """Import while in use""" if self == DB.Milvus: from .milvus.milvus import Milvus @@ -164,11 +165,16 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901 return Vespa + if self == DB.LanceDB: + from .lancedb.lancedb import LanceDB + + return LanceDB + msg = f"Unknown DB: {self.name}" raise ValueError(msg) @property - def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 + def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915 """Import while in use""" if self == DB.Milvus: from .milvus.config import MilvusConfig @@ -285,6 +291,11 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901 return VespaConfig + if self == DB.LanceDB: + from .lancedb.config import LanceDBConfig + + return LanceDBConfig + msg = f"Unknown DB: {self.name}" raise ValueError(msg) @@ -382,6 +393,11 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912 return VespaHNSWConfig + if self == DB.LanceDB: + from .lancedb.config import _lancedb_case_config + + return 
_lancedb_case_config.get(index_type) + # DB.Pinecone, DB.Chroma, DB.Redis return EmptyDBCaseConfig diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index 790da891b..ff7b378a7 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -34,6 +34,7 @@ class IndexType(str, Enum): GPU_IVF_PQ = "GPU_IVF_PQ" GPU_CAGRA = "GPU_CAGRA" SCANN = "scann" + NONE = "NONE" class SQType(str, Enum): diff --git a/vectordb_bench/backend/clients/lancedb/cli.py b/vectordb_bench/backend/clients/lancedb/cli.py new file mode 100644 index 000000000..573c64d05 --- /dev/null +++ b/vectordb_bench/backend/clients/lancedb/cli.py @@ -0,0 +1,92 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB +from ..api import IndexType + + +class LanceDBTypedDict(CommonTypedDict): + uri: Annotated[ + str, + click.option("--uri", type=str, help="URI connection string", required=True), + ] + token: Annotated[ + str | None, + click.option("--token", type=str, help="Authentication token", required=False), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(LanceDBTypedDict) +def LanceDB(**parameters: Unpack[LanceDBTypedDict]): + from .config import LanceDBConfig, _lancedb_case_config + + run( + db=DB.LanceDB, + db_config=LanceDBConfig( + db_label=parameters["db_label"], + uri=parameters["uri"], + token=SecretStr(parameters["token"]) if parameters.get("token") else None, + ), + db_case_config=_lancedb_case_config.get("NONE")(), + **parameters, + ) + + +@cli.command() +@click_parameter_decorators_from_typed_dict(LanceDBTypedDict) +def LanceDBAutoIndex(**parameters: Unpack[LanceDBTypedDict]): + from .config import LanceDBConfig, _lancedb_case_config + + run( + db=DB.LanceDB, + db_config=LanceDBConfig( + db_label=parameters["db_label"], + 
uri=parameters["uri"], + token=SecretStr(parameters["token"]) if parameters.get("token") else None, + ), + db_case_config=_lancedb_case_config.get(IndexType.AUTOINDEX)(), + **parameters, + ) + + +@cli.command() +@click_parameter_decorators_from_typed_dict(LanceDBTypedDict) +def LanceDBIVFPQ(**parameters: Unpack[LanceDBTypedDict]): + from .config import LanceDBConfig, _lancedb_case_config + + run( + db=DB.LanceDB, + db_config=LanceDBConfig( + db_label=parameters["db_label"], + uri=parameters["uri"], + token=SecretStr(parameters["token"]) if parameters.get("token") else None, + ), + db_case_config=_lancedb_case_config.get(IndexType.IVFPQ)(), + **parameters, + ) + + +@cli.command() +@click_parameter_decorators_from_typed_dict(LanceDBTypedDict) +def LanceDBHNSW(**parameters: Unpack[LanceDBTypedDict]): + from .config import LanceDBConfig, _lancedb_case_config + + run( + db=DB.LanceDB, + db_config=LanceDBConfig( + db_label=parameters["db_label"], + uri=parameters["uri"], + token=SecretStr(parameters["token"]) if parameters.get("token") else None, + ), + db_case_config=_lancedb_case_config.get(IndexType.HNSW)(), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/lancedb/config.py b/vectordb_bench/backend/clients/lancedb/config.py new file mode 100644 index 000000000..0bbdfc4c9 --- /dev/null +++ b/vectordb_bench/backend/clients/lancedb/config.py @@ -0,0 +1,103 @@ +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + + +class LanceDBConfig(DBConfig): + """LanceDB connection configuration.""" + + db_label: str + uri: str + token: SecretStr | None = None + + def to_dict(self) -> dict: + return { + "uri": self.uri, + "token": self.token.get_secret_value() if self.token else None, + } + + +class LanceDBIndexConfig(BaseModel, DBCaseConfig): + index: IndexType = IndexType.IVFPQ + metric_type: MetricType = MetricType.L2 + num_partitions: int = 0 + num_sub_vectors: int = 0 + nbits: int = 8 # Must be 4 or 8 + 
sample_rate: int = 256 + max_iterations: int = 50 + + def index_param(self) -> dict: + if self.index not in [ + IndexType.IVFPQ, + IndexType.HNSW, + IndexType.AUTOINDEX, + IndexType.NONE, + ]: + msg = f"Index type {self.index} is not supported for LanceDB!" + raise ValueError(msg) + + # See https://lancedb.github.io/lancedb/python/python/#lancedb.table.Table.create_index + params = { + "metric": self.parse_metric(), + "num_bits": self.nbits, + "sample_rate": self.sample_rate, + "max_iterations": self.max_iterations, + } + + if self.num_partitions > 0: + params["num_partitions"] = self.num_partitions + if self.num_sub_vectors > 0: + params["num_sub_vectors"] = self.num_sub_vectors + + return params + + def search_param(self) -> dict: + pass + + def parse_metric(self) -> str: + if self.metric_type in [MetricType.L2, MetricType.COSINE]: + return self.metric_type.value.lower() + if self.metric_type in [MetricType.IP, MetricType.DP]: + return "dot" + msg = f"Metric type {self.metric_type} is not supported for LanceDB!" 
+ raise ValueError(msg) + + +class LanceDBNoIndexConfig(LanceDBIndexConfig): + index: IndexType = IndexType.NONE + + def index_param(self) -> dict: + return {} + + +class LanceDBAutoIndexConfig(LanceDBIndexConfig): + index: IndexType = IndexType.AUTOINDEX + + def index_param(self) -> dict: + return {} + + +class LanceDBHNSWIndexConfig(LanceDBIndexConfig): + index: IndexType = IndexType.HNSW + m: int = 0 + ef_construction: int = 0 + + def index_param(self) -> dict: + params = LanceDBIndexConfig.index_param(self) + + # See https://lancedb.github.io/lancedb/python/python/#lancedb.index.HnswSq + params["index_type"] = "IVF_HNSW_SQ" + if self.m > 0: + params["m"] = self.m + if self.ef_construction > 0: + params["ef_construction"] = self.ef_construction + + return params + + +_lancedb_case_config = { + IndexType.IVFPQ: LanceDBIndexConfig, + IndexType.AUTOINDEX: LanceDBAutoIndexConfig, + IndexType.HNSW: LanceDBHNSWIndexConfig, + IndexType.NONE: LanceDBNoIndexConfig, +} diff --git a/vectordb_bench/backend/clients/lancedb/lancedb.py b/vectordb_bench/backend/clients/lancedb/lancedb.py new file mode 100644 index 000000000..d93871de2 --- /dev/null +++ b/vectordb_bench/backend/clients/lancedb/lancedb.py @@ -0,0 +1,91 @@ +import logging +from contextlib import contextmanager + +import lancedb +import pyarrow as pa +from lancedb.pydantic import LanceModel + +from ..api import IndexType, VectorDB +from .config import LanceDBConfig, LanceDBIndexConfig + +log = logging.getLogger(__name__) + + +class VectorModel(LanceModel): + id: int + vector: list[float] + + +class LanceDB(VectorDB): + def __init__( + self, + dim: int, + db_config: LanceDBConfig, + db_case_config: LanceDBIndexConfig, + collection_name: str = "vector_bench_test", + drop_old: bool = False, + **kwargs, + ): + self.name = "LanceDB" + self.db_config = db_config + self.case_config = db_case_config + self.table_name = collection_name + self.dim = dim + self.uri = db_config["uri"] + + db = lancedb.connect(self.uri) + + if 
drop_old: + try: + db.drop_table(self.table_name) + except Exception as e: + log.warning(f"Failed to drop table {self.table_name}: {e}") + + try: + db.open_table(self.table_name) + except Exception: + schema = pa.schema( + [pa.field("id", pa.int64()), pa.field("vector", pa.list_(pa.float64(), list_size=self.dim))] + ) + db.create_table(self.table_name, schema=schema, mode="overwrite") + + @contextmanager + def init(self): + self.db = lancedb.connect(self.uri) + self.table = self.db.open_table(self.table_name) + yield + self.db = None + self.table = None + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + ) -> tuple[int, Exception | None]: + try: + data = [{"id": meta, "vector": emb} for meta, emb in zip(metadata, embeddings, strict=False)] + self.table.add(data) + return len(metadata), None + except Exception as e: + log.warning(f"Failed to insert data into LanceDB table ({self.table_name}), error: {e}") + return 0, e + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + ) -> list[int]: + if filters: + results = self.table.search(query).where(f"id >= {filters['id']}", prefilter=True).limit(k).to_list() + else: + results = self.table.search(query).limit(k).to_list() + return [int(result["id"]) for result in results] + + def optimize(self, data_size: int | None = None): + if self.table and hasattr(self, "case_config") and self.case_config.index != IndexType.NONE: + log.info(f"Creating index for LanceDB table ({self.table_name})") + self.table.create_index(**self.case_config.index_param()) + # Better recall with IVF_PQ (though still bad) but breaks HNSW: https://github.com/lancedb/lancedb/issues/2369 + if self.case_config.index in (IndexType.IVFPQ, IndexType.AUTOINDEX): + self.table.optimize() diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 1c6ff1260..210a3ccb7 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ 
b/vectordb_bench/cli/vectordbbench.py @@ -1,6 +1,7 @@ from ..backend.clients.alloydb.cli import AlloyDBScaNN from ..backend.clients.aws_opensearch.cli import AWSOpenSearch from ..backend.clients.clickhouse.cli import Clickhouse +from ..backend.clients.lancedb.cli import LanceDB from ..backend.clients.mariadb.cli import MariaDBHNSW from ..backend.clients.memorydb.cli import MemoryDB from ..backend.clients.milvus.cli import MilvusAutoIndex @@ -33,6 +34,7 @@ cli.add_command(TiDB) cli.add_command(Clickhouse) cli.add_command(Vespa) +cli.add_command(LanceDB) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index da5e91d91..3c9430b2b 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -1491,6 +1491,127 @@ class CaseConfigInput(BaseModel): ] VespaPerformanceConfig = VespaLoadingConfig +CaseConfigParamInput_IndexType_LanceDB = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputHelp="AUTOINDEX = IVFPQ with default parameters", + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.NONE.value, + IndexType.AUTOINDEX.value, + IndexType.IVFPQ.value, + IndexType.HNSW.value, + ], + }, +) + +CaseConfigParamInput_num_partitions_LanceDB = CaseConfigInput( + label=CaseConfigParamType.num_partitions, + displayLabel="Number of Partitions", + inputHelp="Number of partitions (clusters) for IVF_PQ. Default (when 0): sqrt(num_rows)", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 10000, + "value": 0, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_num_sub_vectors_LanceDB = CaseConfigInput( + label=CaseConfigParamType.num_sub_vectors, + displayLabel="Number of Sub-vectors", + inputHelp="Number of sub-vectors for PQ. 
Default (when 0): dim/16 or dim/8", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 1000, + "value": 0, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_num_bits_LanceDB = CaseConfigInput( + label=CaseConfigParamType.nbits, + displayLabel="Number of Bits", + inputHelp="Number of bits per sub-vector.", + inputType=InputType.Option, + inputConfig={ + "options": [4, 8], + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_sample_rate_LanceDB = CaseConfigInput( + label=CaseConfigParamType.sample_rate, + displayLabel="Sample Rate", + inputHelp="Sample rate for training. Higher values are more accurate but slower", + inputType=InputType.Number, + inputConfig={ + "min": 16, + "max": 1024, + "value": 256, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_max_iterations_LanceDB = CaseConfigInput( + label=CaseConfigParamType.max_iterations, + displayLabel="Max Iterations", + inputHelp="Maximum iterations for k-means clustering", + inputType=InputType.Number, + inputConfig={ + "min": 10, + "max": 200, + "value": 50, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFPQ.value + or config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_m_LanceDB = CaseConfigInput( + label=CaseConfigParamType.m, + displayLabel="m", + inputHelp="m parameter in HNSW", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 1000, + "value": 0, + }, + isDisplayed=lambda config: 
config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +CaseConfigParamInput_ef_construction_LanceDB = CaseConfigInput( + label=CaseConfigParamType.ef_construction, + displayLabel="ef_construction", + inputHelp="ef_construction parameter in HNSW", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 1000, + "value": 0, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, +) + +LanceDBLoadConfig = [ + CaseConfigParamInput_IndexType_LanceDB, + CaseConfigParamInput_num_partitions_LanceDB, + CaseConfigParamInput_num_sub_vectors_LanceDB, + CaseConfigParamInput_num_bits_LanceDB, + CaseConfigParamInput_sample_rate_LanceDB, + CaseConfigParamInput_max_iterations_LanceDB, + CaseConfigParamInput_m_LanceDB, + CaseConfigParamInput_ef_construction_LanceDB, +] + +LanceDBPerformanceConfig = LanceDBLoadConfig + CASE_CONFIG_MAP = { DB.Milvus: { CaseLabel.Load: MilvusLoadConfig, @@ -1551,4 +1672,8 @@ class CaseConfigInput(BaseModel): CaseLabel.Load: VespaLoadingConfig, CaseLabel.Performance: VespaPerformanceConfig, }, + DB.LanceDB: { + CaseLabel.Load: LanceDBLoadConfig, + CaseLabel.Performance: LanceDBPerformanceConfig, + }, } diff --git a/vectordb_bench/frontend/config/styles.py b/vectordb_bench/frontend/config/styles.py index 03bda0fec..96a5eede4 100644 --- a/vectordb_bench/frontend/config/styles.py +++ b/vectordb_bench/frontend/config/styles.py @@ -49,6 +49,7 @@ def getPatternShape(i): DB.AWSOpenSearch: "https://assets.zilliz.com/opensearch_1eee37584e.jpeg", DB.TiDB: "https://img2.pingcap.com/forms/3/d/3d7fd5f9767323d6f037795704211ac44b4923d6.png", DB.Vespa: "https://vespa.ai/vespa-content/uploads/2025/01/Vespa-symbol-green-rgb.png.webp", + DB.LanceDB: "https://raw.githubusercontent.com/lancedb/lancedb/main/docs/src/assets/logo.png", } # RedisCloud color: #0D6EFD diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 76ceaaddc..997831a00 100644 --- a/vectordb_bench/models.py 
+++ b/vectordb_bench/models.py @@ -97,6 +97,9 @@ class CaseConfigParamType(Enum): maxNumPrefetchDatasets = "max_num_prefetch_datasets" storage_engine = "storage_engine" max_cache_size = "max_cache_size" + num_partitions = "num_partitions" + num_sub_vectors = "num_sub_vectors" + sample_rate = "sample_rate" # mongodb params mongodb_quantization_type = "quantization" From 029666da49efd098ab00c54f84026ceb66381181 Mon Sep 17 00:00:00 2001 From: LoveYou3000 <760583490@qq.com> Date: Thu, 8 May 2025 00:09:46 +0800 Subject: [PATCH 32/36] Add --task-label option for cli (#517) * Add --task-label option for cli * Fix lint issues --- vectordb_bench/cli/cli.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index 4b42a912c..3ec3c18cd 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -405,6 +405,7 @@ class CommonTypedDict(TypedDict): show_default=True, ), ] + task_label: Annotated[str, click.option("--task-label", help="Task label")] class HNSWBaseTypedDict(TypedDict): @@ -499,10 +500,11 @@ def run( parameters["search_concurrent"], ), ) + task_label = parameters["task_label"] log.info(f"Task:\n{pformat(task)}\n") if not parameters["dry_run"]: - benchmark_runner.run([task]) + benchmark_runner.run([task], task_label) time.sleep(5) if global_result_future: wait([global_result_future]) From 31b8cbdb4c504221b477806e250008e3745da2d7 Mon Sep 17 00:00:00 2001 From: Andreas Opferkuch Date: Tue, 6 May 2025 20:53:30 +0200 Subject: [PATCH 33/36] Add qdrant cli --- .../backend/clients/qdrant_cloud/cli.py | 43 +++++++++++++++++++ .../backend/clients/qdrant_cloud/config.py | 8 ++-- vectordb_bench/cli/vectordbbench.py | 2 + .../components/run_test/dbConfigSetting.py | 14 ++++-- 4 files changed, 59 insertions(+), 8 deletions(-) create mode 100644 vectordb_bench/backend/clients/qdrant_cloud/cli.py diff --git a/vectordb_bench/backend/clients/qdrant_cloud/cli.py 
b/vectordb_bench/backend/clients/qdrant_cloud/cli.py new file mode 100644 index 000000000..32353472b --- /dev/null +++ b/vectordb_bench/backend/clients/qdrant_cloud/cli.py @@ -0,0 +1,43 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB + + +class QdrantTypedDict(CommonTypedDict): + url: Annotated[ + str, + click.option("--url", type=str, help="URL connection string", required=True), + ] + api_key: Annotated[ + str | None, + click.option("--api-key", type=str, help="API key for authentication", required=False), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(QdrantTypedDict) +def QdrantCloud(**parameters: Unpack[QdrantTypedDict]): + from .config import QdrantConfig, QdrantIndexConfig + + config_params = { + "db_label": parameters["db_label"], + "url": SecretStr(parameters["url"]), + } + + config_params["api_key"] = SecretStr(parameters["api_key"]) if parameters["api_key"] else None + + run( + db=DB.QdrantCloud, + db_config=QdrantConfig(**config_params), + db_case_config=QdrantIndexConfig(), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/qdrant_cloud/config.py b/vectordb_bench/backend/clients/qdrant_cloud/config.py index d4e27cb3c..b60733bc3 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/config.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/config.py @@ -6,14 +6,14 @@ # Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant. 
class QdrantConfig(DBConfig): url: SecretStr - api_key: SecretStr + api_key: SecretStr | None = None def to_dict(self) -> dict: - api_key = self.api_key.get_secret_value() - if len(api_key) > 0: + api_key_value = self.api_key.get_secret_value() if self.api_key else None + if api_key_value: return { "url": self.url.get_secret_value(), - "api_key": self.api_key.get_secret_value(), + "api_key": api_key_value, "prefer_grpc": True, } return { diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 210a3ccb7..d4153bc1e 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -9,6 +9,7 @@ from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat from ..backend.clients.pgvector.cli import PgVectorHNSW from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn +from ..backend.clients.qdrant_cloud.cli import QdrantCloud from ..backend.clients.redis.cli import Redis from ..backend.clients.test.cli import Test from ..backend.clients.tidb.cli import TiDB @@ -35,6 +36,7 @@ cli.add_command(Clickhouse) cli.add_command(Vespa) cli.add_command(LanceDB) +cli.add_command(QdrantCloud) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py index 800e6dede..a2d2de77f 100644 --- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py +++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py @@ -36,21 +36,27 @@ def dbConfigSettingItem(st, activeDb: DB): columns = st.columns(DB_CONFIG_SETTING_COLUMNS) dbConfigClass = activeDb.config_cls - properties = dbConfigClass.schema().get("properties") + schema = dbConfigClass.schema() + property_items = schema.get("properties").items() + required_fields = set(schema.get("required", [])) dbConfig = {} idx = 0 # db config (unique) - for key, property in properties.items(): + for key, property in property_items: if key not 
in dbConfigClass.common_short_configs() and key not in dbConfigClass.common_long_configs(): column = columns[idx % DB_CONFIG_SETTING_COLUMNS] idx += 1 - dbConfig[key] = column.text_input( + input_value = column.text_input( key, - key="%s-%s" % (activeDb.name, key), + key=f"{activeDb.name}-{key}", value=property.get("default", ""), type="password" if inputIsPassword(key) else "default", + placeholder="optional" if key not in required_fields else None, ) + if key in required_fields or input_value: + dbConfig[key] = input_value + # db config (common short labels) for key in dbConfigClass.common_short_configs(): column = columns[idx % DB_CONFIG_SETTING_COLUMNS] From 7d8464c95ead32bf24270ae21561f58b59a1c518 Mon Sep 17 00:00:00 2001 From: Yuyuan Kang <36235611+yuyuankang@users.noreply.github.com> Date: Sun, 11 May 2025 22:43:17 -0500 Subject: [PATCH 34/36] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 662461a73..a17220749 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,8 @@ Options: # Memory Management --cb-threshold TEXT k-NN Memory circuit breaker threshold - --help Show this message and exit.``` + --help Show this message and exit. + ``` #### Using a configuration file. 
From 975ba84a085a4d413c902f59bef800cc40cd1490 Mon Sep 17 00:00:00 2001 From: Yuyuan Kang <36235611+yuyuankang@users.noreply.github.com> Date: Tue, 13 May 2025 01:39:56 -0500 Subject: [PATCH 35/36] Fixing Bugs in Benchmarking ClickHouse with vectordbbench (#523) * Update cli.py * Update clickhouse.py * Update clickhouse.py * Update cli.py * Update config.py * remove space --- vectordb_bench/backend/clients/clickhouse/cli.py | 1 + vectordb_bench/backend/clients/clickhouse/clickhouse.py | 6 +++--- vectordb_bench/backend/clients/clickhouse/config.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vectordb_bench/backend/clients/clickhouse/cli.py b/vectordb_bench/backend/clients/clickhouse/cli.py index 6fc1a84d7..4b50bc55b 100644 --- a/vectordb_bench/backend/clients/clickhouse/cli.py +++ b/vectordb_bench/backend/clients/clickhouse/cli.py @@ -51,6 +51,7 @@ def Clickhouse(**parameters: Unpack[ClickhouseHNSWTypedDict]): db=DB.Clickhouse, db_config=ClickhouseConfig( db_label=parameters["db_label"], + user=parameters["user"], password=SecretStr(parameters["password"]) if parameters["password"] else None, host=parameters["host"], port=parameters["port"], diff --git a/vectordb_bench/backend/clients/clickhouse/clickhouse.py b/vectordb_bench/backend/clients/clickhouse/clickhouse.py index 498132ca9..de09895a8 100644 --- a/vectordb_bench/backend/clients/clickhouse/clickhouse.py +++ b/vectordb_bench/backend/clients/clickhouse/clickhouse.py @@ -106,7 +106,7 @@ def _create_index(self): query = f""" ALTER TABLE {self.db_config["database"]}.{self.table_name} ADD INDEX {self._index_name} {self._vector_field} - TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}', + TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}',{self.dim}, '{self.index_param["quantization"]}', {self.index_param["params"]["M"]}, {self.index_param["params"]["efConstruction"]}) GRANULARITY {self.index_param["granularity"]} @@ -115,7 +115,7 @@ def _create_index(self): 
query = f""" ALTER TABLE {self.db_config["database"]}.{self.table_name} ADD INDEX {self._index_name} {self._vector_field} - TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}') + TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}', {self.dim}) GRANULARITY {self.index_param["granularity"]} """ self.conn.command(cmd=query) @@ -186,7 +186,7 @@ def search_embedding( "vector_field": self._vector_field, "schema": self.db_config["database"], "table": self.table_name, - "gt": filters.get("id"), + "gt": 0 if filters is None else filters.get("id", 0), "k": k, "metric_type": self.search_param["metric_type"], "query": query, diff --git a/vectordb_bench/backend/clients/clickhouse/config.py b/vectordb_bench/backend/clients/clickhouse/config.py index a4c5fe499..f9e09812b 100644 --- a/vectordb_bench/backend/clients/clickhouse/config.py +++ b/vectordb_bench/backend/clients/clickhouse/config.py @@ -16,7 +16,7 @@ class ClickhouseConfigDict(TypedDict): class ClickhouseConfig(DBConfig): - user_name: str = "clickhouse" + user: str = "clickhouse" password: SecretStr host: str = "localhost" port: int = 8123 @@ -29,7 +29,7 @@ def to_dict(self) -> ClickhouseConfigDict: "host": self.host, "port": self.port, "database": self.db_name, - "user": self.user_name, + "user": self.user, "password": pwd_str, "secure": self.secure, } From 556b703455324d6be800d80921b0962c469365d4 Mon Sep 17 00:00:00 2001 From: LoveYou3000 <760583490@qq.com> Date: Wed, 14 May 2025 12:03:09 +0800 Subject: [PATCH 36/36] Add --concurrency-timeout option to avoid long time waiting (#521) * Add --concurrency-timeout option to avoid long time waiting, by default, it's 3600s. 
* Fix lint error * Update README.md, add --concurrency-timeout option --- README.md | 4 ++++ vectordb_bench/__init__.py | 4 +++- vectordb_bench/backend/runner/mp_runner.py | 21 ++++++++++++++++----- vectordb_bench/backend/task_runner.py | 1 + vectordb_bench/cli/cli.py | 15 +++++++++++++-- vectordb_bench/models.py | 6 ++++++ 6 files changed, 43 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index a17220749..3d38d9444 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,10 @@ Options: --num-concurrency TEXT Comma-separated list of concurrency values to test during concurrent search [default: 1,10,20] + --concurrency-timeout INTEGER Timeout (in seconds) to wait for a + concurrency slot before failing. Set to a + negative value to wait indefinitely. + [default: 3600] --user-name TEXT Db username [required] --password TEXT Db password [required] --host TEXT Db host [required] diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py index c07fc855d..52c2094b4 100644 --- a/vectordb_bench/__init__.py +++ b/vectordb_bench/__init__.py @@ -6,7 +6,7 @@ from . import log_util env = environs.Env() -env.read_env(".env", False) +env.read_env(path=".env", recurse=False) class config: @@ -52,6 +52,8 @@ class config: CONCURRENCY_DURATION = 30 + CONCURRENCY_TIMEOUT = 3600 + RESULTS_LOCAL_DIR = env.path( "RESULTS_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("results"), diff --git a/vectordb_bench/backend/runner/mp_runner.py b/vectordb_bench/backend/runner/mp_runner.py index 687a0ecd7..fd87d7ece 100644 --- a/vectordb_bench/backend/runner/mp_runner.py +++ b/vectordb_bench/backend/runner/mp_runner.py @@ -5,10 +5,12 @@ import time import traceback from collections.abc import Iterable +from multiprocessing.queues import Queue import numpy as np from ... 
import config +from ...models import ConcurrencySlotTimeoutError from ..clients import api NUM_PER_BATCH = config.NUM_PER_BATCH @@ -28,16 +30,18 @@ def __init__( self, db: api.VectorDB, test_data: list[list[float]], - k: int = 100, + k: int = config.K_DEFAULT, filters: dict | None = None, concurrencies: Iterable[int] = config.NUM_CONCURRENCY, - duration: int = 30, + duration: int = config.CONCURRENCY_DURATION, + concurrency_timeout: int = config.CONCURRENCY_TIMEOUT, ): self.db = db self.k = k self.filters = filters self.concurrencies = concurrencies self.duration = duration + self.concurrency_timeout = concurrency_timeout self.test_data = test_data log.debug(f"test dataset columns: {len(test_data)}") @@ -114,9 +118,7 @@ def _run_all_concurrencies_mem_efficient(self): log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}") future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)] # Sync all processes - while q.qsize() < conc: - sleep_t = conc if conc < 10 else 10 - time.sleep(sleep_t) + self._wait_for_queue_fill(q, size=conc) with cond: cond.notify_all() @@ -160,6 +162,15 @@ def _run_all_concurrencies_mem_efficient(self): conc_latency_avg_list, ) + def _wait_for_queue_fill(self, q: Queue, size: int): + wait_t = 0 + while q.qsize() < size: + sleep_t = size if size < 10 else 10 + wait_t += sleep_t + if wait_t > self.concurrency_timeout > 0: + raise ConcurrencySlotTimeoutError + time.sleep(sleep_t) + def run(self) -> float: """ Returns: diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 2a583b4f5..1da3bb4cd 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -275,6 +275,7 @@ def _init_search_runner(self): filters=self.ca.filters, concurrencies=self.config.case_config.concurrency_search_config.num_concurrency, duration=self.config.case_config.concurrency_search_config.concurrency_duration, + 
concurrency_timeout=self.config.case_config.concurrency_search_config.concurrency_timeout, k=self.config.case_config.k, ) diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index 3ec3c18cd..1b0eb295b 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -17,10 +17,9 @@ import click from yaml import load -from vectordb_bench.backend.clients.api import MetricType - from .. import config from ..backend.clients import DB +from ..backend.clients.api import MetricType from ..interface import benchmark_runner, global_result_future from ..models import ( CaseConfig, @@ -303,6 +302,17 @@ class CommonTypedDict(TypedDict): callback=lambda *args: list(map(int, click_arg_split(*args))), ), ] + concurrency_timeout: Annotated[ + int, + click.option( + "--concurrency-timeout", + type=int, + default=config.CONCURRENCY_TIMEOUT, + show_default=True, + help="Timeout (in seconds) to wait for a concurrency slot before failing. " + "Set to a negative value to wait indefinitely.", + ), + ] custom_case_name: Annotated[ str, click.option( @@ -490,6 +500,7 @@ def run( concurrency_search_config=ConcurrencySearchConfig( concurrency_duration=parameters["concurrency_duration"], num_concurrency=[int(s) for s in parameters["num_concurrency"]], + concurrency_timeout=parameters["concurrency_timeout"], ), custom_case=get_custom_case_config(parameters), ), diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 997831a00..c35c21755 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -30,6 +30,11 @@ def __init__(self): super().__init__("Performance case optimize timeout") +class ConcurrencySlotTimeoutError(TimeoutError): + def __init__(self): + super().__init__("Timeout while waiting for a concurrency slot to become available") + + class CaseConfigParamType(Enum): """ Value will be the key of CaseConfig.params and displayed in UI @@ -113,6 +118,7 @@ class CustomizedCase(BaseModel): class ConcurrencySearchConfig(BaseModel): 
num_concurrency: list[int] = config.NUM_CONCURRENCY concurrency_duration: int = config.CONCURRENCY_DURATION + concurrency_timeout: int = config.CONCURRENCY_TIMEOUT class CaseConfig(BaseModel):