diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 7ffd5ba19f..0aafeda92a 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -27,6 +27,7 @@ class ServiceContainer: sql_pairs_service: services.SqlPairsService sql_question_service: services.SqlQuestionService instructions_service: services.InstructionsService + sql_correction_service: services.SqlCorrectionService @dataclass @@ -257,6 +258,15 @@ def create_service_container( }, **query_cache, ), + sql_correction_service=services.SqlCorrectionService( + pipelines={ + "sql_correction": generation.SQLCorrection( + **pipe_components["sql_correction"], + engine_timeout=settings.engine_timeout, + ), + }, + **query_cache, + ), ) diff --git a/wren-ai-service/src/web/v1/routers/__init__.py b/wren-ai-service/src/web/v1/routers/__init__.py index d67687a3a1..2eb8384691 100644 --- a/wren-ai-service/src/web/v1/routers/__init__.py +++ b/wren-ai-service/src/web/v1/routers/__init__.py @@ -11,6 +11,7 @@ semantics_description, semantics_preparation, sql_answers, + sql_corrections, sql_expansions, sql_pairs, sql_question, @@ -30,3 +31,4 @@ router.include_router(sql_pairs.router) router.include_router(sql_question.router) router.include_router(instructions.router) +router.include_router(sql_corrections.router) diff --git a/wren-ai-service/src/web/v1/routers/sql_corrections.py b/wren-ai-service/src/web/v1/routers/sql_corrections.py new file mode 100644 index 0000000000..b097047000 --- /dev/null +++ b/wren-ai-service/src/web/v1/routers/sql_corrections.py @@ -0,0 +1,113 @@ +import uuid +from dataclasses import asdict +from typing import Literal, Optional + +from fastapi import APIRouter, BackgroundTasks, Depends +from pydantic import BaseModel + +from src.globals import ( + ServiceContainer, + ServiceMetadata, + get_service_container, + get_service_metadata, +) +from src.web.v1.services import SqlCorrectionService + +router = APIRouter() + + +""" +SQL Correction Router + +This router handles endpoints related to correcting invalid SQL queries. + +Endpoints: +1. POST /sql-corrections + - Initiates SQL correction process for invalid SQL queries + - Request body: PostRequest + { + "sql": "SELECT * FROM table", # Invalid SQL statement + "error": "Error message" # Error message + "project_id": "project-id" # Optional project ID + } + - Response: PostResponse + { + "event_id": "unique-uuid" # Unique identifier for tracking correction + } + +2. GET /sql-corrections/{event_id} + - Retrieves status and results of SQL correction process + - Path parameter: event_id (str) + - Response: GetResponse + { + "event_id": "unique-uuid", # Unique identifier + "status": "correcting" | "finished" | "failed", + "response": "corrected-sql", # Correction results (when status is "finished") + "error": { # Present only if status is "failed" + "code": "OTHERS", + "message": "Error description" + }, + "trace_id": "trace-id" # Optional trace ID for debugging + } + +The SQL correction is an asynchronous process. The POST endpoint initiates the operation +and returns immediately with an event_id. The GET endpoint can then be used to check the +status and retrieve the results. + +Usage: +1. Send a POST request to start the correction process +2. Use the returned event_id to poll the GET endpoint until status is "finished" or "failed" + +Note: The actual processing is performed in the background using FastAPI's BackgroundTasks. +Results are cached with a TTL defined in the service configuration. +""" + + +class PostRequest(BaseModel): + sql: str + error: str + project_id: Optional[str] = None + + +class PostResponse(BaseModel): + event_id: str + + +@router.post("/sql-corrections") +async def correct( + request: PostRequest, + background_tasks: BackgroundTasks, + service_container: ServiceContainer = Depends(get_service_container), + service_metadata: ServiceMetadata = Depends(get_service_metadata), +) -> PostResponse: + event_id = str(uuid.uuid4()) + service = service_container.sql_correction_service + service[event_id] = SqlCorrectionService.Event(event_id=event_id) + + _request = SqlCorrectionService.CorrectionRequest( + event_id=event_id, **request.model_dump() + ) + + background_tasks.add_task( + service.correct, + _request, + service_metadata=asdict(service_metadata), + ) + return PostResponse(event_id=event_id) + + +class GetResponse(BaseModel): + event_id: str + status: Literal["correcting", "finished", "failed"] + response: Optional[str] = None + error: Optional[dict] = None + trace_id: Optional[str] = None + + +@router.get("/sql-corrections/{event_id}") +async def get( + event_id: str, + container: ServiceContainer = Depends(get_service_container), +) -> GetResponse: + event: SqlCorrectionService.Event = container.sql_correction_service[event_id] + return GetResponse(**event.model_dump()) diff --git a/wren-ai-service/src/web/v1/services/__init__.py b/wren-ai-service/src/web/v1/services/__init__.py index 55c1ed9ec9..81b909e82b 100644 --- a/wren-ai-service/src/web/v1/services/__init__.py +++ b/wren-ai-service/src/web/v1/services/__init__.py @@ -69,6 +69,7 @@ def serialize(self): from .semantics_description import SemanticsDescription # noqa: E402 from .semantics_preparation import SemanticsPreparationService # noqa: E402 from .sql_answer import SqlAnswerService # noqa: E402 +from .sql_corrections import SqlCorrectionService # noqa: E402 from .sql_expansion import SqlExpansionService # noqa: E402 from .sql_pairs import SqlPairsService # noqa: E402 from .sql_question import SqlQuestionService # noqa: E402 @@ -83,6 +84,7 @@ def serialize(self): "SemanticsDescription", "SemanticsPreparationService", "SqlAnswerService", + "SqlCorrectionService", "SqlExpansionService", "SqlPairsService", "SqlQuestionService", diff --git a/wren-ai-service/src/web/v1/services/instructions.py b/wren-ai-service/src/web/v1/services/instructions.py index 27e9e1eb45..393175f03f 100644 --- a/wren-ai-service/src/web/v1/services/instructions.py +++ b/wren-ai-service/src/web/v1/services/instructions.py @@ -51,7 +51,7 @@ def _handle_exception( self._cache[id] = self.Event( event_id=id, status="failed", - error=self.Event.Error(code=code, message=error_message), + error=self.Error(code=code, message=error_message), trace_id=trace_id, ) logger.error(error_message) @@ -151,7 +151,7 @@ def __getitem__(self, event_id: str) -> Event: return self.Event( event_id=event_id, status="failed", - error=self.Event.Error(code="OTHERS", message=message), + error=self.Error(code="OTHERS", message=message), ) return response diff --git a/wren-ai-service/src/web/v1/services/sql_corrections.py b/wren-ai-service/src/web/v1/services/sql_corrections.py new file mode 100644 index 0000000000..a0bcf4576e --- /dev/null +++ b/wren-ai-service/src/web/v1/services/sql_corrections.py @@ -0,0 +1,122 @@ +import logging +from typing import Literal, Optional + +from cachetools import TTLCache +from langfuse.decorators import observe +from pydantic import BaseModel + +from src.core.pipeline import BasicPipeline +from src.utils import trace_metadata +from src.web.v1.services import MetadataTraceable + +logger = logging.getLogger("wren-ai-service") + + +class SqlCorrectionService: + class Error(BaseModel): + code: Literal["OTHERS"] + message: str + + class Event(BaseModel, MetadataTraceable): + event_id: str + status: Literal["correcting", "finished", "failed"] = "correcting" + response: Optional[str] = None + error: Optional["SqlCorrectionService.Error"] = None + trace_id: Optional[str] = None + + def __init__( + self, + pipelines: dict[str, BasicPipeline], + maxsize: int = 1_000_000, + ttl: int = 120, + ): + self._pipelines = pipelines + self._cache: dict[str, self.Event] = TTLCache(maxsize=maxsize, ttl=ttl) + + def _handle_exception( + self, + event_id: str, + error_message: str, + code: str = "OTHERS", + trace_id: Optional[str] = None, + ): + self._cache[event_id] = self.Event( + event_id=event_id, + status="failed", + error=self.Error(code=code, message=error_message), + trace_id=trace_id, + ) + logger.error(error_message) + + class CorrectionRequest(BaseModel): + event_id: str + sql: str + error: str + project_id: Optional[str] = None + + @observe(name="SQL Correction") + @trace_metadata + async def correct( + self, + request: CorrectionRequest, + **kwargs, + ): + logger.info(f"Request {request.event_id}: SQL Correction process is running...") + trace_id = kwargs.get("trace_id") + + try: + _invalid = { + "sql": request.sql, + "error": request.error, + } + + res = await self._pipelines["sql_correction"].run( + contexts=[], + invalid_generation_results=[_invalid], + project_id=request.project_id, + ) + + post_process = res["post_process"] + valid = post_process["valid_generation_results"] + invalid = post_process["invalid_generation_results"] + + if not valid: + error = invalid[0]["error"] + raise Exception( + f"Unable to correct the SQL query. Error: {error}. Please try with a different SQL query or simplify your request." + ) + + corrected = valid[0]["sql"] + + self._cache[request.event_id] = self.Event( + event_id=request.event_id, + status="finished", + trace_id=trace_id, + response=corrected, + ) + + except Exception as e: + self._handle_exception( + request.event_id, + f"An error occurred during SQL correction: {str(e)}", + trace_id=trace_id, + ) + + return self._cache[request.event_id].with_metadata() + + def __getitem__(self, event_id: str) -> Event: + response = self._cache.get(event_id) + + if response is None: + message = f"SQL Correction Event with ID '{event_id}' not found." + logger.exception(message) + return self.Event( + event_id=event_id, + status="failed", + error=self.Error(code="OTHERS", message=message), + ) + + return response + + def __setitem__(self, event_id: str, value: Event): + self._cache[event_id] = value