Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class AskResultResponse(BaseModel):
type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None
retrieved_tables: Optional[List[str]] = None
response: Optional[List[AskResult]] = None
invalid_sql: Optional[str] = None
error: Optional[AskError] = None


Expand Down Expand Up @@ -372,6 +373,9 @@ async def ask(
sql_generation_reasoning=sql_generation_reasoning,
)

invalid_sql = None
error_message = None

if not self._is_stopped(query_id, self._ask_results) and not api_results:
self._ask_results[query_id] = AskResultResponse(
status="generating",
Expand Down Expand Up @@ -462,9 +466,13 @@ async def ask(
elif failed_dry_run_results := sql_correction_results[
"post_process"
]["invalid_generation_results"]:
error_message = failed_dry_run_results[0]["error"]
invalid = failed_dry_run_results[0]
invalid_sql = invalid["sql"]
error_message = invalid["error"]
else:
error_message = failed_dry_run_results[0]["error"]
invalid = failed_dry_run_results[0]
invalid_sql = invalid["sql"]
error_message = invalid["error"]

if api_results:
if not self._is_stopped(query_id, self._ask_results):
Expand Down Expand Up @@ -493,6 +501,7 @@ async def ask(
intent_reasoning=intent_reasoning,
retrieved_tables=table_names,
sql_generation_reasoning=sql_generation_reasoning,
invalid_sql=invalid_sql,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["error_message"] = error_message
Expand Down Expand Up @@ -631,10 +640,10 @@ async def ask_feedback(
"post_process"
]["invalid_generation_results"]:
if failed_dry_run_results[0]["type"] != "TIME_OUT":
self._ask_feedback_results[
query_id
] = AskFeedbackResultResponse(
status="correcting",
self._ask_feedback_results[query_id] = (
AskFeedbackResultResponse(
status="correcting",
)
)
sql_correction_results = await self._pipelines[
"sql_correction"
Expand Down Expand Up @@ -704,10 +713,10 @@ def stop_ask_feedback(
self,
stop_ask_feedback_request: StopAskFeedbackRequest,
):
self._ask_feedback_results[
stop_ask_feedback_request.query_id
] = AskFeedbackResultResponse(
status="stopped",
self._ask_feedback_results[stop_ask_feedback_request.query_id] = (
AskFeedbackResultResponse(
status="stopped",
)
)

def get_ask_feedback_result(
Expand Down