diff --git a/wren-ai-service/poetry.lock b/wren-ai-service/poetry.lock index f5b7ff77fc..7779e44b83 100644 --- a/wren-ai-service/poetry.lock +++ b/wren-ai-service/poetry.lock @@ -5115,27 +5115,27 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} [[package]] name = "qdrant-client" -version = "1.13.2" +version = "1.11.0" description = "Client library for the Qdrant vector search engine" optional = false -python-versions = ">=3.9" +python-versions = ">=3.8" files = [ - {file = "qdrant_client-1.13.2-py3-none-any.whl", hash = "sha256:db97e759bd3f8d483a383984ba4c2a158eef56f2188d83df7771591d43de2201"}, - {file = "qdrant_client-1.13.2.tar.gz", hash = "sha256:c8cce87ce67b006f49430a050a35c85b78e3b896c0c756dafc13bdeca543ec13"}, + {file = "qdrant_client-1.11.0-py3-none-any.whl", hash = "sha256:1f574ccebb91c0bc8a620c9a41a5a010084fbc4d8c6f1cd0ab7b2eeb97336fc0"}, + {file = "qdrant_client-1.11.0.tar.gz", hash = "sha256:7c1d4d7a96cfd1ee0cde2a21c607e9df86bcca795ad8d1fd274d295ab64b8458"}, ] [package.dependencies] grpcio = ">=1.41.0" grpcio-tools = ">=1.41.0" httpx = {version = ">=0.20.0", extras = ["http2"]} -numpy = {version = ">=1.26", markers = "python_version >= \"3.12\" and python_version < \"3.13\""} +numpy = {version = ">=1.26", markers = "python_version >= \"3.12\""} portalocker = ">=2.7.0,<3.0.0" pydantic = ">=1.10.8" urllib3 = ">=1.26.14,<3" [package.extras] -fastembed = ["fastembed (==0.5.1)"] -fastembed-gpu = ["fastembed-gpu (==0.5.1)"] +fastembed = ["fastembed (==0.3.4)"] +fastembed-gpu = ["fastembed-gpu (==0.3.4)"] [[package]] name = "qdrant-haystack" @@ -7110,4 +7110,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.12.*, <3.13" -content-hash = "6029475d60f6588b5fbf8c07dd8ce3179115fcf7b018ea6d1e88adf561d49e73" +content-hash = "4523687d231739fa1265a084bf32788e074219983517c42a2d0fe35441e30ced" diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml index 23918202f6..31c274ecc4 100644 --- a/wren-ai-service/pyproject.toml +++ b/wren-ai-service/pyproject.toml @@ -35,6 +35,7 @@ tiktoken = "^0.8.0" jsonschema = "^4.23.0" litellm = "^1.60.5" boto3 = "^1.35.90" +qdrant-client = "==1.11.0" [tool.poetry.group.dev.dependencies] pre-commit = "^3.7.1" diff --git a/wren-ai-service/src/pipelines/generation/sql_answer.py b/wren-ai-service/src/pipelines/generation/sql_answer.py index c1337bf557..72c0899648 100644 --- a/wren-ai-service/src/pipelines/generation/sql_answer.py +++ b/wren-ai-service/src/pipelines/generation/sql_answer.py @@ -27,10 +27,11 @@ 4. Generate a concise and clear answer in string format to answerthe user's question based on the data and sql. 5. If answer is in list format, only list top few examples, and tell users there are more results omitted. 6. Answer must be in the same language user specified. +7. Do not include ```markdown or ``` in the answer. ### OUTPUT FORMAT -Please provide your response in proper Markdown format. +Please provide your response in proper Markdown stringformat. """ sql_to_answer_user_prompt_template = """ diff --git a/wren-ai-service/src/pipelines/indexing/table_description.py b/wren-ai-service/src/pipelines/indexing/table_description.py index 2255bd66c6..0b1a6c9ff0 100644 --- a/wren-ai-service/src/pipelines/indexing/table_description.py +++ b/wren-ai-service/src/pipelines/indexing/table_description.py @@ -31,6 +31,7 @@ def _additional_meta() -> Dict[str, Any]: "id": str(uuid.uuid4()), "meta": { "type": "TABLE_DESCRIPTION", + "name": chunk["name"], **_additional_meta(), }, "content": str(chunk), @@ -53,6 +54,7 @@ def _structure_data(mdl_type: str, payload: Dict[str, Any]) -> Dict[str, Any]: return { "mdl_type": mdl_type, "name": payload.get("name"), + "columns": [column["name"] for column in payload.get("columns", [])], "properties": payload.get("properties", {}), } @@ -65,8 +67,8 @@ def _structure_data(mdl_type: str, payload: Dict[str, Any]) -> Dict[str, Any]: return [ { "name": resource["name"], - "mdl_type": resource["mdl_type"], "description": resource["properties"].get("description", ""), + "columns": ", ".join(resource["columns"]), } for resource in resources if resource["name"] is not None diff --git a/wren-ai-service/src/pipelines/retrieval/retrieval.py b/wren-ai-service/src/pipelines/retrieval/retrieval.py index 50ee7bdcae..fbfda607bc 100644 --- a/wren-ai-service/src/pipelines/retrieval/retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/retrieval.py @@ -226,19 +226,37 @@ def check_using_db_schemas_without_pruning( for table_schema in construct_db_schemas: if table_schema["type"] == "TABLE": ddl, _has_calculated_field = build_table_ddl(table_schema) - retrieval_results.append(ddl) + retrieval_results.append( + { + "table_name": table_schema["name"], + "table_ddl": ddl, + } + ) has_calculated_field = has_calculated_field or _has_calculated_field for document in dbschema_retrieval: content = ast.literal_eval(document.content) if content["type"] == "METRIC": - retrieval_results.append(_build_metric_ddl(content)) + retrieval_results.append( + { + "table_name": content["name"], + "table_ddl": _build_metric_ddl(content), + } + ) has_metric = True elif content["type"] == "VIEW": - retrieval_results.append(_build_view_ddl(content)) + retrieval_results.append( + { + "table_name": content["name"], + "table_ddl": _build_view_ddl(content), + } + ) - _token_count = len(encoding.encode(" ".join(retrieval_results))) + table_ddls = [ + retrieval_result["table_ddl"] for retrieval_result in retrieval_results + ] + _token_count = len(encoding.encode(" ".join(table_ddls))) if _token_count > 100_000 or not allow_using_db_schemas_without_pruning: return { "db_schemas": [], @@ -328,17 +346,32 @@ def construct_retrieval_results( tables=tables, ) has_calculated_field = has_calculated_field or _has_calculated_field - retrieval_results.append(ddl) + retrieval_results.append( + { + "table_name": table_schema["name"], + "table_ddl": ddl, + } + ) for document in dbschema_retrieval: if document.meta["name"] in columns_and_tables_needed: content = ast.literal_eval(document.content) if content["type"] == "METRIC": - retrieval_results.append(_build_metric_ddl(content)) + retrieval_results.append( + { + "table_name": content["name"], + "table_ddl": _build_metric_ddl(content), + } + ) has_metric = True elif content["type"] == "VIEW": - retrieval_results.append(_build_view_ddl(content)) + retrieval_results.append( + { + "table_name": content["name"], + "table_ddl": _build_view_ddl(content), + } + ) return { "retrieval_results": retrieval_results, diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 13b74e63f2..b3f427ddeb 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -94,6 +94,7 @@ class AskResultResponse(BaseModel): intent_reasoning: Optional[str] = None sql_generation_reasoning: Optional[str] = None type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None + retrieved_tables: Optional[List[str]] = None response: Optional[List[AskResult]] = None error: Optional[AskError] = None @@ -310,6 +311,8 @@ async def ask( "construct_retrieval_results", {} ) documents = _retrieval_result.get("retrieval_results", []) + table_names = [document.get("table_name") for document in documents] + table_ddls = [document.get("table_ddl") for document in documents] if not documents: logger.exception(f"ask pipeline - NO_RELEVANT_DATA: {user_query}") @@ -338,6 +341,7 @@ async def ask( type="TEXT_TO_SQL", rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieved_tables=table_names, ) sql_samples = ( @@ -351,7 +355,7 @@ async def ask( ( await self._pipelines["sql_generation_reasoning"].run( query=user_query, - contexts=documents, + contexts=table_ddls, sql_samples=sql_samples, configuration=ask_request.configurations, ) @@ -365,6 +369,7 @@ async def ask( type="TEXT_TO_SQL", rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, ) @@ -374,6 +379,7 @@ async def ask( type="TEXT_TO_SQL", rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, ) @@ -387,7 +393,7 @@ async def ask( "followup_sql_generation" ].run( query=user_query, - contexts=documents, + contexts=table_ddls, sql_generation_reasoning=sql_generation_reasoning, history=ask_request.history, project_id=ask_request.project_id, @@ -401,7 +407,7 @@ async def ask( "sql_generation" ].run( query=user_query, - contexts=documents, + contexts=table_ddls, sql_generation_reasoning=sql_generation_reasoning, project_id=ask_request.project_id, configuration=ask_request.configurations, @@ -431,12 +437,13 @@ async def ask( type="TEXT_TO_SQL", rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, ) sql_correction_results = await self._pipelines[ "sql_correction" ].run( - contexts=documents, + contexts=table_ddls, invalid_generation_results=failed_dry_run_results, project_id=ask_request.project_id, ) @@ -468,6 +475,7 @@ async def ask( response=api_results, rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, ) results["ask_result"] = api_results @@ -484,6 +492,7 @@ async def ask( ), rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" diff --git a/wren-ai-service/src/web/v1/services/question_recommendation.py b/wren-ai-service/src/web/v1/services/question_recommendation.py index 4b6ef9b2b9..e4524fd501 100644 --- a/wren-ai-service/src/web/v1/services/question_recommendation.py +++ b/wren-ai-service/src/web/v1/services/question_recommendation.py @@ -76,6 +76,7 @@ async def _validate_question( ) _retrieval_result = retrieval_result.get("construct_retrieval_results", {}) documents = _retrieval_result.get("retrieval_results", []) + table_ddls = [document.get("table_ddl") for document in documents] has_calculated_field = _retrieval_result.get("has_calculated_field", False) has_metric = _retrieval_result.get("has_metric", False) @@ -83,7 +84,7 @@ async def _validate_question( ( await self._pipelines["sql_generation_reasoning"].run( query=candidate["question"], - contexts=documents, + contexts=table_ddls, configuration=configuration, ) ) @@ -93,7 +94,7 @@ async def _validate_question( generated_sql = await self._pipelines["sql_generation"].run( query=candidate["question"], - contexts=documents, + contexts=table_ddls, sql_generation_reasoning=sql_generation_reasoning, configuration=configuration, project_id=project_id, diff --git a/wren-ai-service/tests/pytest/pipelines/indexing/test_table_description.py b/wren-ai-service/tests/pytest/pipelines/indexing/test_table_description.py index fbe1cdb463..7def966fdb 100644 --- a/wren-ai-service/tests/pytest/pipelines/indexing/test_table_description.py +++ b/wren-ai-service/tests/pytest/pipelines/indexing/test_table_description.py @@ -37,12 +37,12 @@ def test_single_table_description(): assert len(actual["documents"]) == 1 document: Document = actual["documents"][0] - assert document.meta == {"type": "TABLE_DESCRIPTION"} + assert document.meta == {"type": "TABLE_DESCRIPTION", "name": "user"} assert document.content == str( { "name": "user", - "mdl_type": "MODEL", "description": "A table containing user information.", + "columns": "", } ) @@ -71,22 +71,23 @@ def test_multiple_table_descriptions(): document_1: Document = actual["documents"][0] assert document_1.meta == { "type": "TABLE_DESCRIPTION", + "name": "user", } assert document_1.content == str( { "name": "user", - "mdl_type": "MODEL", "description": "A table containing user information.", + "columns": "", } ) document_2: Document = actual["documents"][1] - assert document_2.meta == {"type": "TABLE_DESCRIPTION"} + assert document_2.meta == {"type": "TABLE_DESCRIPTION", "name": "order"} assert document_2.content == str( { "name": "order", - "mdl_type": "MODEL", "description": "A table containing order details.", + "columns": "", } ) @@ -121,10 +122,8 @@ def test_table_description_missing_description(): assert len(actual["documents"]) == 1 document: Document = actual["documents"][0] - assert document.meta == {"type": "TABLE_DESCRIPTION"} - assert document.content == str( - {"name": "user", "mdl_type": "MODEL", "description": ""} - ) + assert document.meta == {"type": "TABLE_DESCRIPTION", "name": "user"} + assert document.content == str({"name": "user", "description": "", "columns": ""}) @pytest.mark.asyncio