From 719f18861b9c79d459c1f53f8efb269467e15a72 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Fri, 27 Dec 2024 14:55:19 +0800 Subject: [PATCH] update --- wren-ai-service/src/web/v1/services/ask.py | 281 +++++++++++---------- 1 file changed, 142 insertions(+), 139 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index c8e22018b2..76fca58eb2 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -142,6 +142,7 @@ async def ask( query_id = ask_request.query_id rephrased_question = None intent_reasoning = None + api_results = [] try: # ask status can be understanding, searching, generating, finished, failed, stopped @@ -151,60 +152,85 @@ async def ask( status="understanding", ) - intent_classification_result = ( - await self._pipelines["intent_classification"].run( - query=ask_request.query, - history=ask_request.history, - id=ask_request.project_id, - ) - ).get("post_process", {}) - intent = intent_classification_result.get("intent") - rephrased_question = intent_classification_result.get( - "rephrased_question" + historical_question = await self._pipelines["historical_question"].run( + query=ask_request.query, + id=ask_request.project_id, ) - intent_reasoning = intent_classification_result.get("reasoning") - user_query = ( - ask_request.query if not rephrased_question else rephrased_question - ) + # we only return top 1 result + historical_question_result = historical_question.get( + "formatted_output", {} + ).get("documents", [])[:1] - if intent == "MISLEADING_QUERY": - self._ask_results[query_id] = AskResultResponse( - status="finished", - type="MISLEADING_QUERY", - rephrased_question=rephrased_question, - intent_reasoning=intent_reasoning, - ) - results["metadata"]["type"] = "MISLEADING_QUERY" - return results - elif intent == "GENERAL": - asyncio.create_task( - self._pipelines["data_assistance"].run( - query=user_query, + if historical_question_result: + api_results = [ + AskResult( + **{ + "sql": result.get("statement"), + "type": "view", + "viewId": result.get("viewId"), + } + ) + for result in historical_question_result + ] + else: + intent_classification_result = ( + await self._pipelines["intent_classification"].run( + query=ask_request.query, history=ask_request.history, - db_schemas=intent_classification_result.get("db_schemas"), - language=ask_request.configurations.language, - query_id=ask_request.query_id, + id=ask_request.project_id, ) + ).get("post_process", {}) + intent = intent_classification_result.get("intent") + rephrased_question = intent_classification_result.get( + "rephrased_question" ) + intent_reasoning = intent_classification_result.get("reasoning") - self._ask_results[query_id] = AskResultResponse( - status="finished", - type="GENERAL", - rephrased_question=rephrased_question, - intent_reasoning=intent_reasoning, - ) - results["metadata"]["type"] = "GENERAL" - return results - else: - self._ask_results[query_id] = AskResultResponse( - status="understanding", - type="TEXT_TO_SQL", - rephrased_question=rephrased_question, - intent_reasoning=intent_reasoning, + user_query = ( + ask_request.query + if not rephrased_question + else rephrased_question ) - if not self._is_stopped(query_id): + if intent == "MISLEADING_QUERY": + self._ask_results[query_id] = AskResultResponse( + status="finished", + type="MISLEADING_QUERY", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + ) + results["metadata"]["type"] = "MISLEADING_QUERY" + return results + elif intent == "GENERAL": + asyncio.create_task( + self._pipelines["data_assistance"].run( + query=user_query, + history=ask_request.history, + db_schemas=intent_classification_result.get( + "db_schemas" + ), + language=ask_request.configurations.language, + query_id=ask_request.query_id, + ) + ) + + self._ask_results[query_id] = AskResultResponse( + status="finished", + type="GENERAL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + ) + results["metadata"]["type"] = "GENERAL" + return results + else: + self._ask_results[query_id] = AskResultResponse( + status="understanding", + type="TEXT_TO_SQL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + ) + if not self._is_stopped(query_id) and not api_results: self._ask_results[query_id] = AskResultResponse( status="searching", type="TEXT_TO_SQL", @@ -236,7 +262,7 @@ async def ask( results["metadata"]["type"] = "TEXT_TO_SQL" return results - if not self._is_stopped(query_id): + if not self._is_stopped(query_id) and not api_results: self._ask_results[query_id] = AskResultResponse( status="generating", type="TEXT_TO_SQL", @@ -244,117 +270,94 @@ async def ask( intent_reasoning=intent_reasoning, ) - historical_question = await self._pipelines["historical_question"].run( - query=ask_request.query, - id=ask_request.project_id, - ) - - # we only return top 1 result - historical_question_result = historical_question.get( - "formatted_output", {} - ).get("documents", [])[:1] + if ask_request.history: + text_to_sql_generation_results = await self._pipelines[ + "followup_sql_generation" + ].run( + query=user_query, + contexts=documents, + history=ask_request.history, + project_id=ask_request.project_id, + configuration=ask_request.configurations, + ) + else: + text_to_sql_generation_results = await self._pipelines[ + "sql_generation" + ].run( + query=user_query, + contexts=documents, + exclude=historical_question_result, + project_id=ask_request.project_id, + configuration=ask_request.configurations, + ) - api_results = [] - if historical_question_result: + if sql_valid_results := text_to_sql_generation_results["post_process"][ + "valid_generation_results" + ]: api_results = [ AskResult( **{ - "sql": result.get("statement"), - "type": "view", - "viewId": result.get("viewId"), + "sql": result.get("sql"), + "type": "llm", } ) - for result in historical_question_result + for result in sql_valid_results + ][:1] + elif failed_dry_run_results := self._get_failed_dry_run_results( + text_to_sql_generation_results["post_process"][ + "invalid_generation_results" ] - else: - if ask_request.history: - text_to_sql_generation_results = await self._pipelines[ - "followup_sql_generation" - ].run( - query=user_query, - contexts=documents, - history=ask_request.history, - project_id=ask_request.project_id, - configuration=ask_request.configurations, - ) - else: - text_to_sql_generation_results = await self._pipelines[ - "sql_generation" - ].run( - query=user_query, - contexts=documents, - exclude=historical_question_result, - project_id=ask_request.project_id, - configuration=ask_request.configurations, - ) + ): + self._ask_results[query_id] = AskResultResponse( + status="correcting", + ) + sql_correction_results = await self._pipelines[ + "sql_correction" + ].run( + contexts=documents, + invalid_generation_results=failed_dry_run_results, + project_id=ask_request.project_id, + ) - if sql_valid_results := text_to_sql_generation_results[ + if valid_generation_results := sql_correction_results[ "post_process" ]["valid_generation_results"]: api_results = [ AskResult( **{ - "sql": result.get("sql"), + "sql": valid_generation_result.get("sql"), "type": "llm", } ) - for result in sql_valid_results + for valid_generation_result in valid_generation_results ][:1] - elif failed_dry_run_results := self._get_failed_dry_run_results( - text_to_sql_generation_results["post_process"][ - "invalid_generation_results" - ] - ): - self._ask_results[query_id] = AskResultResponse( - status="correcting", - ) - sql_correction_results = await self._pipelines[ - "sql_correction" - ].run( - contexts=documents, - invalid_generation_results=failed_dry_run_results, - project_id=ask_request.project_id, - ) - if valid_generation_results := sql_correction_results[ - "post_process" - ]["valid_generation_results"]: - api_results = [ - AskResult( - **{ - "sql": valid_generation_result.get("sql"), - "type": "llm", - } - ) - for valid_generation_result in valid_generation_results - ][:1] - - if api_results: - if not self._is_stopped(query_id): - self._ask_results[query_id] = AskResultResponse( - status="finished", - type="TEXT_TO_SQL", - response=api_results, - rephrased_question=rephrased_question, - intent_reasoning=intent_reasoning, - ) - results["ask_result"] = api_results - results["metadata"]["type"] = "TEXT_TO_SQL" - else: - logger.exception(f"ask pipeline - NO_RELEVANT_SQL: {user_query}") - if not self._is_stopped(query_id): - self._ask_results[query_id] = AskResultResponse( - status="failed", - type="TEXT_TO_SQL", - error=AskError( - code="NO_RELEVANT_SQL", - message="No relevant SQL", - ), - rephrased_question=rephrased_question, - intent_reasoning=intent_reasoning, - ) - results["metadata"]["error_type"] = "NO_RELEVANT_SQL" - results["metadata"]["type"] = "TEXT_TO_SQL" + if api_results: + if not self._is_stopped(query_id): + self._ask_results[query_id] = AskResultResponse( + status="finished", + type="TEXT_TO_SQL", + response=api_results, + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + ) + results["ask_result"] = api_results + results["metadata"]["type"] = "TEXT_TO_SQL" + else: + logger.exception(f"ask pipeline - NO_RELEVANT_SQL: {user_query}") + if not self._is_stopped(query_id): + self._ask_results[query_id] = AskResultResponse( + status="failed", + type="TEXT_TO_SQL", + error=AskError( + code="NO_RELEVANT_SQL", + message="No relevant SQL", + ), + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + ) + results["metadata"]["error_type"] = "NO_RELEVANT_SQL" + results["metadata"]["type"] = "TEXT_TO_SQL" return results except Exception as e: