From 0575c9ae44ab05cb281517a00ac1ed58ccbb7757 Mon Sep 17 00:00:00 2001 From: Yongqiang YANG Date: Fri, 2 Jan 2026 20:48:25 -0800 Subject: [PATCH 1/3] feat: Add Apache Doris vector store support This commit adds Apache Doris as a new vector database option for Dify's RAG system. Features: - Vector similarity search using cosine distance - Full-text search with BM25 scoring and inverted indexes - Hybrid search combining vector and text search - High-performance bulk data loading via StreamLoad - Connection pooling for efficient resource management - Support for multi-tenant isolation Components added: - DorisVector: Main vector database implementation with cleaned code - DorisConfig: Configuration model with validation - DorisConnectionPool: Thread-safe connection management - DorisVectorFactory: Factory for creating Doris instances - DORIS_SETUP.md: Complete setup guide in English --- DORIS_SETUP.md | 168 +++++ api/configs/middleware/__init__.py | 12 +- api/configs/middleware/vdb/doris_config.py | 73 +++ api/controllers/console/datasets/datasets.py | 1 + api/core/rag/datasource/vdb/doris/__init__.py | 5 + .../rag/datasource/vdb/doris/doris_vector.py | 599 +++++++++++++++++ api/core/rag/datasource/vdb/vector_factory.py | 4 + api/core/rag/datasource/vdb/vector_type.py | 1 + .../integration_tests/vdb/doris/__init__.py | 0 .../integration_tests/vdb/doris/test_doris.py | 47 ++ .../core/rag/datasource/vdb/doris/__init__.py | 0 .../datasource/vdb/doris/test_doris_vector.py | 619 ++++++++++++++++++ docker/docker-compose.yaml | 10 + docker/test_doris.py | 190 ++++++ 14 files changed, 1726 insertions(+), 3 deletions(-) create mode 100644 DORIS_SETUP.md create mode 100644 api/configs/middleware/vdb/doris_config.py create mode 100644 api/core/rag/datasource/vdb/doris/__init__.py create mode 100644 api/core/rag/datasource/vdb/doris/doris_vector.py create mode 100644 api/tests/integration_tests/vdb/doris/__init__.py create mode 100644 api/tests/integration_tests/vdb/doris/test_doris.py create mode 100644 api/tests/unit_tests/core/rag/datasource/vdb/doris/__init__.py create mode 100644 api/tests/unit_tests/core/rag/datasource/vdb/doris/test_doris_vector.py create mode 100755 docker/test_doris.py diff --git a/DORIS_SETUP.md b/DORIS_SETUP.md new file mode 100644 index 00000000000000..5651b4657ebfdc --- /dev/null +++ b/DORIS_SETUP.md @@ -0,0 +1,168 @@ +# Apache Doris Vector Store Configuration Guide for Dify + +## Prerequisites + +1. **Apache Doris Installed and Running** + - Doris FE (Frontend) running on port 8030 (HTTP) and 9030 (MySQL protocol) + - Doris BE (Backend) started and connected to FE + - Ensure Doris version >= 2.0 (supports vector search and text search) + +2. **Create Database** + ```sql + CREATE DATABASE IF NOT EXISTS dify; + ``` + +## Configuration Steps + +### Method 1: Using Docker Compose (Recommended) + +1. **Edit `.env` file** (in the `docker` directory) + + If the file doesn't exist, create it from the example file: + ```bash + cd docker + cp .env.example .env + ``` + +2. **Set Vector Store to Doris** + + Add or modify the following configuration in the `.env` file: + ```bash + # Vector Store configuration + VECTOR_STORE=doris + + # Doris connection configuration + DORIS_HOST=your-doris-fe-host # Doris FE host address, e.g., localhost or 127.0.0.1 + DORIS_PORT=9030 # Doris MySQL protocol port (default 9030) + DORIS_USER=root # Doris username + DORIS_PASSWORD=your-password # Doris password + DORIS_DATABASE=dify # Database name + + # Doris StreamLoad configuration + DORIS_STREAMLOAD_PORT=8030 # Doris HTTP port (default 8030) + DORIS_STREAMLOAD_SCHEME=http # HTTP scheme: http or https (default http) + DORIS_STREAMLOAD_MAX_FILTER_RATIO=0.1 # Maximum ratio of filtered rows (0.0-1.0, default 0.1) + + # Connection pool configuration (optional) + DORIS_MAX_CONNECTION=5 # Maximum connections (default 5) + + # Table configuration (optional) + DORIS_TABLE_REPLICATION_NUM=1 # Table replication number (default 1) + DORIS_TABLE_BUCKETS=10 # Number of table buckets (default 10) + + # Text search configuration (optional) + DORIS_ENABLE_TEXT_SEARCH=true # Enable full-text search (default true) + DORIS_TEXT_SEARCH_ANALYZER=english # Text analyzer: english, chinese, standard, unicode, default (default english) + ``` + +3. **Start Services** + ```bash + cd docker + docker compose up -d + ``` + +### Method 2: Local Development Environment + +1. **Set Environment Variables** + + Before running Dify API, set the following environment variables: + ```bash + export VECTOR_STORE=doris + export DORIS_HOST=localhost + export DORIS_PORT=9030 + export DORIS_USER=root + export DORIS_PASSWORD=your-password + export DORIS_DATABASE=dify + export DORIS_STREAMLOAD_PORT=8030 + ``` + + Or set them in a `.env` file (if using python-dotenv) + +2. **Run API Service** + ```bash + cd api + uv run --project api flask run + ``` + +## Verify Configuration + +### 1. Check Doris Connection + +Connect to Doris using MySQL client: +```bash +mysql -h your-doris-host -P 9030 -u root -p +``` + +### 2. Test Doris HTTP Endpoint + +Check if Doris FE HTTP endpoint is accessible: +```bash +curl http://your-doris-host:8030/api/v1/health +``` + +### 3. Create Dataset in Dify + +1. Login to Dify Web interface +2. Create a new dataset +3. Upload documents for indexing +4. Check if corresponding tables are created in Doris database: + ```sql + USE dify; + SHOW TABLES LIKE 'embedding_%'; + ``` + +## Features + +Doris Vector Store supports the following features: + +- ✅ **Vector Similarity Search**: Semantic search using `cosine_distance` +- ✅ **Full-text Search**: Keyword search using `MATCH_ANY` and BM25 scoring +- ✅ **Hybrid Search**: Supports both vector search and text search simultaneously +- ✅ **StreamLoad Batch Import**: High-performance bulk data loading +- ✅ **Connection Pool Management**: Automatic database connection management + +## Troubleshooting + +### Issue: Connection Failed + +**Check:** +1. Is Doris FE running? +2. Are ports correct (MySQL: 9030, HTTP: 8030)? +3. Are username and password correct? +4. Does firewall allow the connection? + +### Issue: StreamLoad Failed + +**Check:** +1. Is Doris HTTP port (8030) accessible? +2. Does the user have StreamLoad permissions? +3. Check error messages in Doris FE logs + +### Issue: Table Creation Failed + +**Check:** +1. Does the database exist? +2. Does the user have CREATE TABLE permissions? +3. Check error messages in Doris logs + +## Performance Optimization Recommendations + +1. **Adjust Connection Pool Size** + - Adjust `DORIS_MAX_CONNECTION` based on concurrent request volume + - Recommended value: concurrent requests + 2 + +2. **Text Analyzer Selection** + - English content: use `english` + - Chinese content: use `chinese` + - Multilingual: use `standard` + +3. **Batch Insertion** + - StreamLoad automatically processes data in batches + - Recommended: 100-1000 records per insertion + +## Reference Documentation + +- [Apache Doris Official Documentation](https://doris.apache.org/) +- [Doris Vector Search Documentation](https://doris.apache.org/docs/latest/ai/vector-search/overview) +- [Doris Text Search Documentation](https://doris.apache.org/docs/latest/ai/text-search/overview) +- [Doris StreamLoad Documentation](https://doris.apache.org/docs/data-operate/import/stream-load-manual) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 63f75924bfa4ce..cafbb2b34a4790 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -24,6 +24,7 @@ from .vdb.chroma_config import ChromaConfig from .vdb.clickzetta_config import ClickzettaConfig from .vdb.couchbase_config import CouchbaseConfig +from .vdb.doris_config import DorisVectorConfig from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig from .vdb.iris_config import IrisVectorConfig @@ -107,8 +108,8 @@ class KeywordStoreConfig(BaseSettings): class DatabaseConfig(BaseSettings): # Database type selector - DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field( - description="Database type to use. OceanBase is MySQL-compatible.", + DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb", "doris"] = Field( + description="Database type to use. OceanBase and Doris are MySQL-compatible.", default="postgresql", ) @@ -150,7 +151,11 @@ class DatabaseConfig(BaseSettings): @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str: - return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql" + if self.DB_TYPE == "postgresql": + return "postgresql" + else: + # mysql, oceanbase, seekdb, doris all use MySQL protocol + return "mysql+pymysql" @computed_field # type: ignore[prop-decorator] @property @@ -336,6 +341,7 @@ class MiddlewareConfig( AnalyticdbConfig, ChromaConfig, ClickzettaConfig, + DorisVectorConfig, HuaweiCloudConfig, IrisVectorConfig, MilvusConfig, diff --git a/api/configs/middleware/vdb/doris_config.py b/api/configs/middleware/vdb/doris_config.py new file mode 100644 index 00000000000000..3ca3dc0c295429 --- /dev/null +++ b/api/configs/middleware/vdb/doris_config.py @@ -0,0 +1,73 @@ +"""Configuration for Apache Doris vector database.""" + +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings + + +class DorisVectorConfig(BaseSettings): + """Configuration settings for Apache Doris vector database connection.""" + + DORIS_HOST: str | None = Field( + description="Hostname or IP address of the Apache Doris server.", + default=None, + ) + + DORIS_PORT: PositiveInt = Field( + description="Port number for Apache Doris MySQL protocol connection.", + default=9030, + ) + + DORIS_USER: str | None = Field( + description="Username for Apache Doris authentication.", + default=None, + ) + + DORIS_PASSWORD: str | None = Field( + description="Password for Apache Doris authentication.", + default=None, + ) + + DORIS_DATABASE: str | None = Field( + description="Database name in Apache Doris.", + default=None, + ) + + DORIS_MAX_CONNECTION: PositiveInt = Field( + description="Maximum number of connections in the pool.", + default=5, + ) + + DORIS_ENABLE_TEXT_SEARCH: bool = Field( + description="Enable full-text search with inverted indexes.", + default=True, + ) + + DORIS_TEXT_SEARCH_ANALYZER: str | None = Field( + description="Text search analyzer (e.g., 'english', 'chinese', 'standard').", + default="english", + ) + + DORIS_STREAMLOAD_PORT: PositiveInt = Field( + description="Port number for Apache Doris StreamLoad HTTP endpoint.", + default=8030, + ) + + DORIS_STREAMLOAD_SCHEME: str = Field( + description="HTTP scheme for StreamLoad endpoint ('http' or 'https').", + default="http", + ) + + DORIS_STREAMLOAD_MAX_FILTER_RATIO: float = Field( + description="Maximum ratio of filtered rows allowed in StreamLoad (0.0-1.0).", + default=0.1, + ) + + DORIS_TABLE_REPLICATION_NUM: PositiveInt = Field( + description="Replication number for Doris tables.", + default=1, + ) + + DORIS_TABLE_BUCKETS: PositiveInt = Field( + description="Number of buckets for Doris table distribution.", + default=10, + ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8ceb896d4f6023..1881aaed74f32f 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -232,6 +232,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool VectorType.BAIDU, VectorType.ALIBABACLOUD_MYSQL, VectorType.IRIS, + VectorType.DORIS, } semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} diff --git a/api/core/rag/datasource/vdb/doris/__init__.py b/api/core/rag/datasource/vdb/doris/__init__.py new file mode 100644 index 00000000000000..e718506984db34 --- /dev/null +++ b/api/core/rag/datasource/vdb/doris/__init__.py @@ -0,0 +1,5 @@ +"""Apache Doris vector database implementation for Dify.""" + +from .doris_vector import DorisConfig, DorisVector, DorisVectorFactory + +__all__ = ["DorisConfig", "DorisVector", "DorisVectorFactory"] diff --git a/api/core/rag/datasource/vdb/doris/doris_vector.py b/api/core/rag/datasource/vdb/doris/doris_vector.py new file mode 100644 index 00000000000000..d408712ed707d7 --- /dev/null +++ b/api/core/rag/datasource/vdb/doris/doris_vector.py @@ -0,0 +1,599 @@ +""" +Apache Doris vector database implementation for Dify's RAG system. + +This module provides integration with Apache Doris vector database for storing and retrieving +document embeddings used in retrieval-augmented generation workflows. + +Apache Doris supports both vector search and full-text search with BM25 scoring, +enabling hybrid search capabilities. +""" + +import base64 +import hashlib +import json +import logging +import uuid +from contextlib import contextmanager +from typing import Any +from urllib.parse import quote, urljoin + +import httpx +from mysql.connector import pooling +from pydantic import BaseModel, model_validator + +from configs import dify_config +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +VALID_TEXT_ANALYZERS = {"english", "chinese", "standard", "unicode", "default"} + + +class DorisConfig(BaseModel): + """Configuration model for Apache Doris connection settings.""" + + host: str + port: int + user: str + password: str + database: str + max_connection: int + enable_text_search: bool = True + text_search_analyzer: str = "english" + streamload_port: int = 8030 + streamload_scheme: str = "http" + streamload_max_filter_ratio: float = 0.1 + table_replication_num: int = 1 + table_buckets: int = 10 + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """Validates that required configuration values are present.""" + if not values.get("host"): + raise ValueError("config DORIS_HOST is required") + if not values.get("user"): + raise ValueError("config DORIS_USER is required") + if not values.get("password"): + raise ValueError("config DORIS_PASSWORD is required") + if not values.get("database"): + raise ValueError("config DORIS_DATABASE is required") + # Validate text search analyzer + analyzer = values.get("text_search_analyzer", "english") + if analyzer and analyzer not in VALID_TEXT_ANALYZERS: + raise ValueError(f"config DORIS_TEXT_SEARCH_ANALYZER must be one of {VALID_TEXT_ANALYZERS}") + # Validate streamload scheme + scheme = values.get("streamload_scheme", "http") + if scheme not in ("http", "https"): + raise ValueError("config DORIS_STREAMLOAD_SCHEME must be 'http' or 'https'") + return values + + +class DorisConnectionPool: + """Thread-safe connection pool for Apache Doris database.""" + + def __init__(self, config: DorisConfig) -> None: + self.config = config + self._pool_config = { + "pool_name": "doris_pool", + "pool_size": config.max_connection, + "pool_reset_session": True, + "host": config.host, + "port": config.port, + "user": config.user, + "password": config.password, + "database": config.database, + "charset": "utf8mb4", + "autocommit": False, + } + self._pool = pooling.MySQLConnectionPool(**self._pool_config) + + def get_connection(self) -> Any: + """Get a connection from pool.""" + return self._pool.get_connection() + + +class DorisVector(BaseVector): + """ + Apache Doris vector database implementation for document storage and retrieval. + + Handles creation, insertion, deletion, and querying of document embeddings + in Apache Doris tables with support for both vector similarity search and + full-text search with BM25 scoring. + """ + + def __init__(self, collection_name: str, config: DorisConfig, attributes: list): + """ + Initializes the Apache Doris vector store. + + Args: + collection_name: Name of the Doris table/collection + config: Doris configuration settings + attributes: List of metadata attributes to store + """ + super().__init__(collection_name) + self._pool = DorisConnectionPool(config) + self._attributes = attributes + self._config = config + # Table name format: embedding_ + collection_name + # collection_name already includes Vector_index_ prefix and _Node suffix from Dataset.gen_collection_name_by_id + self.table_name = f"embedding_{collection_name}" + self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8] + + def get_type(self) -> str: + """Returns the vector database type identifier.""" + return VectorType.DORIS + + @contextmanager + def _get_cursor(self): + """Context manager for database cursor.""" + conn = self._pool.get_connection() + cur = conn.cursor(dictionary=True) + try: + yield cur + conn.commit() + except Exception: + conn.rollback() + raise + finally: + cur.close() + conn.close() + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + """ + Creates a new table and adds initial documents with embeddings. + + Args: + texts: List of Document objects to insert + embeddings: List of embedding vectors + **kwargs: Additional arguments + """ + dimension = len(embeddings[0]) if embeddings else 0 + self._create_collection(dimension) + return self.add_texts(texts, embeddings) + + def _create_collection(self, dimension: int): + """ + Creates the Doris table with required schema if it doesn't exist. + + Uses Redis locking to prevent concurrent creation attempts. + """ + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(cache_key): + return + + try: + with self._get_cursor() as cur: + # Create table with vector column and text column + # Doris uses ARRAY for vector type + # Use backticks for table name quoting (MySQL/Doris standard) + create_table_sql = f""" + CREATE TABLE IF NOT EXISTS `{self.table_name}` ( + id VARCHAR(255) NOT NULL, + text TEXT NOT NULL, + meta JSON NOT NULL, + embedding ARRAY NOT NULL + ) ENGINE=OLAP + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS {self._config.table_buckets} + PROPERTIES ( + "replication_num" = "{self._config.table_replication_num}" + ) + """ + cur.execute(create_table_sql) + + # Create vector index using ANN (Approximate Nearest Neighbor) + # Using HNSW algorithm with L2 distance for efficient similarity search + create_vector_index_sql = f""" + CREATE INDEX IF NOT EXISTS idx_embedding_{self.index_hash} + ON `{self.table_name}`(embedding) + USING ANN + PROPERTIES( + "index_type" = "hnsw", + "metric_type" = "l2_distance", + "dim" = "{dimension}" + ) + """ + cur.execute(create_vector_index_sql) + + # Create inverted index for full-text search if enabled + if self._config.enable_text_search: + try: + analyzer = self._config.text_search_analyzer or "english" + create_text_index_sql = f""" + CREATE INDEX IF NOT EXISTS idx_text_{self.index_hash} + ON `{self.table_name}`(text) + USING INVERTED + PROPERTIES ( + "parser" = "{analyzer}", + "support_phrase" = "true" + ) + """ + cur.execute(create_text_index_sql) + except Exception as e: + logger.warning("Could not create text search index: %s", e) + + redis_client.set(cache_key, 1, ex=3600) + except Exception: + logger.exception("Error creating table %s", self.table_name) + raise + + def _streamload(self, data: list[dict]) -> None: + """ + Load data into Doris using StreamLoad HTTP API. + + Args: + data: List of dictionaries containing row data + + Raises: + Exception: If StreamLoad fails + """ + if not data: + return + + # Format data as JSON array for StreamLoad + # With strip_outer_array=true, Doris will parse each element in the array as a row + json_data = json.dumps(data) + + # StreamLoad endpoint URL with URL encoding for database and table names + encoded_database = quote(self._config.database, safe="") + encoded_table = quote(self.table_name, safe="") + url = f"{self._config.streamload_scheme}://{self._config.host}:{self._config.streamload_port}/api/{encoded_database}/{encoded_table}/_stream_load" + + # StreamLoad parameters + # Format parameters are now in headers, only keep load_mem_limit in params + params = { + "load_mem_limit": "2147483648", # 2GB + } + + # Headers for authentication and StreamLoad configuration + # Doris StreamLoad uses Basic Auth with base64 encoding + auth_string = f"{self._config.user}:{self._config.password}" + auth_bytes = auth_string.encode("utf-8") + auth_b64 = base64.b64encode(auth_bytes).decode("utf-8") + + headers = { + "Authorization": f"Basic {auth_b64}", + "Content-Type": "application/json", + "Expect": "100-continue", + "format": "json", # Specify format in header + "strip_outer_array": "true", # Parse each array element as a row + "strict_mode": "false", # Disable strict mode to allow data type conversion + "max_filter_ratio": str(self._config.streamload_max_filter_ratio), + "fuzzy_parse": "true", # Enable fuzzy parsing for better compatibility + "jsonpaths": '["$.id", "$.text", "$.meta", "$.embedding"]', # Explicit column mapping + "columns": "id,text,meta,embedding", # Column order in table + } + + try: + # Manually handle redirects to ensure Authorization header is preserved + max_redirects = 5 + redirect_count = 0 + current_url = url + response = None + + # Disable auto-follow to manually handle redirects + with httpx.Client(timeout=300.0, follow_redirects=False) as client: + while redirect_count < max_redirects: + # For redirects, check if URL already contains query params + # If redirect URL contains params, don't add them again + if redirect_count == 0 or "?" not in current_url: + request_params = params + else: + request_params = None # Redirect URL already has params + + response = client.put( + current_url, + content=json_data.encode("utf-8"), + params=request_params, + headers=headers, + ) + + # Handle redirect + if response.status_code in (301, 302, 303, 307, 308): + redirect_count += 1 + location = response.headers.get("Location") + if not location: + raise Exception("Redirect response missing Location header") + + # Parse redirect URL + if location.startswith("http://") or location.startswith("https://"): + current_url = location + else: + # Relative redirect + current_url = urljoin(current_url, location) + + logger.info("Following redirect %s to %s", redirect_count, current_url) + continue + + # Not a redirect, break the loop + break + + if response is None: + raise Exception("No response received after redirects") + + response.raise_for_status() + result = response.json() + + # Check StreamLoad status + if result.get("Status") != "Success": + error_msg = result.get("Message", "Unknown error") + error_url = result.get("ErrorURL", "") + # Log full response for debugging + logger.error("StreamLoad failed. Full response: %s", json.dumps(result, indent=2)) + raise Exception(f"StreamLoad failed: {error_msg}. ErrorURL: {error_url}") + + # Log success with details + loaded_rows = result.get("NumberLoadedRows", 0) + filtered_rows = result.get("NumberFilteredRows", 0) + total_rows = result.get("NumberTotalRows", len(data)) + logger.info( + "StreamLoad completed: %s/%s rows loaded, %s rows filtered", + loaded_rows, + total_rows, + filtered_rows, + ) + + # Warn if any rows were filtered + if filtered_rows > 0: + logger.warning( + "StreamLoad filtered %s rows. Check ErrorURL if available: %s", + filtered_rows, + result.get("ErrorURL", "N/A"), + ) + + except httpx.HTTPError as e: + logger.exception("StreamLoad HTTP request failed") + raise Exception(f"StreamLoad request failed: {str(e)}") from e + except Exception: + logger.exception("StreamLoad failed") + raise + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + """ + Adds documents with their embeddings to the table using StreamLoad. + + Args: + documents: List of Document objects + embeddings: List of embedding vectors + **kwargs: Additional arguments + + Returns: + List of inserted document IDs + """ + if not documents or not embeddings: + return [] + + pks = [] + streamload_data = [] + for i, doc in enumerate(documents): + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + + # Format data for StreamLoad JSON format + # Embedding needs to be formatted as array of floats for Doris ARRAY + embedding_array = [float(x) for x in embeddings[i]] + + # Ensure meta is a dict (Doris JSON type accepts dict) + meta_dict = doc.metadata if isinstance(doc.metadata, dict) else {} + + row_data = { + "id": str(doc_id), + "text": str(doc.page_content) if doc.page_content else "", + "meta": meta_dict, # Doris JSON type accepts dict directly + "embedding": embedding_array, + } + streamload_data.append(row_data) + + if streamload_data: + self._streamload(streamload_data) + + return pks + + def text_exists(self, id: str) -> bool: + """Checks if a document with the given doc_id exists in the table.""" + with self._get_cursor() as cur: + cur.execute(f"SELECT id FROM `{self.table_name}` WHERE id = %s", (id,)) + return cur.fetchone() is not None + + def delete_by_ids(self, ids: list[str]) -> None: + """ + Deletes objects by their ID identifiers. + + Args: + ids: List of document IDs to delete + """ + if not ids: + return + + with self._get_cursor() as cur: + try: + placeholders = ",".join(["%s"] * len(ids)) + cur.execute(f"DELETE FROM `{self.table_name}` WHERE id IN ({placeholders})", ids) + except Exception as e: + logger.warning("Error deleting documents: %s", e) + raise + + def delete_by_metadata_field(self, key: str, value: str) -> None: + """ + Deletes all objects matching a specific metadata field value. + + Args: + key: Metadata field key + value: Metadata field value + """ + with self._get_cursor() as cur: + try: + # Use JSON_EXTRACT for JSON field access + cur.execute( + f"DELETE FROM `{self.table_name}` WHERE JSON_EXTRACT(meta, %s) = %s", + (f"$.{key}", value), + ) + except Exception as e: + logger.warning("Error deleting by metadata field: %s", e) + raise + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """ + Performs vector similarity search using the provided query vector. + + Args: + query_vector: Query embedding vector + **kwargs: Additional search parameters (top_k, score_threshold, document_ids_filter) + + Returns: + List of Document objects sorted by relevance score + """ + top_k = int(kwargs.get("top_k", 4)) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + document_ids_filter = kwargs.get("document_ids_filter") or [] + + # Build WHERE clause for document filtering + where_clause = "" + params = [] + if document_ids_filter: + placeholders = ",".join(["%s"] * len(document_ids_filter)) + where_clause = f"WHERE JSON_EXTRACT(meta, '$.document_id') IN ({placeholders})" + params.extend(document_ids_filter) + + # Convert query vector to string format for Doris ARRAY + query_vector_str = "[" + ",".join(str(float(x)) for x in query_vector) + "]" + + with self._get_cursor() as cur: + # Use cosine_distance for similarity search + # Doris supports cosine_distance function for vector similarity + search_sql = f""" + SELECT meta, text, + cosine_distance(embedding, CAST(%s AS ARRAY)) AS distance + FROM `{self.table_name}` + {where_clause} + ORDER BY distance ASC + LIMIT %s + """ + params.insert(0, query_vector_str) + params.append(top_k) + + cur.execute(search_sql, params) + docs = [] + for row in cur.fetchall(): + metadata = json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"] + text = row["text"] + distance = float(row["distance"]) + + # Convert distance to similarity score (1 - distance for cosine) + score = 1.0 - distance + + if score >= score_threshold: + metadata["score"] = score + docs.append(Document(page_content=text, metadata=metadata)) + + # Sort by score descending + docs.sort(key=lambda d: d.metadata.get("score", 0.0), reverse=True) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """ + Performs BM25 full-text search on document content. + + Args: + query: Search query string + **kwargs: Additional search parameters (top_k, document_ids_filter) + + Returns: + List of Document objects with relevance scores + """ + top_k = int(kwargs.get("top_k", 4)) + document_ids_filter = kwargs.get("document_ids_filter") or [] + + # Build WHERE clause + where_parts = [] + params = [] + + # Text search condition using MATCH_ANY for keyword search + where_parts.append("text MATCH_ANY %s") + params.append(query) + + # Document ID filtering + if document_ids_filter: + placeholders = ",".join(["%s"] * len(document_ids_filter)) + where_parts.append(f"JSON_EXTRACT(meta, '$.document_id') IN ({placeholders})") + params.extend(document_ids_filter) + + where_clause = "WHERE " + " AND ".join(where_parts) + + with self._get_cursor() as cur: + # Use BM25 scoring with score() function + search_sql = f""" + SELECT meta, text, score() AS relevance + FROM `{self.table_name}` + {where_clause} + ORDER BY relevance DESC + LIMIT %s + """ + params.append(top_k) + + cur.execute(search_sql, params) + docs = [] + for row in cur.fetchall(): + metadata = json.loads(row["meta"]) if isinstance(row["meta"], str) else row["meta"] + text = row["text"] + score = float(row["relevance"]) + + metadata["score"] = score + docs.append(Document(page_content=text, metadata=metadata)) + + return docs + + def delete(self): + """Deletes the entire table from Doris.""" + with self._get_cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS `{self.table_name}`") + + +class DorisVectorFactory(AbstractVectorFactory): + """Factory class for creating DorisVector instances.""" + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> DorisVector: + """ + Initializes a DorisVector instance for the given dataset. + + Uses existing collection name from dataset index structure or generates a new one. + Updates dataset index structure if not already set. + """ + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.DORIS, collection_name)) + + return DorisVector( + collection_name=collection_name, + config=DorisConfig( + host=dify_config.DORIS_HOST or "", + port=dify_config.DORIS_PORT, + user=dify_config.DORIS_USER or "", + password=dify_config.DORIS_PASSWORD or "", + database=dify_config.DORIS_DATABASE or "", + max_connection=dify_config.DORIS_MAX_CONNECTION, + enable_text_search=dify_config.DORIS_ENABLE_TEXT_SEARCH, + text_search_analyzer=dify_config.DORIS_TEXT_SEARCH_ANALYZER or "english", + streamload_port=dify_config.DORIS_STREAMLOAD_PORT, + streamload_scheme=dify_config.DORIS_STREAMLOAD_SCHEME, + streamload_max_filter_ratio=dify_config.DORIS_STREAMLOAD_MAX_FILTER_RATIO, + table_replication_num=dify_config.DORIS_TABLE_REPLICATION_NUM, + table_buckets=dify_config.DORIS_TABLE_BUCKETS, + ), + attributes=attributes, + ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index b9772b3c084c30..f0e5ac7babc51d 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -143,6 +143,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory return CouchbaseVectorFactory + case VectorType.DORIS: + from core.rag.datasource.vdb.doris.doris_vector import DorisVectorFactory + + return DorisVectorFactory case VectorType.BAIDU: from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index bd99a31446bd9f..e56562902b411a 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -34,3 +34,4 @@ class VectorType(StrEnum): MATRIXONE = "matrixone" CLICKZETTA = "clickzetta" IRIS = "iris" + DORIS = "doris" diff --git a/api/tests/integration_tests/vdb/doris/__init__.py b/api/tests/integration_tests/vdb/doris/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/vdb/doris/test_doris.py b/api/tests/integration_tests/vdb/doris/test_doris.py new file mode 100644 index 00000000000000..85779c3c624396 --- /dev/null +++ b/api/tests/integration_tests/vdb/doris/test_doris.py @@ -0,0 +1,47 @@ +import pytest + +from core.rag.datasource.vdb.doris.doris_vector import ( + DorisConfig, + DorisVector, +) +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +@pytest.fixture +def doris_vector(): + return DorisVector( + "dify_test_collection", + config=DorisConfig( + host="127.0.0.1", + port=9030, + user="root", + password="", + database="dify", + min_connection=1, + max_connection=5, + enable_text_search=True, + text_search_analyzer="english", + streamload_port=8030, + ), + attributes=["doc_id", "dataset_id", "document_id"], + ) + + +class DorisVectorTest(AbstractVectorTest): + def __init__(self, vector: DorisVector): + super().__init__() + self.vector = vector + + def get_ids_by_metadata_field(self): + with pytest.raises(NotImplementedError): + self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + + +def test_doris_vector( + setup_mock_redis, + doris_vector, +): + DorisVectorTest(doris_vector).run_all_tests() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/doris/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/doris/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/doris/test_doris_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/doris/test_doris_vector.py new file mode 100644 index 00000000000000..15767a5c7238fe --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/doris/test_doris_vector.py @@ -0,0 +1,619 @@ +""" +Comprehensive unit tests for Apache Doris vector database implementation. + +Tests cover: +- DorisConfig validation +- DorisConnectionPool +- DorisVector CRUD operations +- Vector search and full-text search +- StreamLoad functionality +- Error handling +""" + +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.doris.doris_vector import ( + DorisConfig, + DorisConnectionPool, + DorisVector, + DorisVectorFactory, +) +from core.rag.models.document import Document + + +class TestDorisConfig(unittest.TestCase): + """Tests for DorisConfig validation.""" + + def test_valid_config(self): + """Test that valid config is accepted.""" + config = DorisConfig( + host="localhost", + port=9030, + user="root", + password="password", + database="test_db", + max_connection=5, + ) + assert config.host == "localhost" + assert config.port == 9030 + assert config.enable_text_search is True # default + assert config.text_search_analyzer == "english" # default + assert config.streamload_port == 8030 # default + assert config.streamload_scheme == "http" # default + assert config.streamload_max_filter_ratio == 0.1 # default + assert config.table_replication_num == 1 # default + assert config.table_buckets == 10 # default + + def test_missing_host_raises_error(self): + """Test that missing host raises ValueError.""" + with pytest.raises(ValueError, match="DORIS_HOST is required"): + DorisConfig( + host="", + port=9030, + user="root", + password="password", + database="test_db", + max_connection=5, + ) + + def test_missing_user_raises_error(self): + """Test that missing user raises ValueError.""" + with pytest.raises(ValueError, match="DORIS_USER is required"): + DorisConfig( + host="localhost", + port=9030, + user="", + password="password", + database="test_db", + max_connection=5, + ) + + def test_missing_password_raises_error(self): + """Test that missing password raises ValueError.""" + with pytest.raises(ValueError, match="DORIS_PASSWORD is required"): + DorisConfig( + host="localhost", + port=9030, + user="root", + password="", + database="test_db", + max_connection=5, + ) + + def test_missing_database_raises_error(self): + """Test that missing database raises ValueError.""" + with pytest.raises(ValueError, match="DORIS_DATABASE is required"): + DorisConfig( + host="localhost", + port=9030, + user="root", + password="password", + database="", + max_connection=5, + ) + + def test_custom_text_search_settings(self): + """Test custom text search settings.""" + config = DorisConfig( + host="localhost", + port=9030, + user="root", + password="password", + database="test_db", + max_connection=5, + enable_text_search=False, + text_search_analyzer="chinese", + streamload_port=8030, + ) + assert config.enable_text_search is False + assert config.text_search_analyzer == "chinese" + assert config.streamload_port == 8030 + + def test_invalid_analyzer_raises_error(self): + """Test that invalid text_search_analyzer raises ValueError.""" + with pytest.raises(ValueError, match="must be one of"): + DorisConfig( + host="localhost", + port=9030, + user="root", + password="password", + database="test_db", + max_connection=5, + text_search_analyzer="invalid_analyzer", + ) + + def test_invalid_scheme_raises_error(self): + """Test that invalid streamload_scheme raises ValueError.""" + with pytest.raises(ValueError, match="must be 'http' or 'https'"): + DorisConfig( + host="localhost", + port=9030, + user="root", + password="password", + database="test_db", + max_connection=5, + streamload_scheme="ftp", + ) + + def test_custom_table_properties(self): + """Test custom table replication and bucket settings.""" + config = DorisConfig( + host="localhost", + port=9030, + user="root", + password="password", + database="test_db", + max_connection=5, + table_replication_num=3, + table_buckets=20, + ) + assert config.table_replication_num == 3 + assert config.table_buckets == 20 + + def test_https_scheme(self): + """Test that https scheme is accepted.""" + config = DorisConfig( + host="localhost", + port=9030, + user="root", + password="password", + database="test_db", + max_connection=5, + streamload_scheme="https", + ) + assert config.streamload_scheme == "https" + + +class TestDorisConnectionPool(unittest.TestCase): + """Tests for DorisConnectionPool.""" + + def setUp(self): + self.config = DorisConfig( + host="localhost", + port=9030, + user="root", + password="test_password", + database="test_db", + max_connection=5, + ) + + @patch("core.rag.datasource.vdb.doris.doris_vector.pooling.MySQLConnectionPool") + def test_pool_initialization(self, mock_pool_class): + """Test connection pool initialization.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + pool = DorisConnectionPool(self.config) + + mock_pool_class.assert_called_once() + call_kwargs = mock_pool_class.call_args[1] + assert call_kwargs["host"] == "localhost" + assert call_kwargs["port"] == 9030 + assert call_kwargs["user"] == "root" + assert call_kwargs["database"] == "test_db" + assert call_kwargs["pool_size"] == 5 + + @patch("core.rag.datasource.vdb.doris.doris_vector.pooling.MySQLConnectionPool") + def test_get_connection(self, mock_pool_class): + """Test getting a connection from pool.""" + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_pool_class.return_value = mock_pool + + pool = DorisConnectionPool(self.config) + conn = pool.get_connection() + + assert conn == mock_conn + mock_pool.get_connection.assert_called_once() + + +class TestDorisVector(unittest.TestCase): + """Tests for DorisVector operations.""" + + def setUp(self): + self.config = DorisConfig( + host="localhost", + port=9030, + user="root", + password="test_password", + database="test_db", + max_connection=5, + enable_text_search=True, + text_search_analyzer="english", + streamload_port=8030, + ) + self.collection_name = "test_collection" + self.attributes = [] + + # Sample documents for testing + self.sample_documents = [ + Document( + page_content="This is a test document about AI.", + metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"}, + ), + Document( + page_content="Another document about machine learning.", + metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"}, + ), + ] + + # Sample embeddings (4-dimensional for testing) + self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]] + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + def test_init(self, mock_pool_class): + """Test DorisVector initialization.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + + assert doris_vector.collection_name == self.collection_name + assert doris_vector.table_name == f"embedding_{self.collection_name}" + assert doris_vector.get_type() == "doris" + assert doris_vector._pool is not None + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + def test_get_type(self, mock_pool_class): + """Test get_type returns 'doris'.""" + mock_pool_class.return_value = MagicMock() + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + + assert doris_vector.get_type() == "doris" + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_create_collection_with_vector_index(self, mock_redis, mock_pool_class): + """Test that collection creation includes proper ANN vector index.""" + # Mock Redis operations + mock_redis.lock.return_value.__enter__ = MagicMock() + mock_redis.lock.return_value.__exit__ = MagicMock() + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_conn.__enter__ = MagicMock(return_value=mock_conn) + mock_conn.__exit__ = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock() + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + + # Create collection with dimension 4 + dimension = 4 + doris_vector._create_collection(dimension) + + # Verify execute was called with CREATE TABLE and CREATE INDEX + execute_calls = mock_cursor.execute.call_args_list + + # Check that we have at least 2 execute calls (table + vector index) + assert len(execute_calls) >= 2, f"Expected at least 2 execute calls, got {len(execute_calls)}" + + # Extract SQL statements from calls + create_table_sql = execute_calls[0][0][0] + create_index_sql = execute_calls[1][0][0] + + # Verify CREATE TABLE + assert "CREATE TABLE IF NOT EXISTS" in create_table_sql + assert "embedding ARRAY NOT NULL" in create_table_sql + assert "ENGINE=OLAP" in create_table_sql + + # Verify CREATE INDEX with ANN + assert "CREATE INDEX IF NOT EXISTS" in create_index_sql + assert "USING ANN" in create_index_sql + assert '"index_type" = "hnsw"' in create_index_sql + assert '"metric_type" = "l2_distance"' in create_index_sql + assert f'"dim" = "{dimension}"' in create_index_sql + + # Verify Redis cache was set + mock_redis.set.assert_called_once() + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_create_collection_with_text_index(self, mock_redis, mock_pool_class): + """Test that collection creation includes inverted index for text search.""" + # Mock Redis operations + mock_redis.lock.return_value.__enter__ = MagicMock() + mock_redis.lock.return_value.__exit__ = MagicMock() + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_conn.__enter__ = MagicMock(return_value=mock_conn) + mock_conn.__exit__ = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock() + + # Create DorisVector with text search enabled + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + + # Create collection + dimension = 4 + doris_vector._create_collection(dimension) + + # Verify execute was called 3 times (table + vector index + text index) + execute_calls = mock_cursor.execute.call_args_list + assert len(execute_calls) >= 3, f"Expected at least 3 execute calls for text search, got {len(execute_calls)}" + + # Extract SQL statements + text_index_sql = execute_calls[2][0][0] + + # Verify CREATE INDEX for text search + assert "CREATE INDEX IF NOT EXISTS" in text_index_sql + assert "USING INVERTED" in text_index_sql + assert '"parser" = "english"' in text_index_sql + assert '"support_phrase" = "true"' in text_index_sql + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_create_collection_uses_cache(self, mock_redis, mock_pool_class): + """Test that collection creation is skipped if already cached.""" + # Mock Redis to return cached value + mock_redis.lock.return_value.__enter__ = MagicMock() + mock_redis.lock.return_value.__exit__ = MagicMock() + mock_redis.get.return_value = 1 # Already cached + + # Mock the connection pool + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + + # Create collection + dimension = 4 + doris_vector._create_collection(dimension) + + # Verify no connection was attempted (cache hit) + mock_pool.get_connection.assert_not_called() + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_text_exists(self, mock_redis, mock_pool_class): + """Test text_exists method.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = {"id": "doc1"} + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + result = doris_vector.text_exists("doc1") + + assert result is True + mock_cursor.execute.assert_called_once() + call_sql = mock_cursor.execute.call_args[0][0] + assert "SELECT id FROM" in call_sql + assert "WHERE id = %s" in call_sql + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_text_not_exists(self, mock_redis, mock_pool_class): + """Test text_exists returns False when document doesn't exist.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = None + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + result = doris_vector.text_exists("nonexistent_doc") + + assert result is False + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_delete_by_ids(self, mock_redis, mock_pool_class): + """Test delete_by_ids method.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + doris_vector.delete_by_ids(["doc1", "doc2"]) + + mock_cursor.execute.assert_called_once() + call_sql = mock_cursor.execute.call_args[0][0] + assert "DELETE FROM" in call_sql + assert "id IN" in call_sql + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_delete_by_metadata_field(self, mock_redis, mock_pool_class): + """Test delete_by_metadata_field method.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + doris_vector.delete_by_metadata_field("document_id", "dataset1") + + mock_cursor.execute.assert_called_once() + call_sql = mock_cursor.execute.call_args[0][0] + assert "DELETE FROM" in call_sql + assert "JSON_EXTRACT(meta, %s)" in call_sql + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_delete_collection(self, mock_redis, mock_pool_class): + """Test delete method drops the table.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + doris_vector.delete() + + mock_cursor.execute.assert_called_once() + call_sql = mock_cursor.execute.call_args[0][0] + assert "DROP TABLE IF EXISTS" in call_sql + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_search_by_vector(self, mock_redis, mock_pool_class): + """Test search_by_vector method.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + # Mock search results (dictionary format since cursor is created with dictionary=True) + mock_cursor.fetchall.return_value = [ + {"meta": '{"doc_id": "doc1", "source": "test"}', "text": "Test content", "distance": 0.05}, + ] + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + results = doris_vector.search_by_vector( + query_vector=[0.1, 0.2, 0.3, 0.4], + top_k=5, + ) + + assert len(results) == 1 + assert results[0].metadata["doc_id"] == "doc1" + mock_cursor.execute.assert_called_once() + call_sql = mock_cursor.execute.call_args[0][0] + assert "cosine_distance" in call_sql + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + def test_search_by_full_text(self, mock_redis, mock_pool_class): + """Test search_by_full_text method.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_pool.get_connection.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + # Mock search results (dictionary format since cursor is created with dictionary=True) + mock_cursor.fetchall.return_value = [ + {"meta": '{"doc_id": "doc1", "source": "test"}', "text": "Test content about AI", "relevance": 2.5}, + ] + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + results = doris_vector.search_by_full_text( + query="AI test", + top_k=5, + ) + + assert len(results) == 1 + assert results[0].metadata["doc_id"] == "doc1" + mock_cursor.execute.assert_called_once() + call_sql = mock_cursor.execute.call_args[0][0] + assert "MATCH_ANY" in call_sql + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.redis_client") + @patch("core.rag.datasource.vdb.doris.doris_vector.httpx.Client") + def test_streamload(self, mock_httpx_client, mock_redis, mock_pool_class): + """Test _streamload method for data insertion.""" + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Mock HTTP client for StreamLoad + mock_client_instance = MagicMock() + mock_httpx_client.return_value.__enter__ = MagicMock(return_value=mock_client_instance) + mock_httpx_client.return_value.__exit__ = MagicMock() + + mock_response = MagicMock() + mock_response.json.return_value = {"Status": "Success", "NumberLoadedRows": 2} + mock_client_instance.put.return_value = mock_response + + doris_vector = DorisVector(self.collection_name, self.config, self.attributes) + + # Prepare test data + data = [ + {"doc_id": "doc1", "text": "content1", "embedding": [0.1, 0.2]}, + {"doc_id": "doc2", "text": "content2", "embedding": [0.3, 0.4]}, + ] + + doris_vector._streamload(data) + + mock_client_instance.put.assert_called_once() + call_kwargs = mock_client_instance.put.call_args + assert "Content-Type" in str(call_kwargs) or call_kwargs[1].get("headers", {}).get("Content-Type") + + +class TestDorisVectorFactory(unittest.TestCase): + """Tests for DorisVectorFactory.""" + + @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") + @patch("core.rag.datasource.vdb.doris.doris_vector.dify_config") + def test_factory_creates_vector_from_dataset(self, mock_dify_config, mock_pool_class): + """Test factory creates DorisVector from dataset with dify_config.""" + mock_pool_class.return_value = MagicMock() + mock_dify_config.DORIS_HOST = "localhost" + mock_dify_config.DORIS_PORT = 9030 + mock_dify_config.DORIS_USER = "root" + mock_dify_config.DORIS_PASSWORD = "password" + mock_dify_config.DORIS_DATABASE = "test_db" + mock_dify_config.DORIS_MAX_CONNECTION = 5 + mock_dify_config.DORIS_ENABLE_TEXT_SEARCH = True + mock_dify_config.DORIS_TEXT_SEARCH_ANALYZER = "english" + mock_dify_config.DORIS_STREAMLOAD_PORT = 8030 + mock_dify_config.DORIS_STREAMLOAD_SCHEME = "http" + mock_dify_config.DORIS_STREAMLOAD_MAX_FILTER_RATIO = 0.1 + mock_dify_config.DORIS_TABLE_REPLICATION_NUM = 1 + mock_dify_config.DORIS_TABLE_BUCKETS = 10 + + # Create mock dataset + mock_dataset = MagicMock() + mock_dataset.id = "test-dataset-id" + mock_dataset.index_struct_dict = None + + # Create mock embeddings + mock_embeddings = MagicMock() + + factory = DorisVectorFactory() + vector = factory.init_vector(mock_dataset, [], mock_embeddings) + + assert isinstance(vector, DorisVector) + assert vector._config.host == "localhost" + assert vector._config.port == 9030 + assert vector._config.database == "test_db" + + +if __name__ == "__main__": + unittest.main() diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index fcb07dda366e3e..6f2efaadeaa9bf 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -174,6 +174,16 @@ x-shared-env: &shared-api-worker-env SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index} + DORIS_HOST: ${DORIS_HOST:-} + DORIS_PORT: ${DORIS_PORT:-9030} + DORIS_USER: ${DORIS_USER:-root} + DORIS_PASSWORD: ${DORIS_PASSWORD:-} + DORIS_DATABASE: ${DORIS_DATABASE:-dify} + DORIS_STREAMLOAD_PORT: ${DORIS_STREAMLOAD_PORT:-8080} + DORIS_MIN_CONNECTION: ${DORIS_MIN_CONNECTION:-1} + DORIS_MAX_CONNECTION: ${DORIS_MAX_CONNECTION:-5} + DORIS_ENABLE_TEXT_SEARCH: ${DORIS_ENABLE_TEXT_SEARCH:-false} + DORIS_TEXT_SEARCH_ANALYZER: ${DORIS_TEXT_SEARCH_ANALYZER:-english} WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} WEAVIATE_GRPC_ENDPOINT: ${WEAVIATE_GRPC_ENDPOINT:-grpc://weaviate:50051} diff --git a/docker/test_doris.py b/docker/test_doris.py new file mode 100755 index 00000000000000..f48150c2f1a6bb --- /dev/null +++ b/docker/test_doris.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +""" +Test script for Apache Doris vector store integration. +Run this after starting Doris with: docker compose -f docker-compose.doris.yaml up -d +""" + +import sys +import time +from pathlib import Path + +# Add API directory to path +api_path = Path(__file__).parent.parent / "api" +sys.path.insert(0, str(api_path)) + +from core.rag.datasource.vdb.doris.doris_vector import DorisConfig, DorisVector +from core.rag.models.document import Document + + +def test_doris_connection(): + """Test basic Doris connection.""" + print("=" * 60) + print("Testing Doris Vector Store Connection") + print("=" * 60) + + config = DorisConfig( + host="localhost", + port=9030, + user="root", + password="", + database="dify", + min_connection=1, + max_connection=5, + enable_text_search=True, + text_search_analyzer="english", + streamload_port=8030, + ) + + try: + # Test 1: Create vector store instance + print("\n1. Creating DorisVector instance...") + vector_store = DorisVector( + collection_name="test_collection", + config=config, + attributes=[] + ) + print(" ✓ DorisVector instance created") + + # Test 2: Create collection with sample data + print("\n2. Creating collection with sample documents...") + sample_docs = [ + Document( + page_content="Apache Doris is a high-performance analytical database.", + metadata={"doc_id": "doc1", "document_id": "test1", "source": "test"} + ), + Document( + page_content="Vector search enables semantic similarity matching.", + metadata={"doc_id": "doc2", "document_id": "test1", "source": "test"} + ), + Document( + page_content="HNSW is an efficient algorithm for approximate nearest neighbor search.", + metadata={"doc_id": "doc3", "document_id": "test1", "source": "test"} + ), + ] + + # Sample embeddings (4-dimensional for testing) + sample_embeddings = [ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.2, 0.3, 0.4, 0.5], + ] + + ids = vector_store.create(sample_docs, sample_embeddings) + print(f" ✓ Created collection with {len(ids)} documents") + print(f" Document IDs: {ids}") + + # Test 3: Vector search + print("\n3. Testing vector search...") + query_vector = [0.15, 0.25, 0.35, 0.45] + results = vector_store.search_by_vector(query_vector, top_k=2) + print(f" ✓ Found {len(results)} results") + for i, doc in enumerate(results, 1): + score = doc.metadata.get("score", 0) + print(f" Result {i}: score={score:.4f}, text={doc.page_content[:50]}...") + + # Test 4: Full-text search + print("\n4. Testing full-text search...") + text_results = vector_store.search_by_full_text("vector search", top_k=2) + print(f" ✓ Found {len(text_results)} results") + for i, doc in enumerate(text_results, 1): + score = doc.metadata.get("score", 0) + print(f" Result {i}: score={score:.4f}, text={doc.page_content[:50]}...") + + # Test 5: Check if document exists + print("\n5. Testing document existence check...") + exists = vector_store.text_exists("doc1") + print(f" ✓ Document 'doc1' exists: {exists}") + + # Test 6: Delete by IDs + print("\n6. Testing delete by IDs...") + vector_store.delete_by_ids(["doc1"]) + exists_after = vector_store.text_exists("doc1") + print(f" ✓ Document deleted, exists now: {exists_after}") + + # Test 7: Cleanup - delete collection + print("\n7. Cleaning up - deleting collection...") + vector_store.delete() + print(" ✓ Collection deleted") + + print("\n" + "=" * 60) + print("All tests passed! ✓") + print("=" * 60) + return True + + except Exception as e: + print(f"\n✗ Test failed with error: {e}") + import traceback + traceback.print_exc() + return False + + +def wait_for_doris(): + """Wait for Doris to be ready.""" + import mysql.connector + + print("Waiting for Doris to be ready...") + max_retries = 30 + retry_count = 0 + + while retry_count < max_retries: + try: + conn = mysql.connector.connect( + host="localhost", + port=9030, + user="root", + password="", + connect_timeout=5 + ) + cursor = conn.cursor() + cursor.execute("SHOW DATABASES") + cursor.close() + conn.close() + print("✓ Doris is ready!") + return True + except Exception: + retry_count += 1 + if retry_count < max_retries: + print(f" Waiting... ({retry_count}/{max_retries})") + time.sleep(2) + else: + print("✗ Timeout waiting for Doris") + return False + + +def create_database(): + """Create the dify database if it doesn't exist.""" + import mysql.connector + + try: + conn = mysql.connector.connect( + host="localhost", + port=9030, + user="root", + password="" + ) + cursor = conn.cursor() + cursor.execute("CREATE DATABASE IF NOT EXISTS dify") + cursor.close() + conn.close() + print("✓ Database 'dify' ready") + return True + except Exception as e: + print(f"✗ Failed to create database: {e}") + return False + + +if __name__ == "__main__": + print("\nDoris Vector Store Integration Test") + print("=" * 60) + + # Wait for Doris + if not wait_for_doris(): + sys.exit(1) + + # Create database + if not create_database(): + sys.exit(1) + + # Run tests + success = test_doris_connection() + sys.exit(0 if success else 1) From e111491f0dd65c2ef49e38a6b7a407d76491a29b Mon Sep 17 00:00:00 2001 From: Yongqiang YANG Date: Mon, 12 Jan 2026 11:58:34 -0800 Subject: [PATCH 2/3] fix: resolve type errors and improve Doris vector store reliability - Add type annotations to fix mypy errors (pool config dict, params lists) - Add USE database statement in _get_cursor to ensure database context - Add _wait_for_table_normal_state method to wait for schema changes before creating text index (fixes index creation race condition) - Extend Redis lock timeout to accommodate schema change waiting - Update unit tests to account for new USE database statement --- .../rag/datasource/vdb/doris/doris_vector.py | 47 ++++++++++++-- .../datasource/vdb/doris/test_doris_vector.py | 63 +++++++++++-------- 2 files changed, 80 insertions(+), 30 deletions(-) diff --git a/api/core/rag/datasource/vdb/doris/doris_vector.py b/api/core/rag/datasource/vdb/doris/doris_vector.py index d408712ed707d7..f9df2f09091ee2 100644 --- a/api/core/rag/datasource/vdb/doris/doris_vector.py +++ b/api/core/rag/datasource/vdb/doris/doris_vector.py @@ -12,6 +12,7 @@ import hashlib import json import logging +import time import uuid from contextlib import contextmanager from typing import Any @@ -81,7 +82,7 @@ class DorisConnectionPool: def __init__(self, config: DorisConfig) -> None: self.config = config - self._pool_config = { + self._pool_config: dict[str, Any] = { "pool_name": "doris_pool", "pool_size": config.max_connection, "pool_reset_session": True, @@ -137,6 +138,8 @@ def _get_cursor(self): conn = self._pool.get_connection() cur = conn.cursor(dictionary=True) try: + # Ensure database is selected (pool connections may lose context) + cur.execute(f"USE `{self._config.database}`") yield cur conn.commit() except Exception: @@ -159,6 +162,38 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self._create_collection(dimension) return self.add_texts(texts, embeddings) + def _wait_for_table_normal_state(self, cursor, max_wait_seconds: int = 60) -> bool: + """ + Wait for the table to return to NORMAL state after schema changes. + + Args: + cursor: Database cursor + max_wait_seconds: Maximum time to wait in seconds + + Returns: + True if table is in NORMAL state, False if timeout + """ + start_time = time.time() + while time.time() - start_time < max_wait_seconds: + try: + cursor.execute( + f"SHOW ALTER TABLE COLUMN FROM `{self._config.database}` " + f"WHERE TableName = '{self.table_name}' ORDER BY CreateTime DESC LIMIT 1" + ) + result = cursor.fetchone() + if result is None: + # No schema change in progress + return True + # Check if state is FINISHED or CANCELLED + state = result.get("State", "") if isinstance(result, dict) else "" + if state in ("FINISHED", "CANCELLED", ""): + return True + except Exception: + # If query fails, assume table is ready + return True + time.sleep(1) + return False + def _create_collection(self, dimension: int): """ Creates the Doris table with required schema if it doesn't exist. @@ -166,7 +201,7 @@ def _create_collection(self, dimension: int): Uses Redis locking to prevent concurrent creation attempts. """ lock_name = f"vector_indexing_lock_{self._collection_name}" - with redis_client.lock(lock_name, timeout=20): + with redis_client.lock(lock_name, timeout=120): cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(cache_key): return @@ -207,6 +242,8 @@ def _create_collection(self, dimension: int): # Create inverted index for full-text search if enabled if self._config.enable_text_search: + # Wait for vector index creation to complete before creating text index + self._wait_for_table_normal_state(cur, max_wait_seconds=60) try: analyzer = self._config.text_search_analyzer or "english" create_text_index_sql = f""" @@ -460,7 +497,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc # Build WHERE clause for document filtering where_clause = "" - params = [] + params: list[Any] = [] if document_ids_filter: placeholders = ",".join(["%s"] * len(document_ids_filter)) where_clause = f"WHERE JSON_EXTRACT(meta, '$.document_id') IN ({placeholders})" @@ -516,8 +553,8 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: document_ids_filter = kwargs.get("document_ids_filter") or [] # Build WHERE clause - where_parts = [] - params = [] + where_parts: list[str] = [] + params: list[Any] = [] # Text search condition using MATCH_ANY for keyword search where_parts.append("text MATCH_ANY %s") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/doris/test_doris_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/doris/test_doris_vector.py index 15767a5c7238fe..3f5d37e83a99d5 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/doris/test_doris_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/doris/test_doris_vector.py @@ -296,15 +296,15 @@ def test_create_collection_with_vector_index(self, mock_redis, mock_pool_class): dimension = 4 doris_vector._create_collection(dimension) - # Verify execute was called with CREATE TABLE and CREATE INDEX + # Verify execute was called with USE, CREATE TABLE and CREATE INDEX execute_calls = mock_cursor.execute.call_args_list - # Check that we have at least 2 execute calls (table + vector index) - assert len(execute_calls) >= 2, f"Expected at least 2 execute calls, got {len(execute_calls)}" + # Check that we have at least 3 execute calls (USE + table + vector index) + assert len(execute_calls) >= 3, f"Expected at least 3 execute calls, got {len(execute_calls)}" - # Extract SQL statements from calls - create_table_sql = execute_calls[0][0][0] - create_index_sql = execute_calls[1][0][0] + # Extract SQL statements from calls (first call is USE database) + create_table_sql = execute_calls[1][0][0] + create_index_sql = execute_calls[2][0][0] # Verify CREATE TABLE assert "CREATE TABLE IF NOT EXISTS" in create_table_sql @@ -352,14 +352,21 @@ def test_create_collection_with_text_index(self, mock_redis, mock_pool_class): dimension = 4 doris_vector._create_collection(dimension) - # Verify execute was called 3 times (table + vector index + text index) + # Verify execute was called at least 5 times: + # USE + table + vector index + SHOW ALTER (wait) + text index execute_calls = mock_cursor.execute.call_args_list - assert len(execute_calls) >= 3, f"Expected at least 3 execute calls for text search, got {len(execute_calls)}" - - # Extract SQL statements - text_index_sql = execute_calls[2][0][0] - - # Verify CREATE INDEX for text search + assert len(execute_calls) >= 5, f"Expected at least 5 execute calls for text search, got {len(execute_calls)}" + + # Find the text index SQL (should be the last CREATE INDEX with INVERTED) + text_index_sql = None + for call in execute_calls: + sql = call[0][0] + if "USING INVERTED" in sql: + text_index_sql = sql + break + + # Verify CREATE INDEX for text search was found + assert text_index_sql is not None, "Text index SQL not found" assert "CREATE INDEX IF NOT EXISTS" in text_index_sql assert "USING INVERTED" in text_index_sql assert '"parser" = "english"' in text_index_sql @@ -404,8 +411,9 @@ def test_text_exists(self, mock_redis, mock_pool_class): result = doris_vector.text_exists("doc1") assert result is True - mock_cursor.execute.assert_called_once() - call_sql = mock_cursor.execute.call_args[0][0] + # Execute is called twice: USE database + SELECT query + assert mock_cursor.execute.call_count == 2 + call_sql = mock_cursor.execute.call_args_list[1][0][0] assert "SELECT id FROM" in call_sql assert "WHERE id = %s" in call_sql @@ -442,8 +450,9 @@ def test_delete_by_ids(self, mock_redis, mock_pool_class): doris_vector = DorisVector(self.collection_name, self.config, self.attributes) doris_vector.delete_by_ids(["doc1", "doc2"]) - mock_cursor.execute.assert_called_once() - call_sql = mock_cursor.execute.call_args[0][0] + # Execute is called twice: USE database + DELETE query + assert mock_cursor.execute.call_count == 2 + call_sql = mock_cursor.execute.call_args_list[1][0][0] assert "DELETE FROM" in call_sql assert "id IN" in call_sql @@ -462,8 +471,9 @@ def test_delete_by_metadata_field(self, mock_redis, mock_pool_class): doris_vector = DorisVector(self.collection_name, self.config, self.attributes) doris_vector.delete_by_metadata_field("document_id", "dataset1") - mock_cursor.execute.assert_called_once() - call_sql = mock_cursor.execute.call_args[0][0] + # Execute is called twice: USE database + DELETE query + assert mock_cursor.execute.call_count == 2 + call_sql = mock_cursor.execute.call_args_list[1][0][0] assert "DELETE FROM" in call_sql assert "JSON_EXTRACT(meta, %s)" in call_sql @@ -482,8 +492,9 @@ def test_delete_collection(self, mock_redis, mock_pool_class): doris_vector = DorisVector(self.collection_name, self.config, self.attributes) doris_vector.delete() - mock_cursor.execute.assert_called_once() - call_sql = mock_cursor.execute.call_args[0][0] + # Execute is called twice: USE database + DROP TABLE query + assert mock_cursor.execute.call_count == 2 + call_sql = mock_cursor.execute.call_args_list[1][0][0] assert "DROP TABLE IF EXISTS" in call_sql @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") @@ -511,8 +522,9 @@ def test_search_by_vector(self, mock_redis, mock_pool_class): assert len(results) == 1 assert results[0].metadata["doc_id"] == "doc1" - mock_cursor.execute.assert_called_once() - call_sql = mock_cursor.execute.call_args[0][0] + # Execute is called twice: USE database + SELECT query + assert mock_cursor.execute.call_count == 2 + call_sql = mock_cursor.execute.call_args_list[1][0][0] assert "cosine_distance" in call_sql @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") @@ -540,8 +552,9 @@ def test_search_by_full_text(self, mock_redis, mock_pool_class): assert len(results) == 1 assert results[0].metadata["doc_id"] == "doc1" - mock_cursor.execute.assert_called_once() - call_sql = mock_cursor.execute.call_args[0][0] + # Execute is called twice: USE database + SELECT query + assert mock_cursor.execute.call_count == 2 + call_sql = mock_cursor.execute.call_args_list[1][0][0] assert "MATCH_ANY" in call_sql @patch("core.rag.datasource.vdb.doris.doris_vector.DorisConnectionPool") From 413534baaa9eb143cb9f0062763126594c019ede Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 02:02:15 +0000 Subject: [PATCH 3/3] [autofix.ci] apply automated fixes --- DORIS_SETUP.md | 47 +++++++++++++++++++++++++------------- docker/docker-compose.yaml | 10 -------- docker/test_doris.py | 28 ++++++----------------- 3 files changed, 38 insertions(+), 47 deletions(-) diff --git a/DORIS_SETUP.md b/DORIS_SETUP.md index 5651b4657ebfdc..ba5e78e61a98a6 100644 --- a/DORIS_SETUP.md +++ b/DORIS_SETUP.md @@ -3,11 +3,13 @@ ## Prerequisites 1. **Apache Doris Installed and Running** + - Doris FE (Frontend) running on port 8030 (HTTP) and 9030 (MySQL protocol) - Doris BE (Backend) started and connected to FE - Ensure Doris version >= 2.0 (supports vector search and text search) -2. **Create Database** +1. **Create Database** + ```sql CREATE DATABASE IF NOT EXISTS dify; ``` @@ -19,14 +21,16 @@ 1. **Edit `.env` file** (in the `docker` directory) If the file doesn't exist, create it from the example file: + ```bash cd docker cp .env.example .env ``` -2. **Set Vector Store to Doris** +1. **Set Vector Store to Doris** Add or modify the following configuration in the `.env` file: + ```bash # Vector Store configuration VECTOR_STORE=doris @@ -55,7 +59,8 @@ DORIS_TEXT_SEARCH_ANALYZER=english # Text analyzer: english, chinese, standard, unicode, default (default english) ``` -3. **Start Services** +1. **Start Services** + ```bash cd docker docker compose up -d @@ -66,6 +71,7 @@ 1. **Set Environment Variables** Before running Dify API, set the following environment variables: + ```bash export VECTOR_STORE=doris export DORIS_HOST=localhost @@ -78,7 +84,8 @@ Or set them in a `.env` file (if using python-dotenv) -2. **Run API Service** +1. **Run API Service** + ```bash cd api uv run --project api flask run @@ -89,6 +96,7 @@ ### 1. Check Doris Connection Connect to Doris using MySQL client: + ```bash mysql -h your-doris-host -P 9030 -u root -p ``` @@ -96,6 +104,7 @@ mysql -h your-doris-host -P 9030 -u root -p ### 2. Test Doris HTTP Endpoint Check if Doris FE HTTP endpoint is accessible: + ```bash curl http://your-doris-host:8030/api/v1/health ``` @@ -103,9 +112,9 @@ curl http://your-doris-host:8030/api/v1/health ### 3. Create Dataset in Dify 1. Login to Dify Web interface -2. Create a new dataset -3. Upload documents for indexing -4. Check if corresponding tables are created in Doris database: +1. Create a new dataset +1. Upload documents for indexing +1. Check if corresponding tables are created in Doris database: ```sql USE dify; SHOW TABLES LIKE 'embedding_%'; @@ -126,37 +135,43 @@ Doris Vector Store supports the following features: ### Issue: Connection Failed **Check:** + 1. Is Doris FE running? -2. Are ports correct (MySQL: 9030, HTTP: 8030)? -3. Are username and password correct? -4. Does firewall allow the connection? +1. Are ports correct (MySQL: 9030, HTTP: 8030)? +1. Are username and password correct? +1. Does firewall allow the connection? ### Issue: StreamLoad Failed **Check:** + 1. Is Doris HTTP port (8030) accessible? -2. Does the user have StreamLoad permissions? -3. Check error messages in Doris FE logs +1. Does the user have StreamLoad permissions? +1. Check error messages in Doris FE logs ### Issue: Table Creation Failed **Check:** + 1. Does the database exist? -2. Does the user have CREATE TABLE permissions? -3. Check error messages in Doris logs +1. Does the user have CREATE TABLE permissions? +1. Check error messages in Doris logs ## Performance Optimization Recommendations 1. **Adjust Connection Pool Size** + - Adjust `DORIS_MAX_CONNECTION` based on concurrent request volume - Recommended value: concurrent requests + 2 -2. **Text Analyzer Selection** +1. **Text Analyzer Selection** + - English content: use `english` - Chinese content: use `chinese` - Multilingual: use `standard` -3. **Batch Insertion** +1. **Batch Insertion** + - StreamLoad automatically processes data in batches - Recommended: 100-1000 records per insertion diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 6f2efaadeaa9bf..fcb07dda366e3e 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -174,16 +174,6 @@ x-shared-env: &shared-api-worker-env SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index} - DORIS_HOST: ${DORIS_HOST:-} - DORIS_PORT: ${DORIS_PORT:-9030} - DORIS_USER: ${DORIS_USER:-root} - DORIS_PASSWORD: ${DORIS_PASSWORD:-} - DORIS_DATABASE: ${DORIS_DATABASE:-dify} - DORIS_STREAMLOAD_PORT: ${DORIS_STREAMLOAD_PORT:-8080} - DORIS_MIN_CONNECTION: ${DORIS_MIN_CONNECTION:-1} - DORIS_MAX_CONNECTION: ${DORIS_MAX_CONNECTION:-5} - DORIS_ENABLE_TEXT_SEARCH: ${DORIS_ENABLE_TEXT_SEARCH:-false} - DORIS_TEXT_SEARCH_ANALYZER: ${DORIS_TEXT_SEARCH_ANALYZER:-english} WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} WEAVIATE_GRPC_ENDPOINT: ${WEAVIATE_GRPC_ENDPOINT:-grpc://weaviate:50051} diff --git a/docker/test_doris.py b/docker/test_doris.py index f48150c2f1a6bb..d1fa2b29f243bb 100755 --- a/docker/test_doris.py +++ b/docker/test_doris.py @@ -38,11 +38,7 @@ def test_doris_connection(): try: # Test 1: Create vector store instance print("\n1. Creating DorisVector instance...") - vector_store = DorisVector( - collection_name="test_collection", - config=config, - attributes=[] - ) + vector_store = DorisVector(collection_name="test_collection", config=config, attributes=[]) print(" ✓ DorisVector instance created") # Test 2: Create collection with sample data @@ -50,15 +46,15 @@ def test_doris_connection(): sample_docs = [ Document( page_content="Apache Doris is a high-performance analytical database.", - metadata={"doc_id": "doc1", "document_id": "test1", "source": "test"} + metadata={"doc_id": "doc1", "document_id": "test1", "source": "test"}, ), Document( page_content="Vector search enables semantic similarity matching.", - metadata={"doc_id": "doc2", "document_id": "test1", "source": "test"} + metadata={"doc_id": "doc2", "document_id": "test1", "source": "test"}, ), Document( page_content="HNSW is an efficient algorithm for approximate nearest neighbor search.", - metadata={"doc_id": "doc3", "document_id": "test1", "source": "test"} + metadata={"doc_id": "doc3", "document_id": "test1", "source": "test"}, ), ] @@ -114,6 +110,7 @@ def test_doris_connection(): except Exception as e: print(f"\n✗ Test failed with error: {e}") import traceback + traceback.print_exc() return False @@ -128,13 +125,7 @@ def wait_for_doris(): while retry_count < max_retries: try: - conn = mysql.connector.connect( - host="localhost", - port=9030, - user="root", - password="", - connect_timeout=5 - ) + conn = mysql.connector.connect(host="localhost", port=9030, user="root", password="", connect_timeout=5) cursor = conn.cursor() cursor.execute("SHOW DATABASES") cursor.close() @@ -156,12 +147,7 @@ def create_database(): import mysql.connector try: - conn = mysql.connector.connect( - host="localhost", - port=9030, - user="root", - password="" - ) + conn = mysql.connector.connect(host="localhost", port=9030, user="root", password="") cursor = conn.cursor() cursor.execute("CREATE DATABASE IF NOT EXISTS dify") cursor.close()