Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ServiceContainer:
sql_pairs_service: services.SqlPairsService
sql_question_service: services.SqlQuestionService
instructions_service: services.InstructionsService
sql_correction_service: services.SqlCorrectionService


@dataclass
Expand Down Expand Up @@ -257,6 +258,15 @@ def create_service_container(
},
**query_cache,
),
sql_correction_service=services.SqlCorrectionService(
pipelines={
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
engine_timeout=settings.engine_timeout,
),
},
**query_cache,
),
)


Expand Down
2 changes: 2 additions & 0 deletions wren-ai-service/src/web/v1/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
semantics_description,
semantics_preparation,
sql_answers,
sql_corrections,
sql_expansions,
sql_pairs,
sql_question,
Expand All @@ -30,3 +31,4 @@
router.include_router(sql_pairs.router)
router.include_router(sql_question.router)
router.include_router(instructions.router)
router.include_router(sql_corrections.router)
113 changes: 113 additions & 0 deletions wren-ai-service/src/web/v1/routers/sql_corrections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import uuid
from dataclasses import asdict
from typing import Literal, Optional

from fastapi import APIRouter, BackgroundTasks, Depends
from pydantic import BaseModel

from src.globals import (
ServiceContainer,
ServiceMetadata,
get_service_container,
get_service_metadata,
)
from src.web.v1.services import SqlCorrectionService

router = APIRouter()


"""
SQL Correction Router

This router handles endpoints related to correcting invalid SQL queries.

Endpoints:
1. POST /sql-corrections
- Initiates SQL correction process for invalid SQL queries
- Request body: PostRequest
{
"sql": "SELECT * FROM table", # Invalid SQL statement
"error": "Error message", # Error message associated with the invalid SQL
"project_id": "project-id" # Optional project ID
}
- Response: PostResponse
{
"event_id": "unique-uuid" # Unique identifier for tracking correction
}

2. GET /sql-corrections/{event_id}
- Retrieves status and results of SQL correction process
- Path parameter: event_id (str)
- Response: GetResponse
{
"event_id": "unique-uuid", # Unique identifier
"status": "correcting" | "finished" | "failed",
"response": "corrected-sql", # Correction results (when status is "finished")
"error": { # Present only if status is "failed"
"code": "OTHERS",
"message": "Error description"
},
"trace_id": "trace-id" # Optional trace ID for debugging
}

The SQL correction is an asynchronous process. The POST endpoint initiates the operation
and returns immediately with an event_id. The GET endpoint can then be used to check the
status and retrieve the results.

Usage:
1. Send a POST request to start the correction process
2. Use the returned event_id to poll the GET endpoint until status is "finished" or "failed"

Note: The actual processing is performed in the background using FastAPI's BackgroundTasks.
Results are cached with a TTL defined in the service configuration.
"""


class PostRequest(BaseModel):
    # Body accepted by POST /sql-corrections. Its fields are forwarded
    # unchanged into SqlCorrectionService.CorrectionRequest by the endpoint.
    sql: str  # SQL statement to be corrected
    error: str  # error message that accompanies the SQL
    project_id: Optional[str] = None  # optional project identifier


class PostResponse(BaseModel):
    # Body returned by POST /sql-corrections.
    event_id: str  # identifier to use with GET /sql-corrections/{event_id}


@router.post("/sql-corrections")
async def correct(
    request: PostRequest,
    background_tasks: BackgroundTasks,
    service_container: ServiceContainer = Depends(get_service_container),
    service_metadata: ServiceMetadata = Depends(get_service_metadata),
) -> PostResponse:
    # Kick off a correction in the background and return the event ID that
    # the caller polls via GET /sql-corrections/{event_id}.
    correction_service = service_container.sql_correction_service
    new_event_id = str(uuid.uuid4())

    # Store an initial event before scheduling the work so the ID is already
    # present in the service when the response is returned.
    correction_service[new_event_id] = SqlCorrectionService.Event(
        event_id=new_event_id
    )

    payload = request.model_dump()
    correction_request = SqlCorrectionService.CorrectionRequest(
        event_id=new_event_id, **payload
    )

    background_tasks.add_task(
        correction_service.correct,
        correction_request,
        service_metadata=asdict(service_metadata),
    )
    return PostResponse(event_id=new_event_id)


class GetResponse(BaseModel):
    # Body returned by GET /sql-corrections/{event_id}; built from the dumped
    # fields of the service's event.
    event_id: str
    status: Literal["correcting", "finished", "failed"]
    response: Optional[str] = None  # corrected SQL, when available
    error: Optional[dict] = None  # error details, when the event failed
    trace_id: Optional[str] = None


@router.get("/sql-corrections/{event_id}")
async def get(
    event_id: str,
    container: ServiceContainer = Depends(get_service_container),
) -> GetResponse:
    # Fetch the event for this ID from the correction service and re-wrap its
    # dumped fields in the response model.
    event_fields = container.sql_correction_service[event_id].model_dump()
    return GetResponse(**event_fields)
2 changes: 2 additions & 0 deletions wren-ai-service/src/web/v1/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def serialize(self):
from .semantics_description import SemanticsDescription # noqa: E402
from .semantics_preparation import SemanticsPreparationService # noqa: E402
from .sql_answer import SqlAnswerService # noqa: E402
from .sql_corrections import SqlCorrectionService # noqa: E402
from .sql_expansion import SqlExpansionService # noqa: E402
from .sql_pairs import SqlPairsService # noqa: E402
from .sql_question import SqlQuestionService # noqa: E402
Expand All @@ -83,6 +84,7 @@ def serialize(self):
"SemanticsDescription",
"SemanticsPreparationService",
"SqlAnswerService",
"SqlCorrectionService",
"SqlExpansionService",
"SqlPairsService",
"SqlQuestionService",
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/web/v1/services/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _handle_exception(
self._cache[id] = self.Event(
event_id=id,
status="failed",
error=self.Event.Error(code=code, message=error_message),
error=self.Error(code=code, message=error_message),
trace_id=trace_id,
)
logger.error(error_message)
Expand Down Expand Up @@ -151,7 +151,7 @@ def __getitem__(self, event_id: str) -> Event:
return self.Event(
event_id=event_id,
status="failed",
error=self.Event.Error(code="OTHERS", message=message),
error=self.Error(code="OTHERS", message=message),
)

return response
Expand Down
122 changes: 122 additions & 0 deletions wren-ai-service/src/web/v1/services/sql_corrections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import logging
from typing import Literal, Optional

from cachetools import TTLCache
from langfuse.decorators import observe
from pydantic import BaseModel

from src.core.pipeline import BasicPipeline
from src.utils import trace_metadata
from src.web.v1.services import MetadataTraceable

logger = logging.getLogger("wren-ai-service")


class SqlCorrectionService:
    """Correct invalid SQL via the ``"sql_correction"`` pipeline and cache results.

    Every request is tracked as an ``Event`` stored under its event ID in a
    ``TTLCache``. ``service[event_id]`` reads an event and
    ``service[event_id] = event`` stores one.
    """

    class Error(BaseModel):
        # Failure details attached to an Event whose status is "failed".
        code: Literal["OTHERS"]
        message: str

    class Event(BaseModel, MetadataTraceable):
        # State of a single correction request as kept in the cache.
        event_id: str
        status: Literal["correcting", "finished", "failed"] = "correcting"
        # Corrected SQL; set by `correct` when the status is "finished".
        response: Optional[str] = None
        # String forward reference: the outer class is not yet bound while
        # this nested class body is being evaluated.
        error: Optional["SqlCorrectionService.Error"] = None
        trace_id: Optional[str] = None

    def __init__(
        self,
        pipelines: dict[str, BasicPipeline],
        maxsize: int = 1_000_000,
        ttl: int = 120,
    ):
        """Keep the pipelines and create the event cache.

        Args:
            pipelines: Pipelines by name; ``correct`` uses the
                ``"sql_correction"`` entry.
            maxsize: Forwarded to ``TTLCache`` as ``maxsize``.
            ttl: Forwarded to ``TTLCache`` as ``ttl``.
        """
        self._pipelines = pipelines
        self._cache: dict[str, self.Event] = TTLCache(maxsize=maxsize, ttl=ttl)

    def _handle_exception(
        self,
        event_id: str,
        error_message: str,
        code: str = "OTHERS",
        trace_id: Optional[str] = None,
    ):
        """Record a failed event for ``event_id`` in the cache and log the message.

        Args:
            event_id: ID of the event to overwrite.
            error_message: Text stored in the event's error and logged.
            code: Error code stored on the event.
            trace_id: Optional trace ID stored on the event.
        """
        self._cache[event_id] = self.Event(
            event_id=event_id,
            status="failed",
            error=self.Error(code=code, message=error_message),
            trace_id=trace_id,
        )
        logger.error(error_message)

    class CorrectionRequest(BaseModel):
        # Input to `correct`: the SQL to fix plus the error that came with it.
        event_id: str
        sql: str
        error: str
        project_id: Optional[str] = None

    @observe(name="SQL Correction")
    @trace_metadata
    async def correct(
        self,
        request: CorrectionRequest,
        **kwargs,
    ):
        """Run the correction pipeline for ``request`` and cache the outcome.

        On success the cached event becomes "finished" with the corrected SQL
        as its response. Any exception, including the pipeline yielding no
        valid result, is converted into a "failed" event rather than raised.

        Args:
            request: The correction request.
            **kwargs: Extra keyword arguments; only ``trace_id`` is read here
                (presumably supplied by the decorators - verify).

        Returns:
            The result of ``with_metadata()`` on the cached event.
        """
        logger.info(f"Request {request.event_id}: SQL Correction process is running...")
        trace_id = kwargs.get("trace_id")

        try:
            _invalid = {
                "sql": request.sql,
                "error": request.error,
            }

            res = await self._pipelines["sql_correction"].run(
                contexts=[],
                invalid_generation_results=[_invalid],
                project_id=request.project_id,
            )

            post_process = res["post_process"]
            valid = post_process["valid_generation_results"]
            invalid = post_process["invalid_generation_results"]

            if not valid:
                # Guard the lookup: if the pipeline reports no invalid entry
                # either, an IndexError here would replace the real failure
                # with "list index out of range" in the stored message.
                error = invalid[0]["error"] if invalid else "unknown error"
                raise Exception(
                    f"Unable to correct the SQL query. Error: {error}. Please try with a different SQL query or simplify your request."
                )

            corrected = valid[0]["sql"]

            self._cache[request.event_id] = self.Event(
                event_id=request.event_id,
                status="finished",
                trace_id=trace_id,
                response=corrected,
            )

        except Exception as e:
            self._handle_exception(
                request.event_id,
                f"An error occurred during SQL correction: {str(e)}",
                trace_id=trace_id,
            )

        return self._cache[request.event_id].with_metadata()

    def __getitem__(self, event_id: str) -> Event:
        """Return the cached event for ``event_id``.

        When the cache holds no such entry, a "failed" event describing the
        miss is returned (it is not stored in the cache).
        """
        response = self._cache.get(event_id)

        if response is None:
            message = f"SQL Correction Event with ID '{event_id}' not found."
            # No exception is active here, so logger.exception() would append
            # a meaningless "NoneType: None" traceback; log at the same
            # ERROR level without one.
            logger.error(message)
            return self.Event(
                event_id=event_id,
                status="failed",
                error=self.Error(code="OTHERS", message=message),
            )

        return response

    def __setitem__(self, event_id: str, value: Event):
        """Store ``value`` in the cache under ``event_id``."""
        self._cache[event_id] = value