-
Notifications
You must be signed in to change notification settings - Fork 1.7k
feat(wren-ai-service): Implement Chat History Management with Max History Limit #1377
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a7e8f86
ea2a7ad
303d739
8ef0645
6dd8778
03b917b
54c5107
db635e3
39eef72
6fd8104
fb9e7e4
c002739
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -9,14 +9,13 @@ | |||||
| from src.core.pipeline import BasicPipeline | ||||||
| from src.utils import trace_metadata | ||||||
| from src.web.v1.services import Configuration, SSEEvent | ||||||
| from src.web.v1.services.ask_details import SQLBreakdown | ||||||
|
|
||||||
| logger = logging.getLogger("wren-ai-service") | ||||||
|
|
||||||
|
|
||||||
| class AskHistory(BaseModel): | ||||||
| sql: str | ||||||
| steps: List[SQLBreakdown] | ||||||
| question: str | ||||||
|
|
||||||
|
|
||||||
| # POST /v1/asks | ||||||
|
|
@@ -29,7 +28,7 @@ class AskRequest(BaseModel): | |||||
| # so we need to support as a choice, and will remove it in the future | ||||||
| mdl_hash: Optional[str] = Field(validation_alias=AliasChoices("mdl_hash", "id")) | ||||||
| thread_id: Optional[str] = None | ||||||
| history: Optional[AskHistory] = None | ||||||
| histories: Optional[list[AskHistory]] = Field(default_factory=list) | ||||||
| configurations: Optional[Configuration] = Configuration() | ||||||
|
|
||||||
| @property | ||||||
|
|
@@ -166,6 +165,7 @@ def __init__( | |||||
| pipelines: Dict[str, BasicPipeline], | ||||||
| allow_intent_classification: bool = True, | ||||||
| allow_sql_generation_reasoning: bool = True, | ||||||
| max_histories: int = 10, | ||||||
| maxsize: int = 1_000_000, | ||||||
| ttl: int = 120, | ||||||
| ): | ||||||
|
|
@@ -178,6 +178,7 @@ def __init__( | |||||
| ) | ||||||
| self._allow_sql_generation_reasoning = allow_sql_generation_reasoning | ||||||
| self._allow_intent_classification = allow_intent_classification | ||||||
| self._max_histories = max_histories | ||||||
|
|
||||||
| def _is_stopped(self, query_id: str, container: dict): | ||||||
| if ( | ||||||
|
|
@@ -205,6 +206,7 @@ async def ask( | |||||
| } | ||||||
|
|
||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion: Use the configured `max_histories` value instead of hardcoding. The code currently hardcodes the history limit to 10 instead of using the configured `self._max_histories` value:
- histories = ask_request.histories[:10]
+ histories = ask_request.histories[:self._max_histories]
📝 Committable suggestion
Suggested change
|
||||||
| query_id = ask_request.query_id | ||||||
| histories = ask_request.histories[: self._max_histories] | ||||||
| rephrased_question = None | ||||||
| intent_reasoning = None | ||||||
| sql_generation_reasoning = None | ||||||
|
|
@@ -250,7 +252,7 @@ async def ask( | |||||
| intent_classification_result = ( | ||||||
| await self._pipelines["intent_classification"].run( | ||||||
| query=user_query, | ||||||
| history=ask_request.history, | ||||||
| histories=histories, | ||||||
| id=ask_request.project_id, | ||||||
| configuration=ask_request.configurations, | ||||||
| ) | ||||||
|
|
@@ -278,7 +280,7 @@ async def ask( | |||||
| asyncio.create_task( | ||||||
| self._pipelines["data_assistance"].run( | ||||||
| query=user_query, | ||||||
| history=ask_request.history, | ||||||
| histories=histories, | ||||||
| db_schemas=intent_classification_result.get( | ||||||
| "db_schemas" | ||||||
| ), | ||||||
|
|
@@ -315,7 +317,7 @@ async def ask( | |||||
|
|
||||||
| retrieval_result = await self._pipelines["retrieval"].run( | ||||||
| query=user_query, | ||||||
| history=ask_request.history, | ||||||
| histories=histories, | ||||||
| id=ask_request.project_id, | ||||||
| ) | ||||||
| _retrieval_result = retrieval_result.get( | ||||||
|
|
@@ -403,14 +405,14 @@ async def ask( | |||||
| ) | ||||||
| has_metric = (_retrieval_result.get("has_metric", False),) | ||||||
|
|
||||||
| if ask_request.history: | ||||||
| if histories: | ||||||
| text_to_sql_generation_results = await self._pipelines[ | ||||||
| "followup_sql_generation" | ||||||
| ].run( | ||||||
| query=user_query, | ||||||
| contexts=table_ddls, | ||||||
| sql_generation_reasoning=sql_generation_reasoning, | ||||||
| history=ask_request.history, | ||||||
| histories=histories, | ||||||
| project_id=ask_request.project_id, | ||||||
| configuration=ask_request.configurations, | ||||||
| sql_samples=sql_samples, | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.