Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class DB(Enum):
Pinecone = "Pinecone"
ElasticCloud = "ElasticCloud"
QdrantCloud = "QdrantCloud"
QdrantLocal = "QdrantLocal"
WeaviateCloud = "WeaviateCloud"
PgVector = "PgVector"
PgVectoRS = "PgVectoRS"
Expand All @@ -46,6 +47,7 @@ class DB(Enum):
Clickhouse = "Clickhouse"
Vespa = "Vespa"
LanceDB = "LanceDB"


@property
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
Expand Down Expand Up @@ -74,6 +76,11 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
from .qdrant_cloud.qdrant_cloud import QdrantCloud

return QdrantCloud

if self == DB.QdrantLocal:
from .qdrant_local.qdrant_local import QdrantLocal

return QdrantLocal

if self == DB.WeaviateCloud:
from .weaviate_cloud.weaviate_cloud import WeaviateCloud
Expand Down Expand Up @@ -200,6 +207,9 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915
from .qdrant_cloud.config import QdrantConfig

return QdrantConfig

if self == DB.QdrantLocal:
from .qdrant_local.config import QdrantLocalConfig

if self == DB.WeaviateCloud:
from .weaviate_cloud.config import WeaviateConfig
Expand Down Expand Up @@ -322,6 +332,11 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912
from .qdrant_cloud.config import QdrantIndexConfig

return QdrantIndexConfig

if self == DB.QdrantLocal:
from .qdrant_local.config import QdrantLocalIndexConfig

return QdrantLocalIndexConfig

if self == DB.WeaviateCloud:
from .weaviate_cloud.config import WeaviateIndexConfig
Expand Down
65 changes: 65 additions & 0 deletions vectordb_bench/backend/clients/qdrant_local/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Annotated, TypedDict, Unpack

import click
from pydantic import SecretStr

from vectordb_bench.backend.clients import DB
from vectordb_bench.cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)


# Database enum member benchmarked by this CLI module; handed to run() as `db`.
DBTYPE = DB.QdrantLocal


class QdrantLocalTypedDict(CommonTypedDict):
    """Command-line options of the QdrantLocal command, on top of the common ones.

    Each key is annotated with the click option that populates it.
    """

    # Server address; the only option without a default.
    url: Annotated[
        str,
        click.option(
            "--url",
            required=True,
            type=str,
            help="Qdrant url",
        ),
    ]

    # Storage placement for both the vectors and the HNSW graph.
    on_disk: Annotated[
        bool,
        click.option(
            "--on-disk",
            default=False,
            type=bool,
            help="Store the vectors and the HNSW index on disk",
        ),
    ]

    # HNSW build-time parameters.
    m: Annotated[
        int,
        click.option(
            "--m",
            default=16,
            type=int,
            help="HNSW index parameter m, set 0 to disable the index",
        ),
    ]
    ef_construct: Annotated[
        int,
        click.option(
            "--ef-construct",
            default=200,
            type=int,
            help="HNSW index parameter ef_construct",
        ),
    ]

    # HNSW search-time parameter.
    hnsw_ef: Annotated[
        int,
        click.option(
            "--hnsw-ef",
            default=0,
            type=int,
            help="HNSW index parameter hnsw_ef, set 0 to use ef_construct for search",
        ),
    ]

# NOTE: documented with comments rather than a docstring on purpose — click
# would pick a docstring up as the command's help text.
@cli.command()
@click_parameter_decorators_from_typed_dict(QdrantLocalTypedDict)
def QdrantLocal(**parameters: Unpack[QdrantLocalTypedDict]):
    # Imported lazily, inside the command, as in the original implementation.
    from .config import QdrantLocalConfig, QdrantLocalIndexConfig

    # Connection settings: the URL is wrapped in SecretStr for the config model.
    connection = QdrantLocalConfig(url=SecretStr(parameters["url"]))

    # Index/search tuning knobs are copied one-to-one from the CLI options.
    index_option_names = ("on_disk", "m", "ef_construct", "hnsw_ef")
    index_settings = QdrantLocalIndexConfig(
        **{option: parameters[option] for option in index_option_names},
    )

    # The full option dict is still forwarded so run() sees every option.
    run(
        db=DBTYPE,
        db_config=connection,
        db_case_config=index_settings,
        **parameters,
    )
46 changes: 46 additions & 0 deletions vectordb_bench/backend/clients/qdrant_local/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType

# Connection settings for a Qdrant server; the URL is held as a SecretStr.
class QdrantLocalConfig(DBConfig):
    url: SecretStr

    def to_dict(self) -> dict:
        """Return the settings as a plain dict with the URL revealed as a str."""
        revealed_url = self.url.get_secret_value()
        return dict(url=revealed_url)


# Per-case index and search settings (HNSW parameters, metric, storage placement).
class QdrantLocalIndexConfig(BaseModel, DBCaseConfig):
    metric_type: MetricType | None = None
    m: int
    ef_construct: int
    hnsw_ef: int | None = 0
    on_disk: bool | None = False

    def parse_metric(self) -> str:
        """Translate ``metric_type`` into a distance name.

        L2 becomes "Euclid", IP becomes "Dot"; every other value, including
        ``None``, becomes "Cosine".
        """
        metric = self.metric_type
        if metric == MetricType.L2:
            distance = "Euclid"
        elif metric == MetricType.IP:
            distance = "Dot"
        else:
            distance = "Cosine"
        return distance

    def index_param(self) -> dict:
        """Return the collection-creation parameters as a dict."""
        params = {"distance": self.parse_metric()}
        params["m"] = self.m
        params["ef_construct"] = self.ef_construct
        params["on_disk"] = self.on_disk
        return params

    def search_param(self) -> dict:
        """Return the search parameters as a dict.

        ``exact`` is always False (approximate search); ``hnsw_ef`` is only
        included when it differs from 0.
        """
        if self.hnsw_ef == 0:
            return {"exact": False}
        return {"exact": False, "hnsw_ef": self.hnsw_ef}
231 changes: 231 additions & 0 deletions vectordb_bench/backend/clients/qdrant_local/qdrant_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""Wrapper around the Qdrant over VectorDB"""

import logging
import time
from collections.abc import Iterable
from contextlib import contextmanager

from qdrant_client import QdrantClient
from qdrant_client.http.models import (
Batch,
CollectionStatus,
FieldCondition,
Filter,
HnswConfigDiff,
OptimizersConfigDiff,
PayloadSchemaType,
Range,
SearchParams,
VectorParams,
)

from ..api import VectorDB
from .config import QdrantLocalIndexConfig

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Seconds optimize() sleeps between polls of the collection status.
SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
# Number of points sent per upsert call by insert_embeddings().
QDRANT_BATCH_SIZE = 100


def qdrant_collection_exists(client, collection_name: str) -> bool:
    """Report whether ``client.get_collection(collection_name)`` succeeds.

    Any exception raised by the lookup is treated as "collection does not
    exist" and yields False; otherwise the result is True.
    """
    try:
        client.get_collection(collection_name)
    except Exception:
        return False
    return True

class QdrantLocal(VectorDB):
    """VectorDB client that talks to a Qdrant server through ``QdrantClient``.

    The constructor only prepares the collection with a short-lived client.
    All data operations must run inside ``with self.init():``, which owns the
    long-lived client stored in ``self.client``.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: QdrantLocalIndexConfig,
        collection_name: str = "QdrantLocalCollection",
        drop_old: bool = False,
        name: str = "QdrantLocal",
        **kwargs,
    ):
        """Store the configuration and make sure the collection exists.

        Args:
            dim: dimensionality of the vectors stored in the collection.
            db_config: keyword arguments forwarded to ``QdrantClient``.
            db_case_config: index/search settings; must provide
                ``index_param()`` and ``search_param()``.
            collection_name: name of the collection to use.
            drop_old: when True, an existing collection of that name is
                deleted and recreated.
            name: label used in log messages.
            kwargs: accepted for interface compatibility; unused here.
        """
        self.name = name
        self.db_config = db_config
        self.case_config = db_case_config
        self.search_parameter = self.case_config.search_param()
        self.collection_name = collection_name
        self.client = None

        self._primary_field = "pk"
        self._vector_field = "vector"

        # Temporary client used only for collection setup; it is closed
        # explicitly instead of merely dropping the reference.
        client = QdrantClient(**self.db_config)
        try:
            # Lets just print the parameters here for double check
            log.info(f"Case config: {self.case_config.index_param()}")
            log.info(f"Search parameter: {self.search_parameter}")

            if drop_old and qdrant_collection_exists(client, self.collection_name):
                log.info(f"{self.name} client drop_old collection: {self.collection_name}")
                client.delete_collection(self.collection_name)

            if not qdrant_collection_exists(client, self.collection_name):
                log.info(f"{self.name} create collection: {self.collection_name}")
                self._create_collection(dim, client)
        finally:
            client.close()

    @contextmanager
    def init(self):
        """Open the client for the duration of the ``with`` block.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        # create connection
        self.client = QdrantClient(**self.db_config)
        try:
            yield
        finally:
            # Release the connection even when the body raised, and leave the
            # attribute set to None so the `assert self.client` guards in the
            # other methods fail with their intended message rather than an
            # AttributeError.
            client, self.client = self.client, None
            if client is not None:
                client.close()

    def _create_collection(self, dim: int, qdrant_client: QdrantClient):
        """Create the collection and the integer payload index on the primary field.

        An error whose text contains "already exists!" is ignored; any other
        error is logged and re-raised.
        """
        # Build the parameter dict once instead of once per lookup.
        params = self.case_config.index_param()

        log.info(f"Create collection: {self.collection_name}")
        log.info(
            f"Index parameters: m={params['m']}, ef_construct={params['ef_construct']}, "
            f"on_disk={params['on_disk']}"
        )

        # If the on_disk is true, we enable both on disk index and vectors.
        try:
            qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=dim,
                    distance=params["distance"],
                    on_disk=params["on_disk"],
                ),
                hnsw_config=HnswConfigDiff(
                    m=params["m"],
                    ef_construct=params["ef_construct"],
                    on_disk=params["on_disk"],
                ),
            )

            qdrant_client.create_payload_index(
                collection_name=self.collection_name,
                field_name=self._primary_field,
                field_schema=PayloadSchemaType.INTEGER,
            )

        except Exception as e:
            if "already exists!" in str(e):
                return
            log.warning(f"Failed to create collection: {self.collection_name} error: {e}")
            raise e from None

    def optimize(self, data_size: int | None = None):
        """Block until the collection status is GREEN.

        Args:
            data_size: accepted for interface compatibility; unused here.

        Raises:
            Exception: whatever the status poll raised, after logging it.
        """
        assert self.client, "Please call self.init() before"
        # wait for vectors to be fully indexed
        try:
            while True:
                # Sleep first, then poll, so the status that is evaluated is
                # fresh. The original fetched the status, slept, and then
                # judged the pre-sleep value. The initial wait is kept
                # (presumably it gives the server time to start indexing
                # after insertion — TODO confirm).
                time.sleep(SECONDS_WAITING_FOR_INDEXING_API_CALL)
                info = self.client.get_collection(self.collection_name)
                if info.status != CollectionStatus.GREEN:
                    continue

                log.info(f"Finishing building index for collection: {self.collection_name}")
                msg = (
                    f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, "
                    f"Collection status: {info.status}"
                )
                log.info(msg)
                return

        except Exception as e:
            log.warning(f"QdrantLocal ready to search error: {e}")
            raise e from None

    def insert_embeddings(
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception | None]:
        """Insert embeddings into the database.

        Args:
            embeddings: the vectors to insert; must support ``len()`` and
                slicing, and be as long as ``metadata``.
            metadata(list[int]): one integer id per vector; used both as the
                point id and as the primary-field payload value.
            kwargs: other arguments (unused).

        Returns:
            tuple[int, Exception | None]: number of embeddings actually
            upserted, and the exception that stopped insertion (None on
            success).
        """
        assert self.client is not None
        assert len(embeddings) == len(metadata)
        insert_count = 0

        # disable indexing for quick insertion
        # (`optimizers_config` is the documented keyword; `optimizer_config`
        # is only a legacy alias.)
        self.client.update_collection(
            collection_name=self.collection_name,
            optimizers_config=OptimizersConfigDiff(indexing_threshold=0),
        )
        try:
            for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE):
                vectors = embeddings[offset : offset + QDRANT_BATCH_SIZE]
                ids = metadata[offset : offset + QDRANT_BATCH_SIZE]
                payloads = [{self._primary_field: v} for v in ids]
                self.client.upsert(
                    collection_name=self.collection_name,
                    wait=True,
                    points=Batch(ids=ids, payloads=payloads, vectors=vectors),
                )
                # Count what was really sent: the final batch may be shorter
                # than QDRANT_BATCH_SIZE.
                insert_count += len(ids)
            # enable indexing after insertion
            self.client.update_collection(
                collection_name=self.collection_name,
                optimizers_config=OptimizersConfigDiff(indexing_threshold=100),
            )

        except Exception as e:
            log.info(f"Failed to insert data, {e}")
            return insert_count, e
        else:
            return insert_count, None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Search for the nearest points to ``query`` and return their ids.

        Should call self.init() first.

        Args:
            query: the query vector.
            k: maximum number of results.
            filters: optional dict; when non-empty, results are restricted to
                points whose primary field is greater than ``filters["id"]``.
            timeout: optional per-request timeout forwarded to the client.

        Returns:
            list[int]: ids of the matching points, in the order returned by
            the server.
        """
        assert self.client is not None

        query_filter = None
        if filters:
            query_filter = Filter(
                must=[
                    FieldCondition(
                        key=self._primary_field,
                        range=Range(
                            gt=filters.get("id"),
                        ),
                    ),
                ],
            )

        response = self.client.query_points(
            collection_name=self.collection_name,
            query=query,
            limit=k,
            query_filter=query_filter,
            search_params=SearchParams(**self.search_parameter),
            timeout=timeout,
        )

        return [point.id for point in response.points]

Loading
Loading