From 4041484e59c69c23ed6f3a0c81e09a66fd7f9377 Mon Sep 17 00:00:00 2001 From: Pao-Sheng Wang Date: Wed, 13 Mar 2024 15:02:25 +0800 Subject: [PATCH 1/3] migrate the ai service to the new repo --- .github/workflows/ai-service-ci.yaml | 41 ++ .gitignore | 32 ++ wren-ai-service/.dockerignore | 4 + wren-ai-service/.env.dev.example | 12 + wren-ai-service/.env.prod.example | 12 + wren-ai-service/.pre-commit-config.yaml | 10 + wren-ai-service/Makefile | 45 ++ wren-ai-service/README.md | 30 ++ wren-ai-service/docker/Dockerfile | 13 + wren-ai-service/docker/docker-compose.yml | 24 + wren-ai-service/pyproject.toml | 33 ++ wren-ai-service/ruff.toml | 77 ++++ wren-ai-service/src/__init__.py | 0 wren-ai-service/src/__main__.py | 57 +++ wren-ai-service/src/core/__init__.py | 0 wren-ai-service/src/core/pipeline.py | 20 + wren-ai-service/src/eval/__init__.py | 0 wren-ai-service/src/eval/ask.py | 198 +++++++++ wren-ai-service/src/eval/eval_pipeline.py | 43 ++ wren-ai-service/src/eval/utils.py | 420 ++++++++++++++++++ .../config.properties.example | 10 + .../vulcansql-core-server/docker-compose.yml | 13 + .../eval/vulcansql-core-server/launch-cli.sh | 2 + wren-ai-service/src/globals.py | 81 ++++ wren-ai-service/src/pipelines/__init__.py | 0 wren-ai-service/src/pipelines/ask/__init__.py | 0 .../src/pipelines/ask/components/__init__.py | 0 .../ask/components/document_store.py | 21 + .../src/pipelines/ask/components/embedder.py | 39 ++ .../src/pipelines/ask/components/generator.py | 74 +++ .../src/pipelines/ask/components/prompts.py | 43 ++ .../src/pipelines/ask/components/retriever.py | 41 ++ .../src/pipelines/ask/generation_pipeline.py | 152 +++++++ .../src/pipelines/ask/indexing_pipeline.py | 125 ++++++ .../src/pipelines/ask/retrieval_pipeline.py | 114 +++++ .../src/pipelines/ask_details/__init__.py | 0 .../ask_details/components/__init__.py | 0 .../ask_details/components/generator.py | 78 ++++ .../ask_details/components/prompts.py | 28 ++ .../ask_details/generation_pipeline.py | 150 +++++++ .../src/pipelines/semantics/__init__.py | 0 .../src/pipelines/semantics/description.py | 273 ++++++++++++ wren-ai-service/src/pipelines/trace.py | 96 ++++ wren-ai-service/src/utils.py | 22 + wren-ai-service/src/web/__init__.py | 0 wren-ai-service/src/web/v1/__init__.py | 0 wren-ai-service/src/web/v1/routers.py | 115 +++++ wren-ai-service/src/web/v1/services/ask.py | 180 ++++++++ .../src/web/v1/services/ask_details.py | 93 ++++ .../src/web/v1/services/semantics.py | 55 +++ wren-ai-service/tests/__init__.py | 0 wren-ai-service/tests/data/book_2_mdl.json | 100 +++++ wren-ai-service/tests/services/__init__.py | 0 wren-ai-service/tests/services/test_ask.py | 111 +++++ .../tests/services/test_ask_details.py | 58 +++ .../tests/services/test_semantics.py | 46 ++ wren-ai-service/tests/test_main.py | 237 ++++++++++ 57 files changed, 3428 insertions(+) create mode 100644 .github/workflows/ai-service-ci.yaml create mode 100644 .gitignore create mode 100644 wren-ai-service/.dockerignore create mode 100644 wren-ai-service/.env.dev.example create mode 100644 wren-ai-service/.env.prod.example create mode 100644 wren-ai-service/.pre-commit-config.yaml create mode 100644 wren-ai-service/Makefile create mode 100644 wren-ai-service/README.md create mode 100644 wren-ai-service/docker/Dockerfile create mode 100644 wren-ai-service/docker/docker-compose.yml create mode 100644 wren-ai-service/pyproject.toml create mode 100644 wren-ai-service/ruff.toml create mode 100644 wren-ai-service/src/__init__.py create mode 100644 wren-ai-service/src/__main__.py create mode 100644 wren-ai-service/src/core/__init__.py create mode 100644 wren-ai-service/src/core/pipeline.py create mode 100644 wren-ai-service/src/eval/__init__.py create mode 100644 wren-ai-service/src/eval/ask.py create mode 100644 wren-ai-service/src/eval/eval_pipeline.py create mode 100644 wren-ai-service/src/eval/utils.py create mode 100644 wren-ai-service/src/eval/vulcansql-core-server/config.properties.example create mode 100644 wren-ai-service/src/eval/vulcansql-core-server/docker-compose.yml create mode 100644 wren-ai-service/src/eval/vulcansql-core-server/launch-cli.sh create mode 100644 wren-ai-service/src/globals.py create mode 100644 wren-ai-service/src/pipelines/__init__.py create mode 100644 wren-ai-service/src/pipelines/ask/__init__.py create mode 100644 wren-ai-service/src/pipelines/ask/components/__init__.py create mode 100644 wren-ai-service/src/pipelines/ask/components/document_store.py create mode 100644 wren-ai-service/src/pipelines/ask/components/embedder.py create mode 100644 wren-ai-service/src/pipelines/ask/components/generator.py create mode 100644 wren-ai-service/src/pipelines/ask/components/prompts.py create mode 100644 wren-ai-service/src/pipelines/ask/components/retriever.py create mode 100644 wren-ai-service/src/pipelines/ask/generation_pipeline.py create mode 100644 wren-ai-service/src/pipelines/ask/indexing_pipeline.py create mode 100644 wren-ai-service/src/pipelines/ask/retrieval_pipeline.py create mode 100644 wren-ai-service/src/pipelines/ask_details/__init__.py create mode 100644 wren-ai-service/src/pipelines/ask_details/components/__init__.py create mode 100644 wren-ai-service/src/pipelines/ask_details/components/generator.py create mode 100644 wren-ai-service/src/pipelines/ask_details/components/prompts.py create mode 100644 wren-ai-service/src/pipelines/ask_details/generation_pipeline.py create mode 100644 wren-ai-service/src/pipelines/semantics/__init__.py create mode 100644 wren-ai-service/src/pipelines/semantics/description.py create mode 100644 wren-ai-service/src/pipelines/trace.py create mode 100644 wren-ai-service/src/utils.py create mode 100644 wren-ai-service/src/web/__init__.py create mode 100644 wren-ai-service/src/web/v1/__init__.py create mode 100644 wren-ai-service/src/web/v1/routers.py create mode 100644 wren-ai-service/src/web/v1/services/ask.py create mode 100644 wren-ai-service/src/web/v1/services/ask_details.py create mode 100644 wren-ai-service/src/web/v1/services/semantics.py create mode 100644 wren-ai-service/tests/__init__.py create mode 100644 wren-ai-service/tests/data/book_2_mdl.json create mode 100644 wren-ai-service/tests/services/__init__.py create mode 100644 wren-ai-service/tests/services/test_ask.py create mode 100644 wren-ai-service/tests/services/test_ask_details.py create mode 100644 wren-ai-service/tests/services/test_semantics.py create mode 100644 wren-ai-service/tests/test_main.py diff --git a/.github/workflows/ai-service-ci.yaml b/.github/workflows/ai-service-ci.yaml new file mode 100644 index 0000000000..a9bfbe04f4 --- /dev/null +++ b/.github/workflows/ai-service-ci.yaml @@ -0,0 +1,41 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: AI Service CI + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + ci: + strategy: + fail-fast: false + matrix: + python-version: [ "3.12" ] + poetry-version: [ "1.7.1" ] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Run image + uses: abatilo/actions-poetry@v2 + with: + poetry-version: ${{ matrix.poetry-version }} + - name: Install the project dependencies + run: poetry install + - name: Run Qdrant + run: docker run -p 6333:6333 -p 6334:6334 -d --name qdrant qdrant/qdrant:v1.7.4 + - name: Test with pytest + run: poetry run pytest -s + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ENV: dev diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..f2b0565ece --- /dev/null +++ b/.gitignore @@ -0,0 +1,32 @@ +# vectorstore +vectorstore +retrieve_db/ +wren-ai-service/qdrant_storage/ + +# app +.env +.env.* +!.env.*.example +!src/eval/vulcansql-core-server/.env +src/eval/vulcansql-core-server/config.properties +outputs/ +spider/ +data/ +!wren-ai-service/tests/data +!src/eval/data +poetry.lock +assertion.log + +# cache +__pycache__ +local_cache +.ruff_cache +.pytest_cache + +# ide +.idea +.vscode/ + +# os +.DS_Store +__MACOSX/ \ No newline at end of file diff --git a/wren-ai-service/.dockerignore b/wren-ai-service/.dockerignore new file mode 100644 index 0000000000..5c727874b9 --- /dev/null +++ b/wren-ai-service/.dockerignore @@ -0,0 +1,4 @@ +* +src/eval +!src +!pyproject.toml \ No newline at end of file diff --git a/wren-ai-service/.env.dev.example b/wren-ai-service/.env.dev.example new file mode 100644 index 0000000000..a55e51df75 --- /dev/null +++ b/wren-ai-service/.env.dev.example @@ -0,0 +1,12 @@ +# fastapi related +UVICORN_HOST=127.0.0.1 +UVICORN_PORT=8080 + +# app related +OPENAI_API_KEY= +LANGFUSE_PUBLIC_KEY= +LANGFUSE_SECRET_KEY= +ENABLE_TRACE= + +# evaluation related +DATASET_NAME=book_2 \ No newline at end of file diff --git a/wren-ai-service/.env.prod.example b/wren-ai-service/.env.prod.example new file mode 100644 index 0000000000..7ebe7365be --- /dev/null +++ b/wren-ai-service/.env.prod.example @@ -0,0 +1,12 @@ +# docker related +# This will determine the prefix of container name. +# If the service name in docker-compose.yaml is copilot, +# then the container name will be vulcansql-copilot-1. +# If COMPOSE_PROJECT_NAME is not set, the default prefix name is docker. +COMPOSE_PROJECT_NAME=vulcansql + +# fastapi related +UVICORN_PORT=8080 + +# app related +OPENAI_API_KEY= \ No newline at end of file diff --git a/wren-ai-service/.pre-commit-config.yaml b/wren-ai-service/.pre-commit-config.yaml new file mode 100644 index 0000000000..9639e9b491 --- /dev/null +++ b/wren-ai-service/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.2.2 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/wren-ai-service/Makefile b/wren-ai-service/Makefile new file mode 100644 index 0000000000..1245a241ef --- /dev/null +++ b/wren-ai-service/Makefile @@ -0,0 +1,45 @@ +args ?= + +start: + poetry run python -m src.__main__ + +build: + docker compose -f docker/docker-compose.yml build + +up: + docker compose -f docker/docker-compose.yml --env-file .env.prod up -d + +down: + docker compose -f docker/docker-compose.yml --env-file .env.prod down + +run-qdrant: + docker run \ + -p 6333:6333 \ + -p 6334:6334 \ + -d \ + --name qdrant \ + qdrant/qdrant:v1.7.4 + +stop-qdrant: + docker stop qdrant && docker rm qdrant + +# evaluation related +eval: + poetry run python -m src.eval.ask $(args) + +run-vulcansql-core-server: + docker compose -f ./src/eval/vulcansql-core-server/docker-compose.yml --env-file .env.dev --env-file ./src/eval/vulcansql-core-server/.env up -d + +stop-vulcansql-core-server: + docker compose -f ./src/eval/vulcansql-core-server/docker-compose.yml --env-file .env.dev --env-file ./src/eval/vulcansql-core-server/.env down + +psql: + docker exec -it vulcansql-engine-1 bash launch-cli.sh + +run-all: + make run-qdrant && \ + make run-vulcansql-core-server + +stop-all: + make stop-qdrant && \ + make stop-vulcansql-core-server \ No newline at end of file diff --git a/wren-ai-service/README.md b/wren-ai-service/README.md new file mode 100644 index 0000000000..e5349bafed --- /dev/null +++ b/wren-ai-service/README.md @@ -0,0 +1,30 @@ +## Introduction + +## Environment Setup + +- Python 3.12 or later +- follow the instructions at https://pipx.pypa.io/stable/ install `pipx` +- execute `pipx install poetry` to install `poetry` +- execute `poetry install` to install the dependencies +- copy `.env.example` file to `.env`, and `.env.dev.example` file to `.env.dev` and fill in the environment variables +- [for development] execute `poetry run pre-commit install` to install the pre-commit hooks and `poetry run pre-commit run --all-files` to run the pre-commit checks at the first time to check if everything is set up correctly + +## Start the service for development + +- execute `make start` to start the service + - go to `http://UVICORN_HOST:UVICORN_PORT/docs` to see the API documentation and try the API + +## Production Environment Setup + +- copy `.env.prod.example` file to `.env.prod` and fill in the environment variables +- `make build` to build the docker image +- `make up` to run the docker container +- `make down` to stop the docker container + +## Pipeline Evaluation(for development) + +- fill in environment variables: `.env.dev` in the src folder and `config.properties` in the src/eval/vulcansql-core-server folder +- start docker +- run qdrant and vulcansql-core-server docker containers: `make run-all` +- evaluation: `make eval` and check out the outputs folder +- to run individual pipeline: `poetry run python -m src.pipelines.ask.[pipeline_name]` (e.g. `poetry run python -m src.pipelines.ask.retrieval_pipeline`) \ No newline at end of file diff --git a/wren-ai-service/docker/Dockerfile b/wren-ai-service/docker/Dockerfile new file mode 100644 index 0000000000..43e15b7111 --- /dev/null +++ b/wren-ai-service/docker/Dockerfile @@ -0,0 +1,13 @@ +FROM python:3.12.2-bookworm + +RUN pip install poetry==1.7.1 + +WORKDIR /app + +COPY pyproject.toml ./ + +RUN poetry install --without dev + +COPY src/ src/ + +ENTRYPOINT [ "poetry", "run", "uvicorn", "src.__main__:app", "--host", "0.0.0.0", "--port", "80" ] \ No newline at end of file diff --git a/wren-ai-service/docker/docker-compose.yml b/wren-ai-service/docker/docker-compose.yml new file mode 100644 index 0000000000..ac71ce236d --- /dev/null +++ b/wren-ai-service/docker/docker-compose.yml @@ -0,0 +1,24 @@ +version: '3.8' + +services: + copilot: + image: vulcansql-copilot:latest + build: + context: .. + dockerfile: docker/Dockerfile + environment: + UVICORN_PORT: ${UVICORN_PORT} + OPENAI_API_KEY: ${OPENAI_API_KEY} + # sometimes the console won't show print messages, + # using PYTHONUNBUFFERED: 1 can fix this + PYTHONUNBUFFERED: 1 + ports: + - ${UVICORN_PORT}:80 + depends_on: + - qdrant + + qdrant: + image: qdrant/qdrant:v1.7.4 + ports: + - 6333:6333 + - 6334:6334 diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml new file mode 100644 index 0000000000..e0bb18230e --- /dev/null +++ b/wren-ai-service/pyproject.toml @@ -0,0 +1,33 @@ +[tool.poetry] +name = "wren-ai-service" +version = "0.1.0" +description = "" +authors = ["Jimmy Yeh ", "Pao Sheng Wang "] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.12" +fastapi = "^0.109.2" +uvicorn = "^0.27.1" +python-dotenv = "^1.0.1" +setuptools = "^69.1.0" +haystack-ai = "^2.0.0" +openai = "^1.12.0" +langfuse = "^2.19.1" +ragas = "^0.1.1" +qdrant-haystack = "^3.0.0" +backoff = "^2.2.1" +ragas-haystack = "^0.1.1" +tqdm = "^4.66.2" + +[tool.poetry.group.dev.dependencies] +pytest = "^8.0.0" +pre-commit = "^3.6.1" +pytest-cov = "^4.1.0" +gdown = "^5.1.0" +sqlglot = "^22.1.0" +httpx = "^0.27.0" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/wren-ai-service/ruff.toml b/wren-ai-service/ruff.toml new file mode 100644 index 0000000000..b365289b98 --- /dev/null +++ b/wren-ai-service/ruff.toml @@ -0,0 +1,77 @@ +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + +# Assume Python 3.8 +target-version = "py38" + +[lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. +select = ["E4", "E7", "E9", "F", "I001"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +# Enable auto-formatting of code examples in docstrings. Markdown, +# reStructuredText code/literal blocks and doctests are all supported. +# +# This is currently disabled by default, but it is planned for this +# to be opt-out in the future. +docstring-code-format = false + +# Set the line length limit used when formatting code snippets in +# docstrings. +# +# This only has an effect when the `docstring-code-format` setting is +# enabled. +docstring-code-line-length = "dynamic" diff --git a/wren-ai-service/src/__init__.py b/wren-ai-service/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py new file mode 100644 index 0000000000..34cc7bff61 --- /dev/null +++ b/wren-ai-service/src/__main__.py @@ -0,0 +1,57 @@ +import os +from contextlib import asynccontextmanager + +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse + +import src.globals as container +from src.utils import load_env_vars +from src.web.v1 import routers + +env = load_env_vars() + +server_host = os.getenv("UVICORN_HOST") or "127.0.0.1" +server_port = ( + int(os.getenv("UVICORN_PORT")) if os.getenv("UVICORN_PORT") is not None else 8000 +) + + +# https://fastapi.tiangolo.com/advanced/events/#lifespan +@asynccontextmanager +async def lifespan(app: FastAPI): + # startup events + container.init_globals() + + yield + + # shutdown events + + +app = FastAPI(lifespan=lifespan, redoc_url=None) + +app.include_router(routers.router, prefix="/v1") +app.add_middleware( + CORSMiddleware, + allow_origins=[ + f"http://{server_host}:{server_port}", + ], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/") +def root(): + return RedirectResponse(url="/docs") + + +if __name__ == "__main__": + uvicorn.run( + "src.__main__:app", + host=server_host, + port=server_port, + reload=(env == "dev"), + ) diff --git a/wren-ai-service/src/core/__init__.py b/wren-ai-service/src/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/core/pipeline.py b/wren-ai-service/src/core/pipeline.py new file mode 100644 index 0000000000..127f53632d --- /dev/null +++ b/wren-ai-service/src/core/pipeline.py @@ -0,0 +1,20 @@ +import os +from abc import ABCMeta, abstractmethod +from pathlib import Path +from typing import Any, Dict + +from haystack import Pipeline + + +class BasicPipeline(metaclass=ABCMeta): + def __init__(self, pipe: Pipeline): + self._pipe = pipe + + @abstractmethod + def run(self, *args, **kwargs) -> Dict[str, Any]: + ... + + def draw(self, path: Path) -> None: + dir_path, _ = os.path.split(path) + os.makedirs(dir_path, exist_ok=True) + self._pipe.draw(path) diff --git a/wren-ai-service/src/eval/__init__.py b/wren-ai-service/src/eval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/eval/ask.py b/wren-ai-service/src/eval/ask.py new file mode 100644 index 0000000000..5d7583a4a1 --- /dev/null +++ b/wren-ai-service/src/eval/ask.py @@ -0,0 +1,198 @@ +import argparse +import json +import os +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional + +from tqdm import tqdm + +from src.pipelines.ask.components.document_store import init_document_store +from src.pipelines.ask.components.embedder import init_embedder +from src.pipelines.ask.components.generator import init_generator +from src.pipelines.ask.components.prompts import user_prompt_builder +from src.pipelines.ask.components.retriever import init_retriever +from src.pipelines.ask.generation_pipeline import Generation +from src.pipelines.ask.indexing_pipeline import Indexing +from src.pipelines.ask.retrieval_pipeline import Retrieval +from src.utils import clean_generation_result, load_env_vars + +# from .eval_pipeline import Evaluation +from .utils import ( + download_spider_data, + generate_eval_report, + get_latest_prediction_outputs_file, + write_prediction_results, +) + +load_env_vars() + +if with_trace := os.getenv("ENABLE_TRACE", default=False): + from src.pipelines.trace import ( + langfuse, + ) + + +def process_item(query: str, user_id: Optional[str] = None) -> Dict[str, Any]: + retrieval_start = time.perf_counter() + retrieval_result = retrieval_pipeline.run( + query, + user_id=user_id, + ) + retrieval_end = time.perf_counter() + + generation_start = time.perf_counter() + generation_result = generation_pipeline.run( + query, + contexts=retrieval_result["retriever"]["documents"], + user_id=user_id, + ) + generation_end = time.perf_counter() + + metadata = { + "generation": generation_result["generator"]["meta"][0], + "latency": { + "retrieval": retrieval_end - retrieval_start, + "generation": generation_end - generation_start, + }, + } + + return { + "contexts": retrieval_result["retriever"]["documents"], + "prediction": json.loads( + clean_generation_result(generation_result["generator"]["replies"][0]) + )["sql"], + "metadata": metadata, + } + + +def eval(prediction_results_file: Path, dataset_name: str, ground_truths: list[dict]): + download_spider_data() + + with open(prediction_results_file, "r") as f: + predictions = [json.loads(line) for line in f] + + # eval_pipeline = Evaluation() + # eval_pipeline_inputs = prepare_evaluation_pipeline_inputs( + # eval_pipeline.component_names, + # ground_truths, + # predictions, + # ) + # ragas_eval_results = eval_pipeline.run(eval_pipeline_inputs) + + ragas_eval_results = {} + + eval_results = generate_eval_report( + dataset_name, + ground_truths, + predictions, + ragas_eval_results, + ) + + timestamp = prediction_results_file.stem.split("_")[-1] + + with open(f"./outputs/{dataset_name}_eval_results_{timestamp}.json", "w") as f: + json.dump(eval_results, f, indent=2) + + +if __name__ == "__main__": + DATASET_NAME = os.getenv("DATASET_NAME") + + parser = argparse.ArgumentParser( + description=f"Evaluate the ask pipeline using the Spider dataset: {DATASET_NAME}" + ) + parser.add_argument( + "--input_file", + type=str, + default=get_latest_prediction_outputs_file(Path("./outputs"), DATASET_NAME), + help="Path to the prediction results file. If not provided, the latest prediction results file will be used. The file should be located in the outputs folder in the root directory of the project.", + ) + parser.add_argument( + "--eval_after_prediction", + action="store_true", + default=True, + help="Whether to run the evaluation after making predictions. Default is True.", + ) + parser.add_argument( + "--eval_from_scratch", + action="store_true", + default=False, + help="Whether to run the evaluation from scratch. Default is False.", + ) + args = parser.parse_args() + + PREDICTION_RESULTS_FILE = args.input_file + EVAL_AFTER_PREDICTION = args.eval_after_prediction + + with open(f"./src/eval/data/{DATASET_NAME}_data.json", "r") as f: + ground_truths = [json.loads(line) for line in f] + + if ( + PREDICTION_RESULTS_FILE + and Path(PREDICTION_RESULTS_FILE).exists() + and not args.eval_from_scratch + ): + eval(Path(PREDICTION_RESULTS_FILE), DATASET_NAME, ground_truths) + else: + with open(f"./src/eval/data/{DATASET_NAME}_mdl.json", "r") as f: + mdl_str = json.dumps(json.load(f)) + + document_store = init_document_store() + embedder = init_embedder(with_trace=with_trace) + retriever = init_retriever( + document_store=document_store, + with_trace=with_trace, + ) + generator = init_generator(with_trace=with_trace) + + Indexing(document_store=document_store).run(mdl_str) + print( + f"finished indexing documents, document count: {document_store.count_documents()}" + ) + + retrieval_pipeline = Retrieval( + embedder=embedder, + retriever=retriever, + with_trace=with_trace, + ) + + generation_pipeline = Generation( + generator=generator, + with_trace=with_trace, + prompt_builder=user_prompt_builder, + ) + + start = time.time() + max_workers = os.cpu_count() // 2 if with_trace else None + user_id = str(uuid.uuid4()) if with_trace else None + with ThreadPoolExecutor(max_workers=max_workers) as executor: + args_list = [ + (ground_truth["question"], user_id) for ground_truth in ground_truths + ] + outputs = list( + tqdm( + executor.map(lambda p: process_item(*p), args_list), + total=len(args_list), + ) + ) + end = time.time() + print(f"Time taken: {end - start:.2f}s") + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + write_prediction_results( + f"./outputs/{DATASET_NAME}_predictions_{timestamp}.json", + ground_truths, + outputs, + ) + if with_trace: + langfuse.flush() + + if EVAL_AFTER_PREDICTION: + eval( + Path(f"./outputs/{DATASET_NAME}_predictions_{timestamp}.json"), + DATASET_NAME, + ground_truths, + ) diff --git a/wren-ai-service/src/eval/eval_pipeline.py b/wren-ai-service/src/eval/eval_pipeline.py new file mode 100644 index 0000000000..2670fdeb7e --- /dev/null +++ b/wren-ai-service/src/eval/eval_pipeline.py @@ -0,0 +1,43 @@ +from typing import Any, Dict, List, Optional + +from haystack import Pipeline +from haystack_integrations.components.evaluators.ragas import ( + RagasEvaluator, + RagasMetric, +) + +from src.core.pipeline import BasicPipeline + + +class Evaluation(BasicPipeline): + def __init__( + self, + metrics: Optional[List[RagasMetric]] = [ + RagasMetric.ANSWER_CORRECTNESS, + # RagasMetric.FAITHFULNESS, # Not supported at the moment + RagasMetric.ANSWER_SIMILARITY, + RagasMetric.CONTEXT_UTILIZATION, + RagasMetric.CONTEXT_PRECISION, + RagasMetric.CONTEXT_RECALL, + # RagasMetric.ASPECT_CRITIQUE, # Not supported at the moment + RagasMetric.CONTEXT_RELEVANCY, + RagasMetric.ANSWER_RELEVANCY, + ], + metric_params: Optional[Dict[str, Any]] = None, + ): + self._pipeline = Pipeline() + self.component_names = [] + + for metric in metrics: + component_name = f"evaluator_{metric.name}" + self._pipeline.add_component( + component_name, + RagasEvaluator( + metric=metric, + metric_params=metric_params[metric.name] if metric_params else None, + ), + ) + self.component_names.append(component_name) + + def run(self, data) -> Dict[str, Any]: + return self._pipeline.run(data) diff --git a/wren-ai-service/src/eval/utils.py b/wren-ai-service/src/eval/utils.py new file mode 100644 index 0000000000..df622ba3c7 --- /dev/null +++ b/wren-ai-service/src/eval/utils.py @@ -0,0 +1,420 @@ +import json +import os +import sqlite3 +import subprocess +import zipfile +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +import gdown +import pandas as pd +import sqlglot +from tqdm import tqdm +from tqdm.contrib import tzip + + +def semantic_diff(sql_query1: str, sql_query2: str): + try: + diff = sqlglot.diff( + sqlglot.parse_one(sql_query1, read=sqlglot.Dialects.TRINO), + sqlglot.parse_one(sql_query2, read=sqlglot.Dialects.TRINO), + ) + + for d in diff: + if str(d).startswith("Keep"): + continue + + return True + + return False + except Exception as e: + print(f"semantic_diff: {e}") + return True + + +def execute_sql_query(sql_query: str, db_path: str): + conn = sqlite3.connect(db_path) + cur = conn.cursor() + + try: + cur.execute(sql_query) + # make each row a tuple of strings for easier comparison with the results from Vulcan + # also sort each row to make the order of the columns consistent + results = tuple(tuple(sorted(map(str, row))) for row in cur.fetchall()) + except Exception: + results = [] + finally: + cur.close() + conn.close() + + return results + + +def execute_sql_query_through_vulcan(sql_query: str): + command = f'psql -d "postgres://localhost:7432/canner-cml?options=--search_path%3Dspider" -c "{sql_query}"' + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True + ) + output, error = process.communicate() + + if error: + return "", error.decode() + + df = get_csv_table_from_response(output.decode()) + # also sort each row to make the order of the columns consistent + sorted_df = df.apply(lambda x: tuple(sorted(x)), axis=1) + return tuple(sorted_df.values.tolist()), "" + + +def get_csv_table_from_response(full_response: str): + lines = [line.strip() for line in full_response.strip().split("\n")[2:-1]] + table_content = [[element.strip() for element in line.split("|")] for line in lines] + + sql_query_result_in_csv = pd.DataFrame(table_content) + return sql_query_result_in_csv + + +def ground_truth_query_results_issubset( + ground_truth_query_results, prediction_query_results +): + rows1 = sorted(frozenset(row) for row in ground_truth_query_results) + rows2 = sorted(frozenset(row) for row in prediction_query_results) + for row1, row2 in zip(rows1, rows2): + if not row1.issubset(row2): + return False + + return True + + +def get_ragas_eval_results( + ragas_eval_results: Dict[str, Any], + index: int, +): + return { + metric: eval_result["results"][index][0]["score"] + for metric, eval_result in ragas_eval_results.items() + } + + +def get_generation_model_pricing( + model_name: str, +): + # https://openai.com/pricing + generation_model_pricing = { + "gpt-3.5-turbo": { + "prompt_tokens": 0.5 / 10**6, + "completion_tokens": 1.5 / 10**6, + }, + "gpt-3.5-turbo-0125": { + "prompt_tokens": 0.5 / 10**6, + "completion_tokens": 1.5 / 10**6, + }, + "gpt-4-turbo": { + "prompt_tokens": 10 / 10**6, + "completion_tokens": 30 / 10**6, + }, + "gpt-4-0125-preview": { + "prompt_tokens": 10 / 10**6, + "completion_tokens": 30 / 10**6, + }, + } + + return generation_model_pricing[model_name] + + +def generate_eval_report( + database_name: str, + groundtruths: List[Dict[str, str]], + predictions: List[Dict[str, Any]], + ragas_eval_results: Dict[str, Any], +): + results = { + "eval_results": { + "average_accuracy": 0, + "average_cost": { + "total": 0, + "input": 0, + "output": 0, + }, + "average_latency": { + "total": 0, + "retrieval": 0, + "generation": 0, + }, + "details": { + "correct": { + "sql_semantic_same": [], + "query_results_same": [], + "ground_truth_query_results_issubset": [], + }, + "wrong": [], + }, + } + } + + total = 0 + correct = 0 + retrieval_total_latency = 0 + generation_total_latency = 0 + input_total_cost = 0 + output_total_cost = 0 + for i, (ground_truth, prediction) in enumerate(tzip(groundtruths, predictions)): + ## dealing with cost part + model_name = prediction["metadata"]["generation"]["model"] + generation_model_pricing = get_generation_model_pricing(model_name) + input_total_cost += ( + generation_model_pricing["prompt_tokens"] + * prediction["metadata"]["generation"]["usage"]["prompt_tokens"] + ) + output_total_cost += ( + generation_model_pricing["completion_tokens"] + * prediction["metadata"]["generation"]["usage"]["completion_tokens"] + ) + + ## dealing with latency part + retrieval_total_latency += prediction["metadata"]["latency"]["retrieval"] + generation_total_latency += prediction["metadata"]["latency"]["generation"] + + ## dealing with accuracy part + assert ground_truth["question"] == prediction["question"] + question = ground_truth["question"] + total += 1 + + # directly compare the sql query using semantic diff + if not semantic_diff(ground_truth["answer"], prediction["answer"]): + correct += 1 + results["eval_results"]["details"]["correct"]["sql_semantic_same"].append( + { + "question": ground_truth["question"], + "ground_truth_answer": ground_truth["answer"], + "prediction_answer": prediction["answer"], + "regas_eval_results": get_ragas_eval_results( + ragas_eval_results, + i, + ), + } + ) + continue + + # since the order of the sql query may be different, we compare the results as sets + ground_truth_query_results = execute_sql_query( + ground_truth["answer"], + f"./spider/database/{database_name}/{database_name}.sqlite", + ) + + ( + prediction_query_results, + prediction_error_details, + ) = execute_sql_query_through_vulcan( + prediction["answer"], + ) + if len(ground_truth_query_results) == len(prediction_query_results): + if set(ground_truth_query_results) == set(prediction_query_results): + correct += 1 + results["eval_results"]["details"]["correct"][ + "query_results_same" + ].append( + { + "question": question, + "ground_truth_answer": ground_truth["answer"], + "prediction_answer": prediction["answer"], + "regas_eval_results": get_ragas_eval_results( + ragas_eval_results, + i, + ), + } + ) + continue + elif ground_truth_query_results_issubset( + ground_truth_query_results, prediction_query_results + ): + correct += 1 + results["eval_results"]["details"]["correct"][ + "ground_truth_query_results_issubset" + ].append( + { + "question": question, + "ground_truth_answer": ground_truth["answer"], + "prediction_answer": prediction["answer"], + "ground_truth_query_results": ground_truth_query_results, + "prediction_query_results": prediction_query_results, + "regas_eval_results": get_ragas_eval_results( + ragas_eval_results, + i, + ), + } + ) + continue + + results["eval_results"]["details"]["wrong"].append( + { + "question": question, + "ground_truth_answer": ground_truth["answer"], + "prediction_answer": prediction["answer"], + "ground_truth_query_results": ground_truth_query_results, + "prediction_query_results": prediction_query_results, + "prediction_error_details": prediction_error_details, + "regas_eval_results": get_ragas_eval_results( + ragas_eval_results, + i, + ), + } + ) + + results["eval_results"]["average_accuracy"] = correct / total + results["eval_results"]["average_cost"]["total"] = ( + input_total_cost + output_total_cost + ) / total + results["eval_results"]["average_cost"]["input"] = input_total_cost / total + results["eval_results"]["average_cost"]["output"] = output_total_cost / total + results["eval_results"]["average_latency"]["total"] = ( + retrieval_total_latency + generation_total_latency + ) / total + results["eval_results"]["average_latency"]["retrieval"] = ( + retrieval_total_latency + ) / total + results["eval_results"]["average_latency"]["generation"] = ( + generation_total_latency + ) / total + + return results + + +def download_spider_data(): + if Path("spider").exists(): + return + + if Path("spider.zip").exists(): + os.remove("spider.zip") + + # 1. uploaded to Jimmy's google drive from the official Spider dataset website(data based at 2024/01/28) + # 2. added `table_counts_in_database.json` + # 3. changed column `salary` to `player_salary` in `baseball_1.player` table, + # since in bigquery, column name should not be the same as table name + url = "https://drive.google.com/u/0/uc?id=1StzD_Yha1W-BJOLimuvdzH-cF6sEc_ak&export=download" + + output = "spider.zip" + gdown.download(url, output, quiet=False) + + with zipfile.ZipFile(output, "r") as zip_ref: + zip_ref.extractall(".") + + os.remove("spider.zip") + + +def write_prediction_results( + file_name: str, ground_truths: List[Dict], outputs: List[Dict[str, Any]] +): + with open(file_name, "w") as f: + for ground_truth, output in tzip(ground_truths, outputs): + json.dump( + { + "question": ground_truth["question"], + "contexts": [ + { + "content": [json.loads(context.content)], + "score": context.score, + } + for context in output["contexts"] + ], + "answer": output["prediction"], + "metadata": output["metadata"], + }, + f, + ) + f.write("\n") + + +def prepare_evaluation_pipeline_inputs( + component_names: List[str], + ground_truths_data: List[Dict[str, Any]], + predictions_data: List[Dict[str, Any]], +): + inputs = {} + + questions = [] + ground_truths = [] + contexts = [] + responses = [] + for ground_truth, prediction in tzip(ground_truths_data, predictions_data): + assert ground_truth["question"] == prediction["question"] + + questions.append(ground_truth["question"]) + ground_truths.append(ground_truth["answer"]) + contexts.append( + [json.dumps(context["content"]) for context in prediction["contexts"]] + ) + responses.append(prediction["answer"]) + + for component_name in tqdm(component_names): + ragas_metric = "_".join(component_name.split("_")[1:]) + + # https://docs.haystack.deepset.ai/v2.0/docs/ragasevaluator#supported-metrics + if ragas_metric == "ANSWER_CORRECTNESS": + inputs[component_name] = { + "questions": questions, + "responses": responses, + "ground_truths": ground_truths, + } + elif ragas_metric == "FAITHFULNESS": + inputs[component_name] = { + "questions": questions, + "contexts": contexts, + "responses": responses, + } + elif ragas_metric == "ANSWER_SIMILARITY": + inputs[component_name] = { + "responses": responses, + "ground_truths": ground_truths, + } + elif ragas_metric == "CONTEXT_PRECISION": + inputs[component_name] = { + "questions": questions, + "contexts": contexts, + "ground_truths": ground_truths, + } + elif ragas_metric == "CONTEXT_UTILIZATION": + inputs[component_name] = { + "questions": questions, + "contexts": contexts, + "responses": responses, + } + elif ragas_metric == "CONTEXT_RECALL": + inputs[component_name] = { + "questions": questions, + "contexts": contexts, + "ground_truths": ground_truths, + } + elif ragas_metric == "ASPECT_CRITIQUE": + inputs[component_name] = { + "questions": questions, + "contexts": contexts, + "responses": responses, + } + elif ragas_metric == "CONTEXT_RELEVANCY": + inputs[component_name] = { + "questions": questions, + "contexts": contexts, + } + elif ragas_metric == "ANSWER_RELEVANCY": + inputs[component_name] = { + "questions": questions, + "contexts": contexts, + "responses": responses, + } + + return inputs + + +def get_latest_prediction_outputs_file(path: Path, dataset_name: str) -> str: + def _extract_datetime(file_name: Path) -> datetime: + file_name, _ = os.path.splitext(file_name) + timestamp_str = file_name.split("_")[-1] + return datetime.strptime(timestamp_str, "%Y%m%d%H%M%S") + + files = list(path.glob(f"{dataset_name}_predictions*.json")) + if not files: + return "" + + return str(sorted(files, key=_extract_datetime, reverse=True)[0]) diff --git a/wren-ai-service/src/eval/vulcansql-core-server/config.properties.example b/wren-ai-service/src/eval/vulcansql-core-server/config.properties.example new file mode 100644 index 0000000000..ff95c0ce9b --- /dev/null +++ b/wren-ai-service/src/eval/vulcansql-core-server/config.properties.example @@ -0,0 +1,10 @@ + +node.environment = test +duckdb.storage.access-key = +duckdb.storage.secret-key = +bigquery.project-id = canner-cml +bigquery.credentials-key = +bigquery.location = asia-east1 +bigquery.bucket-name = test-pre-agg-cml +accio.file = etc/mdl.json +accio.datasource.type = bigquery \ No newline at end of file diff --git a/wren-ai-service/src/eval/vulcansql-core-server/docker-compose.yml b/wren-ai-service/src/eval/vulcansql-core-server/docker-compose.yml new file mode 100644 index 0000000000..1cc549dd3d --- /dev/null +++ b/wren-ai-service/src/eval/vulcansql-core-server/docker-compose.yml @@ -0,0 +1,13 @@ +version: '3.8' + +services: + engine: + image: ghcr.io/canner/accio:latest + platform: ${PLATFORM} + ports: + - 8080:8080 + - 7432:7432 + volumes: + - ../data/${DATASET_NAME}_mdl.json:/usr/src/app/etc/mdl.json + - ${CONFIG_PATH}:/usr/src/app/etc/config.properties + - ${LAUNCH_CLI_PATH}:/usr/src/app/launch-cli.sh diff --git a/wren-ai-service/src/eval/vulcansql-core-server/launch-cli.sh b/wren-ai-service/src/eval/vulcansql-core-server/launch-cli.sh new file mode 100644 index 0000000000..441d6450de --- /dev/null +++ b/wren-ai-service/src/eval/vulcansql-core-server/launch-cli.sh @@ -0,0 +1,2 @@ +#!/bin/bash +psql postgres://localhost:7432/canner-cml?options=--search_path%3Dspider \ No newline at end of file diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py new file mode 100644 index 0000000000..52575d6843 --- /dev/null +++ b/wren-ai-service/src/globals.py @@ -0,0 +1,81 @@ +import os + +from dotenv import load_dotenv + +from src.pipelines.ask import ( + generation_pipeline as ask_generation_pipeline, +) +from src.pipelines.ask import ( + indexing_pipeline as ask_indexing_pipeline, +) +from src.pipelines.ask import ( + retrieval_pipeline as ask_retrieval_pipeline, +) +from src.pipelines.ask.components.document_store import init_document_store +from src.pipelines.ask.components.embedder import init_embedder +from src.pipelines.ask.components.generator import init_generator +from src.pipelines.ask.components.prompts import init_generation_prompt_builder +from src.pipelines.ask.components.retriever import init_retriever +from src.pipelines.ask_details import ( + generation_pipeline as ask_details_generation_pipeline, +) +from src.pipelines.ask_details.components.generator import ( + init_generator as init_ask_details_generator, +) +from src.pipelines.semantics import description +from src.web.v1.services.ask import AskService +from src.web.v1.services.ask_details import AskDetailsService +from src.web.v1.services.semantics import SemanticsService + +load_dotenv() + +SEMANTIC_SERVICE = None +ASK_SERVICE = None +ASK_DETAILS_SERVICE = None + + +def init_globals(): + global SEMANTIC_SERVICE, ASK_SERVICE, ASK_DETAILS_SERVICE + + with_trace = os.getenv("ENABLE_TRACE", default=False) + + document_store = init_document_store() + embedder = init_embedder(with_trace=with_trace) + retriever = init_retriever( + document_store=document_store, + with_trace=with_trace, + ) + ask_generator = init_generator(with_trace=with_trace) + ask_details_generator = init_ask_details_generator(with_trace=with_trace) + generation_prompt_builder = init_generation_prompt_builder() + + SEMANTIC_SERVICE = SemanticsService( + pipelines={ + "generate_description": description.Generation(), + } + ) + ASK_SERVICE = AskService( + pipelines={ + "indexing": ask_indexing_pipeline.Indexing( + document_store=document_store, + ), + "retrieval": ask_retrieval_pipeline.Retrieval( + embedder=embedder, + retriever=retriever, + with_trace=with_trace, + ), + "generation": ask_generation_pipeline.Generation( + generator=ask_generator, + prompt_builder=generation_prompt_builder, + with_trace=with_trace, + ), + } + ) + ASK_DETAILS_SERVICE = AskDetailsService( + pipelines={ + "generation": ask_details_generation_pipeline.Generation( + generator=ask_details_generator, + with_trace=with_trace, + ), + } + ) diff --git a/wren-ai-service/src/pipelines/__init__.py b/wren-ai-service/src/pipelines/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/pipelines/ask/__init__.py b/wren-ai-service/src/pipelines/ask/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/pipelines/ask/components/__init__.py b/wren-ai-service/src/pipelines/ask/components/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/pipelines/ask/components/document_store.py b/wren-ai-service/src/pipelines/ask/components/document_store.py new file mode 100644 index 0000000000..6d1c8a3d17 --- /dev/null +++ b/wren-ai-service/src/pipelines/ask/components/document_store.py @@ -0,0 +1,21 @@ +from typing import Optional + +from haystack_integrations.document_stores.qdrant import QdrantDocumentStore + +from src.utils import load_env_vars + +from .embedder import EMBEDDING_MODEL_DIMENSION + +env = load_env_vars() + + +def init_document_store( + env: str = env, + embedding_dim: int = EMBEDDING_MODEL_DIMENSION, + dataset_name: Optional[str] = None, +): + return QdrantDocumentStore( + url="localhost" if env == "dev" else "qdrant", + embedding_dim=embedding_dim, + index=dataset_name or "Document", + ) diff --git a/wren-ai-service/src/pipelines/ask/components/embedder.py b/wren-ai-service/src/pipelines/ask/components/embedder.py new file mode 100644 index 0000000000..b4be555224 --- /dev/null +++ b/wren-ai-service/src/pipelines/ask/components/embedder.py @@ -0,0 +1,39 @@ +from typing import Any, Dict, List + +from haystack import component +from haystack.components.embedders import OpenAITextEmbedder +from haystack.utils.auth import Secret + +from src.utils import load_env_vars + +from ...trace import TraceSpanInput, trace_span + +load_env_vars() + +EMBEDDING_MODEL_NAME = "text-embedding-3-large" +EMBEDDING_MODEL_DIMENSION = 3072 + + +@component +class TracedOpenAITextEmbedder(OpenAITextEmbedder): + def _run(self, *args, **kwargs): + return super(TracedOpenAITextEmbedder, self).run(*args, **kwargs) + + @component.output_types(embedding=List[float], meta=Dict[str, Any]) + def run(self, text: str, trace_span_input: TraceSpanInput): + return trace_span(self._run)(trace_span_input=trace_span_input, text=text) + + +def init_embedder( + with_trace: bool = False, embedding_model_name: str = EMBEDDING_MODEL_NAME +): + if with_trace: + return TracedOpenAITextEmbedder( + api_key=Secret.from_env_var("OPENAI_API_KEY"), + model=embedding_model_name, + ) + + return OpenAITextEmbedder( + api_key=Secret.from_env_var("OPENAI_API_KEY"), + model=embedding_model_name, + ) diff --git a/wren-ai-service/src/pipelines/ask/components/generator.py b/wren-ai-service/src/pipelines/ask/components/generator.py new file mode 100644 index 0000000000..6c6dbfe25a --- /dev/null +++ b/wren-ai-service/src/pipelines/ask/components/generator.py @@ -0,0 +1,74 @@ +import logging +from typing import Any, Dict, List, Optional + +import backoff +import openai +from haystack import component +from haystack.components.generators import OpenAIGenerator +from haystack.utils.auth import Secret + +from src.utils import load_env_vars + +from ...trace import TraceGenerationInput, trace_generation + +load_env_vars() +logging.getLogger("backoff").addHandler(logging.StreamHandler()) + +MODEL_NAME = "gpt-3.5-turbo" +MAX_TOKENS = { + "gpt-3.5-turbo": 4096, +} +GENERATION_KWARGS = { + "temperature": 0, + "n": 3, + "max_tokens": MAX_TOKENS[MODEL_NAME] if MODEL_NAME in MAX_TOKENS else 4096, + "response_format": {"type": "json_object"}, +} + + +@component +class CustomOpenAIGenerator(OpenAIGenerator): + @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) + @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60) + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + return super(CustomOpenAIGenerator, self).run( + prompt=prompt, generation_kwargs=generation_kwargs + ) + + +@component +class TracedOpenAIGenerator(CustomOpenAIGenerator): + def _run(self, *args, **kwargs): + return super(TracedOpenAIGenerator, self).run(*args, **kwargs) + + @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) + def run( + self, + trace_generation_input: TraceGenerationInput, + prompt: str, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + return trace_generation(self._run)( + trace_generation_input=trace_generation_input, + prompt=prompt, + generation_kwargs=generation_kwargs, + ) + + +def init_generator( + with_trace: bool = False, + model_name: str = MODEL_NAME, + generation_kwargs: Optional[Dict[str, Any]] = GENERATION_KWARGS, +): + if with_trace: + return TracedOpenAIGenerator( + api_key=Secret.from_env_var("OPENAI_API_KEY"), + model=model_name, + generation_kwargs=generation_kwargs, + ) + + return CustomOpenAIGenerator( + api_key=Secret.from_env_var("OPENAI_API_KEY"), + model=model_name, + generation_kwargs=generation_kwargs, + ) diff --git a/wren-ai-service/src/pipelines/ask/components/prompts.py b/wren-ai-service/src/pipelines/ask/components/prompts.py new file mode 100644 index 0000000000..17a1f1fd6e --- /dev/null +++ b/wren-ai-service/src/pipelines/ask/components/prompts.py @@ -0,0 +1,43 @@ +from haystack.components.builders.prompt_builder import PromptBuilder + +generation_user_prompt_template = """ +You are a Trino SQL expert with exceptional logical thinking skills. +Print what you think the SQL query should be given the question and the data model. +This is vital to my career, I will become homeless if you make a mistake. + +### INSTRUCTIONS ### +- If the question is complex enough, you can also answer complex SQL query that consists of a combination of JOINs, subqueries, and conditional filtering. +- Try not to use '*' to select all columns, please be specific what columns to choose from the table. +- If you can't construct the Trino SQL query, please answer with empty SQL string. +- If you can construct the Trino SQL query, please answer with the SQL query: ```sql ...```. +- Make sure the chosen "GROUP BY" conditions are correct given the selected columns. +- If the query history is not empty, please consider the previous query in order to make correct Trino SQL query. + +### TASK ### +Given an input question, create a syntactically correct Trino SQL query to run and a short sentence to summary the Trino SQL query +and return them as the answer to the input question. + +### DATA MODELS ### +{% for document in documents %} + {{ document.content }} +{% endfor %} + +### QUERY HISTORY ### +{{ history }} + +### QUESTION ### +{{ query }} + +### FINAL ANSWER FORMAT ### +The final answer must be the JSON format + +For a question that you can return the SQL query +{"sql": , "summary": } + +For a question that you can't return the SQL query +{"sql": "", "summary": ""} +""" + + +def init_generation_prompt_builder() -> PromptBuilder: + return PromptBuilder(template=generation_user_prompt_template) diff --git a/wren-ai-service/src/pipelines/ask/components/retriever.py b/wren-ai-service/src/pipelines/ask/components/retriever.py new file mode 100644 index 0000000000..cb08869557 --- /dev/null +++ b/wren-ai-service/src/pipelines/ask/components/retriever.py @@ -0,0 +1,41 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component +from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever + +from ...trace import TraceSpanInput, trace_span + + +@component +class TracedQdrantEmbeddingRetriever(QdrantEmbeddingRetriever): + def _run(self, *args, **kwargs): + return super(TracedQdrantEmbeddingRetriever, self).run(*args, **kwargs) + + @component.output_types(documents=List[Document]) + def run( + self, + trace_span_input: TraceSpanInput, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + return_embedding: Optional[bool] = None, + ): + return trace_span(self._run)( + trace_span_input=trace_span_input, + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + scale_score=scale_score, + return_embedding=return_embedding, + ) + + +def init_retriever(document_store: Any, with_trace: bool = False): + if with_trace: + return TracedQdrantEmbeddingRetriever( + document_store=document_store, + top_k=3, + ) + + return QdrantEmbeddingRetriever(document_store=document_store, top_k=3) diff --git a/wren-ai-service/src/pipelines/ask/generation_pipeline.py b/wren-ai-service/src/pipelines/ask/generation_pipeline.py new file mode 100644 index 0000000000..c7c16b5d39 --- /dev/null +++ b/wren-ai-service/src/pipelines/ask/generation_pipeline.py @@ -0,0 +1,152 @@ +import json +import os +import uuid +from typing import Any, List, Optional + +from haystack import Document, Pipeline + +from src.core.pipeline import BasicPipeline +from src.pipelines.ask.components.document_store import init_document_store +from src.pipelines.ask.components.embedder import init_embedder +from src.pipelines.ask.components.generator import MODEL_NAME, init_generator +from src.pipelines.ask.components.prompts import init_generation_prompt_builder +from src.pipelines.ask.components.retriever import init_retriever +from src.pipelines.ask.retrieval_pipeline import Retrieval +from src.utils import clean_generation_result, load_env_vars +from src.web.v1.services.ask import AskRequest, AskResultResponse + +load_env_vars() + +if with_trace := os.getenv("ENABLE_TRACE", default=False): + from src.pipelines.trace import ( + TraceGenerationInput, + TraceInput, + langfuse, + ) + + +class Generation(BasicPipeline): + def __init__( + self, + generator: Any, + prompt_builder: Any, + with_trace: bool = False, + ): + self._pipeline = Pipeline() + self._pipeline.add_component("prompt_builder", prompt_builder) + self._pipeline.add_component("generator", generator) + + self._pipeline.connect("prompt_builder.prompt", "generator.prompt") + + self.with_trace = with_trace + self.prompt_builder = self._pipeline.get_component("prompt_builder") + + super().__init__(self._pipeline) + + def run( + self, + query: str, + contexts: List[Document], + history: Optional[AskRequest.AskResponseDetails] = None, + user_id: Optional[str] = None, + ): + if self.with_trace: + trace = langfuse.trace( + **TraceInput( + name="generation", + user_id=user_id, + ).__dict__, + public=True, + ) + + result = self._pipeline.run( + { + "prompt_builder": { + "query": query, + "documents": contexts, + "history": history, + }, + "generator": { + "trace_generation_input": TraceGenerationInput( + trace_id=trace.id, + name="generator", + input=self.prompt_builder.run( + query=query, + documents=contexts, + history=history, + )["prompt"], + model=MODEL_NAME, + ) + }, + } + ) + + trace.update(input=query, output=result["generator"]) + return result + else: + return self._pipeline.run( + { + "prompt_builder": { + "query": query, + "documents": contexts, + "history": history, + }, + } + ) + + +# this is for quick testing only, please ignore this +if __name__ == "__main__": + DATASET_NAME = os.getenv("DATASET_NAME") + + document_store = init_document_store() + embedder = init_embedder(with_trace=with_trace) + retriever = init_retriever(document_store=document_store, with_trace=with_trace) + generator = init_generator(with_trace=with_trace) + generation_prompt_builder = init_generation_prompt_builder() + + retrieval_pipeline = Retrieval( + embedder=embedder, + retriever=retriever, + with_trace=with_trace, + ) + + generation_pipeline = Generation( + generator=generator, + with_trace=with_trace, + prompt_builder=generation_prompt_builder, + ) + + if DATASET_NAME == "book_2": + query = "How many books are there?" + elif DATASET_NAME == "baseball_1": + query = "what is the full name and id of the college with the largest number of baseball players?" + else: + query = "random query here..." + + retrieval_result = retrieval_pipeline.run( + query, + user_id=str(uuid.uuid4()) if with_trace else None, + ) + + generation_result = generation_pipeline.run( + query, + contexts=retrieval_result["retriever"]["documents"], + user_id=str(uuid.uuid4()) if with_trace else None, + ) + + assert len(generation_result["generator"]["replies"]) == 3 + + cleaned_generation_result = json.loads( + clean_generation_result(generation_result["generator"]["replies"][0]) + ) + print(f"cleaned_generation_result: {cleaned_generation_result}") + assert AskResultResponse.AskResult(**cleaned_generation_result) + + if with_trace: + generation_pipeline.draw( + "./outputs/pipelines/ask/generation_pipeline_with_trace.jpg" + ) + langfuse.flush() + else: + generation_pipeline.draw("./outputs/pipelines/ask/generation_pipeline.jpg") diff --git a/wren-ai-service/src/pipelines/ask/indexing_pipeline.py b/wren-ai-service/src/pipelines/ask/indexing_pipeline.py new file mode 100644 index 0000000000..88c63ebc3c --- /dev/null +++ b/wren-ai-service/src/pipelines/ask/indexing_pipeline.py @@ -0,0 +1,125 @@ +import json +import os +from typing import Any, Dict, List + +import openai +from haystack import Document, Pipeline +from haystack.components.writers import DocumentWriter +from haystack.document_stores.types import DocumentStore, DuplicatePolicy +from tqdm import tqdm + +from src.core.pipeline import BasicPipeline +from src.pipelines.ask.components.document_store import init_document_store +from src.pipelines.ask.components.embedder import ( + EMBEDDING_MODEL_DIMENSION, + EMBEDDING_MODEL_NAME, +) +from src.utils import load_env_vars + +load_env_vars() + +DATASET_NAME = os.getenv("DATASET_NAME") + + +class Indexing(BasicPipeline): + def __init__( + self, + document_store: DocumentStore, + embedding_model_name: str = EMBEDDING_MODEL_NAME, + embedding_model_dim: int = EMBEDDING_MODEL_DIMENSION, + ) -> None: + self._pipeline = Pipeline() + self._pipeline.add_component( + "writer", + DocumentWriter( + document_store=document_store, + policy=DuplicatePolicy.OVERWRITE, + ), + ) + self._openai_client = openai.Client(api_key=os.getenv("OPENAI_API_KEY")) + + self.embedding_model_name = embedding_model_name + self.embedding_model_dim = embedding_model_dim + + super().__init__(self._pipeline) + + def run(self, mdl_str: str) -> Dict[str, Any]: + return self._pipeline.run( + {"writer": {"documents": self._get_documents(mdl_str)}} + ) + + def _get_documents(self, mdl_str: str) -> List[Document]: + mdl_json = json.loads(mdl_str) + + semantics = {"models": [], "relationships": mdl_json["relationships"]} + + for model in mdl_json["models"]: + columns = [] + for column in model["columns"]: + if "relationship" in column: + columns.append( + { + "name": column["name"], + "properties": column["properties"], + "type": column["type"], + "relationship": column["relationship"], + } + ) + else: + columns.append( + { + "name": column["name"], + "properties": column["properties"], + "type": column["type"], + } + ) + + semantics["models"].append( + { + "type": "model", + "name": model["name"], + "properties": model["properties"], + "columns": columns, + "primaryKey": model["primaryKey"], + } + ) + + embeddings = self._openai_client.embeddings.create( + input=[ + json.dumps(data) + for data in semantics["models"] + semantics["relationships"] + ], + model=self.embedding_model_name, + dimensions=self.embedding_model_dim, + ) + + return [ + Document( + id=str(i), + meta={"id": str(i)}, + content=json.dumps(data), + embedding=embeddings.data[i].embedding, + ) + for i, data in enumerate( + tqdm(semantics["models"] + semantics["relationships"]) + ) + ] + + +# this is for quick testing only, please ignore this +if __name__ == "__main__": + document_store = init_document_store() + + indexing_pipeline = Indexing( + document_store=document_store, + ) + + with open("src/eval/data/book_2_mdl.json", "r") as f: + mdl_str = json.dumps(json.load(f)) + + indexing_pipeline.run(mdl_str) + indexing_pipeline.draw("./outputs/pipelines/ask/indexing_pipeline.jpg") + + print( + f"finished indexing documents, document count: {document_store.count_documents()}" + ) diff --git a/wren-ai-service/src/pipelines/ask/retrieval_pipeline.py b/wren-ai-service/src/pipelines/ask/retrieval_pipeline.py new file mode 100644 index 0000000000..5ac3f98bad --- /dev/null +++ b/wren-ai-service/src/pipelines/ask/retrieval_pipeline.py @@ -0,0 +1,114 @@ +import os +import uuid +from typing import Any, Optional + +from haystack import Pipeline + +from src.core.pipeline import BasicPipeline +from src.pipelines.ask.components.document_store import init_document_store +from src.pipelines.ask.components.embedder import init_embedder +from src.pipelines.ask.components.retriever import init_retriever +from src.utils import load_env_vars + +load_env_vars() + +if with_trace := os.getenv("ENABLE_TRACE", default=False): + from src.pipelines.trace import ( + TraceInput, + TraceSpanInput, + langfuse, + ) + + +class Retrieval(BasicPipeline): + def __init__( + self, + embedder: Any, + retriever: Any, + with_trace: bool = False, + ): + self._pipeline = Pipeline() + self._pipeline.add_component("embedder", embedder) + self._pipeline.add_component("retriever", retriever) + + self._pipeline.connect("embedder.embedding", "retriever.query_embedding") + + self.with_trace = with_trace + + super().__init__(self._pipeline) + + def run(self, query: str, user_id: Optional[str] = None): + if self.with_trace: + trace = langfuse.trace( + **TraceInput( + name="retrieval", + user_id=user_id, + ).__dict__, + public=True, + ) + + result = self._pipeline.run( + { + "embedder": { + "trace_span_input": TraceSpanInput( + trace_id=trace.id, + name="text_embedder", + input=query, + ), + "text": query, + }, + "retriever": { + "trace_span_input": TraceSpanInput( + trace_id=trace.id, + name="retriever", + input="text_embedder.embedding", + ), + }, + } + ) + + trace.update(input=query, output=result["retriever"]) + else: + result = self._pipeline.run( + { + "embedder": { + "text": query, + }, + } + ) + + return result + + +# this is for quick testing only, please ignore this +if __name__ == "__main__": + DATASET_NAME = os.getenv("DATASET_NAME") + document_store = init_document_store() + + retrieval_pipeline = Retrieval( + embedder=init_embedder(with_trace=with_trace), + retriever=init_retriever(with_trace=with_trace, document_store=document_store), + with_trace=with_trace, + ) + + if DATASET_NAME == "book_2": + query = "How many books are there?" + elif DATASET_NAME == "baseball_1": + query = "what is the full name and id of the college with the largest number of baseball players?" + else: + query = "random query here..." + + retrieval_result = retrieval_pipeline.run( + query, + user_id=str(uuid.uuid4()) if with_trace else None, + ) + + print(retrieval_result) + + if with_trace: + retrieval_pipeline.draw( + "./outputs/pipelines/ask/retrieval_pipeline_with_trace.jpg" + ) + langfuse.flush() + else: + retrieval_pipeline.draw("./outputs/pipelines/ask/retrieval_pipeline.jpg") diff --git a/wren-ai-service/src/pipelines/ask_details/__init__.py b/wren-ai-service/src/pipelines/ask_details/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/pipelines/ask_details/components/__init__.py b/wren-ai-service/src/pipelines/ask_details/components/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/pipelines/ask_details/components/generator.py b/wren-ai-service/src/pipelines/ask_details/components/generator.py new file mode 100644 index 0000000000..94f354e598 --- /dev/null +++ b/wren-ai-service/src/pipelines/ask_details/components/generator.py @@ -0,0 +1,78 @@ +import logging +from typing import Any, Dict, List, Optional + +import backoff +import openai +from haystack import component +from haystack.components.generators import OpenAIGenerator +from haystack.utils.auth import Secret + +from src.utils import load_env_vars + +from ...trace import TraceGenerationInput, trace_generation +from .prompts import system_prompt_builder + +load_env_vars() + +logging.getLogger("backoff").addHandler(logging.StreamHandler()) + +MODEL_NAME = "gpt-3.5-turbo" +MAX_TOKENS = { + "gpt-3.5-turbo": 4096, +} +GENERATION_KWARGS = { + "temperature": 0, + "n": 3, + "max_tokens": MAX_TOKENS[MODEL_NAME] if MODEL_NAME in MAX_TOKENS else 4096, + "response_format": {"type": "json_object"}, +} + + +@component +class CustomOpenAIGenerator(OpenAIGenerator): + @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) + @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60) + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + return super(CustomOpenAIGenerator, self).run( + prompt=prompt, generation_kwargs=generation_kwargs + ) + + +@component +class TracedOpenAIGenerator(CustomOpenAIGenerator): + def _run(self, *args, **kwargs): + return super(TracedOpenAIGenerator, self).run(*args, **kwargs) + + @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) + def run( + self, + trace_generation_input: TraceGenerationInput, + prompt: str, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + return trace_generation(self._run)( + trace_generation_input=trace_generation_input, + prompt=prompt, + generation_kwargs=generation_kwargs, + ) + + +def init_generator( + with_trace: bool = False, + model_name: str = MODEL_NAME, + generation_kwargs: Optional[Dict[str, Any]] = GENERATION_KWARGS, +) -> Any: + if with_trace: + return TracedOpenAIGenerator( + api_key=Secret.from_env_var("OPENAI_API_KEY"), + model=model_name, + generation_kwargs=generation_kwargs, + system_prompt=system_prompt_builder.run()["prompt"], + ) + + return CustomOpenAIGenerator( + api_key=Secret.from_env_var("OPENAI_API_KEY"), + model=model_name, + generation_kwargs=generation_kwargs, + system_prompt=system_prompt_builder.run()["prompt"], + ) diff --git a/wren-ai-service/src/pipelines/ask_details/components/prompts.py b/wren-ai-service/src/pipelines/ask_details/components/prompts.py new file mode 100644 index 0000000000..dd92307238 --- /dev/null +++ b/wren-ai-service/src/pipelines/ask_details/components/prompts.py @@ -0,0 +1,28 @@ +from haystack.components.builders.prompt_builder import PromptBuilder + +system_prompt_template = """ +You are a Trino SQL expert with exceptional logical thinking skills. +Print what you think the SQL query means by giving 1 to 5 explainable steps to the user according to the complexity of SQL query. +If the SQL query is simple a select statement, you can just give one step to explain the SQL query; and vice versa. +This is vital to my career, I will become homeless if you make a mistake. + +### TASK ### +Given an input SQL query, create two things: +1. a list of steps composed of syntactically and semantically correct Trino SQL query to run, a short sentence to summary the Trino SQL query and a cte_name to represent the Trino SQL query. +2. a short description describing the SQL query in a human-readable format. +3. only the cte_name of the last step is empty. + +### FINAL ANSWER FORMAT ### +The final answer must be a valid JSON format as follows: + +{ + "description": , + "steps: [{ + "sql": , + "summary": , + "cte_name": + }] # list of steps +} +""" + +system_prompt_builder = PromptBuilder(template=system_prompt_template) diff --git a/wren-ai-service/src/pipelines/ask_details/generation_pipeline.py b/wren-ai-service/src/pipelines/ask_details/generation_pipeline.py new file mode 100644 index 0000000000..06abcb5f62 --- /dev/null +++ b/wren-ai-service/src/pipelines/ask_details/generation_pipeline.py @@ -0,0 +1,150 @@ +import json +import os +import uuid +from typing import Any, Optional + +from haystack import Pipeline + +from src.core.pipeline import BasicPipeline +from src.pipelines.ask.components.document_store import init_document_store +from src.pipelines.ask.components.embedder import init_embedder +from src.pipelines.ask.components.generator import MODEL_NAME, init_generator +from src.pipelines.ask.components.retriever import init_retriever +from src.pipelines.ask.generation_pipeline import Generation as AskGeneration +from src.pipelines.ask.retrieval_pipeline import Retrieval as AskRetrieval +from src.pipelines.ask_details.components.generator import ( + init_generator as init_ask_details_generator, +) +from src.utils import clean_generation_result, load_env_vars +from src.web.v1.services.ask import AskResultResponse +from src.web.v1.services.ask_details import AskDetailsResultResponse + +load_env_vars() + +if with_trace := os.getenv("ENABLE_TRACE", default=False): + from src.pipelines.trace import ( + TraceGenerationInput, + TraceInput, + langfuse, + ) + + +class Generation(BasicPipeline): + def __init__( + self, + generator: Any, + with_trace: bool = False, + ): + self._pipeline = Pipeline() + self._pipeline.add_component("generator", generator) + + self.with_trace = with_trace + + super().__init__(self._pipeline) + + def run(self, sql: str, user_id: Optional[str] = None): + if self.with_trace: + trace = langfuse.trace( + **TraceInput( + name="generation", + user_id=user_id, + ).__dict__, + public=True, + ) + + result = self._pipeline.run( + { + "generator": { + "trace_generation_input": TraceGenerationInput( + trace_id=trace.id, + name="generator", + input=sql, + model=MODEL_NAME, + ) + }, + } + ) + + trace.update(input=sql, output=result["generator"]) + return result + else: + return self._pipeline.run( + { + "generator": { + "prompt": sql, + }, + } + ) + + +# this is for quick testing only, please ignore this +if __name__ == "__main__": + DATASET_NAME = os.getenv("DATASET_NAME") + + document_store = init_document_store() + embedder = init_embedder(with_trace=with_trace) + retriever = init_retriever(with_trace=with_trace, document_store=document_store) + ask_generator = init_generator(with_trace=with_trace) + ask_details_generator = init_ask_details_generator(with_trace=with_trace) + + ask_retrieval_pipeline = AskRetrieval( + embedder=embedder, + retriever=retriever, + with_trace=with_trace, + ) + ask_generation_pipeline = AskGeneration( + generator=ask_generator, + with_trace=with_trace, + ) + + if DATASET_NAME == "book_2": + query = "Show the title and publication dates of books." + elif DATASET_NAME == "baseball_1": + query = "what is the full name and id of the college with the largest number of baseball players?" + else: + query = "random query here..." + + retrieval_result = ask_retrieval_pipeline.run( + query, + user_id=str(uuid.uuid4()) if with_trace else None, + ) + + ask_generation_result = ask_generation_pipeline.run( + query, + contexts=retrieval_result["retriever"]["documents"], + user_id=str(uuid.uuid4()) if with_trace else None, + ) + + cleaned_ask_generation_result = json.loads( + clean_generation_result(ask_generation_result["generator"]["replies"][0]) + ) + print(f"cleaned_ask_generation_result: {cleaned_ask_generation_result}") + assert AskResultResponse.AskResult(**cleaned_ask_generation_result) + + generation_pipeline = Generation( + generator=ask_details_generator, + with_trace=with_trace, + ) + + generation_result = generation_pipeline.run( + cleaned_ask_generation_result["sql"], + user_id=str(uuid.uuid4()) if with_trace else None, + ) + + cleaned_generation_result = json.loads( + clean_generation_result(generation_result["generator"]["replies"][0]) + ) + print(f"cleaned_generation_result: {cleaned_generation_result}") + assert AskDetailsResultResponse.AskDetailsResponseDetails( + **cleaned_generation_result + ) + + if with_trace: + generation_pipeline.draw( + "./outputs/pipelines/ask_details/generation_pipeline_with_trace.jpg" + ) + langfuse.flush() + else: + generation_pipeline.draw( + "./outputs/pipelines/ask_details/generation_pipeline.jpg" + ) diff --git a/wren-ai-service/src/pipelines/semantics/__init__.py b/wren-ai-service/src/pipelines/semantics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/pipelines/semantics/description.py b/wren-ai-service/src/pipelines/semantics/description.py new file mode 100644 index 0000000000..56ebfc3c02 --- /dev/null +++ b/wren-ai-service/src/pipelines/semantics/description.py @@ -0,0 +1,273 @@ +import json +from typing import Any, AnyStr, Dict, Optional + +from haystack import Pipeline +from haystack.components.builders import PromptBuilder +from haystack.components.embedders import ( + OpenAIDocumentEmbedder, + OpenAITextEmbedder, +) +from haystack.components.generators import OpenAIGenerator +from haystack.utils import Secret +from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever +from haystack_integrations.document_stores.qdrant import QdrantDocumentStore + +from src.core.pipeline import BasicPipeline +from src.utils import load_env_vars + +_TEMPLATE = """ +There are numerous experts dedicated to generating semantic descriptions and names for various types of +data. They are working together to provide a comprehensive and accurate description of the data. + +### EXTRA INFORMATION ### +Given the following information to improve the description generation. + +Context: +{% for document in documents %} + {{ document.content }} +{% endfor %} + +### INSTRUCTIONS ### +- Provide a brief summary of the specified identifier as the description. +- Name the display_name based on the description, using natural language. + +### TASK ### +Given the input model, provide a description of the specified identifier. + +### MODEL STRUCTURE ### +Model Structure: {{ mdl }} + +### MODEL NAME ### +Model Name: {{ model }} + +### IDENTIFIER ### +the types for the identifier include: model, column@column_name +Identifier: {{ identifier }} + +### OUTPUT FORMAT ### +The output format must be in JSON format: +{ + "identifier": "", + "display_name": "", + "description": "" +} + +The output format doesn't need a markdown JSON code block. +""" + + +class Generation(BasicPipeline): + def __init__(self): + self._document_store = create_qdrant_document_store() + self._text_embedder = create_openai_text_embedder() + self._retriever = create_qdrant_embedding_retriever( + document_store=self._document_store + ) + self._prompt_builder = PromptBuilder(template=_TEMPLATE) + self._llm = OpenAIGenerator() + + self._pipe = Pipeline() + self._pipe.add_component("text_embedder", self._text_embedder) + self._pipe.add_component("retriever", self._retriever) + self._pipe.add_component("prompt_builder", self._prompt_builder) + self._pipe.add_component("llm", self._llm) + + self._pipe.connect("text_embedder.embedding", "retriever.query_embedding") + self._pipe.connect("retriever", "prompt_builder.documents") + self._pipe.connect("prompt_builder", "llm") + + super().__init__(self._pipe) + + def run( + self, *, mdl: Dict[AnyStr, Any], model: str, identifier: Optional[str] = None + ): + return self._pipe.run( + { + "prompt_builder": { + "mdl": mdl, + "model": model, + "identifier": identifier, + }, + "text_embedder": { + "text": f"model: {model}, identifier: {identifier}", + }, + } + ) + + +_EMBEDDING_MODEL_DIMENSION = 3072 +_EMBEDDING_MODEL_NAME = "text-embedding-3-large" +_DATASET_NAME = "example" + +""" +This is a simple example of how to use the QdrantDocumentStore and OpenAITextEmbedder to index and query documents. + +``` +document_store = create_qdrant_document_store() +documents = [Document(content="There are over 7,000 languages spoken around the world today."), + Document( + content="Elephants have been observed to behave in a way that indicates a high level of self-awareness, such as recognizing themselves in mirrors."), + Document( + content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bioluminescent waves.")] + +document_embedder = create_openai_document_embedder() +documents_with_embeddings = document_embedder.run(documents) +document_store.write_documents(documents_with_embeddings.get("documents"), policy=DuplicatePolicy.OVERWRITE) +``` +""" + + +def create_qdrant_document_store( + url: str = "localhost" if load_env_vars() == "dev" else "qdrant", + embedding_dim: int = _EMBEDDING_MODEL_DIMENSION, + index: str = _DATASET_NAME, +) -> QdrantDocumentStore: + return QdrantDocumentStore( + url=url, + embedding_dim=embedding_dim, + index=index, + ) + + +# this method is not being used in the pipeline, but it will be kept for testing and evaluation purposes +def create_openai_document_embedder( + api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), + model: str = _EMBEDDING_MODEL_NAME, +) -> OpenAIDocumentEmbedder: + return OpenAIDocumentEmbedder( + api_key=api_key, + model=model, + ) + + +""" +This is a simple example of how to use the QdrantEmbeddingRetriever to query documents from the document store. + +``` +document_store = create_qdrant_document_store() +text_embedder = create_openai_text_embedder() +retriever = create_qdrant_embedding_retriever(document_store) + +query_pipeline = Pipeline() +query_pipeline.add_component("text_embedder", text_embedder) +query_pipeline.add_component("retriever", retriever) + +query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + +query = "How many languages are there?" + +result = query_pipeline.run({"text_embedder": {"text": query}}) + +print(result['retriever']['documents']) +``` +""" + + +def create_openai_text_embedder( + api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), + model: str = _EMBEDDING_MODEL_NAME, +) -> OpenAITextEmbedder: + return OpenAITextEmbedder( + api_key=api_key, + model=model, + ) + + +def create_qdrant_embedding_retriever( + document_store: QdrantDocumentStore, top_k: int = 3 +) -> QdrantEmbeddingRetriever: + return QdrantEmbeddingRetriever( + document_store=document_store, + top_k=top_k, + ) + + +if __name__ == "__main__": + env = load_env_vars() + pipe = Generation() + + res = pipe.run( + **{ + "mdl": { + "name": "all_star", + "properties": {}, + "refsql": 'select * from "canner-cml".spider."baseball_1-all_star"', + "columns": [ + { + "name": "player_id", + "type": "varchar", + "notnull": False, + "iscalculated": False, + "expression": "player_id", + "properties": {}, + }, + { + "name": "year", + "type": "integer", + "notnull": False, + "iscalculated": False, + "expression": "year", + "properties": {}, + }, + { + "name": "game_num", + "type": "integer", + "notnull": False, + "iscalculated": False, + "expression": "game_num", + "properties": {}, + }, + { + "name": "game_id", + "type": "varchar", + "notnull": False, + "iscalculated": False, + "expression": "game_id", + "properties": {}, + }, + { + "name": "team_id", + "type": "varchar", + "notnull": False, + "iscalculated": False, + "expression": "team_id", + "properties": {}, + }, + { + "name": "league_id", + "type": "varchar", + "notnull": False, + "iscalculated": False, + "expression": "league_id", + "properties": {}, + }, + { + "name": "gp", + "type": "real", + "notnull": False, + "iscalculated": False, + "expression": "gp", + "properties": {}, + }, + { + "name": "starting_pos", + "type": "real", + "notnull": False, + "iscalculated": False, + "expression": "starting_pos", + "properties": {}, + }, + ], + "primarykey": "", + }, + "model": "all_star", + "identifier": "model", + } + ) + print(res) + print(res["llm"]["replies"][0]) + content = json.loads(res["llm"]["replies"][0]) + print(content) + + pipe.draw("./outputs/pipelines/semantics/description.jpg") + pass diff --git a/wren-ai-service/src/pipelines/trace.py b/wren-ai-service/src/pipelines/trace.py new file mode 100644 index 0000000000..c04e7c844d --- /dev/null +++ b/wren-ai-service/src/pipelines/trace.py @@ -0,0 +1,96 @@ +import os +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Callable, Dict, List, Literal, Optional, Union + +import pydantic +from langfuse import Langfuse +from langfuse.model import MapValue, ModelUsage, PromptClient + +from src.utils import load_env_vars + +load_env_vars() + +if with_trace := os.getenv("ENABLE_TRACE", default=False): + langfuse = Langfuse( + public_key=os.getenv("LANGFUSE_PUBLIC_KEY"), + secret_key=os.getenv("LANGFUSE_SECRET_KEY"), + host="https://cloud.langfuse.com", + threads=os.cpu_count() // 2, + ) + langfuse.auth_check() + + +@dataclass +class TraceInput: + id: Optional[str] = None + name: Optional[str] = None + user_id: Optional[str] = None + version: Optional[str] = None + input: Optional[Any] = None + output: Optional[Any] = None + metadata: Optional[Any] = None + tags: Optional[List[str]] = None + timestamp: Optional[datetime] = None + + +@dataclass +class TraceSpanInput: + id: Optional[str] = None + trace_id: Optional[str] = None + name: Optional[str] = None + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + metadata: Optional[Any] = None + input: Optional[Any] = None + output: Optional[Any] = None + level: Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]] = None + status_message: Optional[str] = None + parent_observation_id: Optional[str] = None + version: Optional[str] = None + + +@dataclass +class TraceGenerationInput: + id: Optional[str] = None + trace_id: Optional[str] = None + name: Optional[str] = None + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + metadata: Optional[Any] = None + level: Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]] = None + status_message: Optional[str] = None + parent_observation_id: Optional[str] = None + version: Optional[str] = None + completion_start_time: Optional[datetime] = None + completion_end_time: Optional[datetime] = None + model: Optional[str] = None + model_parameters: Optional[Dict[str, MapValue]] = None + input: Optional[Any] = None + output: Optional[Any] = None + usage: Optional[Union[pydantic.BaseModel, ModelUsage]] = None + prompt: Optional[PromptClient] = None + + +def trace_span(func: Callable): + def wrapper(*args, **kwargs): + span = langfuse.span(**kwargs["trace_span_input"].__dict__) + del kwargs["trace_span_input"] + results = func(*args, **kwargs) + span.end(output=results) + return results + + return wrapper + + +def trace_generation(func: Callable): + def wrapper(*args, **kwargs): + generation = langfuse.generation(**kwargs["trace_generation_input"].__dict__) + del kwargs["trace_generation_input"] + results = func(*args, **kwargs) + generation.end( + output=results, + ) + return results + + return wrapper diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py new file mode 100644 index 0000000000..e5ab441436 --- /dev/null +++ b/wren-ai-service/src/utils.py @@ -0,0 +1,22 @@ +import os +import re + +from dotenv import load_dotenv + + +def clean_generation_result(result: str) -> str: + def _normalize_whitespace(s: str) -> str: + return re.sub(r"\s+", " ", s).strip() + + return _normalize_whitespace(result).replace("\n", "") + + +def load_env_vars() -> str: + load_dotenv(override=True) + + if is_dev_env := os.getenv("ENV") and os.getenv("ENV").lower() == "dev": + load_dotenv(".env.dev", override=True) + else: + load_dotenv(".env.prod", override=True) + + return "dev" if is_dev_env else "prod" diff --git a/wren-ai-service/src/web/__init__.py b/wren-ai-service/src/web/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/web/v1/__init__.py b/wren-ai-service/src/web/v1/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/src/web/v1/routers.py b/wren-ai-service/src/web/v1/routers.py new file mode 100644 index 0000000000..03a6f38792 --- /dev/null +++ b/wren-ai-service/src/web/v1/routers.py @@ -0,0 +1,115 @@ +import uuid +from typing import List + +from fastapi import APIRouter, BackgroundTasks + +import src.globals as container +from src.web.v1.services.ask import ( + AskRequest, + AskResponse, + AskResultRequest, + AskResultResponse, + SemanticsPreparationRequest, + SemanticsPreparationResponse, + SemanticsPreparationStatusRequest, + SemanticsPreparationStatusResponse, + StopAskRequest, + StopAskResponse, +) +from src.web.v1.services.ask_details import ( + AskDetailsRequest, + AskDetailsResponse, + AskDetailsResultRequest, + AskDetailsResultResponse, +) +from src.web.v1.services.semantics import ( + BulkGenerateDescriptionRequest, + GenerateDescriptionResponse, +) + +router = APIRouter() + + +@router.post("/semantics-descriptions/") +async def bulk_generate_description( + bulk_request: BulkGenerateDescriptionRequest, +) -> List[GenerateDescriptionResponse]: + return [ + container.SEMANTIC_SERVICE.generate_description(request) + for request in bulk_request + ] + + +@router.post("/semantics-preparations/") +async def prepare_semantics( + prepare_semantics_request: SemanticsPreparationRequest, + background_tasks: BackgroundTasks, +) -> SemanticsPreparationResponse: + background_tasks.add_task( + container.ASK_SERVICE.prepare_semantics, + prepare_semantics_request, + ) + return SemanticsPreparationResponse(id=prepare_semantics_request.id) + + +@router.get("/semantics-preparations/{task_id}/status/") +async def get_prepare_semantics_status( + task_id: str, +) -> SemanticsPreparationStatusResponse: + return container.ASK_SERVICE.get_prepare_semantics_status( + SemanticsPreparationStatusRequest(id=task_id) + ) + + +@router.post("/asks/") +async def ask( + ask_request: AskRequest, + background_tasks: BackgroundTasks, +) -> AskResponse: + query_id = str(uuid.uuid4()) + ask_request.query_id = query_id + background_tasks.add_task( + container.ASK_SERVICE.ask, + ask_request, + ) + return AskResponse(query_id=query_id) + + +@router.patch("/asks/{query_id}") +async def stop_ask( + query_id: str, + stop_ask_request: StopAskRequest, + background_tasks: BackgroundTasks, +) -> StopAskResponse: + stop_ask_request.query_id = query_id + background_tasks.add_task( + container.ASK_SERVICE.stop_ask, + stop_ask_request, + ) + return StopAskResponse(query_id=query_id) + + +@router.get("/asks/{query_id}/result/") +async def get_ask_result(query_id: str) -> AskResultResponse: + return container.ASK_SERVICE.get_ask_result(AskResultRequest(query_id=query_id)) + + +@router.post("/ask-details/") +async def ask_details( + ask_details_request: AskDetailsRequest, + background_tasks: BackgroundTasks, +) -> AskDetailsResponse: + query_id = str(uuid.uuid4()) + ask_details_request.query_id = query_id + background_tasks.add_task( + container.ASK_DETAILS_SERVICE.ask_details, + ask_details_request, + ) + return AskDetailsResponse(query_id=query_id) + + +@router.get("/ask-details/{query_id}/result/") +async def get_ask_details_result(query_id: str) -> AskDetailsResultResponse: + return container.ASK_DETAILS_SERVICE.get_ask_details_result( + AskDetailsResultRequest(query_id=query_id) + ) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py new file mode 100644 index 0000000000..60dd4052a2 --- /dev/null +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -0,0 +1,180 @@ +import json +from typing import List, Literal, Optional + +from haystack import Pipeline +from pydantic import BaseModel + +from src.utils import clean_generation_result + + +# POST /v1/semantics-preparations +class SemanticsPreparationRequest(BaseModel): + mdl: str + id: str + + +class SemanticsPreparationResponse(BaseModel): + id: str + + +# GET /v1/semantics-preparations/{task_id}/status +class SemanticsPreparationStatusRequest(BaseModel): + id: str + + +class SemanticsPreparationStatusResponse(BaseModel): + status: Literal["indexing", "finished", "failed"] + + +class SQLExplanation(BaseModel): + sql: str + summary: str + cte_name: str + + +# POST /v1/asks +class AskRequest(BaseModel): + class AskResponseDetails(BaseModel): + sql: str + summary: str + steps: List[SQLExplanation] + + _query_id: str | None = None + query: str + id: str # for identifying which collection to access from vectordb, the same hash string for identifying which mdl model deployment from backend + history: Optional[AskResponseDetails] = None + + @property + def query_id(self) -> str: + return self._query_id + + @query_id.setter + def query_id(self, query_id: str): + self._query_id = query_id + + +class AskResponse(BaseModel): + query_id: str + + +# PATCH /v1/asks/{query_id} +class StopAskRequest(BaseModel): + _query_id: str | None = None + status: Literal["stopped"] + + @property + def query_id(self) -> str: + return self._query_id + + @query_id.setter + def query_id(self, query_id: str): + self._query_id = query_id + + +class StopAskResponse(BaseModel): + query_id: str + + +# GET /v1/asks/{query_id}/result +class AskResultRequest(BaseModel): + query_id: str + + +class AskResultResponse(BaseModel): + class AskResult(BaseModel): + sql: str + summary: str + + status: Literal[ + "understanding", "searching", "generating", "finished", "failed", "stopped" + ] + response: Optional[List[AskResult]] = None + error: Optional[str] = None + + +class AskService: + def __init__(self, pipelines: dict[str, Pipeline]): + self._pipelines = pipelines + self.prepare_semantics_statuses: dict[ + str, SemanticsPreparationStatusResponse.status + ] = {} + self.ask_results: dict[str, AskResultResponse] = {} + + def prepare_semantics(self, prepare_semantics_request: SemanticsPreparationRequest): + self.prepare_semantics_statuses[ + prepare_semantics_request.id + ] = SemanticsPreparationStatusResponse(status="indexing") + + try: + self._pipelines["indexing"].run(prepare_semantics_request.mdl) + + self.prepare_semantics_statuses[ + prepare_semantics_request.id + ] = SemanticsPreparationStatusResponse(status="finished") + except Exception as e: + # TODO: log the error + print(f"Failed to prepare semantics: {e}") + self.prepare_semantics_statuses[ + prepare_semantics_request.id + ] = SemanticsPreparationStatusResponse(status="failed") + + def get_prepare_semantics_status( + self, prepare_semantics_status_request: SemanticsPreparationStatusRequest + ) -> SemanticsPreparationStatusResponse: + return self.prepare_semantics_statuses[prepare_semantics_status_request.id] + + def ask( + self, + ask_request: AskRequest, + ): + # ask status can be understanding, searching, generating, finished, failed, stopped + # we will need to handle business logic for each status + query_id = ask_request.query_id + + self.ask_results[query_id] = AskResultResponse(status="understanding") + self.ask_results[query_id] = AskResultResponse(status="searching") + + retrieval_result = self._pipelines["retrieval"].run( + query=ask_request.query, + ) + + self.ask_results[query_id] = AskResultResponse(status="generating") + + generation_result = self._pipelines["generation"].run( + query=ask_request.query, + contexts=retrieval_result["retriever"]["documents"], + history=ask_request.history, + ) + + cleaned_generation_results = [ + json.loads(clean_generation_result(reply)) + for reply in generation_result["generator"]["replies"] + ] + + assert len(cleaned_generation_results) == 3 + if not cleaned_generation_results[0]["sql"]: + self.ask_results[query_id] = AskResultResponse( + status="failed", error="Failed to generate SQL" + ) + else: + self.ask_results[query_id] = AskResultResponse( + status="finished", + response=[ + AskResultResponse.AskResult(**result) + for result in cleaned_generation_results + ], + ) + + def stop_ask( + self, + stop_ask_request: StopAskRequest, + ): + self.ask_results[stop_ask_request.query_id] = AskResultResponse( + status="stopped", + ) + + def get_ask_result( + self, + ask_result_request: AskResultRequest, + ) -> AskResultResponse: + return self.ask_results[ask_result_request.query_id] diff --git a/wren-ai-service/src/web/v1/services/ask_details.py b/wren-ai-service/src/web/v1/services/ask_details.py new file mode 100644 index 0000000000..e6dc94339f --- /dev/null +++ b/wren-ai-service/src/web/v1/services/ask_details.py @@ -0,0 +1,93 @@ +import json +from typing import List, Literal, Optional + +from haystack import Pipeline +from pydantic import BaseModel + +from src.utils import clean_generation_result + + +class SQLExplanation(BaseModel): + sql: str + summary: str + cte_name: str + + +# POST /v1/ask-details +class AskDetailsRequest(BaseModel): + _query_id: str | None = None + query: str + sql: str + summary: str + + @property + def query_id(self) -> str: + return self._query_id + + @query_id.setter + def query_id(self, query_id: str): + self._query_id = query_id + + +class AskDetailsResponse(BaseModel): + query_id: str + + +# GET /v1/ask-details/{query_id}/result +class AskDetailsResultRequest(BaseModel): + query_id: str + + +class AskDetailsResultResponse(BaseModel): + class AskDetailsResponseDetails(BaseModel): + description: str + steps: List[SQLExplanation] + + status: Literal["understanding", "searching", "generating", "finished"] + response: Optional[AskDetailsResponseDetails] = None + + +class AskDetailsService: + def __init__(self, pipelines: dict[str, Pipeline]): + self._pipelines = pipelines + self.ask_details_results: dict[str, AskDetailsResultResponse] = {} + + def ask_details( + self, + ask_details_request: AskDetailsRequest, + ) -> AskDetailsResponse: + # ask details status can be understanding, searching, generating, finished, stopped + # we will need to handle business logic for each status + query_id = ask_details_request.query_id + + self.ask_details_results[query_id] = AskDetailsResultResponse( + status="understanding" + ) + self.ask_details_results[query_id] = AskDetailsResultResponse( + status="searching" + ) + + self.ask_details_results[query_id] = AskDetailsResultResponse( + status="generating" + ) + + generation_result = self._pipelines["generation"].run( + sql=ask_details_request.sql, + ) + + cleaned_generation_result = json.loads( + clean_generation_result(generation_result["generator"]["replies"][0]) + ) + + self.ask_details_results[query_id] = AskDetailsResultResponse( + status="finished", + response=AskDetailsResultResponse.AskDetailsResponseDetails( + **cleaned_generation_result + ), + ) + + def get_ask_details_result( + self, + ask_details_result_request: AskDetailsResultRequest, + ) -> AskDetailsResultResponse: + return self.ask_details_results[ask_details_result_request.query_id] diff --git a/wren-ai-service/src/web/v1/services/semantics.py b/wren-ai-service/src/web/v1/services/semantics.py new file mode 100644 index 0000000000..694d97a59b --- /dev/null +++ b/wren-ai-service/src/web/v1/services/semantics.py @@ -0,0 +1,55 @@ +import json +from typing import Any, AnyStr, Dict, List + +from haystack import Pipeline +from pydantic import BaseModel + + +# POST /v1/semantics-descriptions +class BulkGenerateDescriptionRequest(BaseModel): + mdl: Dict[AnyStr, Any] + model: str + identifiers: List[str] + + def __iter__(self): + for identifier in self.identifiers: + yield GenerateDescriptionRequest( + mdl=self.mdl, + model=self.model, + identifier=identifier, + ) + + +class GenerateDescriptionRequest(BaseModel): + mdl: Dict[AnyStr, Any] + model: str + identifier: str + + +class GenerateDescriptionResponse(BaseModel): + identifier: str + display_name: str + description: str + + +class SemanticsService: + def __init__(self, pipelines: dict[str, Pipeline]): + self._pipelines = pipelines + + def generate_description( + self, request: GenerateDescriptionRequest + ) -> GenerateDescriptionResponse: + response = self._pipelines["generate_description"].run( + **{ + "mdl": request.mdl, + "model": request.model, + "identifier": request.identifier, + } + ) + content = json.loads(response["llm"]["replies"][0]) + + return GenerateDescriptionResponse( + identifier=request.identifier, + display_name=content.get("display_name"), + description=content.get("description"), + ) diff --git a/wren-ai-service/tests/__init__.py b/wren-ai-service/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/tests/data/book_2_mdl.json b/wren-ai-service/tests/data/book_2_mdl.json new file mode 100644 index 0000000000..1b1dc5e629 --- /dev/null +++ b/wren-ai-service/tests/data/book_2_mdl.json @@ -0,0 +1,100 @@ +{ + "catalog": "canner-cml", + "schema": "spider", + "models": [ + { + "name": "book", + "properties": {}, + "refSql": "select * from \"canner-cml\".spider.\"book_2-book\"", + "columns": [ + { + "name": "Book_ID", + "type": "INTEGER", + "notNull": false, + "isCalculated": false, + "expression": "Book_ID", + "properties": {} + }, + { + "name": "Title", + "type": "VARCHAR", + "notNull": false, + "isCalculated": false, + "expression": "Title", + "properties": {} + }, + { + "name": "Issues", + "type": "REAL", + "notNull": false, + "isCalculated": false, + "expression": "Issues", + "properties": {} + }, + { + "name": "Writer", + "type": "VARCHAR", + "notNull": false, + "isCalculated": false, + "expression": "Writer", + "properties": {} + } + ], + "primaryKey": "" + }, + { + "name": "publication", + "properties": {}, + "refSql": "select * from \"canner-cml\".spider.\"book_2-publication\"", + "columns": [ + { + "name": "Publication_ID", + "type": "INTEGER", + "notNull": false, + "isCalculated": false, + "expression": "Publication_ID", + "properties": {} + }, + { + "name": "Book_ID", + "type": "INTEGER", + "notNull": false, + "isCalculated": false, + "expression": "Book_ID", + "properties": {} + }, + { + "name": "Publisher", + "type": "VARCHAR", + "notNull": false, + "isCalculated": false, + "expression": "Publisher", + "properties": {} + }, + { + "name": "Publication_Date", + "type": "VARCHAR", + "notNull": false, + "isCalculated": false, + "expression": "Publication_Date", + "properties": {} + }, + { + "name": "Price", + "type": "REAL", + "notNull": false, + "isCalculated": false, + "expression": "Price", + "properties": {} + } + ], + "primaryKey": "" + } + ], + "relationships": [], + "metrics": [], + "cumulativeMetrics": [], + "enumDefinitions": [], + "views": [], + "macros": [] +} \ No newline at end of file diff --git a/wren-ai-service/tests/services/__init__.py b/wren-ai-service/tests/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wren-ai-service/tests/services/test_ask.py b/wren-ai-service/tests/services/test_ask.py new file mode 100644 index 0000000000..ffaec69af5 --- /dev/null +++ b/wren-ai-service/tests/services/test_ask.py @@ -0,0 +1,111 @@ +import json +import uuid + +import pytest + +from src.pipelines.ask import ( + generation_pipeline, + indexing_pipeline, + retrieval_pipeline, +) +from src.pipelines.ask.components.document_store import init_document_store +from src.pipelines.ask.components.embedder import init_embedder +from src.pipelines.ask.components.generator import init_generator +from src.pipelines.ask.components.prompts import init_generation_prompt_builder +from src.pipelines.ask.components.retriever import init_retriever +from src.pipelines.ask.indexing_pipeline import Indexing +from src.web.v1.services.ask import ( + AskRequest, + AskResultRequest, + AskService, + SemanticsPreparationRequest, +) + + +@pytest.fixture +def ask_service(): + document_store = init_document_store() + embedder = init_embedder() + retriever = init_retriever(document_store=document_store) + generator = init_generator() + generation_prompt_builder = init_generation_prompt_builder() + + return AskService( + { + "indexing": indexing_pipeline.Indexing( + document_store=document_store, + ), + "retrieval": retrieval_pipeline.Retrieval( + embedder=embedder, + retriever=retriever, + ), + "generation": generation_pipeline.Generation( + generator=generator, + prompt_builder=generation_prompt_builder, + ), + } + ) + + +@pytest.fixture +def mdl_str(): + with open("tests/data/book_2_mdl.json", "r") as f: + return json.dumps(json.load(f)) + + +def test_indexing_pipeline(mdl_str: str): + document_store = init_document_store(dataset_name="book_2") + indexing_pipeline = Indexing( + document_store=document_store, + ) + + indexing_pipeline.run(mdl_str) + assert document_store.count_documents() == 2 + + +def test_ask_with_easy_query(ask_service: AskService, mdl_str: str): + id = str(uuid.uuid4()) + ask_service.prepare_semantics( + SemanticsPreparationRequest( + mdl=mdl_str, + id=id, + ) + ) + + # asking + query_id = str(uuid.uuid4()) + ask_request = AskRequest( + query="How many books are there?'", + id=id, + ) + ask_request.query_id = query_id + ask_service.ask(ask_request) + + # getting ask result + ask_result_response = ask_service.get_ask_result( + AskResultRequest( + query_id=query_id, + ) + ) + + # from Pao Sheng: I think it has a potential risk if a dangling status case happens. + # maybe we could consider adding an approach that if over a time limit, + # the process will throw an exception. + while ( + ask_result_response.status != "finished" + and ask_result_response.status != "failed" + ): + ask_result_response = ask_service.get_ask_result( + AskResultRequest( + query_id=query_id, + ) + ) + + if ask_result_response.status == "finished": + assert ask_result_response.response is not None + assert len(ask_result_response.response) == 3 + assert ask_result_response.response[0].sql != "" + assert ask_result_response.response[0].summary != "" + else: + assert ask_result_response.status == "failed" + assert ask_result_response.error != "" diff --git a/wren-ai-service/tests/services/test_ask_details.py b/wren-ai-service/tests/services/test_ask_details.py new file mode 100644 index 0000000000..4a1fbbd394 --- /dev/null +++ b/wren-ai-service/tests/services/test_ask_details.py @@ -0,0 +1,58 @@ +import uuid + +import pytest + +from src.pipelines.ask_details import generation_pipeline +from src.pipelines.ask_details.components.generator import init_generator +from src.web.v1.services.ask_details import ( + AskDetailsRequest, + AskDetailsResultRequest, + AskDetailsService, +) + + +@pytest.fixture +def ask_details_service(): + generator = init_generator() + + return AskDetailsService( + { + "generation": generation_pipeline.Generation( + generator=generator, + ), + } + ) + + +def test_ask_details_wit_easy_query(ask_details_service: AskDetailsService): + # asking details + query_id = str(uuid.uuid4()) + sql = "SELECT * FROM books" + ask_details_request = AskDetailsRequest( + query="How many books are there?'", + sql=sql, + summary="This is a summary", + ) + ask_details_request.query_id = query_id + ask_details_service.ask_details(ask_details_request) + + # getting ask details result + ask_details_result_response = ask_details_service.get_ask_details_result( + AskDetailsResultRequest( + query_id=query_id, + ) + ) + + while ask_details_result_response.status != "finished": + ask_details_result_response = ask_details_service.get_ask_details_result( + AskDetailsResultRequest( + query_id=query_id, + ) + ) + + if ask_details_result_response.status == "finished": + assert ask_details_result_response.response.description != "" + assert len(ask_details_result_response.response.steps) >= 1 + assert ask_details_result_response.response.steps[0].sql != "" + assert ask_details_result_response.response.steps[0].summary != "" + assert ask_details_result_response.response.steps[0].cte_name == "" diff --git a/wren-ai-service/tests/services/test_semantics.py b/wren-ai-service/tests/services/test_semantics.py new file mode 100644 index 0000000000..f5af4d474c --- /dev/null +++ b/wren-ai-service/tests/services/test_semantics.py @@ -0,0 +1,46 @@ +import pytest + +from src.pipelines.semantics import description +from src.web.v1.services.semantics import ( + GenerateDescriptionRequest, + SemanticsService, +) + + +@pytest.fixture +def semantics_service(): + return SemanticsService( + pipelines={ + "generate_description": description.Generation(), + } + ) + + +def test_generate_description(semantics_service: SemanticsService): + actual = semantics_service.generate_description( + GenerateDescriptionRequest( + mdl={ + "name": "all_star", + "properties": {}, + "refsql": 'select * from "canner-cml".spider."baseball_1-all_star"', + "columns": [ + { + "name": "player_id", + "type": "varchar", + "notnull": False, + "iscalculated": False, + "expression": "player_id", + "properties": {}, + } + ], + "primarykey": "", + }, + model="all_star", + identifier="column@player_id", + ) + ) + + assert actual is not None + assert actual.identifier == "column@player_id" + assert actual.display_name is not None and actual.display_name != "" + assert actual.description is not None and actual.description != "" diff --git a/wren-ai-service/tests/test_main.py b/wren-ai-service/tests/test_main.py new file mode 100644 index 0000000000..7e38b7a277 --- /dev/null +++ b/wren-ai-service/tests/test_main.py @@ -0,0 +1,237 @@ +import json +import uuid + +import pytest +from fastapi.testclient import TestClient + +from src.__main__ import app + + +@pytest.fixture +def mdl_str(): + with open("tests/data/book_2_mdl.json", "r") as f: + return json.dumps(json.load(f)) + + +def test_semantics_description(): + # using TestClient as a context manager would trigger startup/shutdown events as well as lifespans. + with TestClient(app) as client: + response = client.post( + url="/v1/semantics-descriptions/", + json={ + "mdl": { + "name": "all_star", + "properties": {}, + "refsql": 'select * from "canner-cml".spider."baseball_1-all_star"', + "columns": [ + { + "name": "player_id", + "type": "varchar", + "notnull": False, + "iscalculated": False, + "expression": "player_id", + "properties": {}, + } + ], + "primarykey": "", + }, + "model": "all_star", + "identifiers": ["column@player_id"], + }, + ) + + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0]["identifier"] == "column@player_id" + assert ( + response.json()[0]["display_name"] is not None + and response.json()[0]["display_name"] != "" + ) + assert ( + response.json()[0]["description"] is not None + and response.json()[0]["description"] != "" + ) + + +def test_semantics_preparations(mdl_str: str): + with TestClient(app) as client: + semantics_preperation_id = str(uuid.uuid4()) + + response = client.post( + url="/v1/semantics-preparations/", + json={ + "mdl": mdl_str, + "id": semantics_preperation_id, + }, + ) + + assert response.status_code == 200 + assert response.json()["id"] == semantics_preperation_id + + status = "indexing" + + while status == "indexing": + response = client.get( + url=f"/v1/semantics-preparations/{semantics_preperation_id}/status/" + ) + + assert response.status_code == 200 + assert response.json()["status"] in ["indexing", "finished", "failed"] + status = response.json()["status"] + + assert status == "finished" + + +def test_asks(mdl_str: str): + with TestClient(app) as client: + semantics_preperation_id = str(uuid.uuid4()) + + response = client.post( + url="/v1/semantics-preparations/", + json={ + "mdl": mdl_str, + "id": semantics_preperation_id, + }, + ) + + status = "indexing" + while status != "finished": + response = client.get( + url=f"/v1/semantics-preparations/{semantics_preperation_id}/status/" + ) + status = response.json()["status"] + + response = client.post( + url="/v1/asks", + json={ + "query": "How many books are there?", + "id": semantics_preperation_id, + }, + ) + + assert response.status_code == 200 + assert response.json()["query_id"] != "" + + query_id = response.json()["query_id"] + + response = client.get(url=f"/v1/asks/{query_id}/result/") + while response.json()["status"] != "finished": + response = client.get(url=f"/v1/asks/{query_id}/result/") + + assert response.status_code == 200 + assert response.json()["status"] == "finished" + assert len(response.json()["response"]) == 3 + for r in response.json()["response"]: + assert r["sql"] is not None and r["sql"] != "" + assert r["summary"] is not None and r["summary"] != "" + + +def test_stop_asks(mdl_str: str): + with TestClient(app) as client: + semantics_preperation_id = str(uuid.uuid4()) + + response = client.post( + url="/v1/semantics-preparations/", + json={ + "mdl": mdl_str, + "id": semantics_preperation_id, + }, + ) + + status = "indexing" + while status != "finished": + response = client.get( + url=f"/v1/semantics-preparations/{semantics_preperation_id}/status/" + ) + status = response.json()["status"] + + response = client.post( + url="/v1/asks", + json={ + "query": "How many books are there?", + "id": semantics_preperation_id, + }, + ) + + query_id = response.json()["query_id"] + + response = client.patch( + url=f"/v1/asks/{query_id}", + json={ + "status": "stopped", + }, + ) + + assert response.status_code == 200 + assert response.json()["query_id"] == query_id + + response = client.get(url=f"/v1/asks/{query_id}/result/") + while response.json()["status"] != "stopped": + response = client.get(url=f"/v1/asks/{query_id}/result/") + + assert response.status_code == 200 + assert response.json()["status"] == "stopped" + + +def test_ask_details(mdl_str: str): + with TestClient(app) as client: + semantics_preperation_id = str(uuid.uuid4()) + + response = client.post( + url="/v1/semantics-preparations/", + json={ + "mdl": mdl_str, + "id": semantics_preperation_id, + }, + ) + + status = "indexing" + while status != "finished": + response = client.get( + url=f"/v1/semantics-preparations/{semantics_preperation_id}/status/" + ) + status = response.json()["status"] + + query = "How many books are there?" + response = client.post( + url="/v1/asks", + json={ + "query": query, + "id": semantics_preperation_id, + }, + ) + + query_id = response.json()["query_id"] + + response = client.get(url=f"/v1/asks/{query_id}/result/") + while response.json()["status"] != "finished": + response = client.get(url=f"/v1/asks/{query_id}/result/") + + sql = response.json()["response"][0]["sql"] + summary = response.json()["response"][0]["summary"] + + response = client.post( + url="/v1/ask-details/", + json={ + "query": query, + "sql": sql, + "summary": summary, + }, + ) + + assert response.status_code == 200 + assert response.json()["query_id"] != "" + + query_id = response.json()["query_id"] + response = client.get(url=f"/v1/ask-details/{query_id}/result/") + while response.json()["status"] != "finished": + response = client.get(url=f"/v1/ask-details/{query_id}/result/") + + assert response.status_code == 200 + assert response.json()["status"] == "finished" + assert response.json()["response"]["description"] != "" + assert len(response.json()["response"]["steps"]) >= 1 + + for step in response.json()["response"]["steps"]: + assert step["sql"] != "" + assert step["summary"] != "" From 7e1cb23ffb724293d5ba67f0b6662f3185808ec6 Mon Sep 17 00:00:00 2001 From: Pao-Sheng Wang Date: Wed, 13 Mar 2024 15:57:56 +0800 Subject: [PATCH 2/3] modify the CI execution path --- .github/workflows/ai-service-ci.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ai-service-ci.yaml b/.github/workflows/ai-service-ci.yaml index a9bfbe04f4..8858556d7f 100644 --- a/.github/workflows/ai-service-ci.yaml +++ b/.github/workflows/ai-service-ci.yaml @@ -12,6 +12,10 @@ on: permissions: contents: read +defaults: + run: + working-directory: wren-ai-service + jobs: ci: strategy: @@ -35,7 +39,7 @@ jobs: - name: Run Qdrant run: docker run -p 6333:6333 -p 6334:6334 -d --name qdrant qdrant/qdrant:v1.7.4 - name: Test with pytest - run: poetry run pytest -s + run: poetry run pytest env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} ENV: dev From 39cfe63ace73a2018621377f6265a4463809b688 Mon Sep 17 00:00:00 2001 From: Pao-Sheng Wang Date: Wed, 13 Mar 2024 16:28:21 +0800 Subject: [PATCH 3/3] update README for wren ai service --- wren-ai-service/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/README.md b/wren-ai-service/README.md index e5349bafed..7861ef9dd0 100644 --- a/wren-ai-service/README.md +++ b/wren-ai-service/README.md @@ -1,4 +1,4 @@ -## Introduction +# AI Service of WrenAI ## Environment Setup