Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 142 additions & 139 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ async def ask(
query_id = ask_request.query_id
rephrased_question = None
intent_reasoning = None
api_results = []

try:
# ask status can be understanding, searching, generating, finished, failed, stopped
Expand All @@ -151,60 +152,85 @@ async def ask(
status="understanding",
)

intent_classification_result = (
await self._pipelines["intent_classification"].run(
query=ask_request.query,
history=ask_request.history,
id=ask_request.project_id,
)
).get("post_process", {})
intent = intent_classification_result.get("intent")
rephrased_question = intent_classification_result.get(
"rephrased_question"
historical_question = await self._pipelines["historical_question"].run(
query=ask_request.query,
id=ask_request.project_id,
)
intent_reasoning = intent_classification_result.get("reasoning")

user_query = (
ask_request.query if not rephrased_question else rephrased_question
)
# we only return top 1 result
historical_question_result = historical_question.get(
"formatted_output", {}
).get("documents", [])[:1]

if intent == "MISLEADING_QUERY":
self._ask_results[query_id] = AskResultResponse(
status="finished",
type="MISLEADING_QUERY",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["type"] = "MISLEADING_QUERY"
return results
elif intent == "GENERAL":
asyncio.create_task(
self._pipelines["data_assistance"].run(
query=user_query,
if historical_question_result:
api_results = [
AskResult(
**{
"sql": result.get("statement"),
"type": "view",
"viewId": result.get("viewId"),
}
)
for result in historical_question_result
]
else:
intent_classification_result = (
await self._pipelines["intent_classification"].run(
query=ask_request.query,
history=ask_request.history,
db_schemas=intent_classification_result.get("db_schemas"),
language=ask_request.configurations.language,
query_id=ask_request.query_id,
id=ask_request.project_id,
)
).get("post_process", {})
intent = intent_classification_result.get("intent")
rephrased_question = intent_classification_result.get(
"rephrased_question"
)
intent_reasoning = intent_classification_result.get("reasoning")

self._ask_results[query_id] = AskResultResponse(
status="finished",
type="GENERAL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["type"] = "GENERAL"
return results
else:
self._ask_results[query_id] = AskResultResponse(
status="understanding",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
user_query = (
ask_request.query
if not rephrased_question
else rephrased_question
)

if not self._is_stopped(query_id):
if intent == "MISLEADING_QUERY":
self._ask_results[query_id] = AskResultResponse(
status="finished",
type="MISLEADING_QUERY",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["type"] = "MISLEADING_QUERY"
return results
elif intent == "GENERAL":
asyncio.create_task(
self._pipelines["data_assistance"].run(
query=user_query,
history=ask_request.history,
db_schemas=intent_classification_result.get(
"db_schemas"
),
language=ask_request.configurations.language,
query_id=ask_request.query_id,
)
)

self._ask_results[query_id] = AskResultResponse(
status="finished",
type="GENERAL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["type"] = "GENERAL"
return results
else:
self._ask_results[query_id] = AskResultResponse(
status="understanding",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
if not self._is_stopped(query_id) and not api_results:
self._ask_results[query_id] = AskResultResponse(
status="searching",
type="TEXT_TO_SQL",
Expand Down Expand Up @@ -236,125 +262,102 @@ async def ask(
results["metadata"]["type"] = "TEXT_TO_SQL"
return results

if not self._is_stopped(query_id):
if not self._is_stopped(query_id) and not api_results:
self._ask_results[query_id] = AskResultResponse(
status="generating",
type="TEXT_TO_SQL",
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)

historical_question = await self._pipelines["historical_question"].run(
query=ask_request.query,
id=ask_request.project_id,
)

# we only return top 1 result
historical_question_result = historical_question.get(
"formatted_output", {}
).get("documents", [])[:1]
if ask_request.history:
text_to_sql_generation_results = await self._pipelines[
"followup_sql_generation"
].run(
query=user_query,
contexts=documents,
history=ask_request.history,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
)
else:
text_to_sql_generation_results = await self._pipelines[
"sql_generation"
].run(
query=user_query,
contexts=documents,
exclude=historical_question_result,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
)

api_results = []
if historical_question_result:
if sql_valid_results := text_to_sql_generation_results["post_process"][
"valid_generation_results"
]:
api_results = [
AskResult(
**{
"sql": result.get("statement"),
"type": "view",
"viewId": result.get("viewId"),
"sql": result.get("sql"),
"type": "llm",
}
)
for result in historical_question_result
for result in sql_valid_results
][:1]
elif failed_dry_run_results := self._get_failed_dry_run_results(
text_to_sql_generation_results["post_process"][
"invalid_generation_results"
]
else:
if ask_request.history:
text_to_sql_generation_results = await self._pipelines[
"followup_sql_generation"
].run(
query=user_query,
contexts=documents,
history=ask_request.history,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
)
else:
text_to_sql_generation_results = await self._pipelines[
"sql_generation"
].run(
query=user_query,
contexts=documents,
exclude=historical_question_result,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
)
):
self._ask_results[query_id] = AskResultResponse(
status="correcting",
)
sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=documents,
invalid_generation_results=failed_dry_run_results,
project_id=ask_request.project_id,
)

if sql_valid_results := text_to_sql_generation_results[
if valid_generation_results := sql_correction_results[
"post_process"
]["valid_generation_results"]:
api_results = [
AskResult(
**{
"sql": result.get("sql"),
"sql": valid_generation_result.get("sql"),
"type": "llm",
}
)
for result in sql_valid_results
for valid_generation_result in valid_generation_results
][:1]
elif failed_dry_run_results := self._get_failed_dry_run_results(
text_to_sql_generation_results["post_process"][
"invalid_generation_results"
]
):
self._ask_results[query_id] = AskResultResponse(
status="correcting",
)
sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=documents,
invalid_generation_results=failed_dry_run_results,
project_id=ask_request.project_id,
)

if valid_generation_results := sql_correction_results[
"post_process"
]["valid_generation_results"]:
api_results = [
AskResult(
**{
"sql": valid_generation_result.get("sql"),
"type": "llm",
}
)
for valid_generation_result in valid_generation_results
][:1]

if api_results:
if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="finished",
type="TEXT_TO_SQL",
response=api_results,
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["ask_result"] = api_results
results["metadata"]["type"] = "TEXT_TO_SQL"
else:
logger.exception(f"ask pipeline - NO_RELEVANT_SQL: {user_query}")
if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="failed",
type="TEXT_TO_SQL",
error=AskError(
code="NO_RELEVANT_SQL",
message="No relevant SQL",
),
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["type"] = "TEXT_TO_SQL"
if api_results:
if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="finished",
type="TEXT_TO_SQL",
response=api_results,
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["ask_result"] = api_results
results["metadata"]["type"] = "TEXT_TO_SQL"
else:
logger.exception(f"ask pipeline - NO_RELEVANT_SQL: {user_query}")
if not self._is_stopped(query_id):
self._ask_results[query_id] = AskResultResponse(
status="failed",
type="TEXT_TO_SQL",
error=AskError(
code="NO_RELEVANT_SQL",
message="No relevant SQL",
),
rephrased_question=rephrased_question,
intent_reasoning=intent_reasoning,
)
results["metadata"]["error_type"] = "NO_RELEVANT_SQL"
results["metadata"]["type"] = "TEXT_TO_SQL"

return results
except Exception as e:
Expand Down