Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
3. The reasoning plan should be in the language same as the language user provided in the input.
4. Make sure to consider the current time provided in the input if the user's question is related to the date/time.
5. Don't include SQL in the reasoning plan.
6. Each step in the reasoning plan must start with a number, and a reasoning for the step.
6. Each step in the reasoning plan must start with a number, a title(in bold format in markdown), and a reasoning for the step.
7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan.
8. Do not include ```markdown or ``` in the answer.
9. A table name in the reasoning plan must be in this format: `table: <table_name>`.
10. A column name in the reasoning plan must be in this format: `column: <table_name>.<column_name>`.

### FINAL ANSWER FORMAT ###
The final answer must be a reasoning plan in plain Markdown string format
Expand Down Expand Up @@ -152,8 +154,7 @@ async def _get_streaming_results(query_id):
return await self._user_queues[query_id].get()

if query_id not in self._user_queues:
yield ""
return
self._user_queues[query_id] = asyncio.Queue()

while True:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
3. The reasoning plan should be in the language same as the language user provided in the input.
4. Make sure to consider the current time provided in the input if the user's question is related to the date/time.
5. Don't include SQL in the reasoning plan.
6. Each step in the reasoning plan must start with a number, and a reasoning for the step.
6. Each step in the reasoning plan must start with a number, a title(in bold format in markdown), and a reasoning for the step.
7. If SQL SAMPLES are provided, make sure to consider them in the reasoning plan.
8. Do not include ```markdown or ``` in the answer.
9. A table name in the reasoning plan must be in this format: `table: <table_name>`.
10. A column name in the reasoning plan must be in this format: `column: <table_name>.<column_name>`.

### FINAL ANSWER FORMAT ###
The final answer must be a reasoning plan in plain Markdown string format
Expand Down Expand Up @@ -141,8 +143,7 @@ async def _get_streaming_results(query_id):
return await self._user_queues[query_id].get()

if query_id not in self._user_queues:
yield ""
return
self._user_queues[query_id] = asyncio.Queue()

while True:
try:
Expand Down
56 changes: 37 additions & 19 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class AskResultRequest(BaseModel):
query_id: str


class AskResultResponse(BaseModel):
class _AskResultResponse(BaseModel):
status: Literal[
"understanding",
"searching",
Expand All @@ -99,6 +99,11 @@ class AskResultResponse(BaseModel):
invalid_sql: Optional[str] = None
error: Optional[AskError] = None
trace_id: Optional[str] = None
is_followup: Optional[bool] = False


class AskResultResponse(_AskResultResponse):
is_followup: Optional[bool] = Field(False, exclude=True)


# POST /v1/ask-feedbacks
Expand Down Expand Up @@ -227,6 +232,7 @@ async def ask(
self._ask_results[query_id] = AskResultResponse(
status="understanding",
trace_id=trace_id,
is_followup=True if histories else False,
)

historical_question = await self._pipelines["historical_question"].run(
Expand Down Expand Up @@ -304,6 +310,7 @@ async def ask(
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)
results["metadata"]["type"] = "MISLEADING_QUERY"
return results
Expand All @@ -326,6 +333,7 @@ async def ask(
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)
results["metadata"]["type"] = "GENERAL"
return results
Expand All @@ -336,6 +344,7 @@ async def ask(
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)
if not self._is_stopped(query_id, self._ask_results) and not api_results:
self._ask_results[query_id] = AskResultResponse(
Expand All @@ -344,6 +353,7 @@ async def ask(
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)

retrieval_result = await self._pipelines["retrieval"].run(
Expand Down Expand Up @@ -371,6 +381,7 @@ async def ask(
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)
results["metadata"]["error_type"] = "NO_RELEVANT_DATA"
results["metadata"]["type"] = "TEXT_TO_SQL"
Expand All @@ -388,6 +399,7 @@ async def ask(
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
trace_id=trace_id,
is_followup=True if histories else False,
)

if histories:
Expand Down Expand Up @@ -422,6 +434,7 @@ async def ask(
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)

if not self._is_stopped(query_id, self._ask_results) and not api_results:
Expand All @@ -433,6 +446,7 @@ async def ask(
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)

sql_functions = await self._pipelines["sql_functions_retrieval"].run(
Expand Down Expand Up @@ -500,6 +514,7 @@ async def ask(
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)
sql_correction_results = await self._pipelines[
"sql_correction"
Expand Down Expand Up @@ -543,6 +558,7 @@ async def ask(
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
trace_id=trace_id,
is_followup=True if histories else False,
)
results["ask_result"] = api_results
results["metadata"]["type"] = "TEXT_TO_SQL"
Expand All @@ -562,6 +578,7 @@ async def ask(
sql_generation_reasoning=sql_generation_reasoning,
invalid_sql=invalid_sql,
trace_id=trace_id,
is_followup=True if histories else False,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["error_message"] = error_message
Expand All @@ -579,6 +596,7 @@ async def ask(
message=str(e),
),
trace_id=trace_id,
is_followup=True if histories else False,
)

results["metadata"]["error_type"] = "OTHERS"
Expand Down Expand Up @@ -618,31 +636,31 @@ async def get_ask_streaming_result(
query_id: str,
):
if self._ask_results.get(query_id):
if self._ask_results.get(query_id).type == "GENERAL":
if self._ask_results[query_id].type == "GENERAL":
async for chunk in self._pipelines[
"data_assistance"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()
elif self._ask_results.get(query_id).status == "planning":
# only one of the two pipelines will be used
async for chunk in self._pipelines[
"sql_generation_reasoning"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()

async for chunk in self._pipelines[
"followup_sql_generation_reasoning"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()
elif self._ask_results[query_id].status == "planning":
if self._ask_results[query_id].is_followup:
async for chunk in self._pipelines[
"followup_sql_generation_reasoning"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()
else:
async for chunk in self._pipelines[
"sql_generation_reasoning"
].get_streaming_results(query_id):
event = SSEEvent(
data=SSEEvent.SSEEventMessage(message=chunk),
)
yield event.serialize()

@observe(name="Ask Feedback")
@trace_metadata
Expand Down