Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions wren-ai-service/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Settings(BaseSettings):
# generation config
allow_intent_classification: bool = Field(default=True)
allow_sql_generation_reasoning: bool = Field(default=True)
max_histories: int = Field(default=10)

# engine config
engine_timeout: float = Field(default=30.0)
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def create_service_container(
},
allow_intent_classification=settings.allow_intent_classification,
allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning,
max_histories=settings.max_histories,
**query_cache,
),
chart_service=services.ChartService(
Expand Down
28 changes: 12 additions & 16 deletions wren-ai-service/src/pipelines/generation/data_assistance.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,11 @@ def prompt(
db_schemas: list[str],
language: str,
prompt_builder: PromptBuilder,
history: Optional[AskHistory] = None,
histories: Optional[list[AskHistory]] = None,
) -> dict:
if history:
previous_query_summaries = [
step.summary for step in history.steps if step.summary
]
else:
previous_query_summaries = []

previous_query_summaries = (
[history.question for history in histories] if histories else []
)
query = "\n".join(previous_query_summaries) + "\n" + query

return prompt_builder.run(
Expand Down Expand Up @@ -106,9 +102,9 @@ def __init__(

def _streaming_callback(self, chunk, query_id):
if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Create a new queue for the user if it doesn't exist
self._user_queues[query_id] = (
asyncio.Queue()
) # Create a new queue for the user if it doesn't exist
# Put the chunk content into the user's queue
asyncio.create_task(self._user_queues[query_id].put(chunk.content))
if chunk.meta.get("finish_reason"):
Expand All @@ -119,9 +115,9 @@ async def _get_streaming_results(query_id):
return await self._user_queues[query_id].get()

if query_id not in self._user_queues:
self._user_queues[
query_id
] = asyncio.Queue() # Ensure the user's queue exists
self._user_queues[query_id] = (
asyncio.Queue()
) # Ensure the user's queue exists
while True:
try:
# Wait for an item from the user's queue
Expand All @@ -146,7 +142,7 @@ async def run(
db_schemas: list[str],
language: str,
query_id: Optional[str] = None,
history: Optional[AskHistory] = None,
histories: Optional[list[AskHistory]] = None,
):
logger.info("Data Assistance pipeline is running...")
return await self._pipe.execute(
Expand All @@ -156,7 +152,7 @@ async def run(
"db_schemas": db_schemas,
"language": language,
"query_id": query_id or "",
"history": history,
"histories": histories,
**self._components,
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@
{% endif %}

### CONTEXT ###
Previous SQL Summary:
{% for summary in previous_query_summaries %}
{{ summary }}
User's query history:
{% for history in histories %}
{{ history.question }}
{{ history.sql }}
{% endfor %}
Previous SQL Query: {{ history.sql }}

### QUESTION ###
User's Follow-up Question: {{ query }}
Expand All @@ -71,20 +71,20 @@ def prompt(
query: str,
documents: List[str],
sql_generation_reasoning: str,
history: AskHistory,
histories: list[AskHistory],
configuration: Configuration,
prompt_builder: PromptBuilder,
sql_samples: List[Dict] | None = None,
has_calculated_field: bool = False,
has_metric: bool = False,
) -> dict:
previous_query_summaries = [step.summary for step in history.steps if step.summary]
previous_query_summaries = [history.question for history in histories]
Comment thread
paopa marked this conversation as resolved.

return prompt_builder.run(
query=query,
documents=documents,
sql_generation_reasoning=sql_generation_reasoning,
history=history,
histories=histories,
previous_query_summaries=previous_query_summaries,
instructions=construct_instructions(
configuration,
Expand Down Expand Up @@ -152,7 +152,7 @@ async def run(
query: str,
contexts: List[str],
sql_generation_reasoning: str,
history: AskHistory,
histories: list[AskHistory],
configuration: Configuration = Configuration(),
sql_samples: List[Dict] | None = None,
project_id: str | None = None,
Expand All @@ -166,7 +166,7 @@ async def run(
"query": query,
"documents": contexts,
"sql_generation_reasoning": sql_generation_reasoning,
"history": history,
"histories": histories,
"project_id": project_id,
"configuration": configuration,
"sql_samples": sql_samples,
Expand Down
18 changes: 11 additions & 7 deletions wren-ai-service/src/pipelines/generation/intent_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@

### INPUT ###
{% if query_history %}
User's previous SQLs: {{ query_history }}
User's query history:
{% for history in query_history %}
{{ history.question }}
{{ history.sql }}
{% endfor %}
{% endif %}
User's question: {{query}}
Current Time: {{ current_time }}
Expand All @@ -116,10 +120,10 @@
## Start of Pipeline
@observe(capture_input=False, capture_output=False)
async def embedding(
query: str, embedder: Any, history: Optional[AskHistory] = None
query: str, embedder: Any, histories: Optional[list[AskHistory]] = None
) -> dict:
previous_query_summaries = (
[step.summary for step in history.steps if step.summary] if history else []
[history.question for history in histories] if histories else []
)

query = "\n".join(previous_query_summaries) + "\n" + query
Expand Down Expand Up @@ -222,14 +226,14 @@ def prompt(
query: str,
construct_db_schemas: list[str],
prompt_builder: PromptBuilder,
history: Optional[AskHistory] = None,
histories: Optional[list[AskHistory]] = None,
configuration: Configuration | None = None,
) -> dict:
return prompt_builder.run(
query=query,
language=configuration.language,
db_schemas=construct_db_schemas,
query_history=history.sql if history else [],
query_history=histories,
current_time=configuration.show_current_time(),
)

Expand Down Expand Up @@ -316,7 +320,7 @@ async def run(
self,
query: str,
id: Optional[str] = None,
history: Optional[AskHistory] = None,
histories: Optional[list[AskHistory]] = None,
configuration: Configuration = Configuration(),
):
logger.info("Intent Classification pipeline is running...")
Expand All @@ -325,7 +329,7 @@ async def run(
inputs={
"query": query,
"id": id or "",
"history": history,
"histories": histories,
"configuration": configuration,
**self._components,
},
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,6 @@ async def run(
"sql_expansion",
query="query",
contexts=[],
history=AskHistory(sql="SELECT * FROM table", summary="Summary", steps=[]),
history=AskHistory(sql="SELECT * FROM table", question="user question"),
configuration=Configuration(),
Comment thread
paopa marked this conversation as resolved.
)
21 changes: 9 additions & 12 deletions wren-ai-service/src/pipelines/retrieval/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def _build_view_ddl(content: dict) -> str:
## Start of Pipeline
@observe(capture_input=False, capture_output=False)
async def embedding(
query: str, embedder: Any, history: Optional[AskHistory] = None
query: str, embedder: Any, histories: Optional[list[AskHistory]] = None
) -> dict:
if query:
if history:
if histories:
previous_query_summaries = [
step.summary for step in history.steps if step.summary
history.question for history in histories
]
else:
previous_query_summaries = []
Expand Down Expand Up @@ -292,7 +292,7 @@ def prompt(
construct_db_schemas: list[dict],
prompt_builder: PromptBuilder,
check_using_db_schemas_without_pruning: dict,
history: Optional[AskHistory] = None,
histories: Optional[list[AskHistory]] = None,
) -> dict:
if not check_using_db_schemas_without_pruning["db_schemas"]:
logger.info(
Expand All @@ -303,12 +303,9 @@ def prompt(
for construct_db_schema in construct_db_schemas
]

if history:
previous_query_summaries = [
step.summary for step in history.steps if step.summary
]
else:
previous_query_summaries = []
previous_query_summaries = (
[history.question for history in histories] if histories else []
)

query = "\n".join(previous_query_summaries) + "\n" + query
return prompt_builder.run(question=query, db_schemas=db_schemas)
Expand Down Expand Up @@ -482,7 +479,7 @@ async def run(
query: str = "",
tables: Optional[list[str]] = None,
id: Optional[str] = None,
history: Optional[AskHistory] = None,
histories: Optional[list[AskHistory]] = None,
):
logger.info("Ask Retrieval pipeline is running...")
return await self._pipe.execute(
Expand All @@ -491,7 +488,7 @@ async def run(
"query": query,
"tables": tables,
"id": id or "",
"history": history,
"histories": histories,
**self._components,
**self._configs,
},
Expand Down
8 changes: 4 additions & 4 deletions wren-ai-service/src/web/v1/routers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,10 @@ async def ask_feedback(
) -> AskFeedbackResponse:
query_id = str(uuid.uuid4())
ask_feedback_request.query_id = query_id
service_container.ask_service._ask_feedback_results[
query_id
] = AskFeedbackResultResponse(
status="searching",
service_container.ask_service._ask_feedback_results[query_id] = (
AskFeedbackResultResponse(
status="searching",
)
)

background_tasks.add_task(
Expand Down
6 changes: 3 additions & 3 deletions wren-ai-service/src/web/v1/routers/sql_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@

@router.post("/sql-expansions")
async def sql_expansion(
sql_expansion_request: SqlExpansionRequest,
request: SqlExpansionRequest,
background_tasks: BackgroundTasks,
service_container: ServiceContainer = Depends(get_service_container),
service_metadata: ServiceMetadata = Depends(get_service_metadata),
) -> SqlExpansionResponse:
query_id = str(uuid.uuid4())
sql_expansion_request.query_id = query_id
request.query_id = query_id
service_container.sql_expansion_service._sql_expansion_results[
query_id
] = SqlExpansionResultResponse(
Expand All @@ -103,7 +103,7 @@ async def sql_expansion(

background_tasks.add_task(
service_container.sql_expansion_service.sql_expansion,
sql_expansion_request,
request,
service_metadata=asdict(service_metadata),
)
return SqlExpansionResponse(query_id=query_id)
Expand Down
18 changes: 10 additions & 8 deletions wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
from src.core.pipeline import BasicPipeline
from src.utils import trace_metadata
from src.web.v1.services import Configuration, SSEEvent
from src.web.v1.services.ask_details import SQLBreakdown

logger = logging.getLogger("wren-ai-service")


class AskHistory(BaseModel):
sql: str
steps: List[SQLBreakdown]
question: str


# POST /v1/asks
Expand All @@ -29,7 +28,7 @@ class AskRequest(BaseModel):
# so we need to support as a choice, and will remove it in the future
mdl_hash: Optional[str] = Field(validation_alias=AliasChoices("mdl_hash", "id"))
thread_id: Optional[str] = None
history: Optional[AskHistory] = None
histories: Optional[list[AskHistory]] = Field(default_factory=list)
configurations: Optional[Configuration] = Configuration()

@property
Expand Down Expand Up @@ -166,6 +165,7 @@ def __init__(
pipelines: Dict[str, BasicPipeline],
allow_intent_classification: bool = True,
allow_sql_generation_reasoning: bool = True,
max_histories: int = 10,
maxsize: int = 1_000_000,
ttl: int = 120,
):
Expand All @@ -178,6 +178,7 @@ def __init__(
)
self._allow_sql_generation_reasoning = allow_sql_generation_reasoning
self._allow_intent_classification = allow_intent_classification
self._max_histories = max_histories

def _is_stopped(self, query_id: str, container: dict):
if (
Expand Down Expand Up @@ -205,6 +206,7 @@ async def ask(
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use the configured max_histories value instead of hardcoding.

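The code currently hardcodes the history limit to 10 instead of using the configured self._max_histories value, so changing max_histories in the settings has no effect on how many histories are passed to the pipelines.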

-        histories = ask_request.histories[:10]
+        histories = ask_request.histories[:self._max_histories]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
histories = ask_request.histories[:self._max_histories]

query_id = ask_request.query_id
histories = ask_request.histories[: self._max_histories]
rephrased_question = None
intent_reasoning = None
sql_generation_reasoning = None
Expand Down Expand Up @@ -250,7 +252,7 @@ async def ask(
intent_classification_result = (
await self._pipelines["intent_classification"].run(
query=user_query,
history=ask_request.history,
histories=histories,
id=ask_request.project_id,
configuration=ask_request.configurations,
)
Expand Down Expand Up @@ -278,7 +280,7 @@ async def ask(
asyncio.create_task(
self._pipelines["data_assistance"].run(
query=user_query,
history=ask_request.history,
histories=histories,
db_schemas=intent_classification_result.get(
"db_schemas"
),
Expand Down Expand Up @@ -315,7 +317,7 @@ async def ask(

retrieval_result = await self._pipelines["retrieval"].run(
query=user_query,
history=ask_request.history,
histories=histories,
id=ask_request.project_id,
)
_retrieval_result = retrieval_result.get(
Expand Down Expand Up @@ -403,14 +405,14 @@ async def ask(
)
has_metric = (_retrieval_result.get("has_metric", False),)

if ask_request.history:
if histories:
text_to_sql_generation_results = await self._pipelines[
"followup_sql_generation"
].run(
query=user_query,
contexts=table_ddls,
sql_generation_reasoning=sql_generation_reasoning,
history=ask_request.history,
histories=histories,
project_id=ask_request.project_id,
configuration=ask_request.configurations,
sql_samples=sql_samples,
Expand Down
Loading