From ea1503f3caafa33db2570041d170bddb630e0056 Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 18:14:51 +0800 Subject: [PATCH 1/9] feat: impl svc for sql correction --- .../src/web/v1/services/sql_corrections.py | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 wren-ai-service/src/web/v1/services/sql_corrections.py diff --git a/wren-ai-service/src/web/v1/services/sql_corrections.py b/wren-ai-service/src/web/v1/services/sql_corrections.py new file mode 100644 index 0000000000..403f74c0c9 --- /dev/null +++ b/wren-ai-service/src/web/v1/services/sql_corrections.py @@ -0,0 +1,108 @@ +import logging +from typing import Dict, List, Literal, Optional + +from cachetools import TTLCache +from haystack import Document +from langfuse.decorators import observe +from pydantic import BaseModel + +from src.core.pipeline import BasicPipeline +from src.utils import trace_metadata +from src.web.v1.services import MetadataTraceable + +logger = logging.getLogger("wren-ai-service") + + +class SqlCorrectionService: + class Event(BaseModel, MetadataTraceable): + class Error(BaseModel): + code: Literal["OTHERS"] + message: str + + id: str + status: Literal["correcting", "finished", "failed"] = "correcting" + response: Optional[Dict] = None + error: Optional[Error] = None + trace_id: Optional[str] = None + + def __init__( + self, + pipelines: Dict[str, BasicPipeline], + maxsize: int = 1_000_000, + ttl: int = 120, + ): + self._pipelines = pipelines + self._cache: Dict[str, self.Event] = TTLCache(maxsize=maxsize, ttl=ttl) + + def _handle_exception( + self, + id: str, + error_message: str, + code: str = "OTHERS", + trace_id: Optional[str] = None, + ): + self._cache[id] = self.Event( + id=id, + status="failed", + error=self.Event.Error(code=code, message=error_message), + trace_id=trace_id, + ) + logger.error(error_message) + + class CorrectionRequest(BaseModel): + id: str + contexts: List[Document] + invalid_generation_results: List[Dict[str, str]] + project_id: Optional[str] = None + + @observe(name="SQL Correction") + @trace_metadata + async def correct( + self, + request: CorrectionRequest, + **kwargs, + ): + logger.info(f"Request {request.id}: SQL Correction process is running...") + trace_id = kwargs.get("trace_id") + + try: + # todo: modify the contexts + # todo: check the result format + result = await self._pipelines["sql_correction"].run( + contexts=request.contexts, + invalid_generation_results=request.invalid_generation_results, + project_id=request.project_id, + ) + + self._cache[request.id] = self.Event( + id=request.id, + status="finished", + trace_id=trace_id, + response=result, + ) + + except Exception as e: + self._handle_exception( + request.id, + f"An error occurred during SQL correction: {str(e)}", + trace_id=trace_id, + ) + + return self._cache[request.id].with_metadata() + + def __getitem__(self, id: str) -> Event: + response = self._cache.get(id) + + if response is None: + message = f"SQL Correction Event with ID '{id}' not found." + logger.exception(message) + return self.Event( + id=id, + status="failed", + error=self.Event.Error(code="OTHERS", message=message), + ) + + return response + + def __setitem__(self, id: str, value: Event): + self._cache[id] = value From c3f7b15ccc4c97217c54cc86e8b08e497c5e2955 Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 18:26:13 +0800 Subject: [PATCH 2/9] feat: add service into global container --- wren-ai-service/src/globals.py | 10 ++++++++++ wren-ai-service/src/web/v1/services/__init__.py | 2 ++ 2 files changed, 12 insertions(+) diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 7ffd5ba19f..0aafeda92a 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -27,6 +27,7 @@ class ServiceContainer: sql_pairs_service: services.SqlPairsService sql_question_service: services.SqlQuestionService instructions_service: services.InstructionsService + sql_correction_service: services.SqlCorrectionService @dataclass @@ -257,6 +258,15 @@ def create_service_container( }, **query_cache, ), + sql_correction_service=services.SqlCorrectionService( + pipelines={ + "sql_correction": generation.SQLCorrection( + **pipe_components["sql_correction"], + engine_timeout=settings.engine_timeout, + ), + }, + **query_cache, + ), ) diff --git a/wren-ai-service/src/web/v1/services/__init__.py b/wren-ai-service/src/web/v1/services/__init__.py index 55c1ed9ec9..81b909e82b 100644 --- a/wren-ai-service/src/web/v1/services/__init__.py +++ b/wren-ai-service/src/web/v1/services/__init__.py @@ -69,6 +69,7 @@ def serialize(self): from .semantics_description import SemanticsDescription # noqa: E402 from .semantics_preparation import SemanticsPreparationService # noqa: E402 from .sql_answer import SqlAnswerService # noqa: E402 +from .sql_corrections import SqlCorrectionService # noqa: E402 from .sql_expansion import SqlExpansionService # noqa: E402 from .sql_pairs import SqlPairsService # noqa: E402 from .sql_question import SqlQuestionService # noqa: E402 @@ -83,6 +84,7 @@ def serialize(self): "SemanticsDescription", "SemanticsPreparationService", "SqlAnswerService", + "SqlCorrectionService", "SqlExpansionService", "SqlPairsService", "SqlQuestionService", From 46482d5a4a8e4b34c384e782a8def8a7ab506e40 Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 18:28:56 +0800 Subject: [PATCH 3/9] feat: impl sql correction router --- .../src/web/v1/routers/__init__.py | 2 + .../src/web/v1/routers/sql_corrections.py | 126 ++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 wren-ai-service/src/web/v1/routers/sql_corrections.py diff --git a/wren-ai-service/src/web/v1/routers/__init__.py b/wren-ai-service/src/web/v1/routers/__init__.py index d67687a3a1..2eb8384691 100644 --- a/wren-ai-service/src/web/v1/routers/__init__.py +++ b/wren-ai-service/src/web/v1/routers/__init__.py @@ -11,6 +11,7 @@ semantics_description, semantics_preparation, sql_answers, + sql_corrections, sql_expansions, sql_pairs, sql_question, @@ -30,3 +31,4 @@ router.include_router(sql_pairs.router) router.include_router(sql_question.router) router.include_router(instructions.router) +router.include_router(sql_corrections.router) diff --git a/wren-ai-service/src/web/v1/routers/sql_corrections.py b/wren-ai-service/src/web/v1/routers/sql_corrections.py new file mode 100644 index 0000000000..83145d6518 --- /dev/null +++ b/wren-ai-service/src/web/v1/routers/sql_corrections.py @@ -0,0 +1,126 @@ +import uuid +from dataclasses import asdict +from typing import Dict, List, Literal, Optional + +from fastapi import APIRouter, BackgroundTasks, Depends +from haystack import Document +from pydantic import BaseModel + +from src.globals import ( + ServiceContainer, + ServiceMetadata, + get_service_container, + get_service_metadata, +) +from src.web.v1.services import SqlCorrectionService + +router = APIRouter() + + +""" +SQL Correction Router + +This router handles endpoints related to correcting invalid SQL queries. + +Endpoints: +1. POST /sql-correction + - Initiates SQL correction process for invalid SQL queries + - Request body: PostRequest + { + "contexts": [Document], # List of context documents + "invalid_generation_results": [ # List of invalid SQL generation results + { + "sql": "SELECT * FROM table", # Invalid SQL statement + "error": "Error message" # Error message + } + ], + "project_id": "project-id" # Optional project ID + } + - Response: PostResponse + { + "event_id": "unique-uuid" # Unique identifier for tracking correction + } + +2. GET /sql-correction/{event_id} + - Retrieves status and results of SQL correction process + - Path parameter: event_id (str) + - Response: GetResponse + { + "event_id": "unique-uuid", # Unique identifier + "status": "correcting" | "finished" | "failed", + "response": {}, # Correction results (when status is "finished") + "error": { # Present only if status is "failed" + "code": "OTHERS", + "message": "Error description" + }, + "trace_id": "trace-id" # Optional trace ID for debugging + } + +The SQL correction is an asynchronous process. The POST endpoint initiates the operation +and returns immediately with an event_id. The GET endpoint can then be used to check the +status and retrieve the results. + +Usage: +1. Send a POST request to start the correction process +2. Use the returned event_id to poll the GET endpoint until status is "finished" or "failed" + +Note: The actual processing is performed in the background using FastAPI's BackgroundTasks. +Results are cached with a TTL defined in the service configuration. +""" + + +class PostRequest(BaseModel): + # todo: check the contexts + contexts: List[Document] + invalid_generation_results: List[Dict[str, str]] + project_id: Optional[str] = None + + +class PostResponse(BaseModel): + event_id: str + + +@router.post("/sql-correction") +async def correct( + request: PostRequest, + background_tasks: BackgroundTasks, + service_container: ServiceContainer = Depends(get_service_container), + service_metadata: ServiceMetadata = Depends(get_service_metadata), +) -> PostResponse: + event_id = str(uuid.uuid4()) + service = service_container.sql_correction_service + service[event_id] = SqlCorrectionService.Event(id=event_id, status="correcting") + + correction_request = SqlCorrectionService.CorrectionRequest( + id=event_id, **request.model_dump() + ) + + background_tasks.add_task( + service.correct, + correction_request, + service_metadata=asdict(service_metadata), + ) + return PostResponse(event_id=event_id) + + +class GetResponse(BaseModel): + event_id: str + status: Literal["correcting", "finished", "failed"] + response: Optional[Dict] = None + error: Optional[dict] = None + trace_id: Optional[str] = None + + +@router.get("/sql-correction/{event_id}") +async def get( + event_id: str, + container: ServiceContainer = Depends(get_service_container), +) -> GetResponse: + event: SqlCorrectionService.Event = container.sql_correction_service[event_id] + return GetResponse( + event_id=event.id, + status=event.status, + response=event.response, + error=event.error and event.error.model_dump(), + trace_id=event.trace_id, + ) From 24e89610205f2e614ccc2cea6ab1fa3e9764ac06 Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 18:40:16 +0800 Subject: [PATCH 4/9] chore: modify the endpoint spec --- .../src/web/v1/routers/sql_corrections.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/wren-ai-service/src/web/v1/routers/sql_corrections.py b/wren-ai-service/src/web/v1/routers/sql_corrections.py index 83145d6518..b4d511f203 100644 --- a/wren-ai-service/src/web/v1/routers/sql_corrections.py +++ b/wren-ai-service/src/web/v1/routers/sql_corrections.py @@ -1,6 +1,6 @@ import uuid from dataclasses import asdict -from typing import Dict, List, Literal, Optional +from typing import Literal, Optional from fastapi import APIRouter, BackgroundTasks, Depends from haystack import Document @@ -23,7 +23,7 @@ This router handles endpoints related to correcting invalid SQL queries. Endpoints: -1. POST /sql-correction +1. POST /sql-corrections - Initiates SQL correction process for invalid SQL queries - Request body: PostRequest { @@ -41,7 +41,7 @@ "event_id": "unique-uuid" # Unique identifier for tracking correction } -2. GET /sql-correction/{event_id} +2. GET /sql-corrections/{event_id} - Retrieves status and results of SQL correction process - Path parameter: event_id (str) - Response: GetResponse @@ -71,8 +71,8 @@ class PostRequest(BaseModel): # todo: check the contexts - contexts: List[Document] - invalid_generation_results: List[Dict[str, str]] + contexts: list[Document] + invalid_generation_results: list[dict[str, str]] project_id: Optional[str] = None @@ -80,7 +80,7 @@ class PostResponse(BaseModel): event_id: str -@router.post("/sql-correction") +@router.post("/sql-corrections") async def correct( request: PostRequest, background_tasks: BackgroundTasks, @@ -89,10 +89,12 @@ async def correct( ) -> PostResponse: event_id = str(uuid.uuid4()) service = service_container.sql_correction_service - service[event_id] = SqlCorrectionService.Event(id=event_id, status="correcting") + service[event_id] = SqlCorrectionService.Event( + event_id=event_id, status="correcting" + ) correction_request = SqlCorrectionService.CorrectionRequest( - id=event_id, **request.model_dump() + event_id=event_id, **request.model_dump() ) background_tasks.add_task( @@ -106,21 +108,15 @@ async def correct( class GetResponse(BaseModel): event_id: str status: Literal["correcting", "finished", "failed"] - response: Optional[Dict] = None + response: Optional[dict] = None error: Optional[dict] = None trace_id: Optional[str] = None -@router.get("/sql-correction/{event_id}") +@router.get("/sql-corrections/{event_id}") async def get( event_id: str, container: ServiceContainer = Depends(get_service_container), ) -> GetResponse: event: SqlCorrectionService.Event = container.sql_correction_service[event_id] - return GetResponse( - event_id=event.id, - status=event.status, - response=event.response, - error=event.error and event.error.model_dump(), - trace_id=event.trace_id, - ) + return GetResponse(**event.model_dump()) From 665be7d48a378fd566725b26eb4a3854a1bbd4cc Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 18:40:45 +0800 Subject: [PATCH 5/9] chore: refactor the code for service --- .../src/web/v1/services/sql_corrections.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/sql_corrections.py b/wren-ai-service/src/web/v1/services/sql_corrections.py index 403f74c0c9..1abd5dd0d9 100644 --- a/wren-ai-service/src/web/v1/services/sql_corrections.py +++ b/wren-ai-service/src/web/v1/services/sql_corrections.py @@ -14,15 +14,15 @@ class SqlCorrectionService: - class Event(BaseModel, MetadataTraceable): - class Error(BaseModel): - code: Literal["OTHERS"] - message: str + class Error(BaseModel): + code: Literal["OTHERS"] + message: str - id: str + class Event(BaseModel, MetadataTraceable): + event_id: str status: Literal["correcting", "finished", "failed"] = "correcting" response: Optional[Dict] = None - error: Optional[Error] = None + error: Optional["SqlCorrectionService.Error"] = None trace_id: Optional[str] = None def __init__( @@ -36,21 +36,21 @@ def __init__( def _handle_exception( self, - id: str, + event_id: str, error_message: str, code: str = "OTHERS", trace_id: Optional[str] = None, ): - self._cache[id] = self.Event( - id=id, + self._cache[event_id] = self.Event( + event_id=event_id, status="failed", - error=self.Event.Error(code=code, message=error_message), + error=self.Error(code=code, message=error_message), trace_id=trace_id, ) logger.error(error_message) class CorrectionRequest(BaseModel): - id: str + event_id: str contexts: List[Document] invalid_generation_results: List[Dict[str, str]] project_id: Optional[str] = None @@ -62,7 +62,7 @@ async def correct( request: CorrectionRequest, **kwargs, ): - logger.info(f"Request {request.id}: SQL Correction process is running...") + logger.info(f"Request {request.event_id}: SQL Correction process is running...") trace_id = kwargs.get("trace_id") try: @@ -74,8 +74,8 @@ async def correct( project_id=request.project_id, ) - self._cache[request.id] = self.Event( - id=request.id, + self._cache[request.event_id] = self.Event( + event_id=request.event_id, status="finished", trace_id=trace_id, response=result, @@ -83,26 +83,26 @@ async def correct( except Exception as e: self._handle_exception( - request.id, + request.event_id, f"An error occurred during SQL correction: {str(e)}", trace_id=trace_id, ) - return self._cache[request.id].with_metadata() + return self._cache[request.event_id].with_metadata() - def __getitem__(self, id: str) -> Event: - response = self._cache.get(id) + def __getitem__(self, event_id: str) -> Event: + response = self._cache.get(event_id) if response is None: message = f"SQL Correction Event with ID '{id}' not found." logger.exception(message) return self.Event( - id=id, + event_id=event_id, status="failed", - error=self.Event.Error(code="OTHERS", message=message), + error=self.Error(code="OTHERS", message=message), ) return response - def __setitem__(self, id: str, value: Event): - self._cache[id] = value + def __setitem__(self, event_id: str, value: Event): + self._cache[event_id] = value From 79a8078a032525e1d01377ac64e6bde673487811 Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 18:41:10 +0800 Subject: [PATCH 6/9] fix: invoking the error type --- wren-ai-service/src/web/v1/services/instructions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/instructions.py b/wren-ai-service/src/web/v1/services/instructions.py index 27e9e1eb45..393175f03f 100644 --- a/wren-ai-service/src/web/v1/services/instructions.py +++ b/wren-ai-service/src/web/v1/services/instructions.py @@ -51,7 +51,7 @@ def _handle_exception( self._cache[id] = self.Event( event_id=id, status="failed", - error=self.Event.Error(code=code, message=error_message), + error=self.Error(code=code, message=error_message), trace_id=trace_id, ) logger.error(error_message) @@ -151,7 +151,7 @@ def __getitem__(self, event_id: str) -> Event: return self.Event( event_id=event_id, status="failed", - error=self.Event.Error(code="OTHERS", message=message), + error=self.Error(code="OTHERS", message=message), ) return response From ea2f419bb9c4d54d8780d958a88807d7751a5458 Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 18:53:57 +0800 Subject: [PATCH 7/9] feat: remove document class to avoid PydanticSchemaGenerationError --- .../src/web/v1/routers/sql_corrections.py | 11 ++++------- .../src/web/v1/services/sql_corrections.py | 13 ++++++------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/wren-ai-service/src/web/v1/routers/sql_corrections.py b/wren-ai-service/src/web/v1/routers/sql_corrections.py index b4d511f203..69af0902cb 100644 --- a/wren-ai-service/src/web/v1/routers/sql_corrections.py +++ b/wren-ai-service/src/web/v1/routers/sql_corrections.py @@ -3,7 +3,6 @@ from typing import Literal, Optional from fastapi import APIRouter, BackgroundTasks, Depends -from haystack import Document from pydantic import BaseModel from src.globals import ( @@ -71,7 +70,7 @@ class PostRequest(BaseModel): # todo: check the contexts - contexts: list[Document] + contexts: list[dict] invalid_generation_results: list[dict[str, str]] project_id: Optional[str] = None @@ -89,17 +88,15 @@ async def correct( ) -> PostResponse: event_id = str(uuid.uuid4()) service = service_container.sql_correction_service - service[event_id] = SqlCorrectionService.Event( - event_id=event_id, status="correcting" - ) + service[event_id] = SqlCorrectionService.Event(event_id=event_id) - correction_request = SqlCorrectionService.CorrectionRequest( + _request = SqlCorrectionService.CorrectionRequest( event_id=event_id, **request.model_dump() ) background_tasks.add_task( service.correct, - correction_request, + _request, service_metadata=asdict(service_metadata), ) return PostResponse(event_id=event_id) diff --git a/wren-ai-service/src/web/v1/services/sql_corrections.py b/wren-ai-service/src/web/v1/services/sql_corrections.py index 1abd5dd0d9..e26724e51c 100644 --- a/wren-ai-service/src/web/v1/services/sql_corrections.py +++ b/wren-ai-service/src/web/v1/services/sql_corrections.py @@ -1,8 +1,7 @@ import logging -from typing import Dict, List, Literal, Optional +from typing import Literal, Optional from cachetools import TTLCache -from haystack import Document from langfuse.decorators import observe from pydantic import BaseModel @@ -21,18 +20,18 @@ class Error(BaseModel): class Event(BaseModel, MetadataTraceable): event_id: str status: Literal["correcting", "finished", "failed"] = "correcting" - response: Optional[Dict] = None + response: Optional[dict] = None error: Optional["SqlCorrectionService.Error"] = None trace_id: Optional[str] = None def __init__( self, - pipelines: Dict[str, BasicPipeline], + pipelines: dict[str, BasicPipeline], maxsize: int = 1_000_000, ttl: int = 120, ): self._pipelines = pipelines - self._cache: Dict[str, self.Event] = TTLCache(maxsize=maxsize, ttl=ttl) + self._cache: dict[str, self.Event] = TTLCache(maxsize=maxsize, ttl=ttl) def _handle_exception( self, @@ -51,8 +50,8 @@ def _handle_exception( class CorrectionRequest(BaseModel): event_id: str - contexts: List[Document] - invalid_generation_results: List[Dict[str, str]] + contexts: list[dict] + invalid_generation_results: list[dict[str, str]] project_id: Optional[str] = None @observe(name="SQL Correction") From 54ef445a309c1ba865f3d08ba1e803ee317d1759 Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 19:07:33 +0800 Subject: [PATCH 8/9] feat: modify the interface spec for sql correction router and svc --- .../src/web/v1/routers/sql_corrections.py | 14 ++++---------- .../src/web/v1/services/sql_corrections.py | 15 +++++++++------ 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/wren-ai-service/src/web/v1/routers/sql_corrections.py b/wren-ai-service/src/web/v1/routers/sql_corrections.py index 69af0902cb..582010bcca 100644 --- a/wren-ai-service/src/web/v1/routers/sql_corrections.py +++ b/wren-ai-service/src/web/v1/routers/sql_corrections.py @@ -26,13 +26,8 @@ - Initiates SQL correction process for invalid SQL queries - Request body: PostRequest { - "contexts": [Document], # List of context documents - "invalid_generation_results": [ # List of invalid SQL generation results - { - "sql": "SELECT * FROM table", # Invalid SQL statement - "error": "Error message" # Error message - } - ], + "sql": "SELECT * FROM table", # Invalid SQL statement + "error": "Error message" # Error message "project_id": "project-id" # Optional project ID } - Response: PostResponse @@ -69,9 +64,8 @@ class PostRequest(BaseModel): - # todo: check the contexts - contexts: list[dict] - invalid_generation_results: list[dict[str, str]] + sql: str + error: str project_id: Optional[str] = None diff --git a/wren-ai-service/src/web/v1/services/sql_corrections.py b/wren-ai-service/src/web/v1/services/sql_corrections.py index e26724e51c..69f1895b00 100644 --- a/wren-ai-service/src/web/v1/services/sql_corrections.py +++ b/wren-ai-service/src/web/v1/services/sql_corrections.py @@ -50,8 +50,8 @@ def _handle_exception( class CorrectionRequest(BaseModel): event_id: str - contexts: list[dict] - invalid_generation_results: list[dict[str, str]] + sql: str + error: str project_id: Optional[str] = None @observe(name="SQL Correction") @@ -65,11 +65,14 @@ async def correct( trace_id = kwargs.get("trace_id") try: - # todo: modify the contexts - # todo: check the result format + _invalid = { + "sql": request.sql, + "error": request.error, + } + result = await self._pipelines["sql_correction"].run( - contexts=request.contexts, - invalid_generation_results=request.invalid_generation_results, + contexts=[], + invalid_generation_results=[_invalid], project_id=request.project_id, ) From caa3f87a3e25a7b68db735d6eadb988ef4f132ca Mon Sep 17 00:00:00 2001 From: Pao Sheng Date: Mon, 17 Mar 2025 19:20:42 +0800 Subject: [PATCH 9/9] feat: simplify the output spec --- .../src/web/v1/routers/sql_corrections.py | 4 ++-- .../src/web/v1/services/sql_corrections.py | 20 +++++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/wren-ai-service/src/web/v1/routers/sql_corrections.py b/wren-ai-service/src/web/v1/routers/sql_corrections.py index 582010bcca..b097047000 100644 --- a/wren-ai-service/src/web/v1/routers/sql_corrections.py +++ b/wren-ai-service/src/web/v1/routers/sql_corrections.py @@ -42,7 +42,7 @@ { "event_id": "unique-uuid", # Unique identifier "status": "correcting" | "finished" | "failed", - "response": {}, # Correction results (when status is "finished") + "response": "corrected-sql", # Correction results (when status is "finished") "error": { # Present only if status is "failed" "code": "OTHERS", "message": "Error description" @@ -99,7 +99,7 @@ async def correct( class GetResponse(BaseModel): event_id: str status: Literal["correcting", "finished", "failed"] - response: Optional[dict] = None + response: Optional[str] = None error: Optional[dict] = None trace_id: Optional[str] = None diff --git a/wren-ai-service/src/web/v1/services/sql_corrections.py b/wren-ai-service/src/web/v1/services/sql_corrections.py index 69f1895b00..a0bcf4576e 100644 --- a/wren-ai-service/src/web/v1/services/sql_corrections.py +++ b/wren-ai-service/src/web/v1/services/sql_corrections.py @@ -20,7 +20,7 @@ class Error(BaseModel): class Event(BaseModel, MetadataTraceable): event_id: str status: Literal["correcting", "finished", "failed"] = "correcting" - response: Optional[dict] = None + response: Optional[str] = None error: Optional["SqlCorrectionService.Error"] = None trace_id: Optional[str] = None @@ -70,17 +70,29 @@ async def correct( "error": request.error, } - result = await self._pipelines["sql_correction"].run( + res = await self._pipelines["sql_correction"].run( contexts=[], invalid_generation_results=[_invalid], project_id=request.project_id, ) + post_process = res["post_process"] + valid = post_process["valid_generation_results"] + invalid = post_process["invalid_generation_results"] + + if not valid: + error = invalid[0]["error"] + raise Exception( + f"Unable to correct the SQL query. Error: {error}. Please try with a different SQL query or simplify your request." + ) + + corrected = valid[0]["sql"] + self._cache[request.event_id] = self.Event( event_id=request.event_id, status="finished", trace_id=trace_id, - response=result, + response=corrected, ) except Exception as e: @@ -96,7 +108,7 @@ def __getitem__(self, event_id: str) -> Event: response = self._cache.get(event_id) if response is None: - message = f"SQL Correction Event with ID '{id}' not found." + message = f"SQL Correction Event with ID '{event_id}' not found." logger.exception(message) return self.Event( event_id=event_id,