diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index e812e36c25..b468281c03 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -27,9 +27,11 @@ 3. The reasoning plan should be in the language same as the language user provided in the input. 4. Make sure to consider the current time provided in the input if the user's question is related to the date/time. 5. Don't include SQL in the reasoning plan. -6. Each step in the reasoning plan must start with a number, and a reasoning for the step. +6. Each step in the reasoning plan must start with a number, a title(in bold format in markdown), and a reasoning for the step. 7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan. 8. Do not include ```markdown or ``` in the answer. +9. A table name in the reasoning plan must be in this format: `table: `. +10. A column name in the reasoning plan must be in this format: `column: .`. ### FINAL ANSWER FORMAT ### The final answer must be a reasoning plan in plain Markdown string format @@ -152,8 +154,7 @@ async def _get_streaming_results(query_id): return await self._user_queues[query_id].get() if query_id not in self._user_queues: - yield "" - return + self._user_queues[query_id] = asyncio.Queue() while True: try: diff --git a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py index 7cb35a933f..59c8428981 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation_reasoning.py @@ -26,9 +26,11 @@ 3. The reasoning plan should be in the language same as the language user provided in the input. 4. Make sure to consider the current time provided in the input if the user's question is related to the date/time. 5. Don't include SQL in the reasoning plan. -6. Each step in the reasoning plan must start with a number, and a reasoning for the step. +6. Each step in the reasoning plan must start with a number, a title(in bold format in markdown), and a reasoning for the step. 7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan. 8. Do not include ```markdown or ``` in the answer. +9. A table name in the reasoning plan must be in this format: `table: `. +10. A column name in the reasoning plan must be in this format: `column: .`. ### FINAL ANSWER FORMAT ### The final answer must be a reasoning plan in plain Markdown string format @@ -141,8 +143,7 @@ async def _get_streaming_results(query_id): return await self._user_queues[query_id].get() if query_id not in self._user_queues: - yield "" - return + self._user_queues[query_id] = asyncio.Queue() while True: try: diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 8d339dac2e..6ce93efbb3 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -79,7 +79,7 @@ class AskResultRequest(BaseModel): query_id: str -class AskResultResponse(BaseModel): +class _AskResultResponse(BaseModel): status: Literal[ "understanding", "searching", @@ -99,6 +99,11 @@ class AskResultResponse(BaseModel): invalid_sql: Optional[str] = None error: Optional[AskError] = None trace_id: Optional[str] = None + is_followup: Optional[bool] = False + + +class AskResultResponse(_AskResultResponse): + is_followup: Optional[bool] = Field(False, exclude=True) # POST /v1/ask-feedbacks @@ -227,6 +232,7 @@ async def ask( self._ask_results[query_id] = AskResultResponse( status="understanding", trace_id=trace_id, + is_followup=True if histories else False, ) historical_question = await self._pipelines["historical_question"].run( @@ -304,6 +310,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["type"] = "MISLEADING_QUERY" return results @@ -326,6 +333,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["type"] = "GENERAL" return results @@ -336,6 +344,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) if not self._is_stopped(query_id, self._ask_results) and not api_results: self._ask_results[query_id] = AskResultResponse( @@ -344,6 +353,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) retrieval_result = await self._pipelines["retrieval"].run( @@ -371,6 +381,7 @@ async def ask( rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["error_type"] = "NO_RELEVANT_DATA" results["metadata"]["type"] = "TEXT_TO_SQL" @@ -388,6 +399,7 @@ async def ask( intent_reasoning=intent_reasoning, retrieved_tables=table_names, trace_id=trace_id, + is_followup=True if histories else False, ) if histories: @@ -422,6 +434,7 @@ async def ask( retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) if not self._is_stopped(query_id, self._ask_results) and not api_results: @@ -433,6 +446,7 @@ async def ask( retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) sql_functions = await self._pipelines["sql_functions_retrieval"].run( @@ -500,6 +514,7 @@ async def ask( retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) sql_correction_results = await self._pipelines[ "sql_correction" @@ -543,6 +558,7 @@ async def ask( retrieved_tables=table_names, sql_generation_reasoning=sql_generation_reasoning, trace_id=trace_id, + is_followup=True if histories else False, ) results["ask_result"] = api_results results["metadata"]["type"] = "TEXT_TO_SQL" @@ -562,6 +578,7 @@ async def ask( sql_generation_reasoning=sql_generation_reasoning, invalid_sql=invalid_sql, trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" results["metadata"]["error_message"] = error_message @@ -579,6 +596,7 @@ async def ask( message=str(e), ), trace_id=trace_id, + is_followup=True if histories else False, ) results["metadata"]["error_type"] = "OTHERS" @@ -618,7 +636,7 @@ async def get_ask_streaming_result( query_id: str, ): if self._ask_results.get(query_id): - if self._ask_results.get(query_id).type == "GENERAL": + if self._ask_results[query_id].type == "GENERAL": async for chunk in self._pipelines[ "data_assistance" ].get_streaming_results(query_id): @@ -626,23 +644,23 @@ async def get_ask_streaming_result( data=SSEEvent.SSEEventMessage(message=chunk), ) yield event.serialize() - elif self._ask_results.get(query_id).status == "planning": - # only one of the two pipelines will be used - async for chunk in self._pipelines[ - "sql_generation_reasoning" - ].get_streaming_results(query_id): - event = SSEEvent( - data=SSEEvent.SSEEventMessage(message=chunk), - ) - yield event.serialize() - - async for chunk in self._pipelines[ - "followup_sql_generation_reasoning" - ].get_streaming_results(query_id): - event = SSEEvent( - data=SSEEvent.SSEEventMessage(message=chunk), - ) - yield event.serialize() + elif self._ask_results[query_id].status == "planning": + if self._ask_results[query_id].is_followup: + async for chunk in self._pipelines[ + "followup_sql_generation_reasoning" + ].get_streaming_results(query_id): + event = SSEEvent( + data=SSEEvent.SSEEventMessage(message=chunk), + ) + yield event.serialize() + else: + async for chunk in self._pipelines[ + "sql_generation_reasoning" + ].get_streaming_results(query_id): + event = SSEEvent( + data=SSEEvent.SSEEventMessage(message=chunk), + ) + yield event.serialize() @observe(name="Ask Feedback") @trace_metadata