From 91d3e96efe2ec228a6086e99ce9c305183939468 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 1 Apr 2025 10:24:25 +0800 Subject: [PATCH 1/4] refine --- .../followup_sql_generation_reasoning.py | 7 +-- .../generation/sql_generation_reasoning.py | 7 +-- wren-ai-service/src/web/v1/services/ask.py | 44 ++++++++++++------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index e812e36c25..b468281c03 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -27,9 +27,11 @@ 3. The reasoning plan should be in the language same as the language user provided in the input. 4. Make sure to consider the current time provided in the input if the user's question is related to the date/time. 5. Don't include SQL in the reasoning plan. -6. Each step in the reasoning plan must start with a number, and a reasoning for the step. +6. Each step in the reasoning plan must start with a number, a title(in bold format in markdown), and a reasoning for the step. 7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan. 8. Do not include ```markdown or ``` in the answer. +9. A table name in the reasoning plan must be in this format: `table: `. +10. A column name in the reasoning plan must be in this format: `column: .`. ### FINAL ANSWER FORMAT ### The final answer must be a reasoning plan in plain Markdown string format @@ -152,8 +154,7 @@ async def _get_streaming_results(query_id): return await self._user_queues[query_id].get() if query_id not in self._user_queues: - yield "" - return + self._user_queues[query_id] = asyncio.Queue() while True: try: diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index 7cb35a933f..59c8428981 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -26,9 +26,11 @@ 3. The reasoning plan should be in the language same as the language user provided in the input. 4. Make sure to consider the current time provided in the input if the user's question is related to the date/time. 5. Don't include SQL in the reasoning plan. -6. Each step in the reasoning plan must start with a number, and a reasoning for the step. +6. Each step in the reasoning plan must start with a number, a title(in bold format in markdown), and a reasoning for the step. 7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan. 8. Do not include ```markdown or ``` in the answer. +9. A table name in the reasoning plan must be in this format: `table: `. +10. A column name in the reasoning plan must be in this format: `column: .`. ### FINAL ANSWER FORMAT ### The final answer must be a reasoning plan in plain Markdown string format @@ -141,8 +143,7 @@ async def _get_streaming_results(query_id): return await self._user_queues[query_id].get() if query_id not in self._user_queues: - yield "" - return + self._user_queues[query_id] = asyncio.Queue() while True: try: diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 8d339dac2e..9a5ba753ae 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -79,7 +79,7 @@ class AskResultRequest(BaseModel): query_id: str -class AskResultResponse(BaseModel): +class _AskResultResponse(BaseModel): status: Literal[ "understanding", "searching", @@ -99,6 +99,13 @@ class AskResultResponse(BaseModel): invalid_sql: Optional[str] = None error: Optional[AskError] = None trace_id: Optional[str] = None + is_followup: Optional[bool] = False + is_user_guide: Optional[bool] = False + + +class AskResultResponse(_AskResultResponse): + is_followup: Optional[bool] = Field(False, exclude=True) + is_user_guide: Optional[bool] = Field(False, exclude=True) # POST /v1/ask-feedbacks @@ -229,6 +236,9 @@ async def ask( trace_id=trace_id, ) + if histories: + self._ask_results[query_id].is_followup = True + historical_question = await self._pipelines["historical_question"].run( query=user_query, project_id=ask_request.project_id, @@ -627,22 +637,22 @@ async def get_ask_streaming_result( ) yield event.serialize() elif self._ask_results.get(query_id).status == "planning": - # only one of the two pipelines will be used - async for chunk in self._pipelines[ - "sql_generation_reasoning" - ].get_streaming_results(query_id): - event = SSEEvent( - data=SSEEvent.SSEEventMessage(message=chunk), - ) - yield event.serialize() - - async for chunk in self._pipelines[ - "followup_sql_generation_reasoning" - ].get_streaming_results(query_id): - event = SSEEvent( - data=SSEEvent.SSEEventMessage(message=chunk), - ) - yield event.serialize() + if self._ask_results.get(query_id).is_followup: + async for chunk in self._pipelines[ + "followup_sql_generation_reasoning" + ].get_streaming_results(query_id): + event = SSEEvent( + data=SSEEvent.SSEEventMessage(message=chunk), + ) + yield event.serialize() + else: + async for chunk in self._pipelines[ + "sql_generation_reasoning" + ].get_streaming_results(query_id): + event = SSEEvent( + data=SSEEvent.SSEEventMessage(message=chunk), + ) + yield event.serialize() @observe(name="Ask Feedback") @trace_metadata From 06c5af88e66385741b58886c862189070ea95d3f Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 1 Apr 2025 10:41:26 +0800 Subject: [PATCH 2/4] update --- wren-ai-service/src/web/v1/services/ask.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 9a5ba753ae..4a9520372c 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -238,6 +238,7 @@ async def ask( if histories: self._ask_results[query_id].is_followup = True + print("is_followup: ", self._ask_results[query_id].is_followup) historical_question = await self._pipelines["historical_question"].run( query=user_query, @@ -628,7 +629,7 @@ async def get_ask_streaming_result( query_id: str, ): if self._ask_results.get(query_id): - if self._ask_results.get(query_id).type == "GENERAL": + if self._ask_results["query_id"].type == "GENERAL": async for chunk in self._pipelines[ "data_assistance" ].get_streaming_results(query_id): @@ -636,8 +637,8 @@ async def get_ask_streaming_result( data=SSEEvent.SSEEventMessage(message=chunk), ) yield event.serialize() - elif self._ask_results.get(query_id).status == "planning": - if self._ask_results.get(query_id).is_followup: + elif self._ask_results["query_id"].status == "planning": + if self._ask_results["query_id"].is_followup: async for chunk in self._pipelines[ "followup_sql_generation_reasoning" ].get_streaming_results(query_id): From 669bd6d25999734da476a92b889c056232091e4c Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 1 Apr 2025 11:08:27 +0800 Subject: [PATCH 3/4] fix --- wren-ai-service/src/web/v1/services/ask.py | 23 +++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 4a9520372c..7d8f5913b8 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -234,12 +234,9 @@ async def ask( self._ask_results[query_id] = AskResultResponse( status="understanding", trace_id=trace_id, + is_followup=True if histories else False, ) - if histories: - self._ask_results[query_id].is_followup = True - print("is_followup: ", self._ask_results[query_id].is_followup) - historical_question = await self._pipelines["historical_question"].run( query=user_query, project_id=ask_request.project_id, @@ -315,6 +312,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["type"] = "MISLEADING_QUERY" return results @@ -337,6 +335,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["type"] = "GENERAL" return results @@ -347,6 +346,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) if not self._is_stopped(query_id, self._ask_results) and not api_results: self._ask_results[query_id] = AskResultResponse( @@ -355,6 +355,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) retrieval_result = await self._pipelines["retrieval"].run( @@ -382,6 +383,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["error_type"] = "NO_RELEVANT_DATA" results["metadata"]["type"] = "TEXT_TO_SQL" @@ -399,6 +401,7 @@ async def ask( intent_reasoning=intent_reasoning, retrieved_tables=table_names, trace_id=trace_id, + is_followup=True if histories else False, ) if histories: @@ -433,6 +436,7 @@ async def ask( retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) if not self._is_stopped(query_id, self._ask_results) and not api_results: @@ -444,6 +448,7 @@ async def ask( retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) sql_functions = await self._pipelines["sql_functions_retrieval"].run( @@ -511,6 +516,7 @@ async def ask( retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) sql_correction_results = await self._pipelines[ "sql_correction" @@ -554,6 +560,7 @@ async def ask( retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) results["ask_result"] = api_results results["metadata"]["type"] = "TEXT_TO_SQL" @@ -573,6 +580,7 @@ async def ask( sql_generation_reasoning=sql_generation_reasoning, invalid_sql=invalid_sql, trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" results["metadata"]["error_message"] = error_message @@ -590,6 +598,7 @@ async def ask( message=str(e), ), trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["error_type"] = "OTHERS" @@ -629,7 +638,7 @@ async def get_ask_streaming_result( query_id: str, ): if self._ask_results.get(query_id): - if self._ask_results["query_id"].type == "GENERAL": + if self._ask_results[query_id].type == "GENERAL": async for chunk in self._pipelines[ "data_assistance" ].get_streaming_results(query_id): @@ -637,8 +646,8 @@ async def get_ask_streaming_result( data=SSEEvent.SSEEventMessage(message=chunk), ) yield event.serialize() - elif self._ask_results["query_id"].status == "planning": - if self._ask_results["query_id"].is_followup: + elif self._ask_results[query_id].status == "planning": + if self._ask_results[query_id].is_followup: async for chunk in self._pipelines[ "followup_sql_generation_reasoning" ].get_streaming_results(query_id): From ebfaf18aa2c5a141eb47bf5799c93537e4b1a1da Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 1 Apr 2025 12:02:43 +0800 Subject: [PATCH 4/4] remove user_guide --- wren-ai-service/src/web/v1/services/ask.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 7d8f5913b8..6ce93efbb3 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -100,12 +100,10 @@ class _AskResultResponse(BaseModel): error: Optional[AskError] = None trace_id: Optional[str] = None is_followup: Optional[bool] = False - is_user_guide: Optional[bool] = False class AskResultResponse(_AskResultResponse): is_followup: Optional[bool] = Field(False, exclude=True) - is_user_guide: Optional[bool] = Field(False, exclude=True) # POST /v1/ask-feedbacks