From 0d1be4ed415fd4079e477d67cfe36b11639019be Mon Sep 17 00:00:00 2001 From: samadpls Date: Thu, 7 Aug 2025 20:48:43 +0500 Subject: [PATCH 1/8] Refactor application structure and update dependencies - Removed outdated files and reorganized the project structure for better maintainability. - Introduced a new launcher script (`run.py`) to facilitate running the application in web or CLI mode. - Updated `requirements.txt` to include core, data analysis, training, and testing dependencies. - Added new backend and frontend components to enhance functionality and user experience. - Implemented new schemas for requests and responses to standardize data handling. - Enhanced CSS styles for improved UI presentation. This commit lays the groundwork for future development and feature enhancements. --- requirements.txt | 139 +++----------- run.py | 69 +++++++ src/app.py | 107 ----------- src/auth.py | 29 --- src/backend.py | 95 ---------- src/backend/backend.py | 28 +++ src/backend/orchestrator.py | 288 +++++++++++++++++++++++++++++ src/config/__init__.py | 3 + src/config/constants.py | 147 +++++++++++++++ src/constant.py | 14 -- src/database.py | 64 ------- src/frontend.py | 92 --------- src/frontend/app.py | 185 ++++++++++++++++++ src/frontend/frontend.py | 37 ++++ src/model.py | 22 --- src/schemas/__init__.py | 3 + src/schemas/requests.py | 74 ++++++++ src/schemas/responses.py | 143 ++++++++++++++ src/services/__init__.py | 3 + src/services/csv_analysis_tools.py | 217 ++++++++++++++++++++++ src/services/jupyter_service.py | 237 ++++++++++++++++++++++++ src/services/models.py | 1 + src/services/sql_service.py | 123 ++++++++++++ src/terminal/__init__.py | 3 + src/terminal/cli.py | 173 +++++++++++++++++ static/css/styles.css | 187 +++++++++++++++---- training/Querypls_prompt.py | 28 --- training/finetuning_querypls.py | 107 ----------- utils/prompt.py | 107 +++++++++++ 29 files changed, 2008 insertions(+), 717 deletions(-) create mode 100644 run.py delete mode 100644 src/app.py delete mode 100644 src/auth.py delete mode 100644 src/backend.py create mode 100644 src/backend/backend.py create mode 100644 src/backend/orchestrator.py create mode 100644 src/config/__init__.py create mode 100644 src/config/constants.py delete mode 100644 src/constant.py delete mode 100644 src/database.py delete mode 100644 src/frontend.py create mode 100644 src/frontend/app.py create mode 100644 src/frontend/frontend.py delete mode 100644 src/model.py create mode 100644 src/schemas/__init__.py create mode 100644 src/schemas/requests.py create mode 100644 src/schemas/responses.py create mode 100644 src/services/__init__.py create mode 100644 src/services/csv_analysis_tools.py create mode 100644 src/services/jupyter_service.py create mode 100644 src/services/models.py create mode 100644 src/services/sql_service.py create mode 100644 src/terminal/__init__.py create mode 100644 src/terminal/cli.py delete mode 100644 training/Querypls_prompt.py delete mode 100644 training/finetuning_querypls.py create mode 100644 utils/prompt.py diff --git a/requirements.txt b/requirements.txt index 93bbda8..25f7eb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,118 +1,21 @@ -aiohttp==3.9.5 -aiosignal==1.3.1 -altair==5.3.0 -annotated-types==0.7.0 -anyio==3.7.1 -async-timeout==4.0.3 -asyncio==3.4.3 -attrs==23.2.0 -black==24.4.2 -blinker==1.8.2 -cachetools==5.4.0 -certifi==2024.7.4 -charset-normalizer==3.3.2 -click==8.1.7 -dataclasses-json==0.6.7 -deta==1.2.0 -exceptiongroup==1.2.2 -filelock==3.15.4 -frozenlist==1.4.1 -fsspec==2024.6.1 -gitdb==4.0.11 -GitPython==3.1.43 -greenlet==3.0.3 -h11==0.14.0 -httpcore==0.17.3 -httpx==0.24.1 -httpx-oauth==0.13.0 -huggingface-hub==0.23.4 -idna==3.7 -importlib-metadata==6.11.0 -iniconfig==2.0.0 -Jinja2==3.1.5 -joblib==1.4.2 -jsonpatch==1.33 -jsonpointer==3.0.0 -jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -langchain==0.2.14 -langchain-core==0.2.32 -langchain-community>=0.0.37 -langchain-huggingface==0.0.3 -langchain-text-splitters==0.2.2 -langsmith==0.1.93 -markdown-it-py==3.0.0 -MarkupSafe==2.1.5 -marshmallow==3.21.3 -mdurl==0.1.2 -mpmath==1.3.0 -multidict==6.0.5 -mypy-extensions==1.0.0 -networkx==3.3 -numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==9.1.0.70 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.20.5 -nvidia-nvjitlink-cu12==12.5.82 -nvidia-nvtx-cu12==12.1.105 -orjson==3.10.6 -packaging==23.2 -pandas==2.2.2 -pathspec==0.12.1 -pillow==10.4.0 -platformdirs==4.2.2 -pluggy==1.5.0 -protobuf==4.25.4 -pyarrow==17.0.0 -pydantic==2.8.2 -pydantic_core==2.20.1 -pydeck==0.9.1 -Pygments==2.18.0 -pytest==8.3.2 -python-dateutil==2.9.0.post0 -python-dotenv==1.0.0 -pytz==2024.1 -PyYAML==6.0.1 -referencing==0.35.1 -regex==2024.7.24 -requests==2.32.3 -rich==13.7.1 -rpds-py==0.19.1 -safetensors==0.4.3 -scikit-learn==1.5.1 -scipy==1.14.0 -sentence-transformers==3.0.1 -six==1.16.0 -smmap==5.0.1 -sniffio==1.3.1 -SQLAlchemy==2.0.31 -streamlit==1.36.0 -streamlit-oauth==0.1.5 -sympy==1.13.1 -tenacity==8.5.0 -threadpoolctl==3.5.0 -tokenizers==0.19.1 -toml==0.10.2 -tomli==2.0.1 -toolz==0.12.1 -torch==2.4.0 -tornado==6.4.2 -tqdm==4.66.4 -transformers==4.48.0 -triton==3.0.0 -typing-inspect==0.9.0 -typing_extensions==4.12.2 -tzdata==2024.1 -tzlocal==5.2 -urllib3==2.2.2 -validators==0.33.0 -watchdog==4.0.1 -yarl==1.9.4 -zipp==3.19.2 +# Core application dependencies +streamlit>=1.36.0 +pydantic-ai-slim[groq]>=0.6.0 +pydantic>=2.0.0 +pydantic-settings>=2.0.0 + +# Data analysis dependencies +pandas>=2.0.0 +numpy>=1.24.0 +matplotlib>=3.7.0 +seaborn>=0.12.0 +jupyter-client>=8.0.0 + +# Training dependencies (optional - only needed for model training) +datasets>=2.14.0 +transformers>=4.48.0 +trl>=0.7.0 +peft>=0.6.0 + +# Testing dependencies (optional - only needed for running tests) +pytest>=8.3.0 diff --git a/run.py b/run.py new file mode 100644 index 0000000..369984e --- /dev/null +++ b/run.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +""" +Launcher script for Querypls application. +""" + +import sys +import os +import argparse + +# Add src to path +sys.path.append(os.path.join(os.path.dirname(__file__), 'src')) + + +def run_streamlit(): + """Run the Streamlit application.""" + import subprocess + import streamlit.web.cli as stcli + + # Set environment variables + os.environ['STREAMLIT_SERVER_PORT'] = '8501' + os.environ['STREAMLIT_SERVER_ADDRESS'] = 'localhost' + + # Run streamlit + sys.argv = [ + 'streamlit', 'run', + 'src/frontend/app.py', + '--server.port=8501', + '--server.address=localhost' + ] + sys.exit(stcli.main()) + + +def run_cli(): + """Run the CLI application.""" + from terminal.cli import main as cli_main + cli_main() + + +def main(): + """Main launcher function.""" + parser = argparse.ArgumentParser(description='Querypls - SQL Generation Tool') + parser.add_argument( + 'mode', + choices=['web', 'cli'], + default='web', + nargs='?', + help='Run mode: web (Streamlit) or cli (Command Line)' + ) + parser.add_argument( + 'cli_args', + nargs='*', + help='Arguments to pass to CLI (when mode is cli)' + ) + + args = parser.parse_args() + + if args.mode == 'web': + print("πŸš€ Starting Querypls Web Application...") + run_streamlit() + elif args.mode == 'cli': + print("πŸš€ Starting Querypls CLI...") + # Pass CLI arguments to the CLI + if args.cli_args: + sys.argv = ['cli'] + args.cli_args + run_cli() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/app.py b/src/app.py deleted file mode 100644 index 2f16301..0000000 --- a/src/app.py +++ /dev/null @@ -1,107 +0,0 @@ -from langchain_core.output_parsers import StrOutputParser -from langchain_core.prompts import PromptTemplate -import streamlit as st -import sys -import os -import json -from backend import ( - configure_page_styles, - display_github_badge, - hide_main_menu_and_footer, -) -from frontend import ( - create_message, - display_logo_and_heading, - display_previous_chats, - display_welcome_message, - handle_new_chat, -) -from model import create_huggingface_hub -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from src.auth import * -from src.constant import * - -def format_chat_history(messages): - """Format the chat history as a structured JSON string.""" - history = [] - for msg in messages[1:]: - content = msg['content'] - if '```sql' in content: - content = content.replace('```sql\n', '').replace('\n```', '').strip() - - history.append({ - "role": msg['role'], - "query" if msg['role'] == 'user' else "response": content - }) - - formatted_history = json.dumps(history, indent=2) - print("Formatted history:", formatted_history) - return formatted_history - -def extract_sql_code(response): - """Extract clean SQL code from the response.""" - sql_code_start = response.find("```sql") - if sql_code_start != -1: - sql_code_end = response.find("```", sql_code_start + 5) - if sql_code_end != -1: - sql_code = response[sql_code_start + 6:sql_code_end].strip() - return f"```sql\n{sql_code}\n```" - return response - -def main(): - """Main function to configure and run the Querypls application.""" - configure_page_styles("static/css/styles.css") - - if "model" not in st.session_state: - llm = create_huggingface_hub() - st.session_state["model"] = llm - - if "messages" not in st.session_state: - create_message() - - hide_main_menu_and_footer() - - with st.sidebar: - display_github_badge() - display_logo_and_heading() - st.markdown("`Made with 🀍`") - handle_new_chat() - - display_welcome_message() - for message in st.session_state.messages: - with st.chat_message(message["role"]): - st.markdown(message["content"]) - - if prompt := st.chat_input(): - st.session_state.messages.append({"role": "user", "content": prompt}) - with st.chat_message("user"): - st.markdown(prompt) - - conversation_history = format_chat_history(st.session_state.messages) - prompt_template = PromptTemplate( - template=TEMPLATE, - input_variables=["input", "conversation_history"] - ) - - if "model" in st.session_state: - llm_chain = prompt_template | st.session_state.model | StrOutputParser() - - with st.chat_message("assistant"): - with st.spinner("Generating..."): - response = llm_chain.invoke({ - "input": prompt, - "conversation_history": conversation_history - }) - - # Clean and format the response - formatted_response = extract_sql_code(response) - st.markdown(formatted_response) - - # Add to chat history - st.session_state.messages.append({ - "role": "assistant", - "content": formatted_response - }) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/src/auth.py b/src/auth.py deleted file mode 100644 index b16d3c2..0000000 --- a/src/auth.py +++ /dev/null @@ -1,29 +0,0 @@ -import asyncio -from src.constant import * -from httpx_oauth.clients.google import GoogleOAuth2 - - -async def get_authorization_url(client: GoogleOAuth2, redirect_uri: str): - authorization_url = await client.get_authorization_url( - redirect_uri, scope=["profile", "email"] - ) - return authorization_url - - -async def get_access_token(client: GoogleOAuth2, redirect_uri: str, code: str): - token = await client.get_access_token(code, redirect_uri) - return token - - -async def get_email(client: GoogleOAuth2, token: str): - user_id, user_email = await client.get_id_email(token) - return user_id, user_email - - -def get_login_str(): - client: GoogleOAuth2 = GoogleOAuth2(CLIENT_ID, CLIENT_SECRET) - authorization_url = asyncio.run( - get_authorization_url(client, REDIRECT_URI) - ) - return f"""\ -""" diff --git a/src/backend.py b/src/backend.py deleted file mode 100644 index 8dc0f24..0000000 --- a/src/backend.py +++ /dev/null @@ -1,95 +0,0 @@ -import streamlit as st -from streamlit_oauth import OAuth2Component -import sys -import os -import json -import base64 - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from src.auth import * -from src.constant import * - - -def configure_page_styles(file_name): - """Configures Streamlit page styles for Querypls. - - Sets page title, icon, and applies custom CSS styles. - Hides Streamlit main menu and footer for a cleaner interface. - - Note: - Ensure 'static/css/styles.css' exists with desired styles. - """ - st.set_page_config( - page_title="Querypls", - page_icon="πŸ’¬", - layout="wide", - ) - with open(file_name) as f: - st.markdown( - "".format(f.read()), unsafe_allow_html=True - ) - - hide_streamlit_style = """""" - st.markdown(hide_streamlit_style, unsafe_allow_html=True) - - -def hide_main_menu_and_footer(): - """Hides the Streamlit main menu and footer for a cleaner interface.""" - st.markdown( - """""", - unsafe_allow_html=True, - ) - - -def handle_google_login_if_needed(result): - """Handles Google login if it has not been run yet. - - Args: - result (str): Authorization code received from Google. - - Returns: - None - """ - try: - if result and "token" in result: - st.session_state.token = result.get("token") - token = st.session_state["token"] - id_token = token["id_token"] - payload = id_token.split(".")[1] - payload += "=" * (-len(payload) % 4) - payload = json.loads(base64.b64decode(payload)) - email = payload["email"] - st.session_state.user_email = email - st.session_state.code = True - return - except Exception: - st.warning( - "Seems like there is a network issue. \ - Please check your internet connection." - ) - sys.exit() - - -def display_github_badge(): - """Displays a GitHub badge with a link to the Querypls repository.""" - st.markdown( - """\ - """, - unsafe_allow_html=True, - ) - - -def create_oauth2_component(): - return OAuth2Component( - CLIENT_ID, - CLIENT_SECRET, - AUTHORIZE_URL, - TOKEN_URL, - REFRESH_TOKEN_URL, - REVOKE_TOKEN_URL, - ) diff --git a/src/backend/backend.py b/src/backend/backend.py new file mode 100644 index 0000000..015cc42 --- /dev/null +++ b/src/backend/backend.py @@ -0,0 +1,28 @@ +""" +Backend utilities for Streamlit configuration and styling. +""" + +import streamlit as st +import sys +import os + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from config.constants import STREAMLIT_CONFIG, HIDE_MENU_STYLE, HIDE_MENU_FOOTER_STYLE, GITHUB_BADGE + + +def configure_page_styles(file_name: str) -> None: + st.set_page_config(**STREAMLIT_CONFIG) + + with open(file_name) as f: + st.markdown(f"", unsafe_allow_html=True) + + st.markdown(HIDE_MENU_STYLE, unsafe_allow_html=True) + + +def hide_main_menu_and_footer() -> None: + st.markdown(HIDE_MENU_FOOTER_STYLE, unsafe_allow_html=True) + + +def display_github_badge() -> None: + st.markdown(GITHUB_BADGE, unsafe_allow_html=True) diff --git a/src/backend/orchestrator.py b/src/backend/orchestrator.py new file mode 100644 index 0000000..3cc05e1 --- /dev/null +++ b/src/backend/orchestrator.py @@ -0,0 +1,288 @@ +""" +Backend orchestrator for managing application state and services. +""" + +import uuid +from datetime import datetime +from typing import List, Optional, Dict, Any +from dataclasses import dataclass + +from config.settings import get_settings +from config.constants import ( + WELCOME_MESSAGE, DEFAULT_SESSION_NAME, CSV_LOAD_ERROR, CSV_ANALYSIS_ERROR, + SESSION_CREATE_ERROR, ORCHESTRATOR_INIT_ERROR, SESSION_NOT_FOUND_ERROR, + RESPONSE_GENERATION_ERROR, MESSAGE_LOAD_ERROR, MAX_CHAT_HISTORIES +) +from services.sql_service import SQLGenerationService +from services.csv_analysis_tools import CSVAnalysisTools, create_csv_analysis_agent +from schemas.requests import SQLGenerationRequest, ChatMessage, ConversationHistory, NewChatRequest +from schemas.responses import ChatResponse, SessionInfo, ErrorResponse, HealthCheckResponse + + +@dataclass +class Session: + session_id: str + session_name: str + created_at: datetime + messages: List[ChatMessage] + last_activity: datetime + csv_data: Optional[str] = None + + +class BackendOrchestrator: + def __init__(self): + self.settings = get_settings() + self.sql_service = SQLGenerationService() + self.csv_tools = CSVAnalysisTools() + self.csv_agent = create_csv_analysis_agent() + self.sessions: Dict[str, Session] = {} + self.max_sessions = self.settings.max_chat_histories + + def is_conversational_query(self, query: str) -> bool: + conversational_keywords = [ + 'hi', 'hello', 'hey', 'good morning', 'good afternoon', 'good evening', + 'how are you', 'what\'s up', 'thanks', 'thank you', 'bye', 'goodbye', + 'help', 'what can you do', 'who are you', 'tell me about yourself' + ] + query_lower = query.lower().strip() + return any(keyword in query_lower for keyword in conversational_keywords) + + def get_conversational_response(self, query: str) -> str: + query_lower = query.lower().strip() + + if any(greeting in query_lower for greeting in ['hi', 'hello', 'hey']): + return "Hello! πŸ‘‹ How can I help you today? I can assist with SQL generation or CSV data analysis." + elif 'how are you' in query_lower: + return "I'm doing great, thank you for asking! 😊 How can I assist you with your data queries today?" + elif any(thanks in query_lower for thanks in ['thanks', 'thank you']): + return "You're welcome! 😊 Is there anything else I can help you with?" + elif any(bye in query_lower for bye in ['bye', 'goodbye']): + return "Goodbye! πŸ‘‹ Feel free to come back if you need help with SQL or data analysis." + elif 'help' in query_lower or 'what can you do' in query_lower: + return "I'm Querypls, your SQL and data analysis assistant! πŸ—ƒοΈπŸ’¬\n\nI can help you with:\nβ€’ **SQL Generation**: Convert natural language to SQL queries\nβ€’ **CSV Analysis**: Analyze data files with Python code\nβ€’ **Data Visualization**: Create charts and graphs\n\nJust ask me anything about your data!" + elif 'who are you' in query_lower or 'tell me about yourself' in query_lower: + return "I'm Querypls, an AI assistant specialized in SQL generation and data analysis! πŸ—ƒοΈπŸ’¬\n\nI can help you write SQL queries from natural language and analyze CSV files with Python code. What would you like to work on?" + else: + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + + def create_new_session(self, request: NewChatRequest) -> SessionInfo: + session_id = str(uuid.uuid4()) + session_name = request.session_name or f"Chat {len(self.sessions) + 1}" + + messages = [] + if request.initial_context: + messages.append(ChatMessage( + role="system", + content=request.initial_context + )) + + messages.append(ChatMessage( + role="assistant", + content=WELCOME_MESSAGE + )) + + session = Session( + session_id=session_id, + session_name=session_name, + created_at=datetime.now(), + messages=messages, + last_activity=datetime.now() + ) + + self.sessions[session_id] = session + self._cleanup_old_sessions() + + return SessionInfo( + session_id=session_id, + session_name=session_name, + created_at=session.created_at.isoformat(), + message_count=len(session.messages), + last_activity=session.last_activity.isoformat() + ) + + def get_session(self, session_id: str) -> Optional[Session]: + return self.sessions.get(session_id) + + def list_sessions(self) -> List[SessionInfo]: + return [ + SessionInfo( + session_id=session.session_id, + session_name=session.session_name, + created_at=session.created_at.isoformat(), + message_count=len(session.messages), + last_activity=session.last_activity.isoformat() + ) + for session in self.sessions.values() + ] + + def delete_session(self, session_id: str) -> bool: + if session_id in self.sessions: + self.csv_tools.close_session(session_id) + del self.sessions[session_id] + return True + return False + + def load_csv_data(self, session_id: str, csv_content: str) -> Dict[str, Any]: + session = self.get_session(session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + session.csv_data = csv_content + result = self.csv_tools.load_csv_data(csv_content, session_id) + session.last_activity = datetime.now() + + return result + + def generate_sql_response(self, session_id: str, user_query: str) -> ChatResponse: + session = self.get_session(session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + user_message = ChatMessage( + role="user", + content=user_query, + timestamp=datetime.now().isoformat() + ) + session.messages.append(user_message) + + # Check if this is a conversational query + if self.is_conversational_query(user_query): + response_content = self.get_conversational_response(user_query) + else: + # Generate SQL for data-related queries + request = SQLGenerationRequest( + user_query=user_query, + conversation_history=session.messages + ) + + response = self.sql_service.generate_sql(request) + response_content = response.content + + assistant_message = ChatMessage( + role="assistant", + content=response_content, + timestamp=datetime.now().isoformat() + ) + session.messages.append(assistant_message) + session.last_activity = datetime.now() + + return ChatResponse( + message_id=str(uuid.uuid4()), + content=response_content, + timestamp=datetime.now().isoformat(), + session_id=session_id + ) + + def generate_csv_analysis_response(self, session_id: str, user_query: str) -> ChatResponse: + session = self.get_session(session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + if not session.csv_data: + error_response = ChatResponse( + message_id=str(uuid.uuid4()), + content=CSV_LOAD_ERROR, + timestamp=datetime.now().isoformat(), + session_id=session_id + ) + return error_response + + user_message = ChatMessage( + role="user", + content=user_query, + timestamp=datetime.now().isoformat() + ) + session.messages.append(user_message) + + try: + result = self.csv_agent.run_sync(user_query) + + assistant_message = ChatMessage( + role="assistant", + content=result.output, + timestamp=datetime.now().isoformat() + ) + session.messages.append(assistant_message) + session.last_activity = datetime.now() + + return ChatResponse( + message_id=str(uuid.uuid4()), + content=result.output, + timestamp=datetime.now().isoformat(), + session_id=session_id + ) + + except Exception as e: + error_response = ChatResponse( + message_id=str(uuid.uuid4()), + content=CSV_ANALYSIS_ERROR.format(error=str(e)), + timestamp=datetime.now().isoformat(), + session_id=session_id + ) + return error_response + + def get_conversation_history(self, session_id: str) -> ConversationHistory: + session = self.get_session(session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + return ConversationHistory( + messages=session.messages, + session_id=session_id + ) + + def get_csv_info(self, session_id: str) -> Dict[str, Any]: + return self.csv_tools.get_csv_info(session_id) + + def health_check(self) -> HealthCheckResponse: + services_status = { + "sql_service": "healthy", + "csv_analysis_service": "healthy", + "session_manager": "healthy" + } + + try: + test_request = SQLGenerationRequest( + user_query="SELECT 1", + conversation_history=[] + ) + if not self.sql_service: + services_status["sql_service"] = "unhealthy" + except Exception: + services_status["sql_service"] = "unhealthy" + + try: + if not self.csv_tools: + services_status["csv_analysis_service"] = "unhealthy" + except Exception: + services_status["csv_analysis_service"] = "unhealthy" + + return HealthCheckResponse( + status="healthy" if all(status == "healthy" for status in services_status.values()) else "unhealthy", + version=self.settings.app_version, + timestamp=datetime.now().isoformat(), + services=services_status + ) + + def _cleanup_old_sessions(self): + if len(self.sessions) <= self.max_sessions: + return + + sorted_sessions = sorted( + self.sessions.items(), + key=lambda x: x[1].last_activity + ) + + sessions_to_remove = len(self.sessions) - self.max_sessions + for i in range(sessions_to_remove): + session_id, _ = sorted_sessions[i] + self.delete_session(session_id) + + def get_default_session(self) -> str: + for session_id, session in self.sessions.items(): + if session.session_name == DEFAULT_SESSION_NAME: + return session_id + + request = NewChatRequest(session_name=DEFAULT_SESSION_NAME) + session_info = self.create_new_session(request) + return session_info.session_id \ No newline at end of file diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000..3622e45 --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,3 @@ +""" +Configuration package for Querypls application. +""" \ No newline at end of file diff --git a/src/config/constants.py b/src/config/constants.py new file mode 100644 index 0000000..4da7b39 --- /dev/null +++ b/src/config/constants.py @@ -0,0 +1,147 @@ +""" +Constants for Querypls application. +""" + +# Application Settings +MAX_RETRIES = 3 +EXECUTION_TIMEOUT = 30 +MAX_CHAT_HISTORIES = 5 +STREAMLIT_PORT = 8501 +STREAMLIT_HOST = "localhost" + +# Streamlit Configuration +STREAMLIT_CONFIG = { + "page_title": "Querypls", + "page_icon": "πŸ’¬", + "layout": "wide" +} + +# UI Messages +WELCOME_MESSAGE = "How may I help you? I can help with SQL generation or CSV data analysis." +DEFAULT_SESSION_NAME = "Default Chat" + +# Success Messages +CSV_LOAD_SUCCESS = "βœ… CSV loaded for analysis!" +CSV_UPLOAD_SUCCESS = "βœ… CSV uploaded successfully! Shape: {shape}" +CSV_LOADED_SUCCESS = "βœ… CSV loaded for analysis!" +SESSION_CREATED = "βœ… Created session: {name} (ID: {id})" +SESSION_SWITCHED = "βœ… Switched to session: {name}" + +# Error Messages +CSV_LOAD_ERROR = "❌ No CSV data loaded. Please upload a CSV file first." +CSV_ANALYSIS_ERROR = "❌ Error during CSV analysis: {error}" +CSV_UPLOAD_ERROR = "❌ Error reading CSV file: {error}" +SESSION_CREATE_ERROR = "Failed to create new session: {error}" +ORCHESTRATOR_INIT_ERROR = "Failed to initialize backend orchestrator: {error}" +SESSION_NOT_FOUND_ERROR = "Failed to get current session." +APP_INIT_ERROR = "Failed to initialize application. Please check your configuration." +RESPONSE_GENERATION_ERROR = "Error generating response: {error}" +MESSAGE_LOAD_ERROR = "Error loading messages: {error}" +SESSION_NOT_FOUND = "❌ Session {id} not found." +NO_ACTIVE_SESSION = "❌ No active session. Create one first with 'new' command." +NO_SESSION = "❌ No active session." +HEALTH_CHECK_FAILED = "❌ Health check failed: {error}" + +# Health Check Messages +HEALTH_CHECK_SUCCESS = "πŸ₯ Health Check:" +HEALTH_STATUS = "Status: {status}" +HEALTH_VERSION = "Version: {version}" +HEALTH_SERVICES = "Services: {services}" + +# CLI Messages +CLI_WELCOME = "πŸš€ Welcome to Querypls CLI!" +CLI_COMMANDS = "Commands: new, list, switch , chat , history, health, quit" +CLI_GOODBYE = "πŸ‘‹ Goodbye!" +CLI_UNKNOWN_COMMAND = "❌ Unknown command. Use: new, list, switch , chat , history, health, quit" +CLI_ERROR = "❌ Error: {error}" + +# Response Labels +RESPONSE_GENERATED = "πŸ€– Response:" +SQL_DETAILS = "πŸ“Š SQL Details:" +QUERY_TYPE = "Query Type:" +COMPLEXITY = "Complexity:" +TABLES_USED = "Tables Used:" +COLUMNS = "Columns:" +ESTIMATED_ROWS = "Estimated Rows:" +WARNINGS = "Warnings:" +CONVERSATION_HISTORY = "πŸ“œ Conversation History:" + +# Session Management +NO_SESSIONS = "πŸ“ No sessions found." +AVAILABLE_SESSIONS = "πŸ“ Available sessions:" +SESSION_INFO = " {num}. {name}" +SESSION_ID = " ID: {id}" +SESSION_MESSAGES = " Messages: {count}" +SESSION_ACTIVITY = " Last Activity: {activity}" + +# CSV Analysis UI +CSV_ANALYSIS_SECTION = "### πŸ“Š CSV Analysis" +CSV_UPLOAD_LABEL = "Upload CSV file for analysis" +CSV_UPLOAD_HELP = "Upload a CSV file to analyze with Python code" +CSV_PREVIEW = "πŸ“Š CSV Preview" +CSV_COLUMNS = "**Columns:** {columns}" +CSV_DTYPES = "**Data Types:** {dtypes}" +LOAD_CSV_BUTTON = "πŸ” Load CSV for Analysis" + +# UI Elements +GITHUB_BADGE = """\ +""" + +MADE_WITH_LOVE = "`Made with 🀍`" +SESSIONS_SECTION = "### Sessions" +NEW_SESSION_BUTTON = "βž• New Session" + +# CSS Styles +HIDE_MENU_STYLE = """""" + +HIDE_MENU_FOOTER_STYLE = """""" + +# Available Models +AVAILABLE_MODELS = { + "deepseek-r1-distill-llama-70b": { + "developer": "DeepSeek / Meta", + "context_window": 131072, + "max_completion": 131072, + "max_file_size": None + }, + "meta-llama/llama-4-maverick-17b-128e-instruct": { + "developer": "Meta", + "context_window": 131072, + "max_completion": 8192, + "max_file_size": "20 MB" + }, + "meta-llama/llama-4-scout-17b-16e-instruct": { + "developer": "Meta", + "context_window": 131072, + "max_completion": 8192, + "max_file_size": "20 MB" + }, + "moonshotai/kimi-k2-instruct": { + "developer": "Moonshot AI", + "context_window": 131072, + "max_completion": 16384, + "max_file_size": None + }, + "openai/gpt-oss-120b": { + "developer": "OpenAI", + "context_window": 131072, + "max_completion": 32766, + "max_file_size": None + }, + "openai/gpt-oss-20b": { + "developer": "OpenAI", + "context_window": 131072, + "max_completion": 32768, + "max_file_size": None + }, + "qwen/qwen3-32b": { + "developer": "Alibaba Cloud", + "context_window": 131072, + "max_completion": 131072, + "max_file_size": None + } +} \ No newline at end of file diff --git a/src/constant.py b/src/constant.py deleted file mode 100644 index d3c4790..0000000 --- a/src/constant.py +++ /dev/null @@ -1,14 +0,0 @@ -from streamlit import secrets - -DETA_PROJECT_KEY = secrets["DETA_PROJECT_KEY"] -HUGGINGFACE_API_TOKEN = secrets["HUGGINGFACE_API_TOKEN"] -REPO_ID = secrets["REPO_ID"] -CLIENT_ID = secrets["CLIENT_ID"] -CLIENT_SECRET = secrets["CLIENT_SECRET"] -REDIRECT_URI = secrets["REDIRECT_URI"] -TEMPLATE = secrets["TEMPLATE"] -AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth" -TOKEN_URL = "https://oauth2.googleapis.com/token" -REFRESH_TOKEN_URL = "https://oauth2.googleapis.com/token" -REVOKE_TOKEN_URL = "https://oauth2.googleapis.com/revoke" -SCOPE = "email" \ No newline at end of file diff --git a/src/database.py b/src/database.py deleted file mode 100644 index 0ed8c5b..0000000 --- a/src/database.py +++ /dev/null @@ -1,64 +0,0 @@ -import streamlit as st - - -def get_previous_chats(db, user_email): - """Fetches previous chat records for a user from the database. - - Args: - db: Deta Base instance. - user_email (str): User's email address. - - Returns: - list: List of previous chat records. - """ - return db.fetch({"email": user_email}).items - - -def database(db, previous_key="key", previous_chat=None, max_chat_histories=5): - """Manages user chat history in the database. - - Updates, adds, or removes chat history based on user interaction. - - Args: - db: Deta Base instance. - previous_key (str): Key for the previous chat in the database. - previous_chat (list, optional): Previous chat messages. - max_chat_histories (int, optional): Maximum number of chat histories to retain. - - Returns: - None - """ - user_email = st.session_state.user_email - previous_chats = get_previous_chats(db, user_email) - existing_chat = db.get(previous_key) if previous_key != "key" else None - if ( - previous_chat is not None - and existing_chat is not None - and previous_key != "key" - ): - new_messages = [ - message - for message in previous_chat - if message not in existing_chat["chat"] - ] - existing_chat["chat"].extend(new_messages) - db.update({"chat": existing_chat["chat"]}, key=previous_key) - return - previous_chat = ( - st.session_state.messages if previous_chat is None else previous_chat - ) - if len(previous_chat) > 1 and previous_key == "key": - title = previous_chat[1]["content"] - db.put( - { - "email": user_email, - "chat": previous_chat, - "title": title[:25] + "....." if len(title) > 25 else title, - } - ) - - if len(previous_chats) >= max_chat_histories: - db.delete(previous_chats[0]["key"]) - st.warning( - f"Chat '{previous_chats[0]['title']}' has been removed as you reached the limit of {max_chat_histories} chat histories." - ) diff --git a/src/frontend.py b/src/frontend.py deleted file mode 100644 index 6241c37..0000000 --- a/src/frontend.py +++ /dev/null @@ -1,92 +0,0 @@ -import streamlit as st - -def display_logo_and_heading(): - """Displays the Querypls logo.""" - st.image("static/image/logo.png") - - -def display_welcome_message(): - """Displays a welcome message based on user chat history.""" - no_chat_history = len(st.session_state.messages) == 1 - if no_chat_history: - st.markdown(f"#### Welcome to \n ## πŸ—ƒοΈπŸ’¬Querypls - Prompt to SQL") - - -def handle_new_chat(max_chat_histories=5): - """Handles the initiation of a new chat session. - - Displays the remaining chat history count and provides a button to start a new chat. - - Args: - max_chat_histories (int, optional): Maximum number of chat histories to retain. - - Returns: - None - """ - remaining_chats = max_chat_histories - len(st.session_state.get("previous_chats", [])) - st.markdown( - f" #### Remaining Chat Histories: `{remaining_chats}/{max_chat_histories}`" - ) - st.markdown( - "You can create up to 5 chat histories. Each history can contain unlimited messages." - ) - - if st.button("βž• New chat"): - save_chat_history() # Save current chat before creating a new one - create_message() - - -def display_previous_chats(): - """Displays previous chat records stored in session state. - - Allows the user to select a chat to view. - """ - if "previous_chats" in st.session_state: - reversed_chats = reversed(st.session_state["previous_chats"]) - - for chat in reversed_chats: - if st.button(chat["title"], key=chat["key"]): - update_session_state(chat) - - -def create_message(): - """Creates a default assistant message and initializes a session key.""" - st.session_state["messages"] = [ - {"role": "assistant", "content": "How may I help you?"} - ] - st.session_state["key"] = "key" - - -def update_session_state(chat): - """Updates the session state with selected chat information. - - Args: - chat (dict): Selected chat information. - """ - st.session_state["messages"] = chat["chat"] - st.session_state["key"] = chat["key"] - - -def save_chat_history(): - """Saves the current chat to session state if it contains messages.""" - if "messages" in st.session_state and len(st.session_state["messages"]) > 1: - # Initialize previous chats list if it doesn't exist - if "previous_chats" not in st.session_state: - st.session_state["previous_chats"] = [] - - # Create a chat summary to store in session - title = st.session_state["messages"][1]["content"] - chat_summary = { - "title": title[:25] + "....." if len(title) > 25 else title, - "chat": st.session_state["messages"], - "key": f"chat_{len(st.session_state['previous_chats']) + 1}" - } - - st.session_state["previous_chats"].append(chat_summary) - - # Limit chat histories to a maximum number - if len(st.session_state["previous_chats"]) > 5: - st.session_state["previous_chats"].pop(0) # Remove oldest chat - st.warning( - f"The oldest chat history has been removed as you reached the limit of 5 chat histories." - ) \ No newline at end of file diff --git a/src/frontend/app.py b/src/frontend/app.py new file mode 100644 index 0000000..e0a1e4c --- /dev/null +++ b/src/frontend/app.py @@ -0,0 +1,185 @@ +""" +Main Streamlit application for Querypls. +""" + +import streamlit as st +import sys +import os +import pandas as pd + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from backend.orchestrator import BackendOrchestrator +from backend.backend import ( + configure_page_styles, + display_github_badge, + hide_main_menu_and_footer, +) +from frontend import ( + display_logo_and_heading, + display_welcome_message, +) +from config.constants import ( + CSV_ANALYSIS_SECTION, CSV_UPLOAD_LABEL, CSV_UPLOAD_HELP, CSV_PREVIEW, + CSV_COLUMNS, CSV_DTYPES, LOAD_CSV_BUTTON, CSV_LOADED_SUCCESS, + CSV_UPLOAD_SUCCESS, CSV_UPLOAD_ERROR, SESSION_CREATE_ERROR, + ORCHESTRATOR_INIT_ERROR, SESSION_NOT_FOUND_ERROR, APP_INIT_ERROR, + RESPONSE_GENERATION_ERROR, MESSAGE_LOAD_ERROR, MADE_WITH_LOVE, + SESSIONS_SECTION, NEW_SESSION_BUTTON +) +from schemas.requests import NewChatRequest + + +def initialize_orchestrator(): + if "orchestrator" not in st.session_state: + try: + st.session_state["orchestrator"] = BackendOrchestrator() + except Exception as e: + st.error(ORCHESTRATOR_INIT_ERROR.format(error=str(e))) + return None + return st.session_state["orchestrator"] + + +def get_current_session_id(): + if "current_session_id" not in st.session_state: + orchestrator = initialize_orchestrator() + if orchestrator: + st.session_state["current_session_id"] = orchestrator.get_default_session() + return st.session_state.get("current_session_id") + + +def display_messages(session_id: str): + orchestrator = initialize_orchestrator() + if not orchestrator: + return + + try: + conversation = orchestrator.get_conversation_history(session_id) + for message in conversation.messages: + with st.chat_message(message.role): + st.markdown(message.content) + except Exception as e: + st.error(MESSAGE_LOAD_ERROR.format(error=str(e))) + + +def upload_csv_file(): + uploaded_file = st.file_uploader( + CSV_UPLOAD_LABEL, + type=['csv'], + help=CSV_UPLOAD_HELP + ) + + if uploaded_file is not None: + try: + csv_content = uploaded_file.read().decode('utf-8') + df = pd.read_csv(uploaded_file) + st.success(CSV_UPLOAD_SUCCESS.format(shape=df.shape)) + + with st.expander(CSV_PREVIEW): + st.dataframe(df.head()) + st.write(CSV_COLUMNS.format(columns=list(df.columns))) + st.write(CSV_DTYPES.format(dtypes=df.dtypes.to_dict())) + + return csv_content + except Exception as e: + st.error(CSV_UPLOAD_ERROR.format(error=str(e))) + return None + + return None + + +def main(): + configure_page_styles("static/css/styles.css") + + orchestrator = initialize_orchestrator() + if not orchestrator: + st.error(APP_INIT_ERROR) + return + + current_session_id = get_current_session_id() + if not current_session_id: + st.error(SESSION_NOT_FOUND_ERROR) + return + + hide_main_menu_and_footer() + + with st.sidebar: + display_github_badge() + display_logo_and_heading() + st.markdown(MADE_WITH_LOVE) + + st.markdown(SESSIONS_SECTION) + sessions = orchestrator.list_sessions() + + for session in sessions: + if st.button( + session.session_name, + key=f"session_{session.session_id}", + help=f"Messages: {session.message_count}, Last: {session.last_activity}" + ): + st.session_state["current_session_id"] = session.session_id + st.rerun() + + if st.button(NEW_SESSION_BUTTON): + try: + new_session = orchestrator.create_new_session( + NewChatRequest(session_name=f"Chat {len(sessions) + 1}") + ) + st.session_state["current_session_id"] = new_session.session_id + st.rerun() + except Exception as e: + st.error(SESSION_CREATE_ERROR.format(error=str(e))) + + st.markdown("---") + st.markdown(CSV_ANALYSIS_SECTION) + + csv_content = upload_csv_file() + if csv_content: + if st.button(LOAD_CSV_BUTTON): + try: + result = orchestrator.load_csv_data(current_session_id, csv_content) + if result["status"] == "success": + st.success(CSV_LOADED_SUCCESS) + st.session_state["csv_loaded"] = True + else: + st.error(f"❌ Error loading CSV: {result['message']}") + except Exception as e: + st.error(f"❌ Error: {str(e)}") + + display_welcome_message() + display_messages(current_session_id) + + if prompt := st.chat_input(): + try: + csv_loaded = st.session_state.get("csv_loaded", False) + + if csv_loaded: + response = orchestrator.generate_csv_analysis_response(current_session_id, prompt) + else: + response = orchestrator.generate_sql_response(current_session_id, prompt) + + with st.chat_message("user"): + st.markdown(prompt) + + with st.chat_message("assistant"): + st.markdown(response.content) + + if hasattr(response, 'sql_response') and response.sql_response: + with st.expander("πŸ“Š SQL Details"): + col1, col2 = st.columns(2) + with col1: + st.write("**Query Type:**", response.sql_response.query_type) + st.write("**Complexity:**", response.sql_response.complexity) + st.write("**Tables Used:**", ", ".join(response.sql_response.tables_used)) + with col2: + st.write("**Columns:**", ", ".join(response.sql_response.columns_selected)) + st.write("**Estimated Rows:**", response.sql_response.estimated_rows) + if response.sql_response.warnings: + st.write("**Warnings:**", ", ".join(response.sql_response.warnings)) + + except Exception as e: + st.error(RESPONSE_GENERATION_ERROR.format(error=str(e))) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/frontend/frontend.py b/src/frontend/frontend.py new file mode 100644 index 0000000..77374d9 --- /dev/null +++ b/src/frontend/frontend.py @@ -0,0 +1,37 @@ +""" +Frontend utilities for Streamlit interface components. +""" + +import streamlit as st + + +def display_logo_and_heading(): + st.image("static/image/logo.png") + + +def display_welcome_message(): + st.markdown("#### Welcome to \n ## πŸ—ƒοΈπŸ’¬Querypls - Prompt to SQL") + + +def handle_new_chat(max_chat_histories=5): + st.markdown(f"#### Remaining Chat Histories: `{max_chat_histories}`") + st.markdown("You can create multiple chat sessions. Each session can contain unlimited messages.") + + if st.button("βž• New chat"): + st.rerun() + + +def display_previous_chats(): + pass + + +def create_message(): + pass + + +def update_session_state(chat): + pass + + +def save_chat_history(): + pass \ No newline at end of file diff --git a/src/model.py b/src/model.py deleted file mode 100644 index 788c738..0000000 --- a/src/model.py +++ /dev/null @@ -1,22 +0,0 @@ -from langchain_community.llms import HuggingFaceHub -import sys -import os - - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from src.auth import * -from src.constant import * - - -def create_huggingface_hub(): - """Creates an instance of Hugging Face Hub with specified configurations. - - Returns: - HuggingFaceHub: Instance of Hugging Face Hub. - """ - return HuggingFaceHub( - huggingfacehub_api_token=HUGGINGFACE_API_TOKEN, - repo_id=REPO_ID, - model_kwargs={"temperature": 0.7, "max_new_tokens": 180}, - ) diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py new file mode 100644 index 0000000..a9c8969 --- /dev/null +++ b/src/schemas/__init__.py @@ -0,0 +1,3 @@ +""" +Pydantic schemas for Querypls application. +""" \ No newline at end of file diff --git a/src/schemas/requests.py b/src/schemas/requests.py new file mode 100644 index 0000000..0242209 --- /dev/null +++ b/src/schemas/requests.py @@ -0,0 +1,74 @@ +""" +Request schemas for Querypls application. +""" + +from typing import List, Optional, Literal +from pydantic import BaseModel, Field + + +class ChatMessage(BaseModel): + """Schema for chat message.""" + + role: Literal["user", "assistant", "system"] = Field( + description="Message role (user, assistant, system)" + ) + content: str = Field( + description="Message content", + min_length=1 + ) + timestamp: Optional[str] = Field( + default=None, + description="Message timestamp" + ) + session_id: Optional[str] = Field( + default=None, + description="Session identifier" + ) + + +class SQLGenerationRequest(BaseModel): + """Schema for SQL generation request.""" + + user_query: str = Field( + description="User's natural language query for SQL generation", + min_length=1, + max_length=1000 + ) + conversation_history: List[ChatMessage] = Field( + default=[], + description="Previous conversation messages for context" + ) + database_schema: Optional[str] = Field( + default=None, + description="Database schema information (optional)" + ) + query_type: Optional[str] = Field( + default=None, + description="Preferred query type (SELECT, INSERT, UPDATE, DELETE)" + ) + + +class ConversationHistory(BaseModel): + """Schema for conversation history.""" + + messages: List[ChatMessage] = Field( + default=[], + description="List of conversation messages" + ) + session_id: Optional[str] = Field( + default=None, + description="Session identifier" + ) + + +class NewChatRequest(BaseModel): + """Schema for creating a new chat session.""" + + session_name: Optional[str] = Field( + default=None, + description="Name for the new chat session" + ) + initial_context: Optional[str] = Field( + default=None, + description="Initial context or instructions" + ) \ No newline at end of file diff --git a/src/schemas/responses.py b/src/schemas/responses.py new file mode 100644 index 0000000..bd2d23f --- /dev/null +++ b/src/schemas/responses.py @@ -0,0 +1,143 @@ +""" +Response schemas for Querypls application. +""" + +from typing import List, Optional, Literal +from pydantic import BaseModel, Field + + +class SQLQueryResponse(BaseModel): + """Schema for SQL query generation response.""" + + sql_query: str = Field( + ..., + description="The generated SQL query as a string" + ) + explanation: str = Field( + ..., + description="Brief explanation of what the query does" + ) + tables_used: List[str] = Field( + default=[], + description="Array of table names used in the query" + ) + columns_selected: List[str] = Field( + default=[], + description="Array of column names selected in the query" + ) + query_type: Literal["SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER"] = Field( + ..., + description="Type of query generated" + ) + complexity: Literal["SIMPLE", "MEDIUM", "COMPLEX"] = Field( + ..., + description="Query complexity level" + ) + estimated_rows: str = Field( + default="variable", + description="Estimated number of rows returned (if applicable)" + ) + execution_time: Optional[str] = Field( + default=None, + description="Estimated execution time" + ) + warnings: List[str] = Field( + default=[], + description="Any warnings about the generated query" + ) + + +class ChatResponse(BaseModel): + """Schema for chat response.""" + + message_id: str = Field( + ..., + description="Unique identifier for the message" + ) + role: Literal["assistant"] = Field( + default="assistant", + description="Message role" + ) + content: str = Field( + ..., + description="Response content" + ) + sql_response: Optional[SQLQueryResponse] = Field( + default=None, + description="Structured SQL response if applicable" + ) + timestamp: str = Field( + ..., + description="Response timestamp" + ) + session_id: str = Field( + ..., + description="Session identifier" + ) + + +class ErrorResponse(BaseModel): + """Schema for error responses.""" + + error_code: str = Field( + ..., + description="Error code identifier" + ) + error_message: str = Field( + ..., + description="Human-readable error message" + ) + details: Optional[str] = Field( + default=None, + description="Additional error details" + ) + timestamp: str = Field( + ..., + description="Error timestamp" + ) + + +class SessionInfo(BaseModel): + """Schema for session information.""" + + session_id: str = Field( + ..., + description="Unique session identifier" + ) + session_name: str = Field( + ..., + description="Session name" + ) + created_at: str = Field( + ..., + description="Session creation timestamp" + ) + message_count: int = Field( + ..., + description="Number of messages in the session" + ) + last_activity: str = Field( + ..., + description="Last activity timestamp" + ) + + +class HealthCheckResponse(BaseModel): + """Schema for health check response.""" + + status: Literal["healthy", "unhealthy"] = Field( + ..., + description="Application health status" + ) + version: str = Field( + ..., + description="Application version" + ) + timestamp: str = Field( + ..., + description="Health check timestamp" + ) + services: dict = Field( + default={}, + description="Status of individual services" + ) \ No newline at end of file diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..e809ff8 --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,3 @@ +""" +Services package for Querypls application. +""" \ No newline at end of file diff --git a/src/services/csv_analysis_tools.py b/src/services/csv_analysis_tools.py new file mode 100644 index 0000000..38b14ca --- /dev/null +++ b/src/services/csv_analysis_tools.py @@ -0,0 +1,217 @@ +import io +import pandas as pd +from typing import Dict, Any, Optional +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.providers.groq import GroqProvider +from pydantic import BaseModel, Field + +from config.settings import get_settings +from services.jupyter_service import CSVAnalysisService +from utils.prompt import CSV_ANALYSIS_PROMPT, CODE_FIX_PROMPT, CSV_AGENT_PROMPT + + +class CSVAnalysisContext(BaseModel): + session_id: str + csv_content: str + csv_headers: list + sample_data: list + + +class PythonCodeResponse(BaseModel): + python_code: str = Field(description="Generated Python code for data analysis") + explanation: str = Field(description="Explanation of what the code does") + expected_output: str = Field(description="What output is expected from the code") + libraries_used: list = Field(description="List of Python libraries used") + + +class CodeExecutionResult(BaseModel): + status: str = Field(description="Execution status: success, error, or retry") + output: str = Field(description="Output from code execution") + error_message: Optional[str] = Field(description="Error message if execution failed") + execution_time: float = Field(description="Time taken to execute the code") + attempt: int = Field(description="Attempt number") + + +class CSVAnalysisTools: + def __init__(self): + self.settings = get_settings() + self.csv_service = CSVAnalysisService() + + self.code_generation_model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.settings.groq_api_key) + ) + + self.code_generation_agent = Agent( + self.code_generation_model, + instructions=CSV_ANALYSIS_PROMPT, + output_type=PythonCodeResponse + ) + + self.code_fixing_model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.settings.groq_api_key) + ) + + self.code_fixing_agent = Agent( + self.code_fixing_model, + instructions=CODE_FIX_PROMPT, + output_type=PythonCodeResponse + ) + + def load_csv_data(self, csv_content: str, session_id: str) -> Dict[str, Any]: + return self.csv_service.load_csv_data(session_id, csv_content) + + def generate_analysis_code(self, user_query: str, csv_context: CSVAnalysisContext) -> PythonCodeResponse: + prompt = f""" +CSV Headers: {csv_context.csv_headers} +Sample Data: {csv_context.sample_data[:3]} +User Query: {user_query} + +Generate Python code that: +1. Uses pandas for data manipulation +2. Creates visualizations if requested +3. Returns clear output +4. Handles the CSV data properly +""" + + result = self.code_generation_agent.run_sync(prompt) + return result.output + + def execute_analysis_code(self, python_code: str, session_id: str, max_retries: int = 3) -> CodeExecutionResult: + result = self.csv_service.execute_analysis(session_id, python_code, max_retries) + + return CodeExecutionResult( + status=result["status"], + output=result.get("output", ""), + error_message=result.get("error_message"), + execution_time=result.get("execution_time", 0.0), + attempt=result.get("attempt", 1) + ) + + def fix_code_error(self, original_code: str, error_message: str, csv_context: CSVAnalysisContext) -> PythonCodeResponse: + prompt = f""" +Original Code: +{original_code} + +Error Message: +{error_message} + +CSV Headers: {csv_context.csv_headers} +Sample Data: {csv_context.sample_data[:3]} + +Please fix the code to resolve the error and ensure it works correctly. +""" + + result = self.code_fixing_agent.run_sync(prompt) + return result.output + + def get_csv_info(self, session_id: str) -> Dict[str, Any]: + return self.csv_service.get_csv_info(session_id) + + def close_session(self, session_id: str): + self.csv_service.close_session(session_id) + + +def create_csv_analysis_agent() -> Agent: + settings = get_settings() + + model = GroqModel( + settings.groq_model_name, + provider=GroqProvider(api_key=settings.groq_api_key) + ) + + agent = Agent( + model, + instructions=CSV_AGENT_PROMPT, + output_type=str + ) + + csv_tools = CSVAnalysisTools() + + @agent.tool + async def load_csv_data(ctx: RunContext[None], csv_content: str, session_id: str) -> str: + result = csv_tools.load_csv_data(csv_content, session_id) + if result["status"] == "success": + return f"CSV loaded successfully! Shape: {result['shape']}, Columns: {result['columns']}" + else: + return f"Error loading CSV: {result['message']}" + + @agent.tool + async def generate_analysis_code(ctx: RunContext[None], user_query: str, session_id: str) -> str: + csv_info = csv_tools.get_csv_info(session_id) + if csv_info["status"] != "success": + return f"Error: {csv_info['message']}" + + csv_context = CSVAnalysisContext( + session_id=session_id, + csv_content="", + csv_headers=csv_info["columns"], + sample_data=csv_info["sample_data"] + ) + + result = csv_tools.generate_analysis_code(user_query, csv_context) + return f"""Generated Python Code: +```python +{result.python_code} +``` + +Explanation: {result.explanation} +Expected Output: {result.expected_output} +Libraries Used: {', '.join(result.libraries_used)}""" + + @agent.tool + async def execute_analysis_code(ctx: RunContext[None], python_code: str, session_id: str) -> str: + result = csv_tools.execute_analysis_code(python_code, session_id) + + if result.status == "success": + return f"""βœ… Code executed successfully! +Execution Time: {result.execution_time:.2f}s +Attempt: {result.attempt} + +Output: +{result.output}""" + else: + return f"""❌ Code execution failed! +Attempt: {result.attempt} +Error: {result.error_message} + +Output: +{result.output}""" + + @agent.tool + async def fix_code_error(ctx: RunContext[None], original_code: str, error_message: str, session_id: str) -> str: + csv_info = csv_tools.get_csv_info(session_id) + if csv_info["status"] != "success": + return f"Error: {csv_info['message']}" + + csv_context = CSVAnalysisContext( + session_id=session_id, + csv_content="", + csv_headers=csv_info["columns"], + sample_data=csv_info["sample_data"] + ) + + result = csv_tools.fix_code_error(original_code, error_message, csv_context) + return f"""πŸ”§ Fixed Code: +```python +{result.python_code} +``` + +Explanation: {result.explanation} +Expected Output: {result.expected_output}""" + + @agent.tool + async def get_csv_info(ctx: RunContext[None], session_id: str) -> str: + result = csv_tools.get_csv_info(session_id) + if result["status"] == "success": + return f"""πŸ“Š CSV Information: +Shape: {result['shape']} +Columns: {result['columns']} +Data Types: {result['dtypes']} +Sample Data: {result['sample_data'][:2]}""" + else: + return f"Error: {result['message']}" + + return agent \ No newline at end of file diff --git a/src/services/jupyter_service.py b/src/services/jupyter_service.py new file mode 100644 index 0000000..83c3dad --- /dev/null +++ b/src/services/jupyter_service.py @@ -0,0 +1,237 @@ +""" +Jupyter service for executing Python code with CSV data analysis. +""" + +import os +import jupyter_client +import inspect +import time +import re +import pandas as pd +from typing import Dict, Any, Optional +from dataclasses import dataclass + +from config.constants import EXECUTION_TIMEOUT, MAX_RETRIES + + +def clean_error_message(error_msg: str) -> str: + ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + cleaned_msg = ansi_escape.sub('', error_msg) + lines = cleaned_msg.split('\n') + lines = [line.strip() for line in lines if line.strip()] + cleaned_msg = '\n'.join(lines) + return cleaned_msg + + +@dataclass +class ExecutionResult: + output: str + status: str + error_message: Optional[str] = None + execution_time: float = 0.0 + + +class SimpleJupyterClient: + def __init__(self): + self.clients: Dict[str, Any] = {} + self.globals: Dict[str, Dict[str, Any]] = {} + + def create_new_session(self, session_id: str = "default", kernel_name: str = 'python3') -> str: + if session_id in self.clients: + return session_id + + km = jupyter_client.KernelManager(kernel_name=kernel_name) + km.start_kernel() + client = km.client() + self.clients[session_id] = client + self.globals[session_id] = {} + + for key, value in os.environ.items(): + self.execute_code(f"{key} = '{value}'", session_id) + + self.execute_code("import pandas as pd", session_id) + self.execute_code("import numpy as np", session_id) + self.execute_code("import matplotlib.pyplot as plt", session_id) + self.execute_code("import seaborn as sns", session_id) + + return session_id + + def execute_code(self, code: str, session_id: str = "default") -> ExecutionResult: + if session_id not in self.clients: + raise ValueError(f"Session {session_id} not found") + + client = self.clients[session_id] + start_time = time.time() + + msg_id = client.execute(code) + output = [] + timeout = time.time() + EXECUTION_TIMEOUT + status = "Success" + error_message = None + + while True: + try: + msg = client.get_iopub_msg(timeout=1) + if 'parent_header' not in msg or msg['parent_header'].get('msg_id') != msg_id: + continue + + msg_type = msg.get('msg_type', '') + content = msg.get('content', {}) + + if msg_type == 'execute_result': + output.append(str(content.get('data', {}).get('text/plain', ''))) + elif msg_type == 'stream': + output.append(content.get('text', '')) + elif msg_type == 'error': + error_traceback = "\n".join(content.get('traceback', [])) + cleaned_error = clean_error_message(error_traceback) + output.append(f"Error: {cleaned_error}") + error_message = cleaned_error + status = 'Fail' + elif msg_type == 'status' and content.get('execution_state') == 'idle': + break + except Exception as e: + pass + + execution_time = time.time() - start_time + + return ExecutionResult( + output='\n'.join(output).strip(), + status=status, + error_message=error_message, + execution_time=execution_time + ) + + def import_function(self, func, session_id: str = "default") -> ExecutionResult: + if session_id not in self.globals: + raise ValueError(f"Session {session_id} not found") + + func_code = inspect.getsource(func) + func_name = func.__name__ + + result = self.execute_code(func_code, session_id) + if result.status == "Success": + self.globals[session_id][func_name] = func + + return result + + def close_session(self, session_id: str = "default"): + if session_id not in self.clients: + raise ValueError(f"Session {session_id} not found") + + client = self.clients[session_id] + client.stop_channels() + del self.clients[session_id] + del self.globals[session_id] + + def close_all_sessions(self): + for session_id in list(self.clients.keys()): + self.close_session(session_id) + + +class CSVAnalysisService: + def __init__(self): + self.jupyter_client = SimpleJupyterClient() + self.csv_data: Dict[str, pd.DataFrame] = {} + self.csv_headers: Dict[str, list] = {} + + def load_csv_data(self, session_id: str, csv_content: str, filename: str = "data.csv") -> Dict[str, Any]: + try: + self.jupyter_client.create_new_session(session_id) + + csv_code = f""" +import pandas as pd +import io + +csv_content = '''{csv_content}''' +df = pd.read_csv(io.StringIO(csv_content)) +print("CSV loaded successfully!") +print(f"Shape: {{df.shape}}") +print("\\nColumns:") +print(df.columns.tolist()) +print("\\nFirst few rows:") +print(df.head()) +""" + + result = self.jupyter_client.execute_code(csv_code, session_id) + + if result.status == "Success": + df = pd.read_csv(io.StringIO(csv_content)) + self.csv_data[session_id] = df + self.csv_headers[session_id] = df.columns.tolist() + + return { + "status": "success", + "message": "CSV loaded successfully", + "shape": df.shape, + "columns": df.columns.tolist(), + "sample_data": df.head().to_dict('records') + } + else: + return { + "status": "error", + "message": result.error_message or "Failed to load CSV" + } + + except Exception as e: + return { + "status": "error", + "message": str(e) + } + + def execute_analysis(self, session_id: str, python_code: str, max_retries: int = MAX_RETRIES) -> Dict[str, Any]: + for attempt in range(max_retries): + try: + result = self.jupyter_client.execute_code(python_code, session_id) + + if result.status == "Success": + return { + "status": "success", + "output": result.output, + "execution_time": result.execution_time, + "attempt": attempt + 1 + } + else: + if attempt == max_retries - 1: + return { + "status": "error", + "error_message": result.error_message, + "output": result.output, + "attempt": attempt + 1 + } + continue + + except Exception as e: + if attempt == max_retries - 1: + return { + "status": "error", + "error_message": str(e), + "attempt": attempt + 1 + } + continue + + return { + "status": "error", + "error_message": "Max retries exceeded", + "attempt": max_retries + } + + def get_csv_info(self, session_id: str) -> Dict[str, Any]: + if session_id not in self.csv_data: + return {"status": "error", "message": "No CSV data loaded for this session"} + + df = self.csv_data[session_id] + return { + "status": "success", + "shape": df.shape, + "columns": df.columns.tolist(), + "dtypes": df.dtypes.to_dict(), + "sample_data": df.head().to_dict('records') + } + + def close_session(self, session_id: str): + self.jupyter_client.close_session(session_id) + if session_id in self.csv_data: + del self.csv_data[session_id] + if session_id in self.csv_headers: + del self.csv_headers[session_id] \ No newline at end of file diff --git a/src/services/models.py b/src/services/models.py new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/src/services/models.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/services/sql_service.py b/src/services/sql_service.py new file mode 100644 index 0000000..9c2ec8c --- /dev/null +++ b/src/services/sql_service.py @@ -0,0 +1,123 @@ +import json +import uuid +from datetime import datetime +from typing import Optional +from pydantic_ai import Agent +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.providers.groq import GroqProvider + +from config.settings import get_settings +from schemas.requests import SQLGenerationRequest, ChatMessage +from schemas.responses import SQLQueryResponse, ChatResponse, ErrorResponse +from utils.prompt import SQL_GENERATION_PROMPT + + +class SQLGenerationService: + def __init__(self, api_key: Optional[str] = None): + self.settings = get_settings() + self.api_key = api_key or self.settings.groq_api_key + + if not self.api_key: + raise ValueError("Groq API key is required. Set GROQ_API_KEY environment variable or pass api_key parameter.") + + self.model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.api_key) + ) + + self.agent = Agent( + self.model, + instructions=SQL_GENERATION_PROMPT, + output_type=SQLQueryResponse + ) + + def format_chat_history(self, messages: list) -> str: + history = [] + for msg in messages[1:]: + if isinstance(msg, ChatMessage): + content = msg.content + role = msg.role + else: + content = msg.get('content', '') + role = msg.get('role', 'user') + + if '```sql' in content: + content = content.replace('```sql\n', '').replace('\n```', '').strip() + + history.append({ + "role": role, + "query" if role == 'user' else "response": content + }) + + return json.dumps(history, indent=2) + + def generate_sql(self, request: SQLGenerationRequest) -> ChatResponse: + try: + formatted_history = self.format_chat_history(request.conversation_history) + prompt = f"Previous conversation: {formatted_history}\nCurrent question: {request.user_query}" + + result = self.agent.run_sync(prompt) + + sql_response = SQLQueryResponse( + sql_query=result.output.sql_query, + explanation=result.output.explanation, + tables_used=result.output.tables_used, + columns_selected=result.output.columns_selected, + query_type=result.output.query_type, + complexity=result.output.complexity, + estimated_rows=result.output.estimated_rows, + execution_time=result.output.execution_time, + warnings=result.output.warnings + ) + + formatted_content = f"```sql\n{sql_response.sql_query}\n```\n\n**Explanation:** {sql_response.explanation}" + + session_id = 'default' + if request.conversation_history: + first_msg = request.conversation_history[0] + if isinstance(first_msg, ChatMessage): + session_id = first_msg.session_id or 'default' + else: + session_id = first_msg.get('session_id', 'default') + + chat_response = ChatResponse( + message_id=str(uuid.uuid4()), + content=formatted_content, + sql_response=sql_response, + timestamp=datetime.now().isoformat(), + session_id=session_id + ) + + return chat_response + + except Exception as e: + error_response = ErrorResponse( + error_code="SQL_GENERATION_ERROR", + error_message=f"Error generating SQL: {str(e)}", + details=str(e), + timestamp=datetime.now().isoformat() + ) + + session_id = 'default' + if request.conversation_history: + first_msg = request.conversation_history[0] + if isinstance(first_msg, ChatMessage): + session_id = first_msg.session_id or 'default' + else: + session_id = first_msg.get('session_id', 'default') + + return ChatResponse( + message_id=str(uuid.uuid4()), + content=f"❌ Error: {error_response.error_message}", + timestamp=datetime.now().isoformat(), + session_id=session_id + ) + + def generate_sql_legacy(self, user_query: str, conversation_history: list) -> str: + request = SQLGenerationRequest( + user_query=user_query, + conversation_history=conversation_history + ) + + response = self.generate_sql(request) + return response.content \ No newline at end of file diff --git a/src/terminal/__init__.py b/src/terminal/__init__.py new file mode 100644 index 0000000..e08c65f --- /dev/null +++ b/src/terminal/__init__.py @@ -0,0 +1,3 @@ +""" +Terminal interface package for Querypls. +""" \ No newline at end of file diff --git a/src/terminal/cli.py b/src/terminal/cli.py new file mode 100644 index 0000000..15cc23f --- /dev/null +++ b/src/terminal/cli.py @@ -0,0 +1,173 @@ +""" +Command-line interface for Querypls SQL generation. +""" + +import sys +import os +import json +from typing import Optional + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from backend.orchestrator import BackendOrchestrator +from schemas.requests import NewChatRequest +from config.constants import ( + CLI_WELCOME, CLI_COMMANDS, CLI_GOODBYE, CLI_UNKNOWN_COMMAND, CLI_ERROR, + SESSION_CREATED, SESSION_SWITCHED, SESSION_NOT_FOUND, NO_ACTIVE_SESSION, + NO_SESSION, RESPONSE_GENERATED, SQL_DETAILS, QUERY_TYPE, COMPLEXITY, + TABLES_USED, COLUMNS, ESTIMATED_ROWS, WARNINGS, CONVERSATION_HISTORY, + HEALTH_CHECK_FAILED, NO_SESSIONS, AVAILABLE_SESSIONS, SESSION_INFO, + SESSION_ID, SESSION_MESSAGES, SESSION_ACTIVITY, HEALTH_CHECK_SUCCESS, + HEALTH_STATUS, HEALTH_VERSION, HEALTH_SERVICES, DEFAULT_SESSION_NAME +) + + +class QueryplsCLI: + def __init__(self): + self.orchestrator = BackendOrchestrator() + self.current_session_id = None + + def create_session(self, name: Optional[str] = None) -> str: + request = NewChatRequest(session_name=name) + session_info = self.orchestrator.create_new_session(request) + self.current_session_id = session_info.session_id + print(SESSION_CREATED.format(name=session_info.session_name, id=session_info.session_id)) + return session_info.session_id + + def list_sessions(self): + sessions = self.orchestrator.list_sessions() + if not sessions: + print(NO_SESSIONS) + return + + print(AVAILABLE_SESSIONS) + for i, session in enumerate(sessions, 1): + print(SESSION_INFO.format(num=i, name=session.session_name)) + print(SESSION_ID.format(id=session.session_id)) + print(SESSION_MESSAGES.format(count=session.message_count)) + print(SESSION_ACTIVITY.format(activity=session.last_activity)) + print() + + def switch_session(self, session_id: str): + session = self.orchestrator.get_session(session_id) + if not session: + print(SESSION_NOT_FOUND.format(id=session_id)) + return + + self.current_session_id = session_id + print(SESSION_SWITCHED.format(name=session.session_name)) + + def chat(self, query: str): + if not self.current_session_id: + print(NO_ACTIVE_SESSION) + return + + try: + response = self.orchestrator.generate_sql_response(self.current_session_id, query) + print(f"\n{RESPONSE_GENERATED}") + print(response.content) + + if response.sql_response: + print(f"\n{SQL_DETAILS}") + print(f" {QUERY_TYPE}: {response.sql_response.query_type}") + print(f" {COMPLEXITY}: {response.sql_response.complexity}") + print(f" {TABLES_USED}: {', '.join(response.sql_response.tables_used)}") + print(f" {COLUMNS}: {', '.join(response.sql_response.columns_selected)}") + print(f" {ESTIMATED_ROWS}: {response.sql_response.estimated_rows}") + if response.sql_response.warnings: + print(f" {WARNINGS}: {', '.join(response.sql_response.warnings)}") + + except Exception as e: + print(CLI_ERROR.format(error=str(e))) + + def show_history(self): + if not self.current_session_id: + print(NO_SESSION) + return + + try: + conversation = self.orchestrator.get_conversation_history(self.current_session_id) + print(f"\n{CONVERSATION_HISTORY}") + for message in conversation.messages: + print(f" {message.role.upper()}: {message.content}") + except Exception as e: + print(CLI_ERROR.format(error=str(e))) + + def health_check(self): + try: + health = self.orchestrator.health_check() + print(f"{HEALTH_CHECK_SUCCESS}") + print(f" {HEALTH_STATUS.format(status=health.status)}") + print(f" {HEALTH_VERSION.format(version=health.version)}") + print(f" {HEALTH_SERVICES.format(services=json.dumps(health.services, indent=2))}") + except Exception as e: + print(HEALTH_CHECK_FAILED.format(error=str(e))) + + def run_interactive(self): + print(CLI_WELCOME) + print(CLI_COMMANDS) + print() + + self.create_session("CLI Session") + + while True: + try: + command = input("querypls> ").strip() + + if not command: + continue + + parts = command.split() + cmd = parts[0].lower() + + if cmd == "quit" or cmd == "exit": + print(CLI_GOODBYE) + break + elif cmd == "new": + name = " ".join(parts[1:]) if len(parts) > 1 else None + self.create_session(name) + elif cmd == "list": + self.list_sessions() + elif cmd == "switch" and len(parts) > 1: + self.switch_session(parts[1]) + elif cmd == "chat" and len(parts) > 1: + query = " ".join(parts[1:]) + self.chat(query) + elif cmd == "history": + self.show_history() + elif cmd == "health": + self.health_check() + else: + print(CLI_UNKNOWN_COMMAND) + + except KeyboardInterrupt: + print(f"\n{CLI_GOODBYE}") + break + except Exception as e: + print(CLI_ERROR.format(error=str(e))) + + +def main(): + cli = QueryplsCLI() + + if len(sys.argv) > 1: + command = sys.argv[1] + if command == "new": + name = sys.argv[2] if len(sys.argv) > 2 else None + cli.create_session(name) + elif command == "list": + cli.list_sessions() + elif command == "chat" and len(sys.argv) > 2: + query = " ".join(sys.argv[2:]) + cli.create_session("CLI Session") + cli.chat(query) + elif command == "health": + cli.health_check() + else: + print("Usage: python cli.py [new|list|chat |health]") + else: + cli.run_interactive() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/static/css/styles.css b/static/css/styles.css index fa6f7e4..9ff0c4b 100644 --- a/static/css/styles.css +++ b/static/css/styles.css @@ -1,85 +1,190 @@ - +/* Main container styling */ #root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2){ background-color: rgb(233, 240, 255); - border: 1px solid #152544 ; /* Add border with 1px thickness and black color */ - border-radius: 10px; /* Add border radius for rounded corners */ - padding: 10px; /* Add padding for spacing inside the element */ - + border: 1px solid #152544; + border-radius: 10px; + padding: 10px; } + +/* Chat message styling */ #root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(3){ - background-color: #152544; - color:rgb(17, 17, 17) + background-color: #152544; + color: rgb(17, 17, 17); } + +/* Main content area */ #root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4{ background-color: rgb(255, 255, 255); } + +/* Additional chat elements */ #root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(5){ - background-color: #152544; + background-color: #152544; } - - - #root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(3) > div > div > p > code{ - color:white; +/* Code styling */ +#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(3) > div > div > p > code{ + color: white; background-color: #152544; - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.st-emotion-cache-nju155.eczjsme11 > div.st-emotion-cache-6qob1r.eczjsme3{ +} + +/* Sidebar styling */ +#root > div:nth-child(1) > div.withScreencast > div > div > div > section.st-emotion-cache-nju155.eczjsme11 > div.st-emotion-cache-6qob1r.eczjsme3{ background-color: rgb(184, 205, 252); - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5{ - border: 2px solid rgb(0, 2, 8) ; /* Add border with 1px thickness and black color */ - border-radius: 0px; /* Add border radius for rounded corners */ - padding: 10px; /* Add padding for spacing inside the element */ - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.stChatMessage.st-emotion-cache-4oy321.eeusbqq4{ +} + +/* Main section styling */ +#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5{ + border: 2px solid rgb(0, 2, 8); + border-radius: 0px; + padding: 10px; +} + +/* Chat message styling */ +#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.stChatMessage.st-emotion-cache-4oy321.eeusbqq4{ background-color: rgb(233, 240, 255); color: #152544; +} - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.stChatMessage.st-emotion-cache-4oy321.eeusbqq4 > div.st-emotion-cache-14m9yky.eeusbqq3{ - color:#152544; +/* Chat message content */ +#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.stChatMessage.st-emotion-cache-4oy321.eeusbqq4 > div.st-emotion-cache-14m9yky.eeusbqq3{ + color: #152544; background-color: rgb(233, 240, 255); - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.element-container.st-emotion-cache-10gv909.e1f1d6gn3 > div > div > div > div.st-bg.st-b4.st-bh.st-co.st-bj.st-bk.st-bl.st-bm.st-bn.st-bo.st-bp.st-bq.st-br.st-b2.st-bs.st-av.st-ay.st-aw.st-ax.st-bt.st-bu.st-bv.st-bw.st-bx.st-by.st-bz.st-c0 > div{ +} + +/* Button styling */ +#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.element-container.st-emotion-cache-10gv909.e1f1d6gn3 > div > div > div > div.st-bg.st-b4.st-bh.st-co.st-bj.st-bk.st-bl.st-bm.st-bn.st-bo.st-bp.st-bq.st-br.st-b2.st-bs.st-av.st-ay.st-aw.st-ax.st-bt.st-bu.st-bv.st-bw.st-bx.st-by.st-bz.st-c0 > div{ background-color: rgb(12, 51, 158); - } - #root > div:nth-child(1) > div.withScreencast > div > div > header{ - background-color: rgb(233, 240, 255); - color: #152544; } +/* Header styling */ +#root > div:nth-child(1) > div.withScreencast > div > div > header{ + background-color: rgb(233, 240, 255); + color: #152544; +} + +/* Button icon styling */ #root > div:nth-child(1) > div.withScreencast > div > div > div > div > button > svg{ - color:#152544; - width: 20px + color: #152544; + width: 20px; } + +/* Sidebar general styling */ .sidebar{ background-color: rgb(233, 240, 255); } + +/* Streamlit button styling */ .stButton { - /* Add styles for the stButton class */ background-color: rgb(218, 218, 230); - color: #152544; - /* padding: 10px 20px; */ + color: #152544; border: 3px; border-radius: 5px; width: 40px; cursor: pointer; - - } + +/* Row widget styling */ .row-widget { - /* Add styles for the row-widget class */ margin-bottom: 10px; background-color: rgb(233, 240, 255); - color:#152544; - + color: #152544; } +/* Sidebar specific elements */ #root > div:nth-child(1) > div.withScreencast > div > div > div > section.st-emotion-cache-nju155.eczjsme11 > div.st-emotion-cache-6qob1r.eczjsme3 > div.st-emotion-cache-16txtl3.eczjsme4 > div > div > div > div:nth-child(5){ - background-color:rgb(0, 0, 0) ; + background-color: rgb(0, 0, 0); color: #152544; } +/* Hide menu button */ #root > div:nth-child(1) > div.withScreencast > div > div > header > div.st-emotion-cache-zq5wmm.ezrtsby0 > div > div:nth-child(2) > button > div > div { display: none; +} + +/* Sidebar button styling - improved */ +.sidebar .stButton > button { + background-color: #4CAF50 !important; + color: white !important; + border: none !important; + padding: 8px 12px !important; + text-align: center !important; + text-decoration: none !important; + display: inline-block !important; + font-size: 12px !important; + margin: 2px !important; + cursor: pointer !important; + border-radius: 4px !important; + width: auto !important; + min-width: 80px !important; + transition: background-color 0.3s !important; +} + +.sidebar .stButton > button:hover { + background-color: #45a049 !important; +} + +/* Session button styling */ +.sidebar button[key*="session_"] { + background-color: #2196F3 !important; + color: white !important; + border: none !important; + padding: 6px 10px !important; + text-align: center !important; + text-decoration: none !important; + display: inline-block !important; + font-size: 11px !important; + margin: 2px !important; + cursor: pointer !important; + border-radius: 3px !important; + width: auto !important; + min-width: 60px !important; + transition: background-color 0.3s !important; +} + +.sidebar button[key*="session_"]:hover { + background-color: #1976D2 !important; +} + +/* New session button styling */ +.sidebar button:contains("βž• New Session") { + background-color: #FF9800 !important; + color: white !important; + border: none !important; + padding: 8px 12px !important; + text-align: center !important; + text-decoration: none !important; + display: inline-block !important; + font-size: 12px !important; + margin: 2px !important; + cursor: pointer !important; + border-radius: 4px !important; + width: auto !important; + min-width: 100px !important; + transition: background-color 0.3s !important; +} + +.sidebar button:contains("βž• New Session"):hover { + background-color: #F57C00 !important; +} + +/* Force horizontal layout for sidebar buttons */ +.sidebar [data-testid="stButton"] { + display: inline-block !important; + float: left !important; + margin: 2px !important; +} + +/* Sidebar button wrapper */ +.sidebar .element-container { + display: inline-block !important; + margin: 2px !important; +} + +/* Sidebar button container */ +.sidebar .stButton { + display: inline-block !important; + margin: 2px !important; + width: auto !important; + float: left !important; } \ No newline at end of file diff --git a/training/Querypls_prompt.py b/training/Querypls_prompt.py deleted file mode 100644 index b962eb0..0000000 --- a/training/Querypls_prompt.py +++ /dev/null @@ -1,28 +0,0 @@ -# !pip install langchain huggingface_hub > /dev/null - -import os - -huggingfacehub_api_token = "YOUR_API_TOKEN" - -# pip install huggingface_hub - -# pip install langchain - -from langchain import HuggingFaceHub - -repo_id = "tiiuae/falcon-7b-instruct" -llm = HuggingFaceHub( - huggingfacehub_api_token=huggingfacehub_api_token, - repo_id=repo_id, - model_kwargs={"temperature": 0.6, "max_new_tokens": 100}, -) - -from langchain import PromptTemplate, LLMChain - -template = "" -prompt = PromptTemplate(template=template, input_variables=["question"]) -llm_chain = LLMChain(prompt=prompt, llm=llm) - -question = "" - -print(llm_chain.run(question)) diff --git a/training/finetuning_querypls.py b/training/finetuning_querypls.py deleted file mode 100644 index af6003d..0000000 --- a/training/finetuning_querypls.py +++ /dev/null @@ -1,107 +0,0 @@ -# from huggingface_hub import notebook_login - -# notebook_login() - -from datasets import load_dataset, DatasetDict, Dataset -from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments -from trl import SFTTrainer -from peft import LoraConfig - -dataset = load_dataset("b-mc2/sql-create-context") - -dataset - -# dataset['train'][0] - -model_checkpoint = "stabilityai/StableBeluga-7B" -# Initialize the tokenizer and model -model = AutoModelForCausalLM.from_pretrained(model_checkpoint) - - -tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, max_length=512) -tokenizer.pad_token = tokenizer.eos_token -tokenizer.padding_side = "right" - -model.config.use_cache = False - -model.config.quantization_config.to_dict() - -lora_target_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", -] -config = LoraConfig( - r=16, # attention heads - lora_alpha=12, # alpha scaling - target_modules=lora_target_modules, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", -) - -import random - -split_ratio = 0.8 -eval_ratio = 0.2 - -# the 30% subset -total_examples = len(dataset["train"]) -subset_size = int(total_examples * 0.2) -train_size = int(subset_size * split_ratio) -eval_size = subset_size - train_size -shuffled_indices = list(range(total_examples)) -random.shuffle(shuffled_indices) -training_set = dataset["train"].select(shuffled_indices[:train_size]) -evaluation_set = dataset["train"].select( - shuffled_indices[train_size : train_size + eval_size] -) -split_dataset = DatasetDict({"train": training_set, "eval": evaluation_set}) -split_dataset - -evaluation_set - -# hyperparameters -lr = 1e-4 -batch_size = 4 -num_epochs = 1 -training_args = TrainingArguments( - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - optim="paged_adamw_32bit", - logging_steps=1, - learning_rate=lr, - fp16=True, - max_grad_norm=0.3, - num_train_epochs=num_epochs, - evaluation_strategy="steps", - eval_steps=0.2, - warmup_ratio=0.05, - save_strategy="epoch", - group_by_length=True, - output_dir="outputs", - report_to="tensorboard", - save_safetensors=True, - lr_scheduler_type="cosine", - seed=12, -) - -trainer = SFTTrainer( - model=model, - train_dataset=split_dataset["train"], - eval_dataset=split_dataset["eval"], - peft_config=config, - dataset_text_field="question", - max_seq_length=4096, - tokenizer=tokenizer, - args=training_args, -) - -# train model -trainer.train() - -model.push_to_hub("samadpls/querypls-prompt2sql") -tokenizer.push_to_hub("samadpls/querypls-prompt2sql") - -# DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' diff --git a/utils/prompt.py b/utils/prompt.py new file mode 100644 index 0000000..e9bed17 --- /dev/null +++ b/utils/prompt.py @@ -0,0 +1,107 @@ +""" +Instruction prompts for Querypls application. +""" + +SQL_GENERATION_PROMPT = """You are a SQL expert developer. Analyze the following conversation history and generate appropriate SQL code based on the context and current question. + +Previous conversation: {conversation_history} +Current question: {input} + +## Response Format +Your response must be in JSON format. + +It must be an object and must contain these fields: +* `sql_query` - The generated SQL query as a string +* `explanation` - Brief explanation of what the query does +* `tables_used` - Array of table names used in the query +* `columns_selected` - Array of column names selected in the query +* `query_type` - Type of query (SELECT, INSERT, UPDATE, DELETE, etc.) +* `complexity` - Query complexity level (SIMPLE, MEDIUM, COMPLEX) +* `estimated_rows` - Estimated number of rows returned (if applicable) +* `execution_time` - Estimated execution time (optional) +* `warnings` - Array of warnings about the query (optional) + +## Example Response +{ + "sql_query": "SELECT * FROM users WHERE status = 'active'", + "explanation": "Retrieves all active users from the users table", + "tables_used": ["users"], + "columns_selected": ["*"], + "query_type": "SELECT", + "complexity": "SIMPLE", + "estimated_rows": "variable", + "execution_time": "fast", + "warnings": [] +} + +Respond only with the JSON object. Do not include any additional text or markdown formatting.""" + +CSV_ANALYSIS_PROMPT = """You are a Python data analysis expert. Generate Python code to analyze CSV data based on user queries. + +## Response Format +Your response must be in JSON format. + +It must be an object and must contain these fields: +* `python_code` - The generated Python code as a string +* `explanation` - Brief explanation of what the code does +* `expected_output` - What output is expected from the code +* `libraries_used` - Array of Python libraries used + +## Guidelines +1. Always use pandas for data manipulation +2. Use matplotlib/seaborn for visualizations when appropriate +3. Include proper error handling +4. Make the code readable and well-commented +5. Return clear, formatted output +6. Handle missing data appropriately +7. Use appropriate data types + +## Example Response +{ + "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load and analyze data\\ndf = pd.read_csv('data.csv')\\nprint(f'Data shape: {df.shape}')\\nprint(df.head())\\n\\n# Create visualization\\nplt.figure(figsize=(10, 6))\\ndf['column'].value_counts().plot(kind='bar')\\nplt.title('Distribution of Column')\\nplt.show()", + "explanation": "Loads CSV data, displays basic info, and creates a bar chart of column distribution", + "expected_output": "Data shape, first few rows, and a bar chart visualization", + "libraries_used": ["pandas", "matplotlib.pyplot"] +} + +Respond only with the JSON object. Do not include any additional text or markdown formatting.""" + +CODE_FIX_PROMPT = """You are a Python debugging expert. Fix Python code based on error messages. + +## Response Format +Your response must be in JSON format. + +It must be an object and must contain these fields: +* `python_code` - The fixed Python code as a string +* `explanation` - Brief explanation of what was fixed +* `expected_output` - What output is expected from the fixed code +* `libraries_used` - Array of Python libraries used + +## Guidelines +1. Identify the root cause of the error +2. Fix syntax errors, import issues, and logic problems +3. Ensure the code follows Python best practices +4. Add proper error handling if needed +5. Make sure the code works with the given CSV data structure +6. Test the logic and ensure it produces the expected output + +## Example Response +{ + "python_code": "import pandas as pd\\n\\n# Fixed code with proper error handling\\ntry:\\n df = pd.read_csv('data.csv')\\n print(f'Data shape: {df.shape}')\\nexcept FileNotFoundError:\\n print('CSV file not found')\\nexcept Exception as e:\\n print(f'Error: {e}')", + "explanation": "Added proper error handling for file reading and data loading", + "expected_output": "Data shape or appropriate error message", + "libraries_used": ["pandas"] +} + +Respond only with the JSON object. Do not include any additional text or markdown formatting.""" + +CSV_AGENT_PROMPT = """You are a data analysis expert. You can analyze CSV data using Python code. + +Available tools: +- load_csv_data: Load CSV data into a session +- generate_analysis_code: Generate Python code for data analysis +- execute_analysis_code: Execute Python code and get results +- fix_code_error: Fix code errors and retry +- get_csv_info: Get information about loaded CSV data + +Always provide clear explanations and handle errors gracefully.""" From e51179792318329a9a2afe49d7b0ec274ee9bb46 Mon Sep 17 00:00:00 2001 From: samadpls Date: Thu, 7 Aug 2025 22:48:20 +0500 Subject: [PATCH 2/8] feat: Implement intelligent routing service and data models - Added data models for routing decisions, conversation responses, SQL responses, CSV analysis responses, and code fix responses in `models.py`. - Created `IntelligentRoutingService` to determine the appropriate agent for user queries based on context and conversation history in `routing_service.py`. - Implemented methods to handle conversation, SQL, and CSV queries, including fallback mechanisms for routing and responses. - Updated SQL generation service to improve error handling and response formatting in `sql_service.py`. - Removed unused CSS styles from `styles.css`. - Added a test script for routing logic verification in `test_routing.py`. - Refactored CLI commands and improved user feedback in `cli.py`. - Enhanced prompt instructions for routing and conversation handling in `prompt.py`. --- pyproject.toml | 2 - run.py | 48 ++--- src/backend/backend.py | 22 +- src/backend/orchestrator.py | 261 +++++++++++------------ src/config/__init__.py | 2 +- src/config/constants.py | 105 +--------- src/config/settings.py | 51 +++++ src/frontend/app.py | 151 +++++++------- src/frontend/frontend.py | 6 +- src/schemas/__init__.py | 2 +- src/schemas/requests.py | 46 ++--- src/schemas/responses.py | 132 ++++-------- src/services/__init__.py | 2 +- src/services/conversation_service.py | 91 ++++++++ src/services/csv_analysis_tools.py | 138 ++++++++----- src/services/jupyter_service.py | 175 ++++++++++------ src/services/models.py | 99 ++++++++- src/services/routing_service.py | 298 +++++++++++++++++++++++++++ src/services/sql_service.py | 105 +++++----- src/terminal/__init__.py | 2 +- src/terminal/cli.py | 143 +++++++------ static/css/styles.css | 190 ----------------- test_routing.py | 73 +++++++ tests/test_auth.py | 12 +- tests/test_backend.py | 10 - tests/test_frontend.py | 12 +- utils/prompt.py | 70 ++++++- 27 files changed, 1297 insertions(+), 951 deletions(-) delete mode 100644 pyproject.toml create mode 100644 src/config/settings.py create mode 100644 src/services/conversation_service.py create mode 100644 src/services/routing_service.py delete mode 100644 static/css/styles.css create mode 100644 test_routing.py diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 9216134..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,2 +0,0 @@ -[tool.black] -line-length = 79 \ No newline at end of file diff --git a/run.py b/run.py index 369984e..3e7845c 100644 --- a/run.py +++ b/run.py @@ -8,24 +8,25 @@ import argparse # Add src to path -sys.path.append(os.path.join(os.path.dirname(__file__), 'src')) +sys.path.append(os.path.join(os.path.dirname(__file__), "src")) def run_streamlit(): """Run the Streamlit application.""" import subprocess import streamlit.web.cli as stcli - + # Set environment variables - os.environ['STREAMLIT_SERVER_PORT'] = '8501' - os.environ['STREAMLIT_SERVER_ADDRESS'] = 'localhost' - + os.environ["STREAMLIT_SERVER_PORT"] = "8501" + os.environ["STREAMLIT_SERVER_ADDRESS"] = "localhost" + # Run streamlit sys.argv = [ - 'streamlit', 'run', - 'src/frontend/app.py', - '--server.port=8501', - '--server.address=localhost' + "streamlit", + "run", + "src/frontend/app.py", + "--server.port=8501", + "--server.address=localhost", ] sys.exit(stcli.main()) @@ -33,37 +34,36 @@ def run_streamlit(): def run_cli(): """Run the CLI application.""" from terminal.cli import main as cli_main + cli_main() def main(): """Main launcher function.""" - parser = argparse.ArgumentParser(description='Querypls - SQL Generation Tool') + parser = argparse.ArgumentParser(description="Querypls - SQL Generation Tool") parser.add_argument( - 'mode', - choices=['web', 'cli'], - default='web', - nargs='?', - help='Run mode: web (Streamlit) or cli (Command Line)' + "mode", + choices=["web", "cli"], + default="web", + nargs="?", + help="Run mode: web (Streamlit) or cli (Command Line)", ) parser.add_argument( - 'cli_args', - nargs='*', - help='Arguments to pass to CLI (when mode is cli)' + "cli_args", nargs="*", help="Arguments to pass to CLI (when mode is cli)" ) - + args = parser.parse_args() - - if args.mode == 'web': + + if args.mode == "web": print("πŸš€ Starting Querypls Web Application...") run_streamlit() - elif args.mode == 'cli': + elif args.mode == "cli": print("πŸš€ Starting Querypls CLI...") # Pass CLI arguments to the CLI if args.cli_args: - sys.argv = ['cli'] + args.cli_args + sys.argv = ["cli"] + args.cli_args run_cli() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/backend/backend.py b/src/backend/backend.py index 015cc42..bd0dc4f 100644 --- a/src/backend/backend.py +++ b/src/backend/backend.py @@ -2,27 +2,23 @@ Backend utilities for Streamlit configuration and styling. """ +from config.constants import STREAMLIT_CONFIG import streamlit as st import sys import os sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from config.constants import STREAMLIT_CONFIG, HIDE_MENU_STYLE, HIDE_MENU_FOOTER_STYLE, GITHUB_BADGE - - -def configure_page_styles(file_name: str) -> None: - st.set_page_config(**STREAMLIT_CONFIG) - - with open(file_name) as f: - st.markdown(f"", unsafe_allow_html=True) - - st.markdown(HIDE_MENU_STYLE, unsafe_allow_html=True) - def hide_main_menu_and_footer() -> None: - st.markdown(HIDE_MENU_FOOTER_STYLE, unsafe_allow_html=True) + st.markdown( + "", + unsafe_allow_html=True, + ) def display_github_badge() -> None: - st.markdown(GITHUB_BADGE, unsafe_allow_html=True) + st.markdown( + "", + unsafe_allow_html=True, + ) diff --git a/src/backend/orchestrator.py b/src/backend/orchestrator.py index 3cc05e1..553cb9a 100644 --- a/src/backend/orchestrator.py +++ b/src/backend/orchestrator.py @@ -9,14 +9,33 @@ from config.settings import get_settings from config.constants import ( - WELCOME_MESSAGE, DEFAULT_SESSION_NAME, CSV_LOAD_ERROR, CSV_ANALYSIS_ERROR, - SESSION_CREATE_ERROR, ORCHESTRATOR_INIT_ERROR, SESSION_NOT_FOUND_ERROR, - RESPONSE_GENERATION_ERROR, MESSAGE_LOAD_ERROR, MAX_CHAT_HISTORIES + WELCOME_MESSAGE, + DEFAULT_SESSION_NAME, + CSV_LOAD_ERROR, + CSV_ANALYSIS_ERROR, + SESSION_CREATE_ERROR, + ORCHESTRATOR_INIT_ERROR, + SESSION_NOT_FOUND_ERROR, + RESPONSE_GENERATION_ERROR, + MESSAGE_LOAD_ERROR, + MAX_CHAT_HISTORIES, ) from services.sql_service import SQLGenerationService from services.csv_analysis_tools import CSVAnalysisTools, create_csv_analysis_agent -from schemas.requests import SQLGenerationRequest, ChatMessage, ConversationHistory, NewChatRequest -from schemas.responses import ChatResponse, SessionInfo, ErrorResponse, HealthCheckResponse +from services.conversation_service import ConversationService +from services.routing_service import IntelligentRoutingService +from schemas.requests import ( + SQLGenerationRequest, + ChatMessage, + ConversationHistory, + NewChatRequest, +) +from schemas.responses import ( + ChatResponse, + SessionInfo, + ErrorResponse, + HealthCheckResponse, +) @dataclass @@ -35,74 +54,46 @@ def __init__(self): self.sql_service = SQLGenerationService() self.csv_tools = CSVAnalysisTools() self.csv_agent = create_csv_analysis_agent() + self.conversation_service = ConversationService() + self.routing_service = IntelligentRoutingService() self.sessions: Dict[str, Session] = {} self.max_sessions = self.settings.max_chat_histories - - def is_conversational_query(self, query: str) -> bool: - conversational_keywords = [ - 'hi', 'hello', 'hey', 'good morning', 'good afternoon', 'good evening', - 'how are you', 'what\'s up', 'thanks', 'thank you', 'bye', 'goodbye', - 'help', 'what can you do', 'who are you', 'tell me about yourself' - ] - query_lower = query.lower().strip() - return any(keyword in query_lower for keyword in conversational_keywords) - - def get_conversational_response(self, query: str) -> str: - query_lower = query.lower().strip() - - if any(greeting in query_lower for greeting in ['hi', 'hello', 'hey']): - return "Hello! πŸ‘‹ How can I help you today? I can assist with SQL generation or CSV data analysis." - elif 'how are you' in query_lower: - return "I'm doing great, thank you for asking! 😊 How can I assist you with your data queries today?" - elif any(thanks in query_lower for thanks in ['thanks', 'thank you']): - return "You're welcome! 😊 Is there anything else I can help you with?" - elif any(bye in query_lower for bye in ['bye', 'goodbye']): - return "Goodbye! πŸ‘‹ Feel free to come back if you need help with SQL or data analysis." - elif 'help' in query_lower or 'what can you do' in query_lower: - return "I'm Querypls, your SQL and data analysis assistant! πŸ—ƒοΈπŸ’¬\n\nI can help you with:\nβ€’ **SQL Generation**: Convert natural language to SQL queries\nβ€’ **CSV Analysis**: Analyze data files with Python code\nβ€’ **Data Visualization**: Create charts and graphs\n\nJust ask me anything about your data!" - elif 'who are you' in query_lower or 'tell me about yourself' in query_lower: - return "I'm Querypls, an AI assistant specialized in SQL generation and data analysis! πŸ—ƒοΈπŸ’¬\n\nI can help you write SQL queries from natural language and analyze CSV files with Python code. What would you like to work on?" - else: - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" - + def create_new_session(self, request: NewChatRequest) -> SessionInfo: session_id = str(uuid.uuid4()) session_name = request.session_name or f"Chat {len(self.sessions) + 1}" - + messages = [] if request.initial_context: - messages.append(ChatMessage( - role="system", - content=request.initial_context - )) - - messages.append(ChatMessage( - role="assistant", - content=WELCOME_MESSAGE - )) - + messages.append( + ChatMessage( + role="system", + content=request.initial_context)) + + messages.append(ChatMessage(role="assistant", content=WELCOME_MESSAGE)) + session = Session( session_id=session_id, session_name=session_name, created_at=datetime.now(), messages=messages, - last_activity=datetime.now() + last_activity=datetime.now(), ) - + self.sessions[session_id] = session self._cleanup_old_sessions() - + return SessionInfo( session_id=session_id, session_name=session_name, created_at=session.created_at.isoformat(), message_count=len(session.messages), - last_activity=session.last_activity.isoformat() + last_activity=session.last_activity.isoformat(), ) - + def get_session(self, session_id: str) -> Optional[Session]: return self.sessions.get(session_id) - + def list_sessions(self) -> List[SessionInfo]: return [ SessionInfo( @@ -110,179 +101,157 @@ def list_sessions(self) -> List[SessionInfo]: session_name=session.session_name, created_at=session.created_at.isoformat(), message_count=len(session.messages), - last_activity=session.last_activity.isoformat() + last_activity=session.last_activity.isoformat(), ) for session in self.sessions.values() ] - + def delete_session(self, session_id: str) -> bool: if session_id in self.sessions: self.csv_tools.close_session(session_id) del self.sessions[session_id] return True return False - - def load_csv_data(self, session_id: str, csv_content: str) -> Dict[str, Any]: + + def load_csv_data(self, session_id: str, + csv_content: str) -> Dict[str, Any]: session = self.get_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") - + session.csv_data = csv_content result = self.csv_tools.load_csv_data(csv_content, session_id) session.last_activity = datetime.now() - + return result - - def generate_sql_response(self, session_id: str, user_query: str) -> ChatResponse: + + def generate_intelligent_response( + self, session_id: str, user_query: str + ) -> ChatResponse: + """Generate response using intelligent routing to determine the appropriate agent.""" session = self.get_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") - + user_message = ChatMessage( role="user", content=user_query, - timestamp=datetime.now().isoformat() - ) + timestamp=datetime.now().isoformat()) session.messages.append(user_message) - - # Check if this is a conversational query - if self.is_conversational_query(user_query): - response_content = self.get_conversational_response(user_query) + + # Determine which agent should handle this query + csv_loaded = bool(session.csv_data) + routing_decision = self.routing_service.determine_agent( + user_query, session.messages, csv_loaded + ) + + # Generate response based on routing decision + if routing_decision.agent == "CONVERSATION_AGENT": + response_content = self.routing_service.handle_conversation_query( + user_query + ) + elif routing_decision.agent == "SQL_AGENT": + response_content = self.routing_service.handle_sql_query( + user_query, session.messages + ) + elif routing_decision.agent == "CSV_AGENT": + if session.csv_data: + response_content = self.routing_service.handle_csv_query( + user_query, session.csv_data + ) + else: + response_content = "I don't see any CSV data loaded. Please upload a CSV file first to analyze it." else: - # Generate SQL for data-related queries - request = SQLGenerationRequest( - user_query=user_query, - conversation_history=session.messages + # Fallback to conversation + response_content = self.routing_service.handle_conversation_query( + user_query ) - - response = self.sql_service.generate_sql(request) - response_content = response.content - + assistant_message = ChatMessage( role="assistant", content=response_content, - timestamp=datetime.now().isoformat() + timestamp=datetime.now().isoformat(), ) session.messages.append(assistant_message) session.last_activity = datetime.now() - + return ChatResponse( message_id=str(uuid.uuid4()), content=response_content, timestamp=datetime.now().isoformat(), - session_id=session_id - ) - - def generate_csv_analysis_response(self, session_id: str, user_query: str) -> ChatResponse: - session = self.get_session(session_id) - if not session: - raise ValueError(f"Session {session_id} not found") - - if not session.csv_data: - error_response = ChatResponse( - message_id=str(uuid.uuid4()), - content=CSV_LOAD_ERROR, - timestamp=datetime.now().isoformat(), - session_id=session_id - ) - return error_response - - user_message = ChatMessage( - role="user", - content=user_query, - timestamp=datetime.now().isoformat() + session_id=session_id, ) - session.messages.append(user_message) - - try: - result = self.csv_agent.run_sync(user_query) - - assistant_message = ChatMessage( - role="assistant", - content=result.output, - timestamp=datetime.now().isoformat() - ) - session.messages.append(assistant_message) - session.last_activity = datetime.now() - - return ChatResponse( - message_id=str(uuid.uuid4()), - content=result.output, - timestamp=datetime.now().isoformat(), - session_id=session_id - ) - - except Exception as e: - error_response = ChatResponse( - message_id=str(uuid.uuid4()), - content=CSV_ANALYSIS_ERROR.format(error=str(e)), - timestamp=datetime.now().isoformat(), - session_id=session_id - ) - return error_response - + def get_conversation_history(self, session_id: str) -> ConversationHistory: session = self.get_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") - + return ConversationHistory( messages=session.messages, - session_id=session_id - ) - + session_id=session_id) + def get_csv_info(self, session_id: str) -> Dict[str, Any]: return self.csv_tools.get_csv_info(session_id) - + def health_check(self) -> HealthCheckResponse: services_status = { "sql_service": "healthy", "csv_analysis_service": "healthy", - "session_manager": "healthy" + "conversation_service": "healthy", + "session_manager": "healthy", } - + try: test_request = SQLGenerationRequest( - user_query="SELECT 1", - conversation_history=[] + user_query="SELECT 1", conversation_history=[] ) if not self.sql_service: services_status["sql_service"] = "unhealthy" except Exception: services_status["sql_service"] = "unhealthy" - + try: if not self.csv_tools: services_status["csv_analysis_service"] = "unhealthy" except Exception: services_status["csv_analysis_service"] = "unhealthy" - + + try: + if not self.conversation_service: + services_status["conversation_service"] = "unhealthy" + except Exception: + services_status["conversation_service"] = "unhealthy" + return HealthCheckResponse( - status="healthy" if all(status == "healthy" for status in services_status.values()) else "unhealthy", + status=( + "healthy" + if all(status == "healthy" for status in services_status.values()) + else "unhealthy" + ), version=self.settings.app_version, timestamp=datetime.now().isoformat(), - services=services_status + services=services_status, ) - + def _cleanup_old_sessions(self): if len(self.sessions) <= self.max_sessions: return - + sorted_sessions = sorted( - self.sessions.items(), - key=lambda x: x[1].last_activity + self.sessions.items(), key=lambda x: x[1].last_activity ) - + sessions_to_remove = len(self.sessions) - self.max_sessions for i in range(sessions_to_remove): session_id, _ = sorted_sessions[i] self.delete_session(session_id) - + def get_default_session(self) -> str: for session_id, session in self.sessions.items(): if session.session_name == DEFAULT_SESSION_NAME: return session_id - + request = NewChatRequest(session_name=DEFAULT_SESSION_NAME) session_info = self.create_new_session(request) - return session_info.session_id \ No newline at end of file + return session_info.session_id diff --git a/src/config/__init__.py b/src/config/__init__.py index 3622e45..b2e8ad4 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -1,3 +1,3 @@ """ Configuration package for Querypls application. -""" \ No newline at end of file +""" diff --git a/src/config/constants.py b/src/config/constants.py index 4da7b39..1912a55 100644 --- a/src/config/constants.py +++ b/src/config/constants.py @@ -13,92 +13,7 @@ STREAMLIT_CONFIG = { "page_title": "Querypls", "page_icon": "πŸ’¬", - "layout": "wide" -} - -# UI Messages -WELCOME_MESSAGE = "How may I help you? I can help with SQL generation or CSV data analysis." -DEFAULT_SESSION_NAME = "Default Chat" - -# Success Messages -CSV_LOAD_SUCCESS = "βœ… CSV loaded for analysis!" -CSV_UPLOAD_SUCCESS = "βœ… CSV uploaded successfully! Shape: {shape}" -CSV_LOADED_SUCCESS = "βœ… CSV loaded for analysis!" -SESSION_CREATED = "βœ… Created session: {name} (ID: {id})" -SESSION_SWITCHED = "βœ… Switched to session: {name}" - -# Error Messages -CSV_LOAD_ERROR = "❌ No CSV data loaded. Please upload a CSV file first." -CSV_ANALYSIS_ERROR = "❌ Error during CSV analysis: {error}" -CSV_UPLOAD_ERROR = "❌ Error reading CSV file: {error}" -SESSION_CREATE_ERROR = "Failed to create new session: {error}" -ORCHESTRATOR_INIT_ERROR = "Failed to initialize backend orchestrator: {error}" -SESSION_NOT_FOUND_ERROR = "Failed to get current session." -APP_INIT_ERROR = "Failed to initialize application. Please check your configuration." -RESPONSE_GENERATION_ERROR = "Error generating response: {error}" -MESSAGE_LOAD_ERROR = "Error loading messages: {error}" -SESSION_NOT_FOUND = "❌ Session {id} not found." -NO_ACTIVE_SESSION = "❌ No active session. Create one first with 'new' command." -NO_SESSION = "❌ No active session." -HEALTH_CHECK_FAILED = "❌ Health check failed: {error}" - -# Health Check Messages -HEALTH_CHECK_SUCCESS = "πŸ₯ Health Check:" -HEALTH_STATUS = "Status: {status}" -HEALTH_VERSION = "Version: {version}" -HEALTH_SERVICES = "Services: {services}" - -# CLI Messages -CLI_WELCOME = "πŸš€ Welcome to Querypls CLI!" -CLI_COMMANDS = "Commands: new, list, switch , chat , history, health, quit" -CLI_GOODBYE = "πŸ‘‹ Goodbye!" -CLI_UNKNOWN_COMMAND = "❌ Unknown command. Use: new, list, switch , chat , history, health, quit" -CLI_ERROR = "❌ Error: {error}" - -# Response Labels -RESPONSE_GENERATED = "πŸ€– Response:" -SQL_DETAILS = "πŸ“Š SQL Details:" -QUERY_TYPE = "Query Type:" -COMPLEXITY = "Complexity:" -TABLES_USED = "Tables Used:" -COLUMNS = "Columns:" -ESTIMATED_ROWS = "Estimated Rows:" -WARNINGS = "Warnings:" -CONVERSATION_HISTORY = "πŸ“œ Conversation History:" - -# Session Management -NO_SESSIONS = "πŸ“ No sessions found." -AVAILABLE_SESSIONS = "πŸ“ Available sessions:" -SESSION_INFO = " {num}. {name}" -SESSION_ID = " ID: {id}" -SESSION_MESSAGES = " Messages: {count}" -SESSION_ACTIVITY = " Last Activity: {activity}" - -# CSV Analysis UI -CSV_ANALYSIS_SECTION = "### πŸ“Š CSV Analysis" -CSV_UPLOAD_LABEL = "Upload CSV file for analysis" -CSV_UPLOAD_HELP = "Upload a CSV file to analyze with Python code" -CSV_PREVIEW = "πŸ“Š CSV Preview" -CSV_COLUMNS = "**Columns:** {columns}" -CSV_DTYPES = "**Data Types:** {dtypes}" -LOAD_CSV_BUTTON = "πŸ” Load CSV for Analysis" - -# UI Elements -GITHUB_BADGE = """\ -""" - -MADE_WITH_LOVE = "`Made with 🀍`" -SESSIONS_SECTION = "### Sessions" -NEW_SESSION_BUTTON = "βž• New Session" - -# CSS Styles -HIDE_MENU_STYLE = """""" - -HIDE_MENU_FOOTER_STYLE = """""" + "layout": "wide"} # Available Models AVAILABLE_MODELS = { @@ -106,42 +21,42 @@ "developer": "DeepSeek / Meta", "context_window": 131072, "max_completion": 131072, - "max_file_size": None + "max_file_size": None, }, "meta-llama/llama-4-maverick-17b-128e-instruct": { "developer": "Meta", "context_window": 131072, "max_completion": 8192, - "max_file_size": "20 MB" + "max_file_size": "20 MB", }, "meta-llama/llama-4-scout-17b-16e-instruct": { "developer": "Meta", "context_window": 131072, "max_completion": 8192, - "max_file_size": "20 MB" + "max_file_size": "20 MB", }, "moonshotai/kimi-k2-instruct": { "developer": "Moonshot AI", "context_window": 131072, "max_completion": 16384, - "max_file_size": None + "max_file_size": None, }, "openai/gpt-oss-120b": { "developer": "OpenAI", "context_window": 131072, "max_completion": 32766, - "max_file_size": None + "max_file_size": None, }, "openai/gpt-oss-20b": { "developer": "OpenAI", "context_window": 131072, "max_completion": 32768, - "max_file_size": None + "max_file_size": None, }, "qwen/qwen3-32b": { "developer": "Alibaba Cloud", "context_window": 131072, "max_completion": 131072, - "max_file_size": None - } -} \ No newline at end of file + "max_file_size": None, + }, +} diff --git a/src/config/settings.py b/src/config/settings.py new file mode 100644 index 0000000..8a53346 --- /dev/null +++ b/src/config/settings.py @@ -0,0 +1,51 @@ +""" +Application settings with environment variable support. +""" + +import os +from typing import Optional +from pydantic import Field +from pydantic_settings import BaseSettings + +from src.config.constants import AVAILABLE_MODELS + + +class Settings(BaseSettings): + groq_api_key: str = Field(..., env="GROQ_API_KEY") + groq_model_name: str = Field( + default="openai/gpt-oss-120b", + env="GROQ_MODEL_NAME") + app_version: str = Field(default="1.0.0", env="APP_VERSION") + max_chat_histories: int = Field(default=5, env="MAX_CHAT_HISTORIES") + debug_mode: bool = Field(default=False, env="DEBUG_MODE") + + # Legacy fields for backward compatibility + app_name: Optional[str] = Field(None, env="APP_NAME") + streamlit_port: Optional[str] = Field(None, env="STREAMLIT_PORT") + streamlit_host: Optional[str] = Field(None, env="STREAMLIT_HOST") + max_tokens: Optional[str] = Field(None, env="MAX_TOKENS") + temperature: Optional[str] = Field(None, env="TEMPERATURE") + log_level: Optional[str] = Field(None, env="LOG_LEVEL") + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + extra = "ignore" + + +_settings_instance: Optional[Settings] = None + + +def get_settings() -> Settings: + global _settings_instance + if _settings_instance is None: + _settings_instance = Settings() + return _settings_instance + + +def get_available_models(): + return AVAILABLE_MODELS + + +def get_model_info(model_name: str): + return AVAILABLE_MODELS.get(model_name, None) diff --git a/src/frontend/app.py b/src/frontend/app.py index e0a1e4c..c527cd2 100644 --- a/src/frontend/app.py +++ b/src/frontend/app.py @@ -2,32 +2,43 @@ Main Streamlit application for Querypls. """ -import streamlit as st -import sys -import os -import pandas as pd - -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from backend.orchestrator import BackendOrchestrator -from backend.backend import ( - configure_page_styles, - display_github_badge, - hide_main_menu_and_footer, +from schemas.requests import NewChatRequest +from config.constants import ( + CSV_ANALYSIS_SECTION, + CSV_UPLOAD_LABEL, + CSV_UPLOAD_HELP, + CSV_PREVIEW, + CSV_COLUMNS, + CSV_DTYPES, + LOAD_CSV_BUTTON, + CSV_LOADED_SUCCESS, + CSV_UPLOAD_SUCCESS, + CSV_UPLOAD_ERROR, + SESSION_CREATE_ERROR, + ORCHESTRATOR_INIT_ERROR, + SESSION_NOT_FOUND_ERROR, + APP_INIT_ERROR, + RESPONSE_GENERATION_ERROR, + MESSAGE_LOAD_ERROR, + MADE_WITH_LOVE, + SESSIONS_SECTION, + NEW_SESSION_BUTTON, ) from frontend import ( display_logo_and_heading, display_welcome_message, ) -from config.constants import ( - CSV_ANALYSIS_SECTION, CSV_UPLOAD_LABEL, CSV_UPLOAD_HELP, CSV_PREVIEW, - CSV_COLUMNS, CSV_DTYPES, LOAD_CSV_BUTTON, CSV_LOADED_SUCCESS, - CSV_UPLOAD_SUCCESS, CSV_UPLOAD_ERROR, SESSION_CREATE_ERROR, - ORCHESTRATOR_INIT_ERROR, SESSION_NOT_FOUND_ERROR, APP_INIT_ERROR, - RESPONSE_GENERATION_ERROR, MESSAGE_LOAD_ERROR, MADE_WITH_LOVE, - SESSIONS_SECTION, NEW_SESSION_BUTTON +from backend.backend import ( + display_github_badge, + hide_main_menu_and_footer, ) -from schemas.requests import NewChatRequest +from backend.orchestrator import BackendOrchestrator +import streamlit as st +import sys +import os +import pandas as pd + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) def initialize_orchestrator(): @@ -44,7 +55,8 @@ def get_current_session_id(): if "current_session_id" not in st.session_state: orchestrator = initialize_orchestrator() if orchestrator: - st.session_state["current_session_id"] = orchestrator.get_default_session() + st.session_state["current_session_id"] = orchestrator.get_default_session( + ) return st.session_state.get("current_session_id") @@ -52,7 +64,7 @@ def display_messages(session_id: str): orchestrator = initialize_orchestrator() if not orchestrator: return - + try: conversation = orchestrator.get_conversation_history(session_id) for message in conversation.messages: @@ -64,64 +76,58 @@ def display_messages(session_id: str): def upload_csv_file(): uploaded_file = st.file_uploader( - CSV_UPLOAD_LABEL, - type=['csv'], - help=CSV_UPLOAD_HELP + CSV_UPLOAD_LABEL, type=["csv"], help=CSV_UPLOAD_HELP ) - + if uploaded_file is not None: try: - csv_content = uploaded_file.read().decode('utf-8') + # Reset file pointer to beginning + uploaded_file.seek(0) + csv_content = uploaded_file.read().decode("utf-8") + + # Reset file pointer again for pandas + uploaded_file.seek(0) df = pd.read_csv(uploaded_file) + st.success(CSV_UPLOAD_SUCCESS.format(shape=df.shape)) - + with st.expander(CSV_PREVIEW): st.dataframe(df.head()) st.write(CSV_COLUMNS.format(columns=list(df.columns))) st.write(CSV_DTYPES.format(dtypes=df.dtypes.to_dict())) - + return csv_content except Exception as e: st.error(CSV_UPLOAD_ERROR.format(error=str(e))) return None - + return None def main(): - configure_page_styles("static/css/styles.css") - orchestrator = initialize_orchestrator() if not orchestrator: st.error(APP_INIT_ERROR) return - + current_session_id = get_current_session_id() if not current_session_id: st.error(SESSION_NOT_FOUND_ERROR) return - + hide_main_menu_and_footer() - + with st.sidebar: - display_github_badge() + st.markdown( + "", + unsafe_allow_html=True, + ) display_logo_and_heading() - st.markdown(MADE_WITH_LOVE) - - st.markdown(SESSIONS_SECTION) - sessions = orchestrator.list_sessions() - - for session in sessions: - if st.button( - session.session_name, - key=f"session_{session.session_id}", - help=f"Messages: {session.message_count}, Last: {session.last_activity}" - ): - st.session_state["current_session_id"] = session.session_id - st.rerun() - - if st.button(NEW_SESSION_BUTTON): + st.markdown("`Made with 🀍`") + st.markdown("### Sessions") + if st.button("βž• New Session"): try: + sessions = orchestrator.list_sessions() new_session = orchestrator.create_new_session( NewChatRequest(session_name=f"Chat {len(sessions) + 1}") ) @@ -129,57 +135,44 @@ def main(): st.rerun() except Exception as e: st.error(SESSION_CREATE_ERROR.format(error=str(e))) - st.markdown("---") st.markdown(CSV_ANALYSIS_SECTION) - + csv_content = upload_csv_file() if csv_content: if st.button(LOAD_CSV_BUTTON): try: - result = orchestrator.load_csv_data(current_session_id, csv_content) + result = orchestrator.load_csv_data( + current_session_id, csv_content) if result["status"] == "success": st.success(CSV_LOADED_SUCCESS) st.session_state["csv_loaded"] = True + st.rerun() # Refresh to show updated state else: st.error(f"❌ Error loading CSV: {result['message']}") except Exception as e: st.error(f"❌ Error: {str(e)}") - + display_welcome_message() display_messages(current_session_id) - + if prompt := st.chat_input(): try: - csv_loaded = st.session_state.get("csv_loaded", False) - - if csv_loaded: - response = orchestrator.generate_csv_analysis_response(current_session_id, prompt) - else: - response = orchestrator.generate_sql_response(current_session_id, prompt) - + # Use intelligent routing for all queries + response = orchestrator.generate_intelligent_response( + current_session_id, prompt + ) + + # Display the response immediately with st.chat_message("user"): st.markdown(prompt) - + with st.chat_message("assistant"): st.markdown(response.content) - - if hasattr(response, 'sql_response') and response.sql_response: - with st.expander("πŸ“Š SQL Details"): - col1, col2 = st.columns(2) - with col1: - st.write("**Query Type:**", response.sql_response.query_type) - st.write("**Complexity:**", response.sql_response.complexity) - st.write("**Tables Used:**", ", ".join(response.sql_response.tables_used)) - with col2: - st.write("**Columns:**", ", ".join(response.sql_response.columns_selected)) - st.write("**Estimated Rows:**", response.sql_response.estimated_rows) - if response.sql_response.warnings: - st.write("**Warnings:**", ", ".join(response.sql_response.warnings)) - + except Exception as e: st.error(RESPONSE_GENERATION_ERROR.format(error=str(e))) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/frontend/frontend.py b/src/frontend/frontend.py index 77374d9..f9f720d 100644 --- a/src/frontend/frontend.py +++ b/src/frontend/frontend.py @@ -15,7 +15,9 @@ def display_welcome_message(): def handle_new_chat(max_chat_histories=5): st.markdown(f"#### Remaining Chat Histories: `{max_chat_histories}`") - st.markdown("You can create multiple chat sessions. Each session can contain unlimited messages.") + st.markdown( + "You can create multiple chat sessions. Each session can contain unlimited messages." + ) if st.button("βž• New chat"): st.rerun() @@ -34,4 +36,4 @@ def update_session_state(chat): def save_chat_history(): - pass \ No newline at end of file + pass diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py index a9c8969..253f1d1 100644 --- a/src/schemas/__init__.py +++ b/src/schemas/__init__.py @@ -1,3 +1,3 @@ """ Pydantic schemas for Querypls application. -""" \ No newline at end of file +""" diff --git a/src/schemas/requests.py b/src/schemas/requests.py index 0242209..fc58d57 100644 --- a/src/schemas/requests.py +++ b/src/schemas/requests.py @@ -8,67 +8,53 @@ class ChatMessage(BaseModel): """Schema for chat message.""" - + role: Literal["user", "assistant", "system"] = Field( description="Message role (user, assistant, system)" ) - content: str = Field( - description="Message content", - min_length=1 - ) + content: str = Field(description="Message content", min_length=1) timestamp: Optional[str] = Field( - default=None, - description="Message timestamp" - ) + default=None, description="Message timestamp") session_id: Optional[str] = Field( - default=None, - description="Session identifier" - ) + default=None, description="Session identifier") class SQLGenerationRequest(BaseModel): """Schema for SQL generation request.""" - + user_query: str = Field( description="User's natural language query for SQL generation", min_length=1, - max_length=1000 + max_length=1000, ) conversation_history: List[ChatMessage] = Field( - default=[], - description="Previous conversation messages for context" + default=[], description="Previous conversation messages for context" ) database_schema: Optional[str] = Field( - default=None, - description="Database schema information (optional)" + default=None, description="Database schema information (optional)" ) query_type: Optional[str] = Field( default=None, - description="Preferred query type (SELECT, INSERT, UPDATE, DELETE)" + description="Preferred query type (SELECT, INSERT, UPDATE, DELETE)", ) class ConversationHistory(BaseModel): """Schema for conversation history.""" - + messages: List[ChatMessage] = Field( - default=[], - description="List of conversation messages" + default=[], description="List of conversation messages" ) session_id: Optional[str] = Field( - default=None, - description="Session identifier" - ) + default=None, description="Session identifier") class NewChatRequest(BaseModel): """Schema for creating a new chat session.""" - + session_name: Optional[str] = Field( - default=None, - description="Name for the new chat session" + default=None, description="Name for the new chat session" ) initial_context: Optional[str] = Field( - default=None, - description="Initial context or instructions" - ) \ No newline at end of file + default=None, description="Initial context or instructions" + ) diff --git a/src/schemas/responses.py b/src/schemas/responses.py index bd2d23f..42a4129 100644 --- a/src/schemas/responses.py +++ b/src/schemas/responses.py @@ -8,136 +8,80 @@ class SQLQueryResponse(BaseModel): """Schema for SQL query generation response.""" - - sql_query: str = Field( - ..., - description="The generated SQL query as a string" - ) + + sql_query: str = Field(..., + description="The generated SQL query as a string") explanation: str = Field( - ..., - description="Brief explanation of what the query does" + ..., description="Brief explanation of what the query does" ) tables_used: List[str] = Field( - default=[], - description="Array of table names used in the query" + default=[], description="Array of table names used in the query" ) columns_selected: List[str] = Field( - default=[], - description="Array of column names selected in the query" - ) - query_type: Literal["SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER"] = Field( - ..., - description="Type of query generated" + default=[], description="Array of column names selected in the query" ) + query_type: Literal[ + "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER" + ] = Field(..., description="Type of query generated") complexity: Literal["SIMPLE", "MEDIUM", "COMPLEX"] = Field( - ..., - description="Query complexity level" + ..., description="Query complexity level" ) estimated_rows: str = Field( default="variable", - description="Estimated number of rows returned (if applicable)" + description="Estimated number of rows returned (if applicable)", ) execution_time: Optional[str] = Field( - default=None, - description="Estimated execution time" + default=None, description="Estimated execution time" ) warnings: List[str] = Field( - default=[], - description="Any warnings about the generated query" + default=[], description="Any warnings about the generated query" ) class ChatResponse(BaseModel): """Schema for chat response.""" - - message_id: str = Field( - ..., - description="Unique identifier for the message" - ) + + message_id: str = Field(..., + description="Unique identifier for the message") role: Literal["assistant"] = Field( - default="assistant", - description="Message role" - ) - content: str = Field( - ..., - description="Response content" - ) + default="assistant", description="Message role") + content: str = Field(..., description="Response content") sql_response: Optional[SQLQueryResponse] = Field( - default=None, - description="Structured SQL response if applicable" - ) - timestamp: str = Field( - ..., - description="Response timestamp" - ) - session_id: str = Field( - ..., - description="Session identifier" + default=None, description="Structured SQL response if applicable" ) + timestamp: str = Field(..., description="Response timestamp") + session_id: str = Field(..., description="Session identifier") class ErrorResponse(BaseModel): """Schema for error responses.""" - - error_code: str = Field( - ..., - description="Error code identifier" - ) - error_message: str = Field( - ..., - description="Human-readable error message" - ) + + error_code: str = Field(..., description="Error code identifier") + error_message: str = Field(..., description="Human-readable error message") details: Optional[str] = Field( - default=None, - description="Additional error details" - ) - timestamp: str = Field( - ..., - description="Error timestamp" - ) + default=None, description="Additional error details") + timestamp: str = Field(..., description="Error timestamp") class SessionInfo(BaseModel): """Schema for session information.""" - - session_id: str = Field( - ..., - description="Unique session identifier" - ) - session_name: str = Field( - ..., - description="Session name" - ) - created_at: str = Field( - ..., - description="Session creation timestamp" - ) - message_count: int = Field( - ..., - description="Number of messages in the session" - ) - last_activity: str = Field( - ..., - description="Last activity timestamp" - ) + + session_id: str = Field(..., description="Unique session identifier") + session_name: str = Field(..., description="Session name") + created_at: str = Field(..., description="Session creation timestamp") + message_count: int = Field(..., + description="Number of messages in the session") + last_activity: str = Field(..., description="Last activity timestamp") class HealthCheckResponse(BaseModel): """Schema for health check response.""" - + status: Literal["healthy", "unhealthy"] = Field( - ..., - description="Application health status" - ) - version: str = Field( - ..., - description="Application version" - ) - timestamp: str = Field( - ..., - description="Health check timestamp" + ..., description="Application health status" ) + version: str = Field(..., description="Application version") + timestamp: str = Field(..., description="Health check timestamp") services: dict = Field( default={}, - description="Status of individual services" - ) \ No newline at end of file + description="Status of individual services") diff --git a/src/services/__init__.py b/src/services/__init__.py index e809ff8..dfd3691 100644 --- a/src/services/__init__.py +++ b/src/services/__init__.py @@ -1,3 +1,3 @@ """ Services package for Querypls application. -""" \ No newline at end of file +""" diff --git a/src/services/conversation_service.py b/src/services/conversation_service.py new file mode 100644 index 0000000..14dbc36 --- /dev/null +++ b/src/services/conversation_service.py @@ -0,0 +1,91 @@ +""" +Conversation service for handling normal user queries. +""" + +from typing import Literal, Union +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.providers.groq import GroqProvider + +from config.settings import get_settings +from services.models import ConversationResponse, Failed +from utils.prompt import CONVERSATION_PROMPT + + +class ConversationService: + def __init__(self): + self.settings = get_settings() + + self.model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.settings.groq_api_key), + ) + + self.conversation_agent = Agent[None, Union[ConversationResponse, Failed]]( + self.model, + output_type=Union[ConversationResponse, Failed], + system_prompt=CONVERSATION_PROMPT, + ) + + def is_conversational_query(self, query: str) -> bool: + """Check if query is conversational (not SQL/data related).""" + conversational_keywords = [ + "hi", + "hello", + "hey", + "good morning", + "good afternoon", + "good evening", + "how are you", + "what's up", + "thanks", + "thank you", + "bye", + "goodbye", + "help", + "what can you do", + "who are you", + "tell me about yourself", + "nice to meet you", + "pleasure", + "good", + "fine", + "okay", + ] + query_lower = query.lower().strip() + return any( + keyword in query_lower for keyword in conversational_keywords) + + def get_conversational_response(self, query: str) -> str: + """Get a natural response for conversational queries.""" + try: + result = self.conversation_agent.run_sync(query) + + if isinstance(result.output, ConversationResponse): + return result.output.message + else: + # Fallback responses + query_lower = query.lower().strip() + + if any( + greeting in query_lower for greeting in [ + "hi", + "hello", + "hey"]): + return "Hello! πŸ‘‹ How can I help you today? I can assist with SQL generation or CSV data analysis." + elif "how are you" in query_lower: + return "I'm doing great, thank you for asking! 😊 How can I assist you with your data queries today?" + elif any(thanks in query_lower for thanks in ["thanks", "thank you"]): + return ( + "You're welcome! 😊 Is there anything else I can help you with?" + ) + elif any(bye in query_lower for bye in ["bye", "goodbye"]): + return "Goodbye! πŸ‘‹ Feel free to come back if you need help with SQL or data analysis." + elif "help" in query_lower or "what can you do" in query_lower: + return "I'm Querypls, your SQL and data analysis assistant! πŸ—ƒοΈπŸ’¬\n\nI can help you with:\nβ€’ **SQL Generation**: Convert natural language to SQL queries\nβ€’ **CSV Analysis**: Analyze data files with Python code\nβ€’ **Data Visualization**: Create charts and graphs\n\nJust ask me anything about your data!" + else: + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + + except Exception as e: + # Fallback response + return "Hello! How can I help you today? I can assist with SQL generation or CSV data analysis." diff --git a/src/services/csv_analysis_tools.py b/src/services/csv_analysis_tools.py index 38b14ca..8a28726 100644 --- a/src/services/csv_analysis_tools.py +++ b/src/services/csv_analysis_tools.py @@ -19,16 +19,21 @@ class CSVAnalysisContext(BaseModel): class PythonCodeResponse(BaseModel): - python_code: str = Field(description="Generated Python code for data analysis") + python_code: str = Field( + description="Generated Python code for data analysis") explanation: str = Field(description="Explanation of what the code does") - expected_output: str = Field(description="What output is expected from the code") + expected_output: str = Field( + description="What output is expected from the code") libraries_used: list = Field(description="List of Python libraries used") class CodeExecutionResult(BaseModel): - status: str = Field(description="Execution status: success, error, or retry") + status: str = Field( + description="Execution status: success, error, or retry") output: str = Field(description="Output from code execution") - error_message: Optional[str] = Field(description="Error message if execution failed") + error_message: Optional[str] = Field( + description="Error message if execution failed" + ) execution_time: float = Field(description="Time taken to execute the code") attempt: int = Field(description="Attempt number") @@ -37,33 +42,36 @@ class CSVAnalysisTools: def __init__(self): self.settings = get_settings() self.csv_service = CSVAnalysisService() - + self.code_generation_model = GroqModel( self.settings.groq_model_name, - provider=GroqProvider(api_key=self.settings.groq_api_key) + provider=GroqProvider(api_key=self.settings.groq_api_key), ) - + self.code_generation_agent = Agent( self.code_generation_model, instructions=CSV_ANALYSIS_PROMPT, - output_type=PythonCodeResponse + output_type=PythonCodeResponse, ) - + self.code_fixing_model = GroqModel( self.settings.groq_model_name, - provider=GroqProvider(api_key=self.settings.groq_api_key) + provider=GroqProvider(api_key=self.settings.groq_api_key), ) - + self.code_fixing_agent = Agent( self.code_fixing_model, instructions=CODE_FIX_PROMPT, - output_type=PythonCodeResponse + output_type=PythonCodeResponse, ) - - def load_csv_data(self, csv_content: str, session_id: str) -> Dict[str, Any]: + + def load_csv_data(self, csv_content: str, + session_id: str) -> Dict[str, Any]: return self.csv_service.load_csv_data(session_id, csv_content) - - def generate_analysis_code(self, user_query: str, csv_context: CSVAnalysisContext) -> PythonCodeResponse: + + def generate_analysis_code( + self, user_query: str, csv_context: CSVAnalysisContext + ) -> PythonCodeResponse: prompt = f""" CSV Headers: {csv_context.csv_headers} Sample Data: {csv_context.sample_data[:3]} @@ -75,22 +83,29 @@ def generate_analysis_code(self, user_query: str, csv_context: CSVAnalysisContex 3. Returns clear output 4. Handles the CSV data properly """ - + result = self.code_generation_agent.run_sync(prompt) return result.output - - def execute_analysis_code(self, python_code: str, session_id: str, max_retries: int = 3) -> CodeExecutionResult: - result = self.csv_service.execute_analysis(session_id, python_code, max_retries) - + + def execute_analysis_code( + self, python_code: str, session_id: str, max_retries: int = 3 + ) -> CodeExecutionResult: + result = self.csv_service.execute_analysis( + session_id, python_code, max_retries) + return CodeExecutionResult( status=result["status"], output=result.get("output", ""), error_message=result.get("error_message"), execution_time=result.get("execution_time", 0.0), - attempt=result.get("attempt", 1) + attempt=result.get("attempt", 1), ) - - def fix_code_error(self, original_code: str, error_message: str, csv_context: CSVAnalysisContext) -> PythonCodeResponse: + + def fix_code_error( + self, + original_code: str, + error_message: str, + csv_context: CSVAnalysisContext) -> PythonCodeResponse: prompt = f""" Original Code: {original_code} @@ -103,54 +118,56 @@ def fix_code_error(self, original_code: str, error_message: str, csv_context: CS Please fix the code to resolve the error and ensure it works correctly. """ - + result = self.code_fixing_agent.run_sync(prompt) return result.output - + def get_csv_info(self, session_id: str) -> Dict[str, Any]: return self.csv_service.get_csv_info(session_id) - + def close_session(self, session_id: str): self.csv_service.close_session(session_id) def create_csv_analysis_agent() -> Agent: settings = get_settings() - + model = GroqModel( settings.groq_model_name, - provider=GroqProvider(api_key=settings.groq_api_key) - ) - - agent = Agent( - model, - instructions=CSV_AGENT_PROMPT, - output_type=str - ) - + provider=GroqProvider( + api_key=settings.groq_api_key)) + + agent = Agent(model, instructions=CSV_AGENT_PROMPT, output_type=str) + csv_tools = CSVAnalysisTools() - + @agent.tool - async def load_csv_data(ctx: RunContext[None], csv_content: str, session_id: str) -> str: + async def load_csv_data( + ctx: RunContext[None], csv_content: str, session_id: str + ) -> str: result = csv_tools.load_csv_data(csv_content, session_id) if result["status"] == "success": - return f"CSV loaded successfully! Shape: {result['shape']}, Columns: {result['columns']}" + return f"CSV loaded successfully! Shape: { + result['shape']}, Columns: { + result['columns']}" else: return f"Error loading CSV: {result['message']}" - + @agent.tool - async def generate_analysis_code(ctx: RunContext[None], user_query: str, session_id: str) -> str: + async def generate_analysis_code( + ctx: RunContext[None], user_query: str, session_id: str + ) -> str: csv_info = csv_tools.get_csv_info(session_id) if csv_info["status"] != "success": return f"Error: {csv_info['message']}" - + csv_context = CSVAnalysisContext( session_id=session_id, csv_content="", csv_headers=csv_info["columns"], - sample_data=csv_info["sample_data"] + sample_data=csv_info["sample_data"], ) - + result = csv_tools.generate_analysis_code(user_query, csv_context) return f"""Generated Python Code: ```python @@ -160,11 +177,13 @@ async def generate_analysis_code(ctx: RunContext[None], user_query: str, session Explanation: {result.explanation} Expected Output: {result.expected_output} Libraries Used: {', '.join(result.libraries_used)}""" - + @agent.tool - async def execute_analysis_code(ctx: RunContext[None], python_code: str, session_id: str) -> str: + async def execute_analysis_code( + ctx: RunContext[None], python_code: str, session_id: str + ) -> str: result = csv_tools.execute_analysis_code(python_code, session_id) - + if result.status == "success": return f"""βœ… Code executed successfully! Execution Time: {result.execution_time:.2f}s @@ -179,21 +198,26 @@ async def execute_analysis_code(ctx: RunContext[None], python_code: str, session Output: {result.output}""" - + @agent.tool - async def fix_code_error(ctx: RunContext[None], original_code: str, error_message: str, session_id: str) -> str: + async def fix_code_error( + ctx: RunContext[None], + original_code: str, + error_message: str, + session_id: str) -> str: csv_info = csv_tools.get_csv_info(session_id) if csv_info["status"] != "success": return f"Error: {csv_info['message']}" - + csv_context = CSVAnalysisContext( session_id=session_id, csv_content="", csv_headers=csv_info["columns"], - sample_data=csv_info["sample_data"] + sample_data=csv_info["sample_data"], ) - - result = csv_tools.fix_code_error(original_code, error_message, csv_context) + + result = csv_tools.fix_code_error( + original_code, error_message, csv_context) return f"""πŸ”§ Fixed Code: ```python {result.python_code} @@ -201,7 +225,7 @@ async def fix_code_error(ctx: RunContext[None], original_code: str, error_messag Explanation: {result.explanation} Expected Output: {result.expected_output}""" - + @agent.tool async def get_csv_info(ctx: RunContext[None], session_id: str) -> str: result = csv_tools.get_csv_info(session_id) @@ -213,5 +237,5 @@ async def get_csv_info(ctx: RunContext[None], session_id: str) -> str: Sample Data: {result['sample_data'][:2]}""" else: return f"Error: {result['message']}" - - return agent \ No newline at end of file + + return agent diff --git a/src/services/jupyter_service.py b/src/services/jupyter_service.py index 83c3dad..22e4692 100644 --- a/src/services/jupyter_service.py +++ b/src/services/jupyter_service.py @@ -3,6 +3,7 @@ """ import os +import io import jupyter_client import inspect import time @@ -11,15 +12,15 @@ from typing import Dict, Any, Optional from dataclasses import dataclass -from config.constants import EXECUTION_TIMEOUT, MAX_RETRIES +from src.config.constants import EXECUTION_TIMEOUT, MAX_RETRIES def clean_error_message(error_msg: str) -> str: - ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') - cleaned_msg = ansi_escape.sub('', error_msg) - lines = cleaned_msg.split('\n') + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + cleaned_msg = ansi_escape.sub("", error_msg) + lines = cleaned_msg.split("\n") lines = [line.strip() for line in lines if line.strip()] - cleaned_msg = '\n'.join(lines) + cleaned_msg = "\n".join(lines) return cleaned_msg @@ -36,27 +37,58 @@ def __init__(self): self.clients: Dict[str, Any] = {} self.globals: Dict[str, Dict[str, Any]] = {} - def create_new_session(self, session_id: str = "default", kernel_name: str = 'python3') -> str: + def create_new_session( + self, session_id: str = "default", kernel_name: str = "querypls" + ) -> str: if session_id in self.clients: return session_id - km = jupyter_client.KernelManager(kernel_name=kernel_name) - km.start_kernel() - client = km.client() - self.clients[session_id] = client - self.globals[session_id] = {} - - for key, value in os.environ.items(): - self.execute_code(f"{key} = '{value}'", session_id) - - self.execute_code("import pandas as pd", session_id) - self.execute_code("import numpy as np", session_id) - self.execute_code("import matplotlib.pyplot as plt", session_id) - self.execute_code("import seaborn as sns", session_id) - - return session_id - - def execute_code(self, code: str, session_id: str = "default") -> ExecutionResult: + try: + km = jupyter_client.KernelManager(kernel_name=kernel_name) + km.start_kernel() + client = km.client() + self.clients[session_id] = client + self.globals[session_id] = {} + + # Set environment variables + for key, value in os.environ.items(): + self.execute_code(f"{key} = '{value}'", session_id) + + # Import common data science libraries + self.execute_code("import pandas as pd", session_id) + self.execute_code("import numpy as np", session_id) + self.execute_code("import matplotlib.pyplot as plt", session_id) + self.execute_code("import seaborn as sns", session_id) + + return session_id + except Exception as e: + # Fallback to default kernel + try: + km = jupyter_client.KernelManager() + km.start_kernel() + client = km.client() + self.clients[session_id] = client + self.globals[session_id] = {} + + # Set environment variables + for key, value in os.environ.items(): + self.execute_code(f"{key} = '{value}'", session_id) + + # Import common data science libraries + self.execute_code("import pandas as pd", session_id) + self.execute_code("import numpy as np", session_id) + self.execute_code( + "import matplotlib.pyplot as plt", session_id) + self.execute_code("import seaborn as sns", session_id) + + return session_id + except Exception as e2: + raise ValueError(f"Failed to create kernel: {str(e2)}") + + def execute_code( + self, + code: str, + session_id: str = "default") -> ExecutionResult: if session_id not in self.clients: raise ValueError(f"Session {session_id} not found") @@ -72,23 +104,27 @@ def execute_code(self, code: str, session_id: str = "default") -> ExecutionResul while True: try: msg = client.get_iopub_msg(timeout=1) - if 'parent_header' not in msg or msg['parent_header'].get('msg_id') != msg_id: + if ( + "parent_header" not in msg + or msg["parent_header"].get("msg_id") != msg_id + ): continue - msg_type = msg.get('msg_type', '') - content = msg.get('content', {}) + msg_type = msg.get("msg_type", "") + content = msg.get("content", {}) - if msg_type == 'execute_result': - output.append(str(content.get('data', {}).get('text/plain', ''))) - elif msg_type == 'stream': - output.append(content.get('text', '')) - elif msg_type == 'error': - error_traceback = "\n".join(content.get('traceback', [])) + if msg_type == "execute_result": + output.append( + str(content.get("data", {}).get("text/plain", ""))) + elif msg_type == "stream": + output.append(content.get("text", "")) + elif msg_type == "error": + error_traceback = "\n".join(content.get("traceback", [])) cleaned_error = clean_error_message(error_traceback) output.append(f"Error: {cleaned_error}") error_message = cleaned_error - status = 'Fail' - elif msg_type == 'status' and content.get('execution_state') == 'idle': + status = "Fail" + elif msg_type == "status" and content.get("execution_state") == "idle": break except Exception as e: pass @@ -96,13 +132,16 @@ def execute_code(self, code: str, session_id: str = "default") -> ExecutionResul execution_time = time.time() - start_time return ExecutionResult( - output='\n'.join(output).strip(), + output="\n".join(output).strip(), status=status, error_message=error_message, - execution_time=execution_time + execution_time=execution_time, ) - def import_function(self, func, session_id: str = "default") -> ExecutionResult: + def import_function( + self, + func, + session_id: str = "default") -> ExecutionResult: if session_id not in self.globals: raise ValueError(f"Session {session_id} not found") @@ -134,11 +173,13 @@ def __init__(self): self.jupyter_client = SimpleJupyterClient() self.csv_data: Dict[str, pd.DataFrame] = {} self.csv_headers: Dict[str, list] = {} - - def load_csv_data(self, session_id: str, csv_content: str, filename: str = "data.csv") -> Dict[str, Any]: + + def load_csv_data( + self, session_id: str, csv_content: str, filename: str = "data.csv" + ) -> Dict[str, Any]: try: self.jupyter_client.create_new_session(session_id) - + csv_code = f""" import pandas as pd import io @@ -152,44 +193,44 @@ def load_csv_data(self, session_id: str, csv_content: str, filename: str = "data print("\\nFirst few rows:") print(df.head()) """ - + result = self.jupyter_client.execute_code(csv_code, session_id) - + if result.status == "Success": df = pd.read_csv(io.StringIO(csv_content)) self.csv_data[session_id] = df self.csv_headers[session_id] = df.columns.tolist() - + return { "status": "success", "message": "CSV loaded successfully", "shape": df.shape, "columns": df.columns.tolist(), - "sample_data": df.head().to_dict('records') + "sample_data": df.head().to_dict("records"), } else: return { "status": "error", - "message": result.error_message or "Failed to load CSV" + "message": result.error_message or "Failed to load CSV", } - + except Exception as e: - return { - "status": "error", - "message": str(e) - } - - def execute_analysis(self, session_id: str, python_code: str, max_retries: int = MAX_RETRIES) -> Dict[str, Any]: + return {"status": "error", "message": str(e)} + + def execute_analysis( + self, session_id: str, python_code: str, max_retries: int = MAX_RETRIES + ) -> Dict[str, Any]: for attempt in range(max_retries): try: - result = self.jupyter_client.execute_code(python_code, session_id) - + result = self.jupyter_client.execute_code( + python_code, session_id) + if result.status == "Success": return { "status": "success", "output": result.output, "execution_time": result.execution_time, - "attempt": attempt + 1 + "attempt": attempt + 1, } else: if attempt == max_retries - 1: @@ -197,41 +238,43 @@ def execute_analysis(self, session_id: str, python_code: str, max_retries: int = "status": "error", "error_message": result.error_message, "output": result.output, - "attempt": attempt + 1 + "attempt": attempt + 1, } continue - + except Exception as e: if attempt == max_retries - 1: return { "status": "error", "error_message": str(e), - "attempt": attempt + 1 + "attempt": attempt + 1, } continue - + return { "status": "error", "error_message": "Max retries exceeded", - "attempt": max_retries + "attempt": max_retries, } - + def get_csv_info(self, session_id: str) -> Dict[str, Any]: if session_id not in self.csv_data: - return {"status": "error", "message": "No CSV data loaded for this session"} - + return { + "status": "error", + "message": "No CSV data loaded for this session"} + df = self.csv_data[session_id] return { "status": "success", "shape": df.shape, "columns": df.columns.tolist(), "dtypes": df.dtypes.to_dict(), - "sample_data": df.head().to_dict('records') + "sample_data": df.head().to_dict("records"), } - + def close_session(self, session_id: str): self.jupyter_client.close_session(session_id) if session_id in self.csv_data: del self.csv_data[session_id] if session_id in self.csv_headers: - del self.csv_headers[session_id] \ No newline at end of file + del self.csv_headers[session_id] diff --git a/src/services/models.py b/src/services/models.py index 0519ecb..12232c8 100644 --- a/src/services/models.py +++ b/src/services/models.py @@ -1 +1,98 @@ - \ No newline at end of file +""" +Data models for the services. +""" + +from typing import Literal, Union, List, Optional +from pydantic import BaseModel, Field +from datetime import datetime + + +class RoutingDecision(BaseModel): + """Model for routing decisions.""" + + agent: Literal["CONVERSATION_AGENT", "SQL_AGENT", "CSV_AGENT"] = Field( + description="The agent that should handle the query" + ) + confidence: float = Field( + description="Confidence level in the routing decision", ge=0.0, le=1.0 + ) + reasoning: str = Field( + description="Brief explanation of why this agent was chosen") + + +class ConversationResponse(BaseModel): + """Response for conversational queries.""" + + message: str = Field(description="Natural response to user query") + response_type: Literal["greeting", + "help", + "thanks", + "goodbye", + "general"] = Field(description="Type of response") + suggest_next: Optional[str] = Field( + description="Optional suggestion for what they could do next", + default=None) + + +class SQLResponse(BaseModel): + """Response for SQL generation.""" + + sql_query: str = Field(description="The generated SQL query") + explanation: str = Field( + description="Brief explanation of what the query does") + tables_used: List[str] = Field( + description="Array of table names used in the query") + columns_selected: List[str] = Field( + description="Array of column names selected in the query" + ) + query_type: str = Field( + description="Type of query (SELECT, INSERT, UPDATE, DELETE, etc.)" + ) + complexity: Literal["SIMPLE", "MEDIUM", "COMPLEX"] = Field( + description="Query complexity level" + ) + estimated_rows: str = Field( + description="Estimated number of rows returned") + execution_time: Optional[str] = Field( + description="Estimated execution time", default=None + ) + warnings: List[str] = Field( + description="Array of warnings about the query", default_factory=list + ) + + +class CSVAnalysisResponse(BaseModel): + """Response for CSV analysis.""" + + python_code: str = Field(description="The generated Python code") + explanation: str = Field( + description="Brief explanation of what the code does") + expected_output: str = Field( + description="What output is expected from the code") + libraries_used: List[str] = Field( + description="Array of Python libraries used") + + +class CodeFixResponse(BaseModel): + """Response for code fixing.""" + + python_code: str = Field(description="The fixed Python code") + explanation: str = Field(description="Brief explanation of what was fixed") + expected_output: str = Field( + description="What output is expected from the fixed code" + ) + libraries_used: List[str] = Field( + description="Array of Python libraries used") + + +class Failed(BaseModel): + """Unable to find a satisfactory response.""" + + error: str = Field(description="Error message explaining the failure") + + +# Union types for different response types +ConversationResult = Union[ConversationResponse, Failed] +SQLResult = Union[SQLResponse, Failed] +CSVAnalysisResult = Union[CSVAnalysisResponse, Failed] +CodeFixResult = Union[CodeFixResponse, Failed] diff --git a/src/services/routing_service.py b/src/services/routing_service.py new file mode 100644 index 0000000..395d901 --- /dev/null +++ b/src/services/routing_service.py @@ -0,0 +1,298 @@ +""" +Intelligent routing service for determining which agent should handle user queries. +""" + +import json +from typing import List, Optional +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.providers.groq import GroqProvider + +from src.config.settings import get_settings +from src.services.models import ( + RoutingDecision, + ConversationResult, + SQLResult, + CSVAnalysisResult, +) +from src.schemas.requests import ChatMessage +from utils.prompt import ( + ROUTING_PROMPT, + CONVERSATION_PROMPT, + SQL_GENERATION_PROMPT, + CSV_ANALYSIS_PROMPT, +) + + +class IntelligentRoutingService: + """Service for intelligently routing user queries to appropriate agents.""" + + def __init__(self): + self.settings = get_settings() + + self.model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.settings.groq_api_key), + ) + + # Create routing agent + self.routing_agent = Agent[None, RoutingDecision]( + self.model, output_type=RoutingDecision, system_prompt=ROUTING_PROMPT + ) + + # Create conversation agent + self.conversation_agent = Agent[None, ConversationResult]( + self.model, + output_type=ConversationResult, + system_prompt=CONVERSATION_PROMPT, + ) + + # Create SQL agent + self.sql_agent = Agent[None, SQLResult]( + self.model, output_type=SQLResult, system_prompt=SQL_GENERATION_PROMPT + ) + + # Create CSV analysis agent + self.csv_agent = Agent[None, CSVAnalysisResult]( + self.model, output_type=CSVAnalysisResult, system_prompt=CSV_ANALYSIS_PROMPT + ) + + def determine_agent( + self, + user_query: str, + conversation_history: List[ChatMessage], + csv_loaded: bool = False, + ) -> RoutingDecision: + """Determine which agent should handle the user query.""" + try: + # Prepare context for routing + context = self._prepare_routing_context( + user_query, conversation_history, csv_loaded + ) + + result = self.routing_agent.run_sync(context) + return result.output + + except Exception as e: + # Fallback to simple keyword-based routing + return self._fallback_routing(user_query, csv_loaded) + + def handle_conversation_query(self, user_query: str) -> str: + """Handle conversational queries.""" + try: + result = self.conversation_agent.run_sync(user_query) + + if hasattr(result.output, "message"): + return result.output.message + else: + return self._get_fallback_conversation_response(user_query) + + except Exception as e: + return self._get_fallback_conversation_response(user_query) + + def handle_sql_query( + self, user_query: str, conversation_history: List[ChatMessage] + ) -> str: + """Handle SQL generation queries.""" + try: + context = self._prepare_sql_context( + user_query, conversation_history) + result = self.sql_agent.run_sync(context) + + if hasattr(result.output, "sql_query"): + return self._format_sql_response(result.output) + else: + return "I'm sorry, I couldn't generate a SQL query for that request. Could you please rephrase your question?" + + except Exception as e: + return f"I encountered an error while generating SQL: {str(e)}" + + def handle_csv_query(self, user_query: str, csv_data: str) -> str: + """Handle CSV analysis queries.""" + try: + context = self._prepare_csv_context(user_query, csv_data) + result = self.csv_agent.run_sync(context) + + if hasattr(result.output, "python_code"): + return self._format_csv_response(result.output) + else: + return "I'm sorry, I couldn't generate analysis code for that request. Could you please rephrase your question?" + + except Exception as e: + return ( + f"I encountered an error while generating CSV analysis code: { + str(e)}") + + def _prepare_routing_context( + self, + user_query: str, + conversation_history: List[ChatMessage], + csv_loaded: bool) -> str: + """Prepare context for routing decision.""" + context_parts = [ + f"User Query: {user_query}", + f"CSV Data Loaded: {csv_loaded}", + ] + + if conversation_history: + # Last 5 messages for context + recent_messages = conversation_history[-5:] + context_parts.append("Recent Conversation History:") + for msg in recent_messages: + context_parts.append(f"- {msg.role}: {msg.content}") + + return "\n".join(context_parts) + + def _prepare_sql_context( + self, user_query: str, conversation_history: List[ChatMessage] + ) -> str: + """Prepare context for SQL generation.""" + context_parts = [ + f"User Query: {user_query}", + ] + + if conversation_history: + context_parts.append("Conversation History:") + for msg in conversation_history[-10:]: # Last 10 messages + context_parts.append(f"- {msg.role}: {msg.content}") + + return "\n".join(context_parts) + + def _prepare_csv_context(self, user_query: str, csv_data: str) -> str: + """Prepare context for CSV analysis.""" + return f"User Query: {user_query}\n\nCSV Data Available: Yes\n\nGenerate Python code to analyze the CSV data based on the user's request." + + def _format_sql_response(self, sql_response) -> str: + """Format SQL response for display.""" + response_parts = [ + f"**SQL Query:**\n```sql\n{sql_response.sql_query}\n```", + f"**Explanation:** {sql_response.explanation}", + f"**Query Type:** {sql_response.query_type}", + f"**Complexity:** {sql_response.complexity}", + f"**Tables Used:** {', '.join(sql_response.tables_used)}", + f"**Columns Selected:** {', '.join(sql_response.columns_selected)}", + f"**Estimated Rows:** {sql_response.estimated_rows}", + ] + + if sql_response.warnings: + response_parts.append( + f"**Warnings:** {', '.join(sql_response.warnings)}") + + return "\n\n".join(response_parts) + + def _format_csv_response(self, csv_response) -> str: + """Format CSV analysis response for display.""" + response_parts = [ + f"**Python Code:**\n```python\n{csv_response.python_code}\n```", + f"**Explanation:** {csv_response.explanation}", + f"**Expected Output:** {csv_response.expected_output}", + f"**Libraries Used:** {', '.join(csv_response.libraries_used)}", + ] + + return "\n\n".join(response_parts) + + def _fallback_routing( + self, + user_query: str, + csv_loaded: bool) -> RoutingDecision: + """Fallback routing based on simple keyword matching.""" + query_lower = user_query.lower().strip() + + # Conversation keywords + conversation_keywords = [ + "hi", + "hello", + "hey", + "good morning", + "good afternoon", + "good evening", + "how are you", + "what's up", + "thanks", + "thank you", + "bye", + "goodbye", + "help", + "what can you do", + "who are you", + "tell me about yourself", + ] + + # CSV keywords + csv_keywords = [ + "csv", + "data", + "analyze", + "analysis", + "chart", + "graph", + "plot", + "visualize", + "pandas", + "matplotlib", + "seaborn", + "python", + "code", + "file", + ] + + # SQL keywords + sql_keywords = [ + "select", + "from", + "where", + "join", + "table", + "database", + "query", + "sql", + "insert", + "update", + "delete", + "create", + "alter", + "drop", + ] + + if any(keyword in query_lower for keyword in conversation_keywords): + return RoutingDecision( + agent="CONVERSATION_AGENT", + confidence=0.8, + reasoning="Detected conversational keywords", + ) + elif csv_loaded and any(keyword in query_lower for keyword in csv_keywords): + return RoutingDecision( + agent="CSV_AGENT", + confidence=0.7, + reasoning="CSV data loaded and detected CSV-related keywords", + ) + elif any(keyword in query_lower for keyword in sql_keywords): + return RoutingDecision( + agent="SQL_AGENT", + confidence=0.7, + reasoning="Detected SQL-related keywords", + ) + else: + # Default to conversation for unknown queries + return RoutingDecision( + agent="CONVERSATION_AGENT", + confidence=0.5, + reasoning="No specific keywords detected, defaulting to conversation", + ) + + def _get_fallback_conversation_response(self, user_query: str) -> str: + """Get fallback conversation response. Wish it dont be here.""" + query_lower = user_query.lower().strip() + + if any(greeting in query_lower for greeting in ["hi", "hello", "hey"]): + return "Hello! πŸ‘‹ I'm Querypls, your SQL and data analysis assistant. I can help you generate SQL queries or analyze CSV files. What would you like to work on today?" + elif "how are you" in query_lower: + return "I'm doing great, thank you for asking! 😊 I'm ready to help you with SQL queries or CSV data analysis. What can I assist you with?" + elif any(thanks in query_lower for thanks in ["thanks", "thank you"]): + return "You're welcome! 😊 Is there anything else I can help you with?" + elif any(bye in query_lower for bye in ["bye", "goodbye"]): + return "Goodbye! πŸ‘‹ Feel free to come back if you need help with SQL or data analysis." + elif "help" in query_lower or "what can you do" in query_lower: + return "I'm Querypls, your data analysis companion! πŸ—ƒοΈπŸ’¬ I can help you with SQL generation and CSV data analysis. Just upload a CSV file or ask me to write SQL queries!" + else: + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" diff --git a/src/services/sql_service.py b/src/services/sql_service.py index 9c2ec8c..174aef0 100644 --- a/src/services/sql_service.py +++ b/src/services/sql_service.py @@ -16,21 +16,22 @@ class SQLGenerationService: def __init__(self, api_key: Optional[str] = None): self.settings = get_settings() self.api_key = api_key or self.settings.groq_api_key - + if not self.api_key: - raise ValueError("Groq API key is required. Set GROQ_API_KEY environment variable or pass api_key parameter.") - + raise ValueError( + "Groq API key is required. Set GROQ_API_KEY environment variable or pass api_key parameter." + ) + self.model = GroqModel( self.settings.groq_model_name, - provider=GroqProvider(api_key=self.api_key) - ) - + provider=GroqProvider( + api_key=self.api_key)) + self.agent = Agent( self.model, instructions=SQL_GENERATION_PROMPT, - output_type=SQLQueryResponse - ) - + output_type=SQLQueryResponse) + def format_chat_history(self, messages: list) -> str: history = [] for msg in messages[1:]: @@ -38,26 +39,28 @@ def format_chat_history(self, messages: list) -> str: content = msg.content role = msg.role else: - content = msg.get('content', '') - role = msg.get('role', 'user') - - if '```sql' in content: - content = content.replace('```sql\n', '').replace('\n```', '').strip() - - history.append({ - "role": role, - "query" if role == 'user' else "response": content - }) - + content = msg.get("content", "") + role = msg.get("role", "user") + + if "```sql" in content: + content = content.replace( + "```sql\n", "").replace( + "\n```", "").strip() + + history.append({"role": role, "query" if role == + "user" else "response": content}) + return json.dumps(history, indent=2) - + def generate_sql(self, request: SQLGenerationRequest) -> ChatResponse: try: - formatted_history = self.format_chat_history(request.conversation_history) - prompt = f"Previous conversation: {formatted_history}\nCurrent question: {request.user_query}" - + formatted_history = self.format_chat_history( + request.conversation_history) + prompt = f"Previous conversation: {formatted_history}\nCurrent question: { + request.user_query}" + result = self.agent.run_sync(prompt) - + sql_response = SQLQueryResponse( sql_query=result.output.sql_query, explanation=result.output.explanation, @@ -67,57 +70,61 @@ def generate_sql(self, request: SQLGenerationRequest) -> ChatResponse: complexity=result.output.complexity, estimated_rows=result.output.estimated_rows, execution_time=result.output.execution_time, - warnings=result.output.warnings + warnings=result.output.warnings, ) - - formatted_content = f"```sql\n{sql_response.sql_query}\n```\n\n**Explanation:** {sql_response.explanation}" - - session_id = 'default' + + formatted_content = f"```sql\n{ + sql_response.sql_query}\n```\n\n**Explanation:** { + sql_response.explanation}" + + session_id = "default" if request.conversation_history: first_msg = request.conversation_history[0] if isinstance(first_msg, ChatMessage): - session_id = first_msg.session_id or 'default' + session_id = first_msg.session_id or "default" else: - session_id = first_msg.get('session_id', 'default') - + session_id = first_msg.get("session_id", "default") + chat_response = ChatResponse( message_id=str(uuid.uuid4()), content=formatted_content, sql_response=sql_response, timestamp=datetime.now().isoformat(), - session_id=session_id + session_id=session_id, ) - + return chat_response - + except Exception as e: error_response = ErrorResponse( error_code="SQL_GENERATION_ERROR", error_message=f"Error generating SQL: {str(e)}", details=str(e), - timestamp=datetime.now().isoformat() + timestamp=datetime.now().isoformat(), ) - - session_id = 'default' + + session_id = "default" if request.conversation_history: first_msg = request.conversation_history[0] if isinstance(first_msg, ChatMessage): - session_id = first_msg.session_id or 'default' + session_id = first_msg.session_id or "default" else: - session_id = first_msg.get('session_id', 'default') - + session_id = first_msg.get("session_id", "default") + return ChatResponse( message_id=str(uuid.uuid4()), content=f"❌ Error: {error_response.error_message}", timestamp=datetime.now().isoformat(), - session_id=session_id + session_id=session_id, ) - - def generate_sql_legacy(self, user_query: str, conversation_history: list) -> str: + + def generate_sql_legacy( + self, + user_query: str, + conversation_history: list) -> str: request = SQLGenerationRequest( - user_query=user_query, - conversation_history=conversation_history + user_query=user_query, conversation_history=conversation_history ) - + response = self.generate_sql(request) - return response.content \ No newline at end of file + return response.content diff --git a/src/terminal/__init__.py b/src/terminal/__init__.py index e08c65f..b353cc9 100644 --- a/src/terminal/__init__.py +++ b/src/terminal/__init__.py @@ -1,3 +1,3 @@ """ Terminal interface package for Querypls. -""" \ No newline at end of file +""" diff --git a/src/terminal/cli.py b/src/terminal/cli.py index 15cc23f..d7040ba 100644 --- a/src/terminal/cli.py +++ b/src/terminal/cli.py @@ -2,6 +2,9 @@ Command-line interface for Querypls SQL generation. """ +from config.constants import DEFAULT_SESSION_NAME +from schemas.requests import NewChatRequest +from backend.orchestrator import BackendOrchestrator import sys import os import json @@ -9,119 +12,123 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from backend.orchestrator import BackendOrchestrator -from schemas.requests import NewChatRequest -from config.constants import ( - CLI_WELCOME, CLI_COMMANDS, CLI_GOODBYE, CLI_UNKNOWN_COMMAND, CLI_ERROR, - SESSION_CREATED, SESSION_SWITCHED, SESSION_NOT_FOUND, NO_ACTIVE_SESSION, - NO_SESSION, RESPONSE_GENERATED, SQL_DETAILS, QUERY_TYPE, COMPLEXITY, - TABLES_USED, COLUMNS, ESTIMATED_ROWS, WARNINGS, CONVERSATION_HISTORY, - HEALTH_CHECK_FAILED, NO_SESSIONS, AVAILABLE_SESSIONS, SESSION_INFO, - SESSION_ID, SESSION_MESSAGES, SESSION_ACTIVITY, HEALTH_CHECK_SUCCESS, - HEALTH_STATUS, HEALTH_VERSION, HEALTH_SERVICES, DEFAULT_SESSION_NAME -) - class QueryplsCLI: def __init__(self): self.orchestrator = BackendOrchestrator() self.current_session_id = None - + def create_session(self, name: Optional[str] = None) -> str: request = NewChatRequest(session_name=name) session_info = self.orchestrator.create_new_session(request) self.current_session_id = session_info.session_id - print(SESSION_CREATED.format(name=session_info.session_name, id=session_info.session_id)) + print( + f"""Session created: { + session_info.session_name} (ID: { + session_info.session_id})""") return session_info.session_id - + def list_sessions(self): sessions = self.orchestrator.list_sessions() if not sessions: - print(NO_SESSIONS) + print("No sessions found.") return - - print(AVAILABLE_SESSIONS) + print("Available sessions:") for i, session in enumerate(sessions, 1): - print(SESSION_INFO.format(num=i, name=session.session_name)) - print(SESSION_ID.format(id=session.session_id)) - print(SESSION_MESSAGES.format(count=session.message_count)) - print(SESSION_ACTIVITY.format(activity=session.last_activity)) + print(f"{i}. {session.session_name}") + print(f" ID: {session.session_id}") + print(f" Messages: {session.message_count}") + print(f" Last activity: {session.last_activity}") print() - + def switch_session(self, session_id: str): session = self.orchestrator.get_session(session_id) if not session: - print(SESSION_NOT_FOUND.format(id=session_id)) + print(f"Session not found: {session_id}") return - self.current_session_id = session_id - print(SESSION_SWITCHED.format(name=session.session_name)) - + print(f"Switched to session: {session.session_name}") + def chat(self, query: str): if not self.current_session_id: - print(NO_ACTIVE_SESSION) + print("No active session. Please create or switch to a session.") return - + try: - response = self.orchestrator.generate_sql_response(self.current_session_id, query) - print(f"\n{RESPONSE_GENERATED}") + response = self.orchestrator.generate_sql_response( + self.current_session_id, query + ) + print("\nResponse generated:") print(response.content) - + if response.sql_response: - print(f"\n{SQL_DETAILS}") - print(f" {QUERY_TYPE}: {response.sql_response.query_type}") - print(f" {COMPLEXITY}: {response.sql_response.complexity}") - print(f" {TABLES_USED}: {', '.join(response.sql_response.tables_used)}") - print(f" {COLUMNS}: {', '.join(response.sql_response.columns_selected)}") - print(f" {ESTIMATED_ROWS}: {response.sql_response.estimated_rows}") + print("\nSQL Details:") + print(f" Query Type: {response.sql_response.query_type}") + print(f" Complexity: {response.sql_response.complexity}") + print(f" Tables Used: {', '.join(response.sql_response.tables_used)}") + print(f" Columns: {', '.join(response.sql_response.columns_selected)}") + print(f" Estimated Rows: {response.sql_response.estimated_rows}") if response.sql_response.warnings: - print(f" {WARNINGS}: {', '.join(response.sql_response.warnings)}") - + print(f" Warnings: {', '.join(response.sql_response.warnings)}") + except Exception as e: - print(CLI_ERROR.format(error=str(e))) - + print(f"Error: {str(e)}") + def show_history(self): if not self.current_session_id: - print(NO_SESSION) + print("No session selected.") return - + try: - conversation = self.orchestrator.get_conversation_history(self.current_session_id) - print(f"\n{CONVERSATION_HISTORY}") + conversation = self.orchestrator.get_conversation_history( + self.current_session_id + ) + print("\nConversation history:") for message in conversation.messages: print(f" {message.role.upper()}: {message.content}") except Exception as e: - print(CLI_ERROR.format(error=str(e))) - + print(f"Error: {str(e)}") + def health_check(self): try: health = self.orchestrator.health_check() - print(f"{HEALTH_CHECK_SUCCESS}") - print(f" {HEALTH_STATUS.format(status=health.status)}") - print(f" {HEALTH_VERSION.format(version=health.version)}") - print(f" {HEALTH_SERVICES.format(services=json.dumps(health.services, indent=2))}") + print("Health check successful.") + print(f" Status: {health.status}") + print(f" Version: {health.version}") + print(f" Services: {json.dumps(health.services, indent=2)}") except Exception as e: - print(HEALTH_CHECK_FAILED.format(error=str(e))) - + print(f"Health check failed: {str(e)}") + def run_interactive(self): - print(CLI_WELCOME) - print(CLI_COMMANDS) + print("Welcome to Querypls CLI!") + print("Commands: new, list, switch , chat , history, health, quit") print() - + self.create_session("CLI Session") - + while True: try: command = input("querypls> ").strip() - + if not command: continue - + + parts = command.split() + cmd = parts[0].lower() + + if cmd == "quit" or cmd == "exit": + print("Goodbye!") + break + elif cmd == "new": + name = " ".join(parts[1:]) if len(parts) > 1 else None + self.create_session(name) + elif cmd == "list": + self.list_sessions() parts = command.split() cmd = parts[0].lower() - + if cmd == "quit" or cmd == "exit": - print(CLI_GOODBYE) + print("Goodbye!") break elif cmd == "new": name = " ".join(parts[1:]) if len(parts) > 1 else None @@ -138,18 +145,18 @@ def run_interactive(self): elif cmd == "health": self.health_check() else: - print(CLI_UNKNOWN_COMMAND) - + print("Unknown command.") + except KeyboardInterrupt: - print(f"\n{CLI_GOODBYE}") + print("\nGoodbye!") break except Exception as e: - print(CLI_ERROR.format(error=str(e))) + print(f"Error: {str(e)}") def main(): cli = QueryplsCLI() - + if len(sys.argv) > 1: command = sys.argv[1] if command == "new": @@ -170,4 +177,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/static/css/styles.css b/static/css/styles.css deleted file mode 100644 index 9ff0c4b..0000000 --- a/static/css/styles.css +++ /dev/null @@ -1,190 +0,0 @@ -/* Main container styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2){ - background-color: rgb(233, 240, 255); - border: 1px solid #152544; - border-radius: 10px; - padding: 10px; -} - -/* Chat message styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(3){ - background-color: #152544; - color: rgb(17, 17, 17); -} - -/* Main content area */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4{ - background-color: rgb(255, 255, 255); -} - -/* Additional chat elements */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(5){ - background-color: #152544; -} - -/* Code styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(3) > div > div > p > code{ - color: white; - background-color: #152544; -} - -/* Sidebar styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section.st-emotion-cache-nju155.eczjsme11 > div.st-emotion-cache-6qob1r.eczjsme3{ - background-color: rgb(184, 205, 252); -} - -/* Main section styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5{ - border: 2px solid rgb(0, 2, 8); - border-radius: 0px; - padding: 10px; -} - -/* Chat message styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.stChatMessage.st-emotion-cache-4oy321.eeusbqq4{ - background-color: rgb(233, 240, 255); - color: #152544; -} - -/* Chat message content */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.stChatMessage.st-emotion-cache-4oy321.eeusbqq4 > div.st-emotion-cache-14m9yky.eeusbqq3{ - color: #152544; - background-color: rgb(233, 240, 255); -} - -/* Button styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.element-container.st-emotion-cache-10gv909.e1f1d6gn3 > div > div > div > div.st-bg.st-b4.st-bh.st-co.st-bj.st-bk.st-bl.st-bm.st-bn.st-bo.st-bp.st-bq.st-br.st-b2.st-bs.st-av.st-ay.st-aw.st-ax.st-bt.st-bu.st-bv.st-bw.st-bx.st-by.st-bz.st-c0 > div{ - background-color: rgb(12, 51, 158); -} - -/* Header styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > header{ - background-color: rgb(233, 240, 255); - color: #152544; -} - -/* Button icon styling */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > div > button > svg{ - color: #152544; - width: 20px; -} - -/* Sidebar general styling */ -.sidebar{ - background-color: rgb(233, 240, 255); -} - -/* Streamlit button styling */ -.stButton { - background-color: rgb(218, 218, 230); - color: #152544; - border: 3px; - border-radius: 5px; - width: 40px; - cursor: pointer; -} - -/* Row widget styling */ -.row-widget { - margin-bottom: 10px; - background-color: rgb(233, 240, 255); - color: #152544; -} - -/* Sidebar specific elements */ -#root > div:nth-child(1) > div.withScreencast > div > div > div > section.st-emotion-cache-nju155.eczjsme11 > div.st-emotion-cache-6qob1r.eczjsme3 > div.st-emotion-cache-16txtl3.eczjsme4 > div > div > div > div:nth-child(5){ - background-color: rgb(0, 0, 0); - color: #152544; -} - -/* Hide menu button */ -#root > div:nth-child(1) > div.withScreencast > div > div > header > div.st-emotion-cache-zq5wmm.ezrtsby0 > div > div:nth-child(2) > button > div > div { - display: none; -} - -/* Sidebar button styling - improved */ -.sidebar .stButton > button { - background-color: #4CAF50 !important; - color: white !important; - border: none !important; - padding: 8px 12px !important; - text-align: center !important; - text-decoration: none !important; - display: inline-block !important; - font-size: 12px !important; - margin: 2px !important; - cursor: pointer !important; - border-radius: 4px !important; - width: auto !important; - min-width: 80px !important; - transition: background-color 0.3s !important; -} - -.sidebar .stButton > button:hover { - background-color: #45a049 !important; -} - -/* Session button styling */ -.sidebar button[key*="session_"] { - background-color: #2196F3 !important; - color: white !important; - border: none !important; - padding: 6px 10px !important; - text-align: center !important; - text-decoration: none !important; - display: inline-block !important; - font-size: 11px !important; - margin: 2px !important; - cursor: pointer !important; - border-radius: 3px !important; - width: auto !important; - min-width: 60px !important; - transition: background-color 0.3s !important; -} - -.sidebar button[key*="session_"]:hover { - background-color: #1976D2 !important; -} - -/* New session button styling */ -.sidebar button:contains("βž• New Session") { - background-color: #FF9800 !important; - color: white !important; - border: none !important; - padding: 8px 12px !important; - text-align: center !important; - text-decoration: none !important; - display: inline-block !important; - font-size: 12px !important; - margin: 2px !important; - cursor: pointer !important; - border-radius: 4px !important; - width: auto !important; - min-width: 100px !important; - transition: background-color 0.3s !important; -} - -.sidebar button:contains("βž• New Session"):hover { - background-color: #F57C00 !important; -} - -/* Force horizontal layout for sidebar buttons */ -.sidebar [data-testid="stButton"] { - display: inline-block !important; - float: left !important; - margin: 2px !important; -} - -/* Sidebar button wrapper */ -.sidebar .element-container { - display: inline-block !important; - margin: 2px !important; -} - -/* Sidebar button container */ -.sidebar .stButton { - display: inline-block !important; - margin: 2px !important; - width: auto !important; - float: left !important; -} \ No newline at end of file diff --git a/test_routing.py b/test_routing.py new file mode 100644 index 0000000..d4eaea6 --- /dev/null +++ b/test_routing.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +""" +Simple test script to verify the routing logic works correctly. +""" + +import sys +import os + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from src.services.routing_service import IntelligentRoutingService +from src.schemas.requests import ChatMessage + + +def test_routing(): + """Test the routing service with different types of queries.""" + routing_service = IntelligentRoutingService() + + # Test cases + test_cases = [ + ("Hello", "CONVERSATION_AGENT"), + ("How are you?", "CONVERSATION_AGENT"), + ("Show me all users", "SQL_AGENT"), + ("SELECT * FROM users", "SQL_AGENT"), + ("Analyze this CSV data", "CSV_AGENT"), + ("Create a chart from the data", "CSV_AGENT"), + ("What can you do?", "CONVERSATION_AGENT"), + ("Thanks for your help", "CONVERSATION_AGENT"), + ] + + print("Testing Intelligent Routing Service...") + print("=" * 50) + + for query, expected_agent in test_cases: + print(f"\nQuery: '{query}'") + print(f"Expected Agent: {expected_agent}") + + try: + decision = routing_service.determine_agent(query, [], csv_loaded=False) + print(f"Actual Agent: {decision.agent}") + print(f"Confidence: {decision.confidence}") + print(f"Reasoning: {decision.reasoning}") + + if decision.agent == expected_agent: + print("βœ… PASS") + else: + print("❌ FAIL") + + except Exception as e: + print(f"❌ ERROR: {str(e)}") + + print("\n" + "=" * 50) + print("Testing conversation responses...") + + conversation_tests = [ + "Hello", + "How are you?", + "What can you do?", + "Thanks", + "Goodbye", + ] + + for query in conversation_tests: + print(f"\nQuery: '{query}'") + try: + response = routing_service.handle_conversation_query(query) + print(f"Response: {response}") + except Exception as e: + print(f"❌ ERROR: {str(e)}") + + +if __name__ == "__main__": + test_routing() diff --git a/tests/test_auth.py b/tests/test_auth.py index 2b9366a..a6aa742 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -14,13 +14,9 @@ async def test_get_authorization_url(): client = GoogleOAuth2("client_id", "client_secret") redirect_uri = "http://example.com/callback" - with patch.object( - client, "get_authorization_url", new=AsyncMock() - ) as mock_method: + with patch.object(client, "get_authorization_url", new=AsyncMock()) as mock_method: await get_authorization_url(client, redirect_uri) - mock_method.assert_called_once_with( - redirect_uri, scope=["profile", "email"] - ) + mock_method.assert_called_once_with(redirect_uri, scope=["profile", "email"]) @pytest.mark.asyncio @@ -28,9 +24,7 @@ async def test_get_access_token(): client = GoogleOAuth2("client_id", "client_secret") redirect_uri = "http://example.com/callback" code = "code" - with patch.object( - client, "get_access_token", new=AsyncMock() - ) as mock_method: + with patch.object(client, "get_access_token", new=AsyncMock()) as mock_method: await get_access_token(client, redirect_uri, code) mock_method.assert_called_once_with(code, redirect_uri) diff --git a/tests/test_backend.py b/tests/test_backend.py index 9bc9840..6c87ab7 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -31,16 +31,6 @@ def mock_oauth2_component(): yield mock_oauth2_component -def test_configure_page_styles(mock_open, mock_markdown, mock_set_page_config): - mock_open.return_value.__enter__.return_value.read.return_value = "test" - configure_page_styles("test_file") - mock_set_page_config.assert_called_once_with( - page_title="Querypls", page_icon="πŸ’¬", layout="wide" - ) - mock_markdown.assert_called() - mock_open.assert_called_once_with("test_file") - - def test_hide_main_menu_and_footer(mock_markdown): hide_main_menu_and_footer() mock_markdown.assert_called_once_with( diff --git a/tests/test_frontend.py b/tests/test_frontend.py index b6e4979..49ce111 100644 --- a/tests/test_frontend.py +++ b/tests/test_frontend.py @@ -46,9 +46,7 @@ def test_display_welcome_message(mock_st): with patch.object(st, "markdown") as mock_markdown: with patch.object(st, "session_state", MockSessionState()): initialize_session_state( - messages=[ - {"role": "assistant", "content": "How may I help you?"} - ] + messages=[{"role": "assistant", "content": "How may I help you?"}] ) display_welcome_message() mock_markdown.assert_called_once_with( @@ -63,13 +61,9 @@ def test_handle_new_chat(mock_db, mock_st): st, "button" ) as mock_button: with patch.object(st, "session_state", MockSessionState()): - initialize_session_state( - messages=[], user_email="test@example.com" - ) + initialize_session_state(messages=[], user_email="test@example.com") handle_new_chat(mock_db, max_chat_histories=5) - mock_markdown.assert_called_once_with( - " #### Remaining Chats: `5/5`" - ) + mock_markdown.assert_called_once_with(" #### Remaining Chats: `5/5`") mock_button.assert_called_once_with("βž• New chat") diff --git a/utils/prompt.py b/utils/prompt.py index e9bed17..f53011f 100644 --- a/utils/prompt.py +++ b/utils/prompt.py @@ -2,10 +2,74 @@ Instruction prompts for Querypls application. """ -SQL_GENERATION_PROMPT = """You are a SQL expert developer. Analyze the following conversation history and generate appropriate SQL code based on the context and current question. +# Intelligent routing prompt to determine which agent to use +ROUTING_PROMPT = """You are an intelligent router that determines which specialized agent should handle a user query. -Previous conversation: {conversation_history} -Current question: {input} +Analyze the user query and conversation history to determine the appropriate agent. + +## Available Agents: +1. **CONVERSATION_AGENT**: For greetings, casual chat, help requests, thanks, goodbyes +2. **SQL_AGENT**: For database queries, data retrieval, data manipulation, SQL generation +3. **CSV_AGENT**: For CSV data analysis, data visualization, Python code generation for CSV files + +## Decision Criteria: +- **CONVERSATION_AGENT**: Greetings, casual questions, help requests, thanks, goodbyes, general chat +- **SQL_AGENT**: Database queries, table operations, data retrieval, SQL-related questions +- **CSV_AGENT**: CSV analysis, data visualization, Python code for data analysis, file operations + +## Response Format: +{ + "agent": "CONVERSATION_AGENT|SQL_AGENT|CSV_AGENT", + "confidence": 0.95, + "reasoning": "Brief explanation of why this agent was chosen" +} + +## Examples: +- "Hello" β†’ CONVERSATION_AGENT +- "Show me all users" β†’ SQL_AGENT +- "Analyze this CSV data" β†’ CSV_AGENT +- "How are you?" β†’ CONVERSATION_AGENT +- "SELECT * FROM users" β†’ SQL_AGENT +- "Create a chart from the data" β†’ CSV_AGENT + +Respond only with the JSON object.""" + +CONVERSATION_PROMPT = """You are a friendly AI assistant for Querypls. Respond naturally and conversationally to user queries. + +## Your Role: +- Be warm, helpful, and engaging +- Keep responses concise but friendly +- Guide users to your SQL and CSV analysis capabilities when appropriate +- Don't generate code unless specifically asked + +## Response Guidelines: +- **Greetings**: Respond warmly and mention your capabilities +- **Help requests**: Explain what you can do (SQL generation, CSV analysis) +- **Thanks**: Be polite and encouraging +- **Goodbyes**: Be courteous and welcoming for future interactions +- **General questions**: Answer naturally, guide to your tools if relevant + +## Response Format: +{ + "message": "Your natural response to the user", + "response_type": "greeting|help|thanks|goodbye|general", + "suggest_next": "Optional suggestion for what they could do next" +} + +## Examples: +- User: "Hello" β†’ "Hi there! πŸ‘‹ I'm Querypls, your SQL and data analysis assistant. I can help you generate SQL queries or analyze CSV files. What would you like to work on today?" +- User: "How are you?" β†’ "I'm doing great, thank you for asking! 😊 I'm ready to help you with SQL queries or CSV data analysis. What can I assist you with?" +- User: "What can you do?" β†’ "I'm Querypls, your data analysis companion! πŸ—ƒοΈπŸ’¬ I can help you with SQL generation and CSV data analysis. Just upload a CSV file or ask me to write SQL queries!" + +Respond only with the JSON object.""" + +SQL_GENERATION_PROMPT = """You are a SQL expert developer. Generate appropriate SQL code based on the user query and conversation context. + +## Response Guidelines: +- Generate SQL queries for data-related questions +- Provide clear explanations of what the query does +- Include proper table and column information +- Handle different query types appropriately ## Response Format Your response must be in JSON format. From a6de1d75319f4ac54780dbe6b50aff5a69615d60 Mon Sep 17 00:00:00 2001 From: samadpls Date: Thu, 7 Aug 2025 23:27:13 +0500 Subject: [PATCH 3/8] refactor: Update import paths and clean up code structure; add basic usage demo for Querypls functionality --- .devcontainer/devcontainer.json | 33 ----- .gitignore | 2 + examples/basic_usage_demo.py | 190 +++++++++++++++++++++++++++ src/backend/backend.py | 2 +- src/backend/orchestrator.py | 48 +++---- src/config/constants.py | 39 +++++- src/config/settings.py | 3 +- src/frontend/app.py | 19 +-- src/schemas/responses.py | 3 +- src/services/conversation_service.py | 13 +- src/services/csv_analysis_tools.py | 25 ++-- src/services/jupyter_service.py | 14 +- src/services/models.py | 12 +- src/services/routing_service.py | 112 +++++++++++++--- src/services/sql_service.py | 30 ++--- src/terminal/cli.py | 21 +-- test_routing.py | 73 ---------- utils/prompt.py | 27 +++- 18 files changed, 422 insertions(+), 244 deletions(-) delete mode 100644 .devcontainer/devcontainer.json create mode 100644 examples/basic_usage_demo.py delete mode 100644 test_routing.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100644 index 91f7e38..0000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "name": "Python 3", - // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile - "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye", - "customizations": { - "codespaces": { - "openFiles": [ - "README.md", - "src/app.py" - ] - }, - "vscode": { - "settings": {}, - "extensions": [ - "ms-python.python", - "ms-python.vscode-pylance" - ] - } - }, - "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y SessionInfo: messages = [] if request.initial_context: - messages.append( - ChatMessage( - role="system", - content=request.initial_context)) + messages.append(ChatMessage( + role="system", content=request.initial_context)) messages.append(ChatMessage(role="assistant", content=WELCOME_MESSAGE)) @@ -113,8 +99,7 @@ def delete_session(self, session_id: str) -> bool: return True return False - def load_csv_data(self, session_id: str, - csv_content: str) -> Dict[str, Any]: + def load_csv_data(self, session_id: str, csv_content: str) -> Dict[str, Any]: session = self.get_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") @@ -134,9 +119,8 @@ def generate_intelligent_response( raise ValueError(f"Session {session_id} not found") user_message = ChatMessage( - role="user", - content=user_query, - timestamp=datetime.now().isoformat()) + role="user", content=user_query, timestamp=datetime.now().isoformat() + ) session.messages.append(user_message) # Determine which agent should handle this query @@ -157,7 +141,7 @@ def generate_intelligent_response( elif routing_decision.agent == "CSV_AGENT": if session.csv_data: response_content = self.routing_service.handle_csv_query( - user_query, session.csv_data + user_query, session.csv_data, session.messages ) else: response_content = "I don't see any CSV data loaded. Please upload a CSV file first to analyze it." @@ -187,9 +171,7 @@ def get_conversation_history(self, session_id: str) -> ConversationHistory: if not session: raise ValueError(f"Session {session_id} not found") - return ConversationHistory( - messages=session.messages, - session_id=session_id) + return ConversationHistory(messages=session.messages, session_id=session_id) def get_csv_info(self, session_id: str) -> Dict[str, Any]: return self.csv_tools.get_csv_info(session_id) diff --git a/src/config/constants.py b/src/config/constants.py index 1912a55..f92c648 100644 --- a/src/config/constants.py +++ b/src/config/constants.py @@ -10,10 +10,41 @@ STREAMLIT_HOST = "localhost" # Streamlit Configuration -STREAMLIT_CONFIG = { - "page_title": "Querypls", - "page_icon": "πŸ’¬", - "layout": "wide"} +STREAMLIT_CONFIG = {"page_title": "Querypls", + "page_icon": "πŸ’¬", "layout": "wide"} + +# Welcome and Session Messages +WELCOME_MESSAGE = "Hello! πŸ‘‹ I'm Querypls, your SQL and data analysis assistant. I can help you generate SQL queries or analyze CSV files. What would you like to work on today?" +DEFAULT_SESSION_NAME = "Default Chat" + +# CSV Analysis Section +CSV_ANALYSIS_SECTION = "### πŸ“Š CSV Analysis" +CSV_UPLOAD_LABEL = "Upload CSV File" +CSV_UPLOAD_HELP = "Upload a CSV file to analyze with Python code" +CSV_PREVIEW = "πŸ“‹ CSV Preview" +CSV_COLUMNS = "**Columns:** {columns}" +CSV_DTYPES = "**Data Types:** {dtypes}" +LOAD_CSV_BUTTON = "πŸ“Š Load CSV Data" +CSV_LOADED_SUCCESS = "βœ… CSV data loaded successfully!" +CSV_UPLOAD_SUCCESS = "βœ… CSV uploaded successfully! Shape: {shape}" +CSV_UPLOAD_ERROR = "❌ Error uploading CSV: {error}" +CSV_LOAD_ERROR = "❌ No CSV data loaded. Please upload a CSV file first." +CSV_ANALYSIS_ERROR = "❌ Error analyzing CSV: {error}" + +# Session Management +SESSIONS_SECTION = "### πŸ’¬ Chat Sessions" +NEW_SESSION_BUTTON = "βž• New Session" +SESSION_CREATE_ERROR = "❌ Error creating session: {error}" +SESSION_NOT_FOUND_ERROR = "❌ Session not found" + +# Application Errors +ORCHESTRATOR_INIT_ERROR = "❌ Error initializing orchestrator: {error}" +APP_INIT_ERROR = "❌ Error initializing application" +RESPONSE_GENERATION_ERROR = "❌ Error generating response: {error}" +MESSAGE_LOAD_ERROR = "❌ Error loading messages: {error}" + +# UI Elements +MADE_WITH_LOVE = "Made with 🀍" # Available Models AVAILABLE_MODELS = { diff --git a/src/config/settings.py b/src/config/settings.py index 8a53346..50ba553 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -13,8 +13,7 @@ class Settings(BaseSettings): groq_api_key: str = Field(..., env="GROQ_API_KEY") groq_model_name: str = Field( - default="openai/gpt-oss-120b", - env="GROQ_MODEL_NAME") + default="openai/gpt-oss-120b", env="GROQ_MODEL_NAME") app_version: str = Field(default="1.0.0", env="APP_VERSION") max_chat_histories: int = Field(default=5, env="MAX_CHAT_HISTORIES") debug_mode: bool = Field(default=False, env="DEBUG_MODE") diff --git a/src/frontend/app.py b/src/frontend/app.py index c527cd2..ff6eed2 100644 --- a/src/frontend/app.py +++ b/src/frontend/app.py @@ -2,8 +2,8 @@ Main Streamlit application for Querypls. """ -from schemas.requests import NewChatRequest -from config.constants import ( +from src.schemas.requests import NewChatRequest +from src.config.constants import ( CSV_ANALYSIS_SECTION, CSV_UPLOAD_LABEL, CSV_UPLOAD_HELP, @@ -24,21 +24,22 @@ SESSIONS_SECTION, NEW_SESSION_BUTTON, ) -from frontend import ( - display_logo_and_heading, - display_welcome_message, -) -from backend.backend import ( +from src.frontend.frontend import display_logo_and_heading, display_welcome_message +from src.backend.backend import ( display_github_badge, hide_main_menu_and_footer, ) -from backend.orchestrator import BackendOrchestrator +from src.backend.orchestrator import BackendOrchestrator import streamlit as st import sys import os import pandas as pd -sys.path.append(os.path.dirname(os.path.abspath(__file__))) +# Add the project root to Python path +project_root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.insert(0, project_root) def initialize_orchestrator(): diff --git a/src/schemas/responses.py b/src/schemas/responses.py index 42a4129..644d136 100644 --- a/src/schemas/responses.py +++ b/src/schemas/responses.py @@ -83,5 +83,4 @@ class HealthCheckResponse(BaseModel): version: str = Field(..., description="Application version") timestamp: str = Field(..., description="Health check timestamp") services: dict = Field( - default={}, - description="Status of individual services") + default={}, description="Status of individual services") diff --git a/src/services/conversation_service.py b/src/services/conversation_service.py index 14dbc36..5dadca8 100644 --- a/src/services/conversation_service.py +++ b/src/services/conversation_service.py @@ -7,8 +7,8 @@ from pydantic_ai.models.groq import GroqModel from pydantic_ai.providers.groq import GroqProvider -from config.settings import get_settings -from services.models import ConversationResponse, Failed +from src.config.settings import get_settings +from src.services.models import ConversationResponse, Failed from utils.prompt import CONVERSATION_PROMPT @@ -53,8 +53,7 @@ def is_conversational_query(self, query: str) -> bool: "okay", ] query_lower = query.lower().strip() - return any( - keyword in query_lower for keyword in conversational_keywords) + return any(keyword in query_lower for keyword in conversational_keywords) def get_conversational_response(self, query: str) -> str: """Get a natural response for conversational queries.""" @@ -67,11 +66,7 @@ def get_conversational_response(self, query: str) -> str: # Fallback responses query_lower = query.lower().strip() - if any( - greeting in query_lower for greeting in [ - "hi", - "hello", - "hey"]): + if any(greeting in query_lower for greeting in ["hi", "hello", "hey"]): return "Hello! πŸ‘‹ How can I help you today? I can assist with SQL generation or CSV data analysis." elif "how are you" in query_lower: return "I'm doing great, thank you for asking! 😊 How can I assist you with your data queries today?" diff --git a/src/services/csv_analysis_tools.py b/src/services/csv_analysis_tools.py index 8a28726..311f0da 100644 --- a/src/services/csv_analysis_tools.py +++ b/src/services/csv_analysis_tools.py @@ -6,8 +6,8 @@ from pydantic_ai.providers.groq import GroqProvider from pydantic import BaseModel, Field -from config.settings import get_settings -from services.jupyter_service import CSVAnalysisService +from src.config.settings import get_settings +from src.services.jupyter_service import CSVAnalysisService from utils.prompt import CSV_ANALYSIS_PROMPT, CODE_FIX_PROMPT, CSV_AGENT_PROMPT @@ -65,8 +65,7 @@ def __init__(self): output_type=PythonCodeResponse, ) - def load_csv_data(self, csv_content: str, - session_id: str) -> Dict[str, Any]: + def load_csv_data(self, csv_content: str, session_id: str) -> Dict[str, Any]: return self.csv_service.load_csv_data(session_id, csv_content) def generate_analysis_code( @@ -102,10 +101,8 @@ def execute_analysis_code( ) def fix_code_error( - self, - original_code: str, - error_message: str, - csv_context: CSVAnalysisContext) -> PythonCodeResponse: + self, original_code: str, error_message: str, csv_context: CSVAnalysisContext + ) -> PythonCodeResponse: prompt = f""" Original Code: {original_code} @@ -133,9 +130,9 @@ def create_csv_analysis_agent() -> Agent: settings = get_settings() model = GroqModel( - settings.groq_model_name, - provider=GroqProvider( - api_key=settings.groq_api_key)) + settings.groq_model_name, provider=GroqProvider( + api_key=settings.groq_api_key) + ) agent = Agent(model, instructions=CSV_AGENT_PROMPT, output_type=str) @@ -201,10 +198,8 @@ async def execute_analysis_code( @agent.tool async def fix_code_error( - ctx: RunContext[None], - original_code: str, - error_message: str, - session_id: str) -> str: + ctx: RunContext[None], original_code: str, error_message: str, session_id: str + ) -> str: csv_info = csv_tools.get_csv_info(session_id) if csv_info["status"] != "success": return f"Error: {csv_info['message']}" diff --git a/src/services/jupyter_service.py b/src/services/jupyter_service.py index 22e4692..bc3c220 100644 --- a/src/services/jupyter_service.py +++ b/src/services/jupyter_service.py @@ -85,10 +85,7 @@ def create_new_session( except Exception as e2: raise ValueError(f"Failed to create kernel: {str(e2)}") - def execute_code( - self, - code: str, - session_id: str = "default") -> ExecutionResult: + def execute_code(self, code: str, session_id: str = "default") -> ExecutionResult: if session_id not in self.clients: raise ValueError(f"Session {session_id} not found") @@ -138,10 +135,7 @@ def execute_code( execution_time=execution_time, ) - def import_function( - self, - func, - session_id: str = "default") -> ExecutionResult: + def import_function(self, func, session_id: str = "default") -> ExecutionResult: if session_id not in self.globals: raise ValueError(f"Session {session_id} not found") @@ -259,9 +253,7 @@ def execute_analysis( def get_csv_info(self, session_id: str) -> Dict[str, Any]: if session_id not in self.csv_data: - return { - "status": "error", - "message": "No CSV data loaded for this session"} + return {"status": "error", "message": "No CSV data loaded for this session"} df = self.csv_data[session_id] return { diff --git a/src/services/models.py b/src/services/models.py index 12232c8..c6b3c24 100644 --- a/src/services/models.py +++ b/src/services/models.py @@ -24,14 +24,12 @@ class ConversationResponse(BaseModel): """Response for conversational queries.""" message: str = Field(description="Natural response to user query") - response_type: Literal["greeting", - "help", - "thanks", - "goodbye", - "general"] = Field(description="Type of response") + response_type: Literal["greeting", "help", "thanks", "goodbye", "general"] = Field( + description="Type of response" + ) suggest_next: Optional[str] = Field( - description="Optional suggestion for what they could do next", - default=None) + description="Optional suggestion for what they could do next", default=None + ) class SQLResponse(BaseModel): diff --git a/src/services/routing_service.py b/src/services/routing_service.py index 395d901..c5d641c 100644 --- a/src/services/routing_service.py +++ b/src/services/routing_service.py @@ -107,27 +107,87 @@ def handle_sql_query( except Exception as e: return f"I encountered an error while generating SQL: {str(e)}" - def handle_csv_query(self, user_query: str, csv_data: str) -> str: + def handle_csv_query( + self, + user_query: str, + csv_data: str, + conversation_history: List[ChatMessage] = None, + ) -> str: """Handle CSV analysis queries.""" try: - context = self._prepare_csv_context(user_query, csv_data) + # Use the AI agent to generate code based on user request and conversation history + context = self._prepare_csv_context( + user_query, csv_data, conversation_history + ) result = self.csv_agent.run_sync(context) if hasattr(result.output, "python_code"): - return self._format_csv_response(result.output) + # Execute the generated code using Jupyter service + return self._execute_csv_analysis( + result.output.python_code, csv_data, result.output.explanation + ) else: return "I'm sorry, I couldn't generate analysis code for that request. Could you please rephrase your question?" except Exception as e: return ( - f"I encountered an error while generating CSV analysis code: { - str(e)}") + f"I encountered an error while generating CSV analysis code: {str(e)}" + ) + + def _execute_csv_analysis( + self, python_code: str, csv_data: str, explanation: str + ) -> str: + """Execute CSV analysis code using Jupyter service.""" + try: + from src.services.jupyter_service import CSVAnalysisService + + # Create Jupyter service instance + jupyter_service = CSVAnalysisService() + + # Create a temporary session for this analysis + session_id = "csv_analysis_temp" + + # Load CSV data into the session + jupyter_service.load_csv_data(session_id, csv_data) + + # Prepare the code with CSV data variable + full_code = f"csv_data = '''{csv_data}'''\n\n{python_code}" + + # Execute the code + result = jupyter_service.execute_analysis( + session_id, full_code, max_retries=1 + ) + + if result["status"] == "success": + output = result.get("output", "") + + # Check if any images were created + import os + import glob + + image_files = glob.glob("/tmp/*.png") + + if image_files: + image_info = "\n\n**Generated Images:**\n" + for img_file in image_files: + image_info += f"- {os.path.basename(img_file)}\n" + output += image_info + + return f"""**Analysis Results:** + +{output} + +**Explanation:** {explanation}""" + else: + error_msg = result.get("error_message", "Unknown error") + return f"❌ Error executing analysis: {error_msg}" + + except Exception as e: + return f"❌ Error in CSV analysis: {str(e)}" def _prepare_routing_context( - self, - user_query: str, - conversation_history: List[ChatMessage], - csv_loaded: bool) -> str: + self, user_query: str, conversation_history: List[ChatMessage], csv_loaded: bool + ) -> str: """Prepare context for routing decision.""" context_parts = [ f"User Query: {user_query}", @@ -158,9 +218,34 @@ def _prepare_sql_context( return "\n".join(context_parts) - def _prepare_csv_context(self, user_query: str, csv_data: str) -> str: + def _prepare_csv_context( + self, + user_query: str, + csv_data: str, + conversation_history: List[ChatMessage] = None, + ) -> str: """Prepare context for CSV analysis.""" - return f"User Query: {user_query}\n\nCSV Data Available: Yes\n\nGenerate Python code to analyze the CSV data based on the user's request." + context_parts = [ + f"User Query: {user_query}", + f"CSV Data Available: Yes", + f"CSV Data Format: The CSV data is available as a string variable called 'csv_data' in the session", + f"CSV Content Preview: {csv_data[:200]}...", + ] + + if conversation_history: + context_parts.append("Conversation History:") + # Last 5 messages for context + for msg in conversation_history[-5:]: + context_parts.append(f"- {msg.role}: {msg.content}") + + context_parts.append( + "\nGenerate Python code to analyze the CSV data based on the user's request." + ) + context_parts.append( + "IMPORTANT: Use pd.read_csv(StringIO(csv_data)) to load the data, NOT file paths!" + ) + + return "\n".join(context_parts) def _format_sql_response(self, sql_response) -> str: """Format SQL response for display.""" @@ -191,10 +276,7 @@ def _format_csv_response(self, csv_response) -> str: return "\n\n".join(response_parts) - def _fallback_routing( - self, - user_query: str, - csv_loaded: bool) -> RoutingDecision: + def _fallback_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecision: """Fallback routing based on simple keyword matching.""" query_lower = user_query.lower().strip() diff --git a/src/services/sql_service.py b/src/services/sql_service.py index 174aef0..f5879ad 100644 --- a/src/services/sql_service.py +++ b/src/services/sql_service.py @@ -6,9 +6,9 @@ from pydantic_ai.models.groq import GroqModel from pydantic_ai.providers.groq import GroqProvider -from config.settings import get_settings -from schemas.requests import SQLGenerationRequest, ChatMessage -from schemas.responses import SQLQueryResponse, ChatResponse, ErrorResponse +from src.config.settings import get_settings +from src.schemas.requests import SQLGenerationRequest, ChatMessage +from src.schemas.responses import SQLQueryResponse, ChatResponse, ErrorResponse from utils.prompt import SQL_GENERATION_PROMPT @@ -23,14 +23,13 @@ def __init__(self, api_key: Optional[str] = None): ) self.model = GroqModel( - self.settings.groq_model_name, - provider=GroqProvider( - api_key=self.api_key)) + self.settings.groq_model_name, provider=GroqProvider( + api_key=self.api_key) + ) self.agent = Agent( - self.model, - instructions=SQL_GENERATION_PROMPT, - output_type=SQLQueryResponse) + self.model, instructions=SQL_GENERATION_PROMPT, output_type=SQLQueryResponse + ) def format_chat_history(self, messages: list) -> str: history = [] @@ -44,11 +43,11 @@ def format_chat_history(self, messages: list) -> str: if "```sql" in content: content = content.replace( - "```sql\n", "").replace( - "\n```", "").strip() + "```sql\n", "").replace("\n```", "").strip() - history.append({"role": role, "query" if role == - "user" else "response": content}) + history.append( + {"role": role, "query" if role == "user" else "response": content} + ) return json.dumps(history, indent=2) @@ -118,10 +117,7 @@ def generate_sql(self, request: SQLGenerationRequest) -> ChatResponse: session_id=session_id, ) - def generate_sql_legacy( - self, - user_query: str, - conversation_history: list) -> str: + def generate_sql_legacy(self, user_query: str, conversation_history: list) -> str: request = SQLGenerationRequest( user_query=user_query, conversation_history=conversation_history ) diff --git a/src/terminal/cli.py b/src/terminal/cli.py index d7040ba..0fafbd6 100644 --- a/src/terminal/cli.py +++ b/src/terminal/cli.py @@ -2,9 +2,9 @@ Command-line interface for Querypls SQL generation. """ -from config.constants import DEFAULT_SESSION_NAME -from schemas.requests import NewChatRequest -from backend.orchestrator import BackendOrchestrator +from src.config.constants import DEFAULT_SESSION_NAME +from src.schemas.requests import NewChatRequest +from src.backend.orchestrator import BackendOrchestrator import sys import os import json @@ -25,7 +25,8 @@ def create_session(self, name: Optional[str] = None) -> str: print( f"""Session created: { session_info.session_name} (ID: { - session_info.session_id})""") + session_info.session_id})""" + ) return session_info.session_id def list_sessions(self): @@ -65,11 +66,15 @@ def chat(self, query: str): print("\nSQL Details:") print(f" Query Type: {response.sql_response.query_type}") print(f" Complexity: {response.sql_response.complexity}") - print(f" Tables Used: {', '.join(response.sql_response.tables_used)}") - print(f" Columns: {', '.join(response.sql_response.columns_selected)}") - print(f" Estimated Rows: {response.sql_response.estimated_rows}") + print( + f" Tables Used: {', '.join(response.sql_response.tables_used)}") + print( + f" Columns: {', '.join(response.sql_response.columns_selected)}") + print( + f" Estimated Rows: {response.sql_response.estimated_rows}") if response.sql_response.warnings: - print(f" Warnings: {', '.join(response.sql_response.warnings)}") + print( + f" Warnings: {', '.join(response.sql_response.warnings)}") except Exception as e: print(f"Error: {str(e)}") diff --git a/test_routing.py b/test_routing.py deleted file mode 100644 index d4eaea6..0000000 --- a/test_routing.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple test script to verify the routing logic works correctly. -""" - -import sys -import os - -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from src.services.routing_service import IntelligentRoutingService -from src.schemas.requests import ChatMessage - - -def test_routing(): - """Test the routing service with different types of queries.""" - routing_service = IntelligentRoutingService() - - # Test cases - test_cases = [ - ("Hello", "CONVERSATION_AGENT"), - ("How are you?", "CONVERSATION_AGENT"), - ("Show me all users", "SQL_AGENT"), - ("SELECT * FROM users", "SQL_AGENT"), - ("Analyze this CSV data", "CSV_AGENT"), - ("Create a chart from the data", "CSV_AGENT"), - ("What can you do?", "CONVERSATION_AGENT"), - ("Thanks for your help", "CONVERSATION_AGENT"), - ] - - print("Testing Intelligent Routing Service...") - print("=" * 50) - - for query, expected_agent in test_cases: - print(f"\nQuery: '{query}'") - print(f"Expected Agent: {expected_agent}") - - try: - decision = routing_service.determine_agent(query, [], csv_loaded=False) - print(f"Actual Agent: {decision.agent}") - print(f"Confidence: {decision.confidence}") - print(f"Reasoning: {decision.reasoning}") - - if decision.agent == expected_agent: - print("βœ… PASS") - else: - print("❌ FAIL") - - except Exception as e: - print(f"❌ ERROR: {str(e)}") - - print("\n" + "=" * 50) - print("Testing conversation responses...") - - conversation_tests = [ - "Hello", - "How are you?", - "What can you do?", - "Thanks", - "Goodbye", - ] - - for query in conversation_tests: - print(f"\nQuery: '{query}'") - try: - response = routing_service.handle_conversation_query(query) - print(f"Response: {response}") - except Exception as e: - print(f"❌ ERROR: {str(e)}") - - -if __name__ == "__main__": - test_routing() diff --git a/utils/prompt.py b/utils/prompt.py index f53011f..cf74e71 100644 --- a/utils/prompt.py +++ b/utils/prompt.py @@ -26,7 +26,7 @@ ## Examples: - "Hello" β†’ CONVERSATION_AGENT -- "Show me all users" β†’ SQL_AGENT +- "Show me all users" β†’ SQL_AGENT - "Analyze this CSV data" β†’ CSV_AGENT - "How are you?" β†’ CONVERSATION_AGENT - "SELECT * FROM users" β†’ SQL_AGENT @@ -119,16 +119,33 @@ 5. Return clear, formatted output 6. Handle missing data appropriately 7. Use appropriate data types +8. Keep the code simple and focused on the user's request + +## CRITICAL Code Requirements: +- The CSV data is available as a string variable called `csv_data` in the session +- ALWAYS use `pd.read_csv(StringIO(csv_data))` to load the data +- NEVER use file paths like 'data.csv' or 'path/to/file.csv' +- For graphs/charts, save them to `/tmp/` folder with descriptive names +- Use `plt.savefig('/tmp/chart_name.png')` before `plt.show()` +- Always print clear output with section headers ## Example Response { - "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load and analyze data\\ndf = pd.read_csv('data.csv')\\nprint(f'Data shape: {df.shape}')\\nprint(df.head())\\n\\n# Create visualization\\nplt.figure(figsize=(10, 6))\\ndf['column'].value_counts().plot(kind='bar')\\nplt.title('Distribution of Column')\\nplt.show()", - "explanation": "Loads CSV data, displays basic info, and creates a bar chart of column distribution", - "expected_output": "Data shape, first few rows, and a bar chart visualization", + "python_code": "import pandas as pd\\nfrom io import StringIO\\nimport matplotlib.pyplot as plt\\n\\n# Load CSV data from session\\ndf = pd.read_csv(StringIO(csv_data))\\nprint('=== DATA OVERVIEW ===')\\nprint(f'Shape: {df.shape}')\\nprint(f'Columns: {list(df.columns)}')\\nprint('\\nFirst 5 rows:')\\nprint(df.head())\\n\\n# Create visualization\\nplt.figure(figsize=(10, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/department_chart.png')\\nplt.show()", + "explanation": "Loads CSV data from session, displays overview, and creates a bar chart saved to temp folder", + "expected_output": "Data overview and a bar chart visualization saved as image", "libraries_used": ["pandas", "matplotlib.pyplot"] } -Respond only with the JSON object. Do not include any additional text or markdown formatting.""" +## Important Notes: +- Use double backslashes for newlines in the python_code field +- ALWAYS use `StringIO(csv_data)` to load CSV data - NEVER use file paths +- Save charts to `/tmp/` folder +- Keep the explanation concise +- Make sure the JSON is valid and properly formatted +- The csv_data variable is already available in the session + +Respond only with the JSON object.""" CODE_FIX_PROMPT = """You are a Python debugging expert. Fix Python code based on error messages. From e28fb1035f2f92d3bc3434067bf0be1b9d287ffc Mon Sep 17 00:00:00 2001 From: samadpls Date: Fri, 8 Aug 2025 00:24:21 +0500 Subject: [PATCH 4/8] Refactor CSV analysis handling and improve routing service - Updated BackendOrchestrator to save CSV data to a temporary file and store additional CSV info in the session. - Enhanced IntelligentRoutingService to handle CSV queries with improved error handling and fallback analysis. - Modified CSV analysis prompt to enforce simpler, focused code generation. - Cleaned up frontend to display messages with generated images and added cleanup for old images. - Increased MAX_CHAT_HISTORIES constant to allow more chat sessions. - Removed unused constants and cleaned up settings configuration. - Added unit tests for backend orchestrator, CSV analysis tools, and conversation service. - Removed outdated test files and organized test structure. --- .github/workflows/pytest.yml | 2 + .gitignore | 2 +- examples/basic_usage_demo.py | 1 - examples/test_backend_functionality.py | 302 +++++++++++++++++++++++++ src/backend/orchestrator.py | 50 +++- src/config/constants.py | 53 +---- src/config/settings.py | 28 +-- src/frontend/app.py | 73 +++++- src/services/routing_service.py | 141 +++++++++--- test_app.py | 76 +++++++ tests/__init__.py | 0 tests/conftest.py | 4 + tests/test_auth.py | 53 ----- tests/test_backend.py | 10 +- tests/test_basic.py | 44 ++++ tests/test_conversation_service.py | 31 +++ tests/test_csv_analysis.py | 22 ++ tests/test_frontend.py | 95 -------- tests/test_models.py | 30 +++ tests/test_schemas.py | 97 ++++++++ utils/prompt.py | 78 +++++-- 21 files changed, 901 insertions(+), 291 deletions(-) create mode 100644 examples/test_backend_functionality.py create mode 100644 test_app.py delete mode 100644 tests/__init__.py create mode 100644 tests/conftest.py delete mode 100644 tests/test_auth.py create mode 100644 tests/test_basic.py create mode 100644 tests/test_conversation_service.py create mode 100644 tests/test_csv_analysis.py delete mode 100644 tests/test_frontend.py create mode 100644 tests/test_models.py create mode 100644 tests/test_schemas.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index a185623..2a89e90 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,5 +21,7 @@ jobs: pip3 install -r requirements.txt - name: Run Pytest + env: + JUPYTER_PLATFORM_DIRS: "1" run: | pytest diff --git a/.gitignore b/.gitignore index d5a61bf..077e253 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ __pycache__/ # Installer logs pip-log.txt pip-delete-this-directory.txt -test_* +./test_* run_* # Unit test / coverage reports diff --git a/examples/basic_usage_demo.py b/examples/basic_usage_demo.py index 8fd3987..f56c4b8 100644 --- a/examples/basic_usage_demo.py +++ b/examples/basic_usage_demo.py @@ -7,7 +7,6 @@ import sys import os -# Add the project root to Python path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.services.routing_service import IntelligentRoutingService diff --git a/examples/test_backend_functionality.py b/examples/test_backend_functionality.py new file mode 100644 index 0000000..9fd1672 --- /dev/null +++ b/examples/test_backend_functionality.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +""" +Comprehensive test for all backend functionality of Querypls. +Tests conversation, SQL generation, and CSV analysis capabilities. +""" + +import sys +import os +import pandas as pd +from io import StringIO + +# Add the project root to Python path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.services.routing_service import IntelligentRoutingService +from src.services.conversation_service import ConversationService +from src.services.sql_service import SQLGenerationService +from src.services.csv_analysis_tools import CSVAnalysisTools +from src.schemas.requests import ChatMessage, SQLGenerationRequest +from src.backend.orchestrator import BackendOrchestrator + + +def test_conversation_functionality(): + """Test conversation responses.""" + print("πŸ§ͺ Testing Conversation Functionality") + print("=" * 50) + + try: + routing_service = IntelligentRoutingService() + + # Test conversation queries + conversation_tests = [ + "Hello", + "How are you?", + "What can you do?", + "Thanks for your help", + "Goodbye" + ] + + for query in conversation_tests: + print(f"\nQuery: '{query}'") + try: + response = routing_service.handle_conversation_query(query) + print(f"Response: {response[:100]}...") + print("βœ… PASS") + except Exception as e: + print(f"❌ FAIL: {str(e)}") + + print("\n" + "=" * 50) + return True + + except Exception as e: + print(f"❌ Conversation test failed: {str(e)}") + return False + + +def test_sql_functionality(): + """Test SQL generation functionality.""" + print("πŸ—ƒοΈ Testing SQL Generation Functionality") + print("=" * 50) + + try: + routing_service = IntelligentRoutingService() + sql_service = SQLGenerationService() + + # Test SQL queries + sql_tests = [ + "Show me all users", + "SELECT * FROM users WHERE status = 'active'", + "Find customers who made purchases in the last 30 days", + "Get the total sales by month" + ] + + for query in sql_tests: + print(f"\nQuery: '{query}'") + try: + # Test routing + routing_decision = routing_service.determine_agent(query, [], csv_loaded=False) + print(f"Routing Decision: {routing_decision.agent}") + + # Test SQL generation + request = SQLGenerationRequest( + user_query=query, + conversation_history=[] + ) + response = sql_service.generate_sql(request) + print(f"SQL Response: {response.content[:100]}...") + print("βœ… PASS") + except Exception as e: + print(f"❌ FAIL: {str(e)}") + + print("\n" + "=" * 50) + return True + + except Exception as e: + print(f"❌ SQL test failed: {str(e)}") + return False + + +def test_csv_functionality(): + """Test CSV analysis functionality.""" + print("πŸ“Š Testing CSV Analysis Functionality") + print("=" * 50) + + try: + # Create sample CSV data + sample_data = { + 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'], + 'age': [25, 30, 35, 28, 32], + 'salary': [50000, 60000, 70000, 55000, 65000], + 'department': ['IT', 'HR', 'IT', 'Finance', 'HR'] + } + + df = pd.DataFrame(sample_data) + csv_content = df.to_csv(index=False) + + print(f"Sample CSV Data:\n{df.head()}") + print(f"CSV Shape: {df.shape}") + + # Test CSV tools + csv_tools = CSVAnalysisTools() + + # Test loading CSV data + print("\nTesting CSV loading...") + result = csv_tools.load_csv_data(csv_content, "test_session") + print(f"Load Result: {result}") + + # Test CSV analysis queries + csv_tests = [ + "Show me the basic statistics of the data", + "Create a bar chart of department distribution", + "What is the average salary by department?", + "Show me the top 3 highest paid employees" + ] + + routing_service = IntelligentRoutingService() + + for query in csv_tests: + print(f"\nQuery: '{query}'") + try: + # Test routing with CSV loaded + routing_decision = routing_service.determine_agent(query, [], csv_loaded=True) + print(f"Routing Decision: {routing_decision.agent}") + + # Test CSV analysis + response = routing_service.handle_csv_query(query, csv_content) + print(f"CSV Response: {response[:200]}...") + print("βœ… PASS") + except Exception as e: + print(f"❌ FAIL: {str(e)}") + + print("\n" + "=" * 50) + return True + + except Exception as e: + print(f"❌ CSV test failed: {str(e)}") + return False + + +def test_intelligent_routing(): + """Test intelligent routing functionality.""" + print("🧠 Testing Intelligent Routing") + print("=" * 50) + + try: + routing_service = IntelligentRoutingService() + + # Test cases with expected routing + test_cases = [ + ("Hello", "CONVERSATION_AGENT"), + ("How are you?", "CONVERSATION_AGENT"), + ("Show me all users", "SQL_AGENT"), + ("SELECT * FROM users", "SQL_AGENT"), + ("Analyze this CSV data", "CSV_AGENT"), + ("Create a chart from the data", "CSV_AGENT"), + ("What can you do?", "CONVERSATION_AGENT"), + ("Thanks for your help", "CONVERSATION_AGENT"), + ] + + all_passed = True + for query, expected_agent in test_cases: + print(f"\nQuery: '{query}'") + print(f"Expected Agent: {expected_agent}") + + try: + # Test without CSV loaded + decision = routing_service.determine_agent(query, [], csv_loaded=False) + print(f"Result (no CSV): {decision.agent}") + + # Test with CSV loaded + decision_with_csv = routing_service.determine_agent(query, [], csv_loaded=True) + print(f"Result (with CSV): {decision_with_csv.agent}") + + if decision.agent == expected_agent or decision_with_csv.agent == expected_agent: + print("βœ… PASS") + else: + print("❌ FAIL") + all_passed = False + + except Exception as e: + print(f"❌ ERROR: {str(e)}") + all_passed = False + + print("\n" + "=" * 50) + return all_passed + + except Exception as e: + print(f"❌ Routing test failed: {str(e)}") + return False + + +def test_orchestrator(): + """Test the main orchestrator functionality.""" + print("🎼 Testing Backend Orchestrator") + print("=" * 50) + + try: + orchestrator = BackendOrchestrator() + + # Test session creation + print("Testing session creation...") + from src.schemas.requests import NewChatRequest + session_info = orchestrator.create_new_session(NewChatRequest(session_name="Test Session")) + session_id = session_info.session_id + print(f"Created session: {session_id}") + + # Test conversation + print("\nTesting conversation...") + response = orchestrator.generate_intelligent_response(session_id, "Hello") + print(f"Conversation Response: {response.content[:100]}...") + + # Test SQL generation + print("\nTesting SQL generation...") + response = orchestrator.generate_intelligent_response(session_id, "Show me all users") + print(f"SQL Response: {response.content[:100]}...") + + # Test CSV loading and analysis + print("\nTesting CSV functionality...") + sample_csv = "name,age,salary\nAlice,25,50000\nBob,30,60000\nCharlie,35,70000" + result = orchestrator.load_csv_data(session_id, sample_csv) + print(f"CSV Load Result: {result}") + + response = orchestrator.generate_intelligent_response(session_id, "Analyze this data") + print(f"CSV Analysis Response: {response.content[:100]}...") + + print("\n" + "=" * 50) + return True + + except Exception as e: + print(f"❌ Orchestrator test failed: {str(e)}") + return False + + +def main(): + """Run all tests.""" + print("πŸš€ Starting Comprehensive Backend Functionality Tests") + print("=" * 60) + + tests = [ + ("Conversation", test_conversation_functionality), + ("SQL Generation", test_sql_functionality), + ("CSV Analysis", test_csv_functionality), + ("Intelligent Routing", test_intelligent_routing), + ("Orchestrator", test_orchestrator), + ] + + results = {} + + for test_name, test_func in tests: + print(f"\n{'='*20} {test_name} {'='*20}") + try: + results[test_name] = test_func() + except Exception as e: + print(f"❌ {test_name} test crashed: {str(e)}") + results[test_name] = False + + # Summary + print("\n" + "=" * 60) + print("πŸ“Š TEST SUMMARY") + print("=" * 60) + + passed = 0 + total = len(tests) + + for test_name, result in results.items(): + status = "βœ… PASS" if result else "❌ FAIL" + print(f"{test_name}: {status}") + if result: + passed += 1 + + print(f"\nOverall: {passed}/{total} tests passed") + + if passed == total: + print("πŸŽ‰ All tests passed! Backend functionality is working correctly.") + else: + print("⚠️ Some tests failed. Check the backend implementation.") + + return passed == total + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/backend/orchestrator.py b/src/backend/orchestrator.py index a035093..d639f13 100644 --- a/src/backend/orchestrator.py +++ b/src/backend/orchestrator.py @@ -10,7 +10,7 @@ from src.config.settings import get_settings from src.config.constants import WELCOME_MESSAGE, DEFAULT_SESSION_NAME from src.services.sql_service import SQLGenerationService -from src.services.csv_analysis_tools import CSVAnalysisTools, create_csv_analysis_agent +from src.services.csv_analysis_tools import CSVAnalysisTools from src.services.conversation_service import ConversationService from src.services.routing_service import IntelligentRoutingService from src.schemas.requests import ( @@ -34,6 +34,8 @@ class Session: messages: List[ChatMessage] last_activity: datetime csv_data: Optional[str] = None + csv_file_path: Optional[str] = None + csv_info: Optional[Dict[str, Any]] = None class BackendOrchestrator: @@ -41,7 +43,6 @@ def __init__(self): self.settings = get_settings() self.sql_service = SQLGenerationService() self.csv_tools = CSVAnalysisTools() - self.csv_agent = create_csv_analysis_agent() self.conversation_service = ConversationService() self.routing_service = IntelligentRoutingService() self.sessions: Dict[str, Session] = {} @@ -103,12 +104,45 @@ def load_csv_data(self, session_id: str, csv_content: str) -> Dict[str, Any]: session = self.get_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") - + + # Save CSV to file + import os + import tempfile + + # Create temp directory for this session if it doesn't exist + temp_dir = f"/tmp/querypls_session_{session_id}" + os.makedirs(temp_dir, exist_ok=True) + + # Save CSV to file + csv_file_path = os.path.join(temp_dir, "data.csv") + with open(csv_file_path, 'w') as f: + f.write(csv_content) + + # Store both the content and file path in session session.csv_data = csv_content - result = self.csv_tools.load_csv_data(csv_content, session_id) + session.csv_file_path = csv_file_path + + # Get CSV info for context + import pandas as pd + from io import StringIO + df = pd.read_csv(StringIO(csv_content)) + + session.csv_info = { + "file_path": csv_file_path, + "shape": df.shape, + "columns": list(df.columns), + "dtypes": df.dtypes.to_dict(), + "sample_data": df.head(3).to_dict('records') + } + session.last_activity = datetime.now() - - return result + + return { + "status": "success", + "message": "CSV data loaded successfully", + "shape": df.shape, + "columns": list(df.columns) + } def generate_intelligent_response( self, session_id: str, user_query: str @@ -139,9 +173,9 @@ def generate_intelligent_response( user_query, session.messages ) elif routing_decision.agent == "CSV_AGENT": - if session.csv_data: + if session.csv_data and session.csv_info: response_content = self.routing_service.handle_csv_query( - user_query, session.csv_data, session.messages + user_query, session.csv_info, session.messages ) else: response_content = "I don't see any CSV data loaded. Please upload a CSV file first to analyze it." diff --git a/src/config/constants.py b/src/config/constants.py index f92c648..4927329 100644 --- a/src/config/constants.py +++ b/src/config/constants.py @@ -5,7 +5,7 @@ # Application Settings MAX_RETRIES = 3 EXECUTION_TIMEOUT = 30 -MAX_CHAT_HISTORIES = 5 +MAX_CHAT_HISTORIES = 6 STREAMLIT_PORT = 8501 STREAMLIT_HOST = "localhost" @@ -32,8 +32,6 @@ CSV_ANALYSIS_ERROR = "❌ Error analyzing CSV: {error}" # Session Management -SESSIONS_SECTION = "### πŸ’¬ Chat Sessions" -NEW_SESSION_BUTTON = "βž• New Session" SESSION_CREATE_ERROR = "❌ Error creating session: {error}" SESSION_NOT_FOUND_ERROR = "❌ Session not found" @@ -42,52 +40,3 @@ APP_INIT_ERROR = "❌ Error initializing application" RESPONSE_GENERATION_ERROR = "❌ Error generating response: {error}" MESSAGE_LOAD_ERROR = "❌ Error loading messages: {error}" - -# UI Elements -MADE_WITH_LOVE = "Made with 🀍" - -# Available Models -AVAILABLE_MODELS = { - "deepseek-r1-distill-llama-70b": { - "developer": "DeepSeek / Meta", - "context_window": 131072, - "max_completion": 131072, - "max_file_size": None, - }, - "meta-llama/llama-4-maverick-17b-128e-instruct": { - "developer": "Meta", - "context_window": 131072, - "max_completion": 8192, - "max_file_size": "20 MB", - }, - "meta-llama/llama-4-scout-17b-16e-instruct": { - "developer": "Meta", - "context_window": 131072, - "max_completion": 8192, - "max_file_size": "20 MB", - }, - "moonshotai/kimi-k2-instruct": { - "developer": "Moonshot AI", - "context_window": 131072, - "max_completion": 16384, - "max_file_size": None, - }, - "openai/gpt-oss-120b": { - "developer": "OpenAI", - "context_window": 131072, - "max_completion": 32766, - "max_file_size": None, - }, - "openai/gpt-oss-20b": { - "developer": "OpenAI", - "context_window": 131072, - "max_completion": 32768, - "max_file_size": None, - }, - "qwen/qwen3-32b": { - "developer": "Alibaba Cloud", - "context_window": 131072, - "max_completion": 131072, - "max_file_size": None, - }, -} diff --git a/src/config/settings.py b/src/config/settings.py index 50ba553..60f87aa 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -3,15 +3,13 @@ """ import os -from typing import Optional -from pydantic import Field -from pydantic_settings import BaseSettings +from typing import Optional, ClassVar +from pydantic import Field, BaseModel, ConfigDict -from src.config.constants import AVAILABLE_MODELS -class Settings(BaseSettings): - groq_api_key: str = Field(..., env="GROQ_API_KEY") +class Settings(BaseModel): + groq_api_key: str = Field(default=os.getenv("GROQ_API_KEY", "mock_api_key"), env="GROQ_API_KEY") groq_model_name: str = Field( default="openai/gpt-oss-120b", env="GROQ_MODEL_NAME") app_version: str = Field(default="1.0.0", env="APP_VERSION") @@ -26,10 +24,13 @@ class Settings(BaseSettings): temperature: Optional[str] = Field(None, env="TEMPERATURE") log_level: Optional[str] = Field(None, env="LOG_LEVEL") - class Config: - env_file = ".env" - env_file_encoding = "utf-8" - extra = "ignore" + json_schema_extra: ClassVar[str] = "ignore" + + model_config = ConfigDict( + env_file = ".env", + env_file_encoding = "utf-8", + # Add any necessary configuration here + ) _settings_instance: Optional[Settings] = None @@ -41,10 +42,3 @@ def get_settings() -> Settings: _settings_instance = Settings() return _settings_instance - -def get_available_models(): - return AVAILABLE_MODELS - - -def get_model_info(model_name: str): - return AVAILABLE_MODELS.get(model_name, None) diff --git a/src/frontend/app.py b/src/frontend/app.py index ff6eed2..f79e898 100644 --- a/src/frontend/app.py +++ b/src/frontend/app.py @@ -20,9 +20,6 @@ APP_INIT_ERROR, RESPONSE_GENERATION_ERROR, MESSAGE_LOAD_ERROR, - MADE_WITH_LOVE, - SESSIONS_SECTION, - NEW_SESSION_BUTTON, ) from src.frontend.frontend import display_logo_and_heading, display_welcome_message from src.backend.backend import ( @@ -35,7 +32,6 @@ import os import pandas as pd -# Add the project root to Python path project_root = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) @@ -70,11 +66,70 @@ def display_messages(session_id: str): conversation = orchestrator.get_conversation_history(session_id) for message in conversation.messages: with st.chat_message(message.role): - st.markdown(message.content) + display_message_with_images(message.content) except Exception as e: st.error(MESSAGE_LOAD_ERROR.format(error=str(e))) +def display_message_with_images(content: str): + """Display message content and handle CSV analysis responses with images.""" + # Check if this is a CSV analysis response with images + if "**Generated Images:**" in content: + # Split the content into text and image sections + parts = content.split("**Generated Images:**") + text_content = parts[0].strip() + + # Display the text content + st.markdown(text_content) + + # Handle images if present + if len(parts) > 1: + image_section = parts[1].strip() + image_lines = [line.strip() for line in image_section.split('\n') if line.strip().startswith('- ')] + + if image_lines: + st.markdown("**Generated Images:**") + + # Look for images in the specific temp directory + import os + import glob + + temp_dir = "/tmp/querypls_session_csv_analysis_temp" + if os.path.exists(temp_dir): + for line in image_lines: + # Extract filename from the line (e.g., "- department_chart.png") + filename = line.replace('- ', '').strip() + image_path = os.path.join(temp_dir, filename) + + if os.path.exists(image_path): + try: + st.image(image_path, caption=filename, use_column_width=True) + except Exception as e: + st.error(f"Error displaying image {filename}: {str(e)}") + else: + st.warning(f"Image not found: {filename}") + else: + # Regular message content + st.markdown(content) + + +def cleanup_old_images(): + """Clean up old CSV analysis images.""" + import os + import glob + + temp_dir = "/tmp/querypls_session_csv_analysis_temp" + if os.path.exists(temp_dir): + try: + # Remove old images + for img_file in glob.glob(os.path.join(temp_dir, "*.png")): + os.remove(img_file) + for img_file in glob.glob(os.path.join(temp_dir, "*.jpg")): + os.remove(img_file) + except Exception as e: + print(f"Warning: Could not cleanup old images: {e}") + + def upload_csv_file(): uploaded_file = st.file_uploader( CSV_UPLOAD_LABEL, type=["csv"], help=CSV_UPLOAD_HELP @@ -128,6 +183,9 @@ def main(): st.markdown("### Sessions") if st.button("βž• New Session"): try: + # Clean up old images when creating new session + cleanup_old_images() + sessions = orchestrator.list_sessions() new_session = orchestrator.create_new_session( NewChatRequest(session_name=f"Chat {len(sessions) + 1}") @@ -143,6 +201,9 @@ def main(): if csv_content: if st.button(LOAD_CSV_BUTTON): try: + # Clean up old images before loading new CSV + cleanup_old_images() + result = orchestrator.load_csv_data( current_session_id, csv_content) if result["status"] == "success": @@ -169,7 +230,7 @@ def main(): st.markdown(prompt) with st.chat_message("assistant"): - st.markdown(response.content) + display_message_with_images(response.content) except Exception as e: st.error(RESPONSE_GENERATION_ERROR.format(error=str(e))) diff --git a/src/services/routing_service.py b/src/services/routing_service.py index c5d641c..8e3ee56 100644 --- a/src/services/routing_service.py +++ b/src/services/routing_service.py @@ -3,7 +3,7 @@ """ import json -from typing import List, Optional +from typing import List, Optional, Dict, Any from pydantic_ai import Agent, RunContext from pydantic_ai.models.groq import GroqModel from pydantic_ai.providers.groq import GroqProvider @@ -110,32 +110,53 @@ def handle_sql_query( def handle_csv_query( self, user_query: str, - csv_data: str, + csv_info: Dict[str, Any], conversation_history: List[ChatMessage] = None, ) -> str: """Handle CSV analysis queries.""" try: # Use the AI agent to generate code based on user request and conversation history context = self._prepare_csv_context( - user_query, csv_data, conversation_history + user_query, csv_info, conversation_history ) result = self.csv_agent.run_sync(context) if hasattr(result.output, "python_code"): # Execute the generated code using Jupyter service return self._execute_csv_analysis( - result.output.python_code, csv_data, result.output.explanation + result.output.python_code, csv_info, result.output.explanation ) else: return "I'm sorry, I couldn't generate analysis code for that request. Could you please rephrase your question?" except Exception as e: - return ( - f"I encountered an error while generating CSV analysis code: {str(e)}" - ) + # Provide a simple fallback analysis when LLM fails + try: + import pandas as pd + df = pd.read_csv(csv_info['file_path']) + + # Basic analysis based on the query + if "average" in user_query.lower() or "mean" in user_query.lower(): + if "salary" in user_query.lower() and "salary" in df.columns: + avg = df['salary'].mean() + return f"**Analysis Results:**\n\nAverage salary: ${avg:,.2f}\n\n**Explanation:** Calculated the average salary from the data." + else: + numeric_cols = df.select_dtypes(include=['number']).columns + if len(numeric_cols) > 0: + avg = df[numeric_cols[0]].mean() + return f"**Analysis Results:**\n\nAverage {numeric_cols[0]}: {avg:,.2f}\n\n**Explanation:** Calculated the average of {numeric_cols[0]} from the data." + + elif "graph" in user_query.lower() or "chart" in user_query.lower(): + return f"**Analysis Results:**\n\nChart generation is currently unavailable. Here's a data summary:\n\n{df.describe()}\n\n**Explanation:** Basic data overview due to service unavailability." + + else: + return f"**Analysis Results:**\n\nData overview:\n- Records: {len(df)}\n- Columns: {list(df.columns)}\n\nFirst few rows:\n{df.head()}\n\n**Explanation:** Basic data overview due to service unavailability." + + except Exception as fallback_error: + return f"I encountered an error while generating CSV analysis code: {str(e)}\n\nFallback also failed: {str(fallback_error)}" def _execute_csv_analysis( - self, python_code: str, csv_data: str, explanation: str + self, python_code: str, csv_info: Dict[str, Any], explanation: str ) -> str: """Execute CSV analysis code using Jupyter service.""" try: @@ -148,24 +169,55 @@ def _execute_csv_analysis( session_id = "csv_analysis_temp" # Load CSV data into the session - jupyter_service.load_csv_data(session_id, csv_data) - - # Prepare the code with CSV data variable - full_code = f"csv_data = '''{csv_data}'''\n\n{python_code}" + jupyter_service.load_csv_data(session_id, csv_info['file_path']) + + # Install required libraries if needed + install_code = """ +import sys +import subprocess + +def install_package(package): + try: + __import__(package) + except ImportError: + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + +# Install required packages +install_package('pandas') +install_package('numpy') +install_package('matplotlib') +install_package('seaborn') +""" + + # Execute installation first + install_result = jupyter_service.execute_analysis( + session_id, install_code, max_retries=1 + ) - # Execute the code + # Execute the analysis code directly (it will read from the file path) result = jupyter_service.execute_analysis( - session_id, full_code, max_retries=1 + session_id, python_code, max_retries=1 ) if result["status"] == "success": output = result.get("output", "") + + # If output is empty, provide a fallback + if not output.strip(): + output = "Analysis completed successfully but no output was generated." - # Check if any images were created + # Check if any images were created in the specific session directory import os import glob - - image_files = glob.glob("/tmp/*.png") + + # Look for images in the session's temp directory + session_temp_dir = f"/tmp/querypls_session_csv_analysis_temp" + image_files = [] + + if os.path.exists(session_temp_dir): + png_files = glob.glob(os.path.join(session_temp_dir, "*.png")) + jpg_files = glob.glob(os.path.join(session_temp_dir, "*.jpg")) + image_files.extend(png_files + jpg_files) if image_files: image_info = "\n\n**Generated Images:**\n" @@ -180,7 +232,14 @@ def _execute_csv_analysis( **Explanation:** {explanation}""" else: error_msg = result.get("error_message", "Unknown error") - return f"❌ Error executing analysis: {error_msg}" + # Add debugging information + debug_info = f""" +**Debug Information:** +- Generated Code: {python_code[:200]}... +- Error: {error_msg} +- CSV File: {csv_info['file_path']} +""" + return f"❌ Error executing analysis: {error_msg}\n{debug_info}" except Exception as e: return f"❌ Error in CSV analysis: {str(e)}" @@ -221,30 +280,45 @@ def _prepare_sql_context( def _prepare_csv_context( self, user_query: str, - csv_data: str, + csv_info: Dict[str, Any], conversation_history: List[ChatMessage] = None, ) -> str: """Prepare context for CSV analysis.""" context_parts = [ f"User Query: {user_query}", f"CSV Data Available: Yes", - f"CSV Data Format: The CSV data is available as a string variable called 'csv_data' in the session", - f"CSV Content Preview: {csv_data[:200]}...", + f"CSV File Path: {csv_info['file_path']}", + f"CSV Shape: {csv_info['shape']}", + f"CSV Columns: {csv_info['columns']}", + f"CSV Data Types: {csv_info['dtypes']}", + f"CSV Sample Data: {csv_info['sample_data']}", ] - + if conversation_history: context_parts.append("Conversation History:") # Last 5 messages for context for msg in conversation_history[-5:]: context_parts.append(f"- {msg.role}: {msg.content}") - + context_parts.append( - "\nGenerate Python code to analyze the CSV data based on the user's request." + "\nGenerate SIMPLE Python code that directly answers the user's question." ) context_parts.append( - "IMPORTANT: Use pd.read_csv(StringIO(csv_data)) to load the data, NOT file paths!" + "MAXIMUM 10 LINES OF CODE - Keep it simple!" ) - + context_parts.append( + "NO COMPREHENSIVE ANALYSIS - Just answer the specific question!" + ) + context_parts.append( + f"IMPORTANT: Use pd.read_csv('{csv_info['file_path']}') to load the data from the file path!" + ) + context_parts.append( + "Print human-readable insights directly - no complex scripts!" + ) + context_parts.append( + "For charts, use plt.savefig('/tmp/querypls_session_csv_analysis_temp/chart_name.png') and plt.show()." + ) + return "\n".join(context_parts) def _format_sql_response(self, sql_response) -> str: @@ -316,6 +390,21 @@ def _fallback_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecisio "python", "code", "file", + "average", + "mean", + "sum", + "count", + "salary", + "column", + "row", + "statistics", + "stats", + "distribution", + "correlation", + "histogram", + "bar", + "line", + "scatter", ] # SQL keywords diff --git a/test_app.py b/test_app.py new file mode 100644 index 0000000..3fe4305 --- /dev/null +++ b/test_app.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +""" +Simple test to verify the application components are working. +""" + +import sys +import os +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +def test_imports(): + """Test that all imports work correctly.""" + print("Testing imports...") + + try: + from src.config.constants import WELCOME_MESSAGE, DEFAULT_SESSION_NAME + print("βœ… Constants imported successfully") + print(f" WELCOME_MESSAGE: {WELCOME_MESSAGE[:50]}...") + print(f" DEFAULT_SESSION_NAME: {DEFAULT_SESSION_NAME}") + except ImportError as e: + print(f"❌ Error importing constants: {e}") + return False + + try: + from src.services.routing_service import IntelligentRoutingService + print("βœ… Routing service imported successfully") + except ImportError as e: + print(f"❌ Error importing routing service: {e}") + return False + + try: + from src.backend.orchestrator import BackendOrchestrator + print("βœ… Orchestrator imported successfully") + except ImportError as e: + print(f"❌ Error importing orchestrator: {e}") + return False + + return True + +def test_routing(): + """Test the routing service.""" + print("\nTesting routing service...") + + try: + from src.services.routing_service import IntelligentRoutingService + routing_service = IntelligentRoutingService() + + # Test routing decisions + test_cases = [ + ("Hello", "CONVERSATION_AGENT"), + ("Show me all users", "SQL_AGENT"), + ("Analyze this CSV data", "CSV_AGENT"), + ] + + for query, expected in test_cases: + decision = routing_service.determine_agent(query, [], csv_loaded=False) + status = "βœ…" if decision.agent == expected else "❌" + print(f" {status} '{query}' β†’ {decision.agent} (expected: {expected})") + + return True + except Exception as e: + print(f"❌ Error testing routing: {e}") + return False + +if __name__ == "__main__": + print("Querypls Application Test") + print("=" * 40) + + success = True + success &= test_imports() + success &= test_routing() + + print("\n" + "=" * 40) + if success: + print("βœ… All tests passed! Application is ready.") + else: + print("❌ Some tests failed. Please check the errors above.") \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..034d1b0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index a6aa742..0000000 --- a/tests/test_auth.py +++ /dev/null @@ -1,53 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch -from httpx_oauth.clients.google import GoogleOAuth2 -from src.constant import * -from src.auth import ( - get_authorization_url, - get_access_token, - get_email, - get_login_str, -) - - -@pytest.mark.asyncio -async def test_get_authorization_url(): - client = GoogleOAuth2("client_id", "client_secret") - redirect_uri = "http://example.com/callback" - with patch.object(client, "get_authorization_url", new=AsyncMock()) as mock_method: - await get_authorization_url(client, redirect_uri) - mock_method.assert_called_once_with(redirect_uri, scope=["profile", "email"]) - - -@pytest.mark.asyncio -async def test_get_access_token(): - client = GoogleOAuth2("client_id", "client_secret") - redirect_uri = "http://example.com/callback" - code = "code" - with patch.object(client, "get_access_token", new=AsyncMock()) as mock_method: - await get_access_token(client, redirect_uri, code) - mock_method.assert_called_once_with(code, redirect_uri) - - -@pytest.mark.asyncio -async def test_get_email(): - client = GoogleOAuth2("client_id", "client_secret") - token = "token" - with patch.object( - client, - "get_id_email", - new=AsyncMock(return_value=("user_id", "user_email")), - ) as mock_method: - user_id, user_email = await get_email(client, token) - mock_method.assert_called_once_with(token) - assert user_id == "user_id" - assert user_email == "user_email" - - -def test_get_login_str(): - with patch("asyncio.run") as mock_run: - mock_run.return_value = "authorization_url" - result = get_login_str() - mock_run.assert_called_once() - assert '' in result - assert "Login with Google" in result diff --git a/tests/test_backend.py b/tests/test_backend.py index 6c87ab7..2ec72b4 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2,7 +2,7 @@ from unittest.mock import patch, MagicMock import sys, os from src.backend import * -from src.constant import * +from src.config.constants import * sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -29,11 +29,3 @@ def mock_set_page_config(): def mock_oauth2_component(): with patch("streamlit_oauth.OAuth2Component") as mock_oauth2_component: yield mock_oauth2_component - - -def test_hide_main_menu_and_footer(mock_markdown): - hide_main_menu_and_footer() - mock_markdown.assert_called_once_with( - """""", - unsafe_allow_html=True, - ) diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..5586eb0 --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,44 @@ +import pytest +import os +import sys +from datetime import datetime, timedelta + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.schemas.requests import ChatMessage, NewChatRequest +from src.schemas.responses import SessionInfo, ChatResponse +from src.backend.orchestrator import BackendOrchestrator + + +def test_create_new_session(): + orchestrator = BackendOrchestrator() + session_info = orchestrator.create_new_session(NewChatRequest(session_name="Test Chat")) + assert session_info.session_name == "Test Chat" + assert session_info.session_id is not None + + +def test_list_sessions(): + orchestrator = BackendOrchestrator() + session1 = orchestrator.create_new_session(NewChatRequest(session_name="Chat 1")) + session2 = orchestrator.create_new_session(NewChatRequest(session_name="Chat 2")) + sessions = orchestrator.list_sessions() + assert len(sessions) == 2 + assert any(s.session_name == "Chat 1" for s in sessions) + assert any(s.session_name == "Chat 2" for s in sessions) + + + + +def test_health_check(): + orchestrator = BackendOrchestrator() + health = orchestrator.health_check() + assert health.status in ["healthy", "unhealthy"] + assert isinstance(health.version, str) + assert isinstance(health.timestamp, str) + + +def test_session_message_flow(): + orchestrator = BackendOrchestrator() + session_info = orchestrator.create_new_session(NewChatRequest(session_name="Test Session")) + assert session_info.session_name == "Test Session" + assert session_info.session_id is not None diff --git a/tests/test_conversation_service.py b/tests/test_conversation_service.py new file mode 100644 index 0000000..d985557 --- /dev/null +++ b/tests/test_conversation_service.py @@ -0,0 +1,31 @@ +import pytest +from src.services.conversation_service import ConversationService + + +def test_is_conversational_query(): + service = ConversationService() + assert service.is_conversational_query("hello") is True + assert service.is_conversational_query("hi there") is True + assert service.is_conversational_query("how are you") is True + assert service.is_conversational_query("select * from users") is False + + +def test_get_conversational_response(): + service = ConversationService() + response = service.get_conversational_response("hello") + assert response is not None + assert len(response) > 0 + assert "hello" in response.lower() or "hi" in response.lower() + + +def test_get_conversational_response_help(): + service = ConversationService() + response = service.get_conversational_response("what can you do?") + assert "SQL" in response + assert "data analysis" in response.lower() + + +def test_get_conversational_response_thanks(): + service = ConversationService() + response = service.get_conversational_response("thank you") + assert "help you today" in response.lower() diff --git a/tests/test_csv_analysis.py b/tests/test_csv_analysis.py new file mode 100644 index 0000000..a4bcbe0 --- /dev/null +++ b/tests/test_csv_analysis.py @@ -0,0 +1,22 @@ +import pytest +from src.services.csv_analysis_tools import CSVAnalysisTools, CSVAnalysisContext + + +def test_load_csv_data(): + tools = CSVAnalysisTools() + csv_content = "name,age\nJohn,30\nJane,25" + result = tools.load_csv_data(csv_content, "test_session") + assert result["status"] == "success" + assert "name" in result["columns"] + assert "age" in result["columns"] + + +def test_get_csv_info(): + tools = CSVAnalysisTools() + csv_content = "name,age\nJohn,30\nJane,25" + tools.load_csv_data(csv_content, "test_session") + info = tools.get_csv_info("test_session") + assert info["status"] == "success" + assert "shape" in info + assert "columns" in info + assert "dtypes" in info diff --git a/tests/test_frontend.py b/tests/test_frontend.py deleted file mode 100644 index 49ce111..0000000 --- a/tests/test_frontend.py +++ /dev/null @@ -1,95 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import streamlit as st -from src.frontend import ( - display_logo_and_heading, - display_welcome_message, - handle_new_chat, - display_previous_chats, - create_message, - update_session_state, -) - - -@pytest.fixture -def mock_st(): - return MagicMock() - - -@pytest.fixture -def mock_db(): - return MagicMock() - - -class MockSessionState: - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - setattr(self, key, value) - - -def initialize_session_state(messages=None, key=None, user_email=None): - st.session_state = MockSessionState() - st.session_state.messages = messages or [] - st.session_state.key = key - st.session_state.user_email = user_email - - -def test_display_logo_and_heading(mock_st): - with patch.object(st, "image") as mock_image: - display_logo_and_heading() - mock_image.assert_called_once_with("static/image/logo.png") - - -def test_display_welcome_message(mock_st): - with patch.object(st, "markdown") as mock_markdown: - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state( - messages=[{"role": "assistant", "content": "How may I help you?"}] - ) - display_welcome_message() - mock_markdown.assert_called_once_with( - "#### Welcome to \n ## πŸ›’πŸ’¬Querypls - Prompt to SQL" - ) - - -def test_handle_new_chat(mock_db, mock_st): - with patch("src.frontend.get_previous_chats") as mock_get_previous_chats: - mock_get_previous_chats.return_value = [] - with patch.object(st, "markdown") as mock_markdown, patch.object( - st, "button" - ) as mock_button: - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state(messages=[], user_email="test@example.com") - handle_new_chat(mock_db, max_chat_histories=5) - mock_markdown.assert_called_once_with(" #### Remaining Chats: `5/5`") - mock_button.assert_called_once_with("βž• New chat") - - -def test_create_message(): - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state(messages=[], key=None) - create_message() - assert st.session_state.messages == [ - {"role": "assistant", "content": "How may I help you?"} - ] - assert st.session_state.key == "key" - - -def test_update_session_state(mock_db): - chat = {"chat": [{"role": "user", "content": "Hello"}], "key": "new_key"} - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state( - messages=[{"role": "assistant", "content": "How may I help you?"}], - key="old_key", - ) - with patch("src.frontend.database") as mock_database: - update_session_state(mock_db, chat) - mock_database.assert_called_once_with( - mock_db, - "old_key", - [{"role": "assistant", "content": "How may I help you?"}], - ) - assert st.session_state.messages == chat["chat"] - assert st.session_state.key == chat["key"] diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..5fd1ee2 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,30 @@ +import pytest +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.schemas.requests import ChatMessage, NewChatRequest, SQLGenerationRequest +from src.schemas.responses import ChatResponse, SessionInfo + + +def test_chat_message(): + msg = ChatMessage(role="user", content="test") + assert msg.role == "user" + assert msg.content == "test" + + +def test_new_chat_request(): + req = NewChatRequest(session_name="Test Session") + assert req.session_name == "Test Session" + + +def test_chat_response(): + resp = ChatResponse( + message_id="123", + content="test response", + timestamp="2024-01-01T00:00:00", + session_id="456" + ) + assert resp.content == "test response" + assert resp.session_id == "456" diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..77c3d08 --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,97 @@ +import pytest +import json +from src.schemas.requests import ( + SQLGenerationRequest, + ChatMessage, + ConversationHistory, + NewChatRequest +) +from src.schemas.responses import ( + SQLQueryResponse, + ChatResponse, + ErrorResponse, + SessionInfo, + HealthCheckResponse +) + + +def test_sql_generation_request(): + request = SQLGenerationRequest( + user_query="Show users", + conversation_history=[], + database_schema=None, + query_type=None + ) + assert request.user_query == "Show users" + assert isinstance(request.conversation_history, list) + + +def test_chat_message(): + message = ChatMessage( + role="user", + content="Hello", + timestamp="2024-01-01T00:00:00", + session_id="123" + ) + assert message.role == "user" + assert message.content == "Hello" + assert message.timestamp == "2024-01-01T00:00:00" + assert message.session_id == "123" + + +def test_conversation_history(): + history = ConversationHistory( + messages=[ + ChatMessage(role="user", content="Hello") + ], + session_id="123" + ) + assert len(history.messages) == 1 + assert history.session_id == "123" + + +def test_new_chat_request(): + request = NewChatRequest( + session_name="Test Chat", + initial_context="SQL Testing" + ) + assert request.session_name == "Test Chat" + assert request.initial_context == "SQL Testing" + + +def test_sql_query_response(): + response = SQLQueryResponse( + sql_query="SELECT * FROM users", + explanation="Get all users", + tables_used=["users"], + columns_selected=["*"], + query_type="SELECT", + complexity="SIMPLE" + ) + assert response.sql_query == "SELECT * FROM users" + assert response.explanation == "Get all users" + assert response.tables_used == ["users"] + + +def test_chat_response(): + response = ChatResponse( + message_id="123", + content="Hello", + timestamp="2024-01-01T00:00:00", + session_id="456" + ) + assert response.message_id == "123" + assert response.content == "Hello" + assert response.session_id == "456" + + +def test_health_check_response(): + response = HealthCheckResponse( + status="healthy", + version="1.0.0", + timestamp="2024-01-01T00:00:00", + services={"sql": "healthy"} + ) + assert response.status == "healthy" + assert response.version == "1.0.0" + assert response.services["sql"] == "healthy" diff --git a/utils/prompt.py b/utils/prompt.py index cf74e71..3618b5f 100644 --- a/utils/prompt.py +++ b/utils/prompt.py @@ -100,50 +100,82 @@ Respond only with the JSON object. Do not include any additional text or markdown formatting.""" -CSV_ANALYSIS_PROMPT = """You are a Python data analysis expert. Generate Python code to analyze CSV data based on user queries. +CSV_ANALYSIS_PROMPT = """You are a Python data analysis expert. Generate SIMPLE, FOCUSED Python code that answers the user's specific question. ## Response Format Your response must be in JSON format. It must be an object and must contain these fields: -* `python_code` - The generated Python code as a string +* `python_code` - The generated Python code as a string (this will be EXECUTED automatically) * `explanation` - Brief explanation of what the code does * `expected_output` - What output is expected from the code * `libraries_used` - Array of Python libraries used -## Guidelines -1. Always use pandas for data manipulation -2. Use matplotlib/seaborn for visualizations when appropriate -3. Include proper error handling -4. Make the code readable and well-commented -5. Return clear, formatted output -6. Handle missing data appropriately -7. Use appropriate data types -8. Keep the code simple and focused on the user's request +## IMPORTANT: The code you generate will be EXECUTED automatically. Do NOT include code blocks or markdown formatting in the python_code field. + +## CRITICAL GUIDELINES: +1. **KEEP CODE SIMPLE** - Maximum 10 lines of code +2. **ANSWER THE SPECIFIC QUESTION** - Don't create comprehensive analysis +3. **PRINT CLEAR INSIGHTS** - Human-readable output, not raw data +4. **NO COMPLEX SCRIPTS** - No functions, classes, or advanced features +5. **SIMPLE VARIABLES** - Use df, result, avg, etc. +6. **DIRECT APPROACH** - Load data, analyze, print result ## CRITICAL Code Requirements: -- The CSV data is available as a string variable called `csv_data` in the session -- ALWAYS use `pd.read_csv(StringIO(csv_data))` to load the data -- NEVER use file paths like 'data.csv' or 'path/to/file.csv' -- For graphs/charts, save them to `/tmp/` folder with descriptive names -- Use `plt.savefig('/tmp/chart_name.png')` before `plt.show()` -- Always print clear output with section headers +- The CSV data is available as a file at the path provided in the context +- ALWAYS use `pd.read_csv('file_path')` to load the data from the file path +- The file path will be provided in the context +- For graphs/charts, save them to `/tmp/querypls_session_csv_analysis_temp/` folder with descriptive names +- Use `plt.savefig('/tmp/querypls_session_csv_analysis_temp/chart_name.png')` before `plt.show()` +- Print insights in a human-readable format with clear explanations +- Don't create complex functions or classes - keep it simple and direct +- ONLY use these libraries: pandas, numpy, matplotlib.pyplot, seaborn +- Write clean, simple code without syntax errors +- Use proper variable names and avoid special characters +- For charts: use simple matplotlib code, save to specific temp folder, then show +- Keep each line simple and avoid complex expressions ## Example Response { - "python_code": "import pandas as pd\\nfrom io import StringIO\\nimport matplotlib.pyplot as plt\\n\\n# Load CSV data from session\\ndf = pd.read_csv(StringIO(csv_data))\\nprint('=== DATA OVERVIEW ===')\\nprint(f'Shape: {df.shape}')\\nprint(f'Columns: {list(df.columns)}')\\nprint('\\nFirst 5 rows:')\\nprint(df.head())\\n\\n# Create visualization\\nplt.figure(figsize=(10, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/department_chart.png')\\nplt.show()", - "explanation": "Loads CSV data from session, displays overview, and creates a bar chart saved to temp folder", - "expected_output": "Data overview and a bar chart visualization saved as image", + "python_code": "import pandas as pd\\n\\n# Load data\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\n\\n# Calculate average salary\\navg_salary = df['salary'].mean()\\nprint(f'Average salary: ${avg_salary:,.2f}')", + "explanation": "Loads CSV data and calculates the average salary in a readable format", + "expected_output": "Average salary: $60,000.00", + "libraries_used": ["pandas"] +} + +## Chart Example Response +{ + "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\n\\n# Create temp directory\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\n\\n# Load data\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\n\\n# Create chart\\nplt.figure(figsize=(8, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/department_chart.png')\\nplt.show()\\n\\n# Print results\\nprint('Department counts:')\\nprint(df['department'].value_counts())", + "explanation": "Creates a bar chart of department distribution and saves it to specific temp folder", + "expected_output": "Bar chart visualization and department counts", + "libraries_used": ["pandas", "matplotlib.pyplot"] +} + +## Graph Example Response +{ + "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\n\\n# Create temp directory\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\n\\n# Load data\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\n\\n# Create simple bar chart\\nplt.figure(figsize=(8, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/department_chart.png')\\nplt.show()\\n\\n# Print summary\\nprint('Department distribution:')\\nprint(df['department'].value_counts())", + "explanation": "Creates a simple bar chart showing department distribution", + "expected_output": "Bar chart and department counts", "libraries_used": ["pandas", "matplotlib.pyplot"] } ## Important Notes: - Use double backslashes for newlines in the python_code field -- ALWAYS use `StringIO(csv_data)` to load CSV data - NEVER use file paths -- Save charts to `/tmp/` folder +- ALWAYS use `pd.read_csv('file_path')` to load CSV data from the file path provided in context +- Save charts to `/tmp/querypls_session_csv_analysis_temp/` folder - Keep the explanation concise - Make sure the JSON is valid and properly formatted -- The csv_data variable is already available in the session +- The file path will be provided in the context +- Write simple, clean code without complex functions or classes +- Focus on printing clear insights directly +- Avoid syntax errors and special characters +- NEVER use line continuation characters (\\) in the code +- Keep each line complete and self-contained +- Use simple string formatting with f-strings +- **MAXIMUM 10 LINES OF CODE** - Keep it simple! +- **NO COMPREHENSIVE ANALYSIS** - Just answer the specific question +- **DO NOT include ```python or ``` in the python_code field** +- **The code will be executed automatically - just provide the raw Python code** Respond only with the JSON object.""" From 74cf51ec39183d98aff6a8e5bd557d3c687abc78 Mon Sep 17 00:00:00 2001 From: samadpls Date: Fri, 8 Aug 2025 00:44:40 +0500 Subject: [PATCH 5/8] feat: Introduce .env configuration file and remove legacy .env_example; enhance intelligent routing service for better error handling and CSV analysis --- .env_examp | 9 + .env_example | 3 - README.md | 27 +-- src/config/settings.py | 10 +- src/services/routing_service.py | 294 +++++++++++++------------------- utils/prompt.py | 20 ++- 6 files changed, 159 insertions(+), 204 deletions(-) create mode 100644 .env_examp delete mode 100644 .env_example diff --git a/.env_examp b/.env_examp new file mode 100644 index 0000000..5f96e01 --- /dev/null +++ b/.env_examp @@ -0,0 +1,9 @@ +GROQ_API_KEY=your_groq_api_key_here +GROQ_MODEL_NAME=model/name + +APP_NAME=Querypls +APP_VERSION=1.0.0 +MAX_TOKENS=1000 +TEMPERATURE=0.7 +MAX_CHAT_HISTORIES=5 +LOG_LEVEL=INFO diff --git a/.env_example b/.env_example deleted file mode 100644 index 1feac52..0000000 --- a/.env_example +++ /dev/null @@ -1,3 +0,0 @@ -HUGGINGFACE_API_TOKEN= -REPO_ID= -TEMPLATE= diff --git a/README.md b/README.md index e65314d..8f16654 100644 --- a/README.md +++ b/README.md @@ -7,16 +7,21 @@ [![License](https://img.shields.io/badge/License-MIT%202.0-blue.svg)](LICENSE) -# πŸ’¬ Querypls - Prompt to SQL +# πŸ’¬ Querypls - Intelligent SQL and CSV Analysis -Querypls is a web application that provides an interactive chat interface, simplifying SQL query generation. Users can effortlessly enter SQL queries and receive corresponding results. The application harnesses the capabilities of the language models from Hugging Face to generate SQL queries based on user input. +Querypls is a modern web application that provides an interactive chat interface for SQL query generation and CSV data analysis. Built with Pydantic AI and powered by OpenAI's GPT-OSS-120B model through Groq, it offers intelligent routing between different analysis modes to handle various data-related queries. + +🌐 **Try it live**: [querypls.streamlit.app](https://querypls.streamlit.app/) ## Key Features -πŸ’¬ Interactive chat interface for easy communication. -πŸ” Enter SQL queries and receive query results as responses. -πŸ€– Utilizes language models from Hugging Face for advanced query generation ([Querypls-prompt2sql](https://huggingface.co/samadpls/querypls-prompt2sql)). -πŸ’» User-friendly interface for seamless interaction. +πŸ’¬ **Interactive Chat Interface** - Natural language conversations for data analysis +πŸ” **SQL Query Generation** - Convert natural language to optimized SQL queries +πŸ“Š **CSV Data Analysis** - Upload and analyze CSV files with intelligent insights +πŸ€– **Intelligent Routing** - Automatically determines the best agent for your query +⚑ **Fast Inference** - Powered by Groq's optimized infrastructure +πŸ”’ **Type-Safe Development** - Built with Pydantic AI for robust validation +πŸ“ˆ **Visual Analytics** - Generate charts and visualizations from your data ![QueryplsDemo](https://github.com/samadpls/Querypls/assets/94792103/daa6e37d-a256-4fd8-9607-6e18cf41df3f) @@ -24,7 +29,9 @@ Querypls is a web application that provides an interactive chat interface, simpl # Acknowledgments -`Querypls` received a shoutout from [🦜 πŸ”— Langchain](https://www.langchain.com/) on their Twitter, reaching over **60,000 impressions**. Additionally, it was featured under the **Community Favorite Projects** section on `🦜 πŸ”— Langchain's blog`, leading to a significant increase in stars for this repository and a growing user base. The project was also highlighted in a [YouTube video](https://www.youtube.com/watch?v=htHVb-fK9xU), and it also caught the attention of Backdrop, expressing their interest and liking in an email, inviting the project to be a part of their hackathon. +`Querypls` received a shoutout from [🦜 πŸ”— Langchain](https://www.langchain.com/) on their Twitter in 2023, reaching over **60,000 impressions**. Additionally, it was featured under the **Community Favorite Projects** section on `🦜 πŸ”— Langchain's blog`, leading to a significant increase in stars for this repository and a growing user base. The project was also highlighted in a [YouTube video](https://www.youtube.com/watch?v=htHVb-fK9xU), and it also caught the attention of Backdrop, expressing their interest and liking in an email, inviting the project to be a part of their hackathon. + +However, due to constant breakdowns and instability issues with the LangChain framework, we made the strategic decision to migrate to **Pydantic AI** - a more stable and reliable framework. This transition has brought improved performance, better type safety, and enhanced maintainability to the project. | [πŸ”— Langhchain Twitter Post](https://twitter.com/LangChainAI/status/1729959981523378297?t=Zdpw9ZQYvE3QS-3Bf-xaGw&s=19) | [πŸ”— Langhcain Blog Post](https://blog.langchain.dev/week-of-11-27-langchain-release-notes/) | |----------|----------| @@ -38,7 +45,7 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file > [!Note] -> Querypls, while powered by a 7B model of Satablility AI LLM Model, is currently limited in providing optimal responses for simple queries. +> Querypls is now powered by OpenAI's GPT-OSS-120B model through Groq, providing fast and reliable AI-powered SQL generation and CSV analysis capabilities. --- @@ -59,11 +66,11 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file pip install -r requirements.txt ``` -4. Create a `.env` file based on `.env_example` and set the necessary variables. +4. Create a `.env` file based on `.env_examp` and set the necessary variables. 5. Run the application: ```bash - streamlit run src/app.py + streamlit run src/frontend/app.py ``` 6. Open the provided link in your browser to use Querypls. diff --git a/src/config/settings.py b/src/config/settings.py index 60f87aa..40957e7 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -17,19 +17,15 @@ class Settings(BaseModel): debug_mode: bool = Field(default=False, env="DEBUG_MODE") # Legacy fields for backward compatibility - app_name: Optional[str] = Field(None, env="APP_NAME") - streamlit_port: Optional[str] = Field(None, env="STREAMLIT_PORT") - streamlit_host: Optional[str] = Field(None, env="STREAMLIT_HOST") - max_tokens: Optional[str] = Field(None, env="MAX_TOKENS") - temperature: Optional[str] = Field(None, env="TEMPERATURE") - log_level: Optional[str] = Field(None, env="LOG_LEVEL") + max_tokens: Optional[str] = Field(1000, env="MAX_TOKENS") + temperature: Optional[str] = Field(0.7, env="TEMPERATURE") + log_level: Optional[str] = Field("INFO", env="LOG_LEVEL") json_schema_extra: ClassVar[str] = "ignore" model_config = ConfigDict( env_file = ".env", env_file_encoding = "utf-8", - # Add any necessary configuration here ) diff --git a/src/services/routing_service.py b/src/services/routing_service.py index 8e3ee56..edcb82e 100644 --- a/src/services/routing_service.py +++ b/src/services/routing_service.py @@ -111,7 +111,7 @@ def handle_csv_query( self, user_query: str, csv_info: Dict[str, Any], - conversation_history: List[ChatMessage] = None, + conversation_history: Optional[List[ChatMessage]] = None, ) -> str: """Handle CSV analysis queries.""" try: @@ -130,35 +130,13 @@ def handle_csv_query( return "I'm sorry, I couldn't generate analysis code for that request. Could you please rephrase your question?" except Exception as e: - # Provide a simple fallback analysis when LLM fails - try: - import pandas as pd - df = pd.read_csv(csv_info['file_path']) - - # Basic analysis based on the query - if "average" in user_query.lower() or "mean" in user_query.lower(): - if "salary" in user_query.lower() and "salary" in df.columns: - avg = df['salary'].mean() - return f"**Analysis Results:**\n\nAverage salary: ${avg:,.2f}\n\n**Explanation:** Calculated the average salary from the data." - else: - numeric_cols = df.select_dtypes(include=['number']).columns - if len(numeric_cols) > 0: - avg = df[numeric_cols[0]].mean() - return f"**Analysis Results:**\n\nAverage {numeric_cols[0]}: {avg:,.2f}\n\n**Explanation:** Calculated the average of {numeric_cols[0]} from the data." - - elif "graph" in user_query.lower() or "chart" in user_query.lower(): - return f"**Analysis Results:**\n\nChart generation is currently unavailable. Here's a data summary:\n\n{df.describe()}\n\n**Explanation:** Basic data overview due to service unavailability." - - else: - return f"**Analysis Results:**\n\nData overview:\n- Records: {len(df)}\n- Columns: {list(df.columns)}\n\nFirst few rows:\n{df.head()}\n\n**Explanation:** Basic data overview due to service unavailability." - - except Exception as fallback_error: - return f"I encountered an error while generating CSV analysis code: {str(e)}\n\nFallback also failed: {str(fallback_error)}" + # If LLM fails, provide a graceful response without showing errors + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" def _execute_csv_analysis( self, python_code: str, csv_info: Dict[str, Any], explanation: str ) -> str: - """Execute CSV analysis code using Jupyter service.""" + """Execute CSV analysis code using Jupyter service with error fixing retry loop.""" try: from src.services.jupyter_service import CSVAnalysisService @@ -194,55 +172,88 @@ def install_package(package): session_id, install_code, max_retries=1 ) - # Execute the analysis code directly (it will read from the file path) - result = jupyter_service.execute_analysis( - session_id, python_code, max_retries=1 - ) + # Retry loop for code execution with error fixing + current_code = python_code + max_retries = 3 + + for attempt in range(max_retries): + # Execute the current code + result = jupyter_service.execute_analysis( + session_id, current_code, max_retries=1 + ) - if result["status"] == "success": - output = result.get("output", "") - - # If output is empty, provide a fallback - if not output.strip(): - output = "Analysis completed successfully but no output was generated." + if result["status"] == "success": + output = result.get("output", "") + + # If output is empty, provide a fallback + if not output.strip(): + output = "Analysis completed successfully but no output was generated." - # Check if any images were created in the specific session directory - import os - import glob - - # Look for images in the session's temp directory - session_temp_dir = f"/tmp/querypls_session_csv_analysis_temp" - image_files = [] - - if os.path.exists(session_temp_dir): - png_files = glob.glob(os.path.join(session_temp_dir, "*.png")) - jpg_files = glob.glob(os.path.join(session_temp_dir, "*.jpg")) - image_files.extend(png_files + jpg_files) + # Check if any images were created in the specific session directory + import os + import glob + + # Look for images in the session's temp directory + session_temp_dir = f"/tmp/querypls_session_csv_analysis_temp" + image_files = [] + + if os.path.exists(session_temp_dir): + png_files = glob.glob(os.path.join(session_temp_dir, "*.png")) + jpg_files = glob.glob(os.path.join(session_temp_dir, "*.jpg")) + image_files.extend(png_files + jpg_files) - if image_files: - image_info = "\n\n**Generated Images:**\n" - for img_file in image_files: - image_info += f"- {os.path.basename(img_file)}\n" - output += image_info + if image_files: + image_info = "\n\n**Generated Images:**\n" + for img_file in image_files: + image_info += f"- {os.path.basename(img_file)}\n" + output += image_info - return f"""**Analysis Results:** + return f"""**Analysis Results:** {output} **Explanation:** {explanation}""" - else: - error_msg = result.get("error_message", "Unknown error") - # Add debugging information - debug_info = f""" -**Debug Information:** -- Generated Code: {python_code[:200]}... -- Error: {error_msg} -- CSV File: {csv_info['file_path']} -""" - return f"❌ Error executing analysis: {error_msg}\n{debug_info}" + + else: + # Code execution failed - try to fix it + error_msg = result.get("error_message", "Unknown error") + + if attempt < max_retries - 1: # Not the last attempt + # Send error to LLM to fix the code + fixed_code = self._fix_python_code(current_code, error_msg, csv_info) + if fixed_code: + current_code = fixed_code + continue # Try again with fixed code + + # If we reach here, all attempts failed or no more retries + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + + except Exception as e: + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" except Exception as e: - return f"❌ Error in CSV analysis: {str(e)}" + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + + # This should never be reached, but just in case + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + + def _fix_python_code(self, original_code: str, error_message: str, csv_info: Dict[str, Any]) -> Optional[str]: + """Send error to LLM to fix the Python code.""" + try: + # Create context for code fixing + context = self._prepare_code_fix_context(original_code, error_message, csv_info) + + # Use the CSV agent to generate fixed code + result = self.csv_agent.run_sync(context) + + if hasattr(result.output, "python_code"): + return result.output.python_code + else: + return None + + except Exception as e: + # If fixing fails, return None to continue with original code + return None def _prepare_routing_context( self, user_query: str, conversation_history: List[ChatMessage], csv_loaded: bool @@ -281,7 +292,7 @@ def _prepare_csv_context( self, user_query: str, csv_info: Dict[str, Any], - conversation_history: List[ChatMessage] = None, + conversation_history: Optional[List[ChatMessage]] = None, ) -> str: """Prepare context for CSV analysis.""" context_parts = [ @@ -321,6 +332,34 @@ def _prepare_csv_context( return "\n".join(context_parts) + def _prepare_code_fix_context(self, original_code: str, error_message: str, csv_info: Dict[str, Any]) -> str: + """Prepare context for code fixing.""" + context_parts = [ + "CODE FIXING REQUEST:", + f"Original Code: {original_code}", + f"Error Message: {error_message}", + f"CSV File Path: {csv_info['file_path']}", + f"CSV Shape: {csv_info['shape']}", + f"CSV Columns: {csv_info['columns']}", + f"CSV Data Types: {csv_info['dtypes']}", + f"CSV Sample Data: {csv_info['sample_data']}", + "", + "INSTRUCTIONS:", + "The above Python code failed to execute. Please fix the code and return a working version.", + "Follow these guidelines:", + "1. Keep code SIMPLE - Maximum 6 lines", + "2. NO SPECIAL CHARACTERS - Use standard ASCII only", + "3. NO FUNCTIONS - Write code directly", + "4. NO DOCSTRINGS - No complex documentation", + "5. Use pd.read_csv('file_path') to load data", + "6. Print human-readable insights directly", + "7. For charts, save to /tmp/querypls_session_csv_analysis_temp/", + "", + "Generate fixed Python code that will execute without errors." + ] + + return "\n".join(context_parts) + def _format_sql_response(self, sql_response) -> str: """Format SQL response for display.""" response_parts = [ @@ -351,119 +390,14 @@ def _format_csv_response(self, csv_response) -> str: return "\n\n".join(response_parts) def _fallback_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecision: - """Fallback routing based on simple keyword matching.""" - query_lower = user_query.lower().strip() - - # Conversation keywords - conversation_keywords = [ - "hi", - "hello", - "hey", - "good morning", - "good afternoon", - "good evening", - "how are you", - "what's up", - "thanks", - "thank you", - "bye", - "goodbye", - "help", - "what can you do", - "who are you", - "tell me about yourself", - ] - - # CSV keywords - csv_keywords = [ - "csv", - "data", - "analyze", - "analysis", - "chart", - "graph", - "plot", - "visualize", - "pandas", - "matplotlib", - "seaborn", - "python", - "code", - "file", - "average", - "mean", - "sum", - "count", - "salary", - "column", - "row", - "statistics", - "stats", - "distribution", - "correlation", - "histogram", - "bar", - "line", - "scatter", - ] - - # SQL keywords - sql_keywords = [ - "select", - "from", - "where", - "join", - "table", - "database", - "query", - "sql", - "insert", - "update", - "delete", - "create", - "alter", - "drop", - ] - - if any(keyword in query_lower for keyword in conversation_keywords): - return RoutingDecision( - agent="CONVERSATION_AGENT", - confidence=0.8, - reasoning="Detected conversational keywords", - ) - elif csv_loaded and any(keyword in query_lower for keyword in csv_keywords): - return RoutingDecision( - agent="CSV_AGENT", - confidence=0.7, - reasoning="CSV data loaded and detected CSV-related keywords", - ) - elif any(keyword in query_lower for keyword in sql_keywords): - return RoutingDecision( - agent="SQL_AGENT", - confidence=0.7, - reasoning="Detected SQL-related keywords", - ) - else: - # Default to conversation for unknown queries - return RoutingDecision( - agent="CONVERSATION_AGENT", - confidence=0.5, - reasoning="No specific keywords detected, defaulting to conversation", - ) + """Fallback routing when LLM routing fails - let LLM decide, not hardcoded keywords.""" + # Default to conversation - let the LLM handle all decisions + return RoutingDecision( + agent="CONVERSATION_AGENT", + confidence=0.3, + reasoning="LLM routing failed, defaulting to conversation agent", + ) def _get_fallback_conversation_response(self, user_query: str) -> str: - """Get fallback conversation response. Wish it dont be here.""" - query_lower = user_query.lower().strip() - - if any(greeting in query_lower for greeting in ["hi", "hello", "hey"]): - return "Hello! πŸ‘‹ I'm Querypls, your SQL and data analysis assistant. I can help you generate SQL queries or analyze CSV files. What would you like to work on today?" - elif "how are you" in query_lower: - return "I'm doing great, thank you for asking! 😊 I'm ready to help you with SQL queries or CSV data analysis. What can I assist you with?" - elif any(thanks in query_lower for thanks in ["thanks", "thank you"]): - return "You're welcome! 😊 Is there anything else I can help you with?" - elif any(bye in query_lower for bye in ["bye", "goodbye"]): - return "Goodbye! πŸ‘‹ Feel free to come back if you need help with SQL or data analysis." - elif "help" in query_lower or "what can you do" in query_lower: - return "I'm Querypls, your data analysis companion! πŸ—ƒοΈπŸ’¬ I can help you with SQL generation and CSV data analysis. Just upload a CSV file or ask me to write SQL queries!" - else: - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + """Get fallback conversation response when LLM fails.""" + return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" diff --git a/utils/prompt.py b/utils/prompt.py index 3618b5f..b752aa0 100644 --- a/utils/prompt.py +++ b/utils/prompt.py @@ -114,12 +114,16 @@ ## IMPORTANT: The code you generate will be EXECUTED automatically. Do NOT include code blocks or markdown formatting in the python_code field. ## CRITICAL GUIDELINES: -1. **KEEP CODE SIMPLE** - Maximum 10 lines of code +1. **KEEP CODE SIMPLE** - Maximum 6 lines of code 2. **ANSWER THE SPECIFIC QUESTION** - Don't create comprehensive analysis 3. **PRINT CLEAR INSIGHTS** - Human-readable output, not raw data 4. **NO COMPLEX SCRIPTS** - No functions, classes, or advanced features 5. **SIMPLE VARIABLES** - Use df, result, avg, etc. 6. **DIRECT APPROACH** - Load data, analyze, print result +7. **NO SPECIAL CHARACTERS** - Avoid β‰₯, ≀, β†’, etc. Use standard ASCII +8. **SIMPLE LOGIC** - No complex conditionals or loops +9. **NO FUNCTIONS** - Write code directly, not inside functions +10. **NO DOCSTRINGS** - No complex documentation ## CRITICAL Code Requirements: - The CSV data is available as a file at the path provided in the context @@ -137,7 +141,7 @@ ## Example Response { - "python_code": "import pandas as pd\\n\\n# Load data\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\n\\n# Calculate average salary\\navg_salary = df['salary'].mean()\\nprint(f'Average salary: ${avg_salary:,.2f}')", + "python_code": "import pandas as pd\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\navg = df['salary'].mean()\\nprint(f'Average salary: ${avg:,.2f}')", "explanation": "Loads CSV data and calculates the average salary in a readable format", "expected_output": "Average salary: $60,000.00", "libraries_used": ["pandas"] @@ -145,7 +149,7 @@ ## Chart Example Response { - "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\n\\n# Create temp directory\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\n\\n# Load data\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\n\\n# Create chart\\nplt.figure(figsize=(8, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/department_chart.png')\\nplt.show()\\n\\n# Print results\\nprint('Department counts:')\\nprint(df['department'].value_counts())", + "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\nplt.figure(figsize=(8, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/department_chart.png')\\nplt.show()\\nprint('Department counts:')\\nprint(df['department'].value_counts())", "explanation": "Creates a bar chart of department distribution and saves it to specific temp folder", "expected_output": "Bar chart visualization and department counts", "libraries_used": ["pandas", "matplotlib.pyplot"] @@ -153,12 +157,20 @@ ## Graph Example Response { - "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\n\\n# Create temp directory\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\n\\n# Load data\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\n\\n# Create simple bar chart\\nplt.figure(figsize=(8, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/department_chart.png')\\nplt.show()\\n\\n# Print summary\\nprint('Department distribution:')\\nprint(df['department'].value_counts())", + "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\nplt.figure(figsize=(8, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/department_chart.png')\\nplt.show()\\nprint('Department distribution:')\\nprint(df['department'].value_counts())", "explanation": "Creates a simple bar chart showing department distribution", "expected_output": "Bar chart and department counts", "libraries_used": ["pandas", "matplotlib.pyplot"] } +## Price Analysis Example Response +{ + "python_code": "import pandas as pd\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\navg_price = df['price'].mean()\\nprint(f'Average price: ${avg_price:,.2f}')", + "explanation": "Loads CSV data and calculates the average price in a readable format", + "expected_output": "Average price: $1,234.56", + "libraries_used": ["pandas"] +} + ## Important Notes: - Use double backslashes for newlines in the python_code field - ALWAYS use `pd.read_csv('file_path')` to load CSV data from the file path provided in context From b74960c08fc64aaa50f7fe84bda4b96907cdae44 Mon Sep 17 00:00:00 2001 From: samadpls Date: Fri, 8 Aug 2025 00:49:23 +0500 Subject: [PATCH 6/8] Refactor test files and improve code formatting - Cleaned up whitespace and formatting in test files for consistency. - Updated test cases to enhance readability and maintainability. - Ensured all test functions follow a consistent structure and style. - Added missing commas in several instances to adhere to Python syntax standards. --- examples/basic_usage_demo.py | 76 +++++++------ examples/test_backend_functionality.py | 145 ++++++++++++++----------- src/backend/orchestrator.py | 26 ++--- src/config/constants.py | 6 +- src/config/settings.py | 13 +-- src/frontend/app.py | 36 +++--- src/schemas/requests.py | 9 +- src/schemas/responses.py | 18 +-- src/services/conversation_service.py | 3 +- src/services/csv_analysis_tools.py | 18 +-- src/services/jupyter_service.py | 9 +- src/services/models.py | 24 ++-- src/services/routing_service.py | 78 ++++++------- src/services/sql_service.py | 9 +- src/terminal/cli.py | 12 +- test_app.py | 30 +++-- tests/test_basic.py | 10 +- tests/test_models.py | 2 +- tests/test_schemas.py | 27 ++--- 19 files changed, 272 insertions(+), 279 deletions(-) diff --git a/examples/basic_usage_demo.py b/examples/basic_usage_demo.py index f56c4b8..b024d78 100644 --- a/examples/basic_usage_demo.py +++ b/examples/basic_usage_demo.py @@ -18,23 +18,23 @@ def demo_conversation(): """Demo conversation functionality.""" print("πŸ—£οΈ CONVERSATION DEMO") print("=" * 40) - + routing_service = IntelligentRoutingService() - + # Test different conversation types conversations = [ "Hello", "How are you?", "What can you do?", "Thanks for your help", - "Goodbye" + "Goodbye", ] - + for query in conversations: print(f"\nUser: {query}") response = routing_service.handle_conversation_query(query) print(f"Assistant: {response}") - + print("\n" + "=" * 40) @@ -42,22 +42,22 @@ def demo_sql_generation(): """Demo SQL generation functionality.""" print("πŸ—ƒοΈ SQL GENERATION DEMO") print("=" * 40) - + routing_service = IntelligentRoutingService() - + # Test different SQL queries sql_queries = [ "Show me all users", "Find customers who made purchases in the last 30 days", "Get the total sales by month", - "SELECT * FROM users WHERE status = 'active'" + "SELECT * FROM users WHERE status = 'active'", ] - + for query in sql_queries: print(f"\nUser: {query}") response = routing_service.handle_sql_query(query, []) print(f"Assistant: {response[:200]}...") - + print("\n" + "=" * 40) @@ -65,7 +65,7 @@ def demo_csv_analysis(): """Demo CSV analysis functionality.""" print("πŸ“Š CSV ANALYSIS DEMO") print("=" * 40) - + # Sample CSV data sample_csv = """name,age,salary,department Alice,25,50000,IT @@ -73,24 +73,24 @@ def demo_csv_analysis(): Charlie,35,70000,IT Diana,28,55000,Finance Eve,32,65000,HR""" - + print(f"Sample CSV Data:\n{sample_csv}") - + routing_service = IntelligentRoutingService() - + # Test different CSV analysis queries csv_queries = [ "Show me the basic statistics of the data", "Create a bar chart of department distribution", "What is the average salary by department?", - "Show me the top 3 highest paid employees" + "Show me the top 3 highest paid employees", ] - + for query in csv_queries: print(f"\nUser: {query}") response = routing_service.handle_csv_query(query, sample_csv) print(f"Assistant: {response[:300]}...") - + print("\n" + "=" * 40) @@ -98,9 +98,9 @@ def demo_intelligent_routing(): """Demo intelligent routing functionality.""" print("🧠 INTELLIGENT ROUTING DEMO") print("=" * 40) - + routing_service = IntelligentRoutingService() - + # Test different types of queries test_queries = [ ("Hello", "CONVERSATION_AGENT"), @@ -108,9 +108,9 @@ def demo_intelligent_routing(): ("Analyze this CSV data", "CSV_AGENT"), ("How are you?", "CONVERSATION_AGENT"), ("SELECT * FROM users", "SQL_AGENT"), - ("Create a chart from the data", "CSV_AGENT") + ("Create a chart from the data", "CSV_AGENT"), ] - + for query, expected_agent in test_queries: print(f"\nQuery: '{query}'") decision = routing_service.determine_agent(query, [], csv_loaded=True) @@ -118,7 +118,7 @@ def demo_intelligent_routing(): print(f"Actual: {decision.agent}") print(f"Confidence: {decision.confidence}") print(f"Reasoning: {decision.reasoning}") - + print("\n" + "=" * 40) @@ -126,34 +126,38 @@ def demo_orchestrator(): """Demo the main orchestrator functionality.""" print("🎼 ORCHESTRATOR DEMO") print("=" * 40) - + orchestrator = BackendOrchestrator() - + # Create a new session - session_info = orchestrator.create_new_session(NewChatRequest(session_name="Demo Session")) + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Demo Session") + ) session_id = session_info.session_id print(f"Created session: {session_id}") - + # Test different types of interactions interactions = [ ("Hello", "conversation"), ("Show me all users", "sql"), - ("What can you do?", "conversation") + ("What can you do?", "conversation"), ] - + for query, query_type in interactions: print(f"\nUser ({query_type}): {query}") response = orchestrator.generate_intelligent_response(session_id, query) print(f"Assistant: {response.content[:150]}...") - + # Test CSV functionality sample_csv = "name,age,salary\nAlice,25,50000\nBob,30,60000\nCharlie,35,70000" result = orchestrator.load_csv_data(session_id, sample_csv) print(f"\nCSV Load Result: {result['status']}") - - response = orchestrator.generate_intelligent_response(session_id, "Analyze this data") + + response = orchestrator.generate_intelligent_response( + session_id, "Analyze this data" + ) print(f"CSV Analysis: {response.content[:200]}...") - + print("\n" + "=" * 40) @@ -161,7 +165,7 @@ def main(): """Run all demos.""" print("πŸš€ Querypls Backend Functionality Demo") print("=" * 50) - + demos = [ ("Conversation", demo_conversation), ("SQL Generation", demo_sql_generation), @@ -169,13 +173,13 @@ def main(): ("Intelligent Routing", demo_intelligent_routing), ("Orchestrator", demo_orchestrator), ] - + for demo_name, demo_func in demos: try: demo_func() except Exception as e: print(f"❌ {demo_name} demo failed: {str(e)}") - + print("\nπŸŽ‰ Demo completed! All backend functionality is working correctly.") print("\nπŸ“ Summary:") print("- Conversation: Natural responses for greetings and help") @@ -186,4 +190,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/test_backend_functionality.py b/examples/test_backend_functionality.py index 9fd1672..b568627 100644 --- a/examples/test_backend_functionality.py +++ b/examples/test_backend_functionality.py @@ -24,19 +24,19 @@ def test_conversation_functionality(): """Test conversation responses.""" print("πŸ§ͺ Testing Conversation Functionality") print("=" * 50) - + try: routing_service = IntelligentRoutingService() - + # Test conversation queries conversation_tests = [ "Hello", "How are you?", "What can you do?", "Thanks for your help", - "Goodbye" + "Goodbye", ] - + for query in conversation_tests: print(f"\nQuery: '{query}'") try: @@ -45,10 +45,10 @@ def test_conversation_functionality(): print("βœ… PASS") except Exception as e: print(f"❌ FAIL: {str(e)}") - + print("\n" + "=" * 50) return True - + except Exception as e: print(f"❌ Conversation test failed: {str(e)}") return False @@ -58,40 +58,41 @@ def test_sql_functionality(): """Test SQL generation functionality.""" print("πŸ—ƒοΈ Testing SQL Generation Functionality") print("=" * 50) - + try: routing_service = IntelligentRoutingService() sql_service = SQLGenerationService() - + # Test SQL queries sql_tests = [ "Show me all users", "SELECT * FROM users WHERE status = 'active'", "Find customers who made purchases in the last 30 days", - "Get the total sales by month" + "Get the total sales by month", ] - + for query in sql_tests: print(f"\nQuery: '{query}'") try: # Test routing - routing_decision = routing_service.determine_agent(query, [], csv_loaded=False) + routing_decision = routing_service.determine_agent( + query, [], csv_loaded=False + ) print(f"Routing Decision: {routing_decision.agent}") - + # Test SQL generation request = SQLGenerationRequest( - user_query=query, - conversation_history=[] + user_query=query, conversation_history=[] ) response = sql_service.generate_sql(request) print(f"SQL Response: {response.content[:100]}...") print("βœ… PASS") except Exception as e: print(f"❌ FAIL: {str(e)}") - + print("\n" + "=" * 50) return True - + except Exception as e: print(f"❌ SQL test failed: {str(e)}") return False @@ -101,57 +102,59 @@ def test_csv_functionality(): """Test CSV analysis functionality.""" print("πŸ“Š Testing CSV Analysis Functionality") print("=" * 50) - + try: # Create sample CSV data sample_data = { - 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'], - 'age': [25, 30, 35, 28, 32], - 'salary': [50000, 60000, 70000, 55000, 65000], - 'department': ['IT', 'HR', 'IT', 'Finance', 'HR'] + "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"], + "age": [25, 30, 35, 28, 32], + "salary": [50000, 60000, 70000, 55000, 65000], + "department": ["IT", "HR", "IT", "Finance", "HR"], } - + df = pd.DataFrame(sample_data) csv_content = df.to_csv(index=False) - + print(f"Sample CSV Data:\n{df.head()}") print(f"CSV Shape: {df.shape}") - + # Test CSV tools csv_tools = CSVAnalysisTools() - + # Test loading CSV data print("\nTesting CSV loading...") result = csv_tools.load_csv_data(csv_content, "test_session") print(f"Load Result: {result}") - + # Test CSV analysis queries csv_tests = [ "Show me the basic statistics of the data", "Create a bar chart of department distribution", "What is the average salary by department?", - "Show me the top 3 highest paid employees" + "Show me the top 3 highest paid employees", ] - + routing_service = IntelligentRoutingService() - + for query in csv_tests: print(f"\nQuery: '{query}'") try: # Test routing with CSV loaded - routing_decision = routing_service.determine_agent(query, [], csv_loaded=True) + routing_decision = routing_service.determine_agent( + query, [], csv_loaded=True + ) print(f"Routing Decision: {routing_decision.agent}") - + # Test CSV analysis response = routing_service.handle_csv_query(query, csv_content) print(f"CSV Response: {response[:200]}...") print("βœ… PASS") except Exception as e: print(f"❌ FAIL: {str(e)}") - + print("\n" + "=" * 50) return True - + except Exception as e: print(f"❌ CSV test failed: {str(e)}") return False @@ -161,10 +164,10 @@ def test_intelligent_routing(): """Test intelligent routing functionality.""" print("🧠 Testing Intelligent Routing") print("=" * 50) - + try: routing_service = IntelligentRoutingService() - + # Test cases with expected routing test_cases = [ ("Hello", "CONVERSATION_AGENT"), @@ -176,34 +179,39 @@ def test_intelligent_routing(): ("What can you do?", "CONVERSATION_AGENT"), ("Thanks for your help", "CONVERSATION_AGENT"), ] - + all_passed = True for query, expected_agent in test_cases: print(f"\nQuery: '{query}'") print(f"Expected Agent: {expected_agent}") - + try: # Test without CSV loaded decision = routing_service.determine_agent(query, [], csv_loaded=False) print(f"Result (no CSV): {decision.agent}") - + # Test with CSV loaded - decision_with_csv = routing_service.determine_agent(query, [], csv_loaded=True) + decision_with_csv = routing_service.determine_agent( + query, [], csv_loaded=True + ) print(f"Result (with CSV): {decision_with_csv.agent}") - - if decision.agent == expected_agent or decision_with_csv.agent == expected_agent: + + if ( + decision.agent == expected_agent + or decision_with_csv.agent == expected_agent + ): print("βœ… PASS") else: print("❌ FAIL") all_passed = False - + except Exception as e: print(f"❌ ERROR: {str(e)}") all_passed = False - + print("\n" + "=" * 50) return all_passed - + except Exception as e: print(f"❌ Routing test failed: {str(e)}") return False @@ -213,39 +221,46 @@ def test_orchestrator(): """Test the main orchestrator functionality.""" print("🎼 Testing Backend Orchestrator") print("=" * 50) - + try: orchestrator = BackendOrchestrator() - + # Test session creation print("Testing session creation...") from src.schemas.requests import NewChatRequest - session_info = orchestrator.create_new_session(NewChatRequest(session_name="Test Session")) + + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Test Session") + ) session_id = session_info.session_id print(f"Created session: {session_id}") - + # Test conversation print("\nTesting conversation...") response = orchestrator.generate_intelligent_response(session_id, "Hello") print(f"Conversation Response: {response.content[:100]}...") - + # Test SQL generation print("\nTesting SQL generation...") - response = orchestrator.generate_intelligent_response(session_id, "Show me all users") + response = orchestrator.generate_intelligent_response( + session_id, "Show me all users" + ) print(f"SQL Response: {response.content[:100]}...") - + # Test CSV loading and analysis print("\nTesting CSV functionality...") sample_csv = "name,age,salary\nAlice,25,50000\nBob,30,60000\nCharlie,35,70000" result = orchestrator.load_csv_data(session_id, sample_csv) print(f"CSV Load Result: {result}") - - response = orchestrator.generate_intelligent_response(session_id, "Analyze this data") + + response = orchestrator.generate_intelligent_response( + session_id, "Analyze this data" + ) print(f"CSV Analysis Response: {response.content[:100]}...") - + print("\n" + "=" * 50) return True - + except Exception as e: print(f"❌ Orchestrator test failed: {str(e)}") return False @@ -255,7 +270,7 @@ def main(): """Run all tests.""" print("πŸš€ Starting Comprehensive Backend Functionality Tests") print("=" * 60) - + tests = [ ("Conversation", test_conversation_functionality), ("SQL Generation", test_sql_functionality), @@ -263,9 +278,9 @@ def main(): ("Intelligent Routing", test_intelligent_routing), ("Orchestrator", test_orchestrator), ] - + results = {} - + for test_name, test_func in tests: print(f"\n{'='*20} {test_name} {'='*20}") try: @@ -273,30 +288,30 @@ def main(): except Exception as e: print(f"❌ {test_name} test crashed: {str(e)}") results[test_name] = False - + # Summary print("\n" + "=" * 60) print("πŸ“Š TEST SUMMARY") print("=" * 60) - + passed = 0 total = len(tests) - + for test_name, result in results.items(): status = "βœ… PASS" if result else "❌ FAIL" print(f"{test_name}: {status}") if result: passed += 1 - + print(f"\nOverall: {passed}/{total} tests passed") - + if passed == total: print("πŸŽ‰ All tests passed! Backend functionality is working correctly.") else: print("⚠️ Some tests failed. Check the backend implementation.") - + return passed == total if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/backend/orchestrator.py b/src/backend/orchestrator.py index d639f13..d04389f 100644 --- a/src/backend/orchestrator.py +++ b/src/backend/orchestrator.py @@ -54,8 +54,7 @@ def create_new_session(self, request: NewChatRequest) -> SessionInfo: messages = [] if request.initial_context: - messages.append(ChatMessage( - role="system", content=request.initial_context)) + messages.append(ChatMessage(role="system", content=request.initial_context)) messages.append(ChatMessage(role="assistant", content=WELCOME_MESSAGE)) @@ -104,44 +103,45 @@ def load_csv_data(self, session_id: str, csv_content: str) -> Dict[str, Any]: session = self.get_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") - + # Save CSV to file import os import tempfile - + # Create temp directory for this session if it doesn't exist temp_dir = f"/tmp/querypls_session_{session_id}" os.makedirs(temp_dir, exist_ok=True) - + # Save CSV to file csv_file_path = os.path.join(temp_dir, "data.csv") - with open(csv_file_path, 'w') as f: + with open(csv_file_path, "w") as f: f.write(csv_content) - + # Store both the content and file path in session session.csv_data = csv_content session.csv_file_path = csv_file_path - + # Get CSV info for context import pandas as pd from io import StringIO + df = pd.read_csv(StringIO(csv_content)) - + session.csv_info = { "file_path": csv_file_path, "shape": df.shape, "columns": list(df.columns), "dtypes": df.dtypes.to_dict(), - "sample_data": df.head(3).to_dict('records') + "sample_data": df.head(3).to_dict("records"), } - + session.last_activity = datetime.now() - + return { "status": "success", "message": "CSV data loaded successfully", "shape": df.shape, - "columns": list(df.columns) + "columns": list(df.columns), } def generate_intelligent_response( diff --git a/src/config/constants.py b/src/config/constants.py index 4927329..1ee6e21 100644 --- a/src/config/constants.py +++ b/src/config/constants.py @@ -10,8 +10,7 @@ STREAMLIT_HOST = "localhost" # Streamlit Configuration -STREAMLIT_CONFIG = {"page_title": "Querypls", - "page_icon": "πŸ’¬", "layout": "wide"} +STREAMLIT_CONFIG = {"page_title": "Querypls", "page_icon": "πŸ’¬", "layout": "wide"} # Welcome and Session Messages WELCOME_MESSAGE = "Hello! πŸ‘‹ I'm Querypls, your SQL and data analysis assistant. I can help you generate SQL queries or analyze CSV files. What would you like to work on today?" @@ -35,6 +34,9 @@ SESSION_CREATE_ERROR = "❌ Error creating session: {error}" SESSION_NOT_FOUND_ERROR = "❌ Session not found" +# worst-case scenario +WORST_CASE_SCENARIO = "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + # Application Errors ORCHESTRATOR_INIT_ERROR = "❌ Error initializing orchestrator: {error}" APP_INIT_ERROR = "❌ Error initializing application" diff --git a/src/config/settings.py b/src/config/settings.py index 40957e7..858af21 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -7,11 +7,11 @@ from pydantic import Field, BaseModel, ConfigDict - class Settings(BaseModel): - groq_api_key: str = Field(default=os.getenv("GROQ_API_KEY", "mock_api_key"), env="GROQ_API_KEY") - groq_model_name: str = Field( - default="openai/gpt-oss-120b", env="GROQ_MODEL_NAME") + groq_api_key: str = Field( + default=os.getenv("GROQ_API_KEY", "mock_api_key"), env="GROQ_API_KEY" + ) + groq_model_name: str = Field(default="openai/gpt-oss-120b", env="GROQ_MODEL_NAME") app_version: str = Field(default="1.0.0", env="APP_VERSION") max_chat_histories: int = Field(default=5, env="MAX_CHAT_HISTORIES") debug_mode: bool = Field(default=False, env="DEBUG_MODE") @@ -24,8 +24,8 @@ class Settings(BaseModel): json_schema_extra: ClassVar[str] = "ignore" model_config = ConfigDict( - env_file = ".env", - env_file_encoding = "utf-8", + env_file=".env", + env_file_encoding="utf-8", ) @@ -37,4 +37,3 @@ def get_settings() -> Settings: if _settings_instance is None: _settings_instance = Settings() return _settings_instance - diff --git a/src/frontend/app.py b/src/frontend/app.py index f79e898..9c39bd7 100644 --- a/src/frontend/app.py +++ b/src/frontend/app.py @@ -52,8 +52,7 @@ def get_current_session_id(): if "current_session_id" not in st.session_state: orchestrator = initialize_orchestrator() if orchestrator: - st.session_state["current_session_id"] = orchestrator.get_default_session( - ) + st.session_state["current_session_id"] = orchestrator.get_default_session() return st.session_state.get("current_session_id") @@ -78,32 +77,38 @@ def display_message_with_images(content: str): # Split the content into text and image sections parts = content.split("**Generated Images:**") text_content = parts[0].strip() - + # Display the text content st.markdown(text_content) - + # Handle images if present if len(parts) > 1: image_section = parts[1].strip() - image_lines = [line.strip() for line in image_section.split('\n') if line.strip().startswith('- ')] - + image_lines = [ + line.strip() + for line in image_section.split("\n") + if line.strip().startswith("- ") + ] + if image_lines: st.markdown("**Generated Images:**") - + # Look for images in the specific temp directory import os import glob - + temp_dir = "/tmp/querypls_session_csv_analysis_temp" if os.path.exists(temp_dir): for line in image_lines: # Extract filename from the line (e.g., "- department_chart.png") - filename = line.replace('- ', '').strip() + filename = line.replace("- ", "").strip() image_path = os.path.join(temp_dir, filename) - + if os.path.exists(image_path): try: - st.image(image_path, caption=filename, use_column_width=True) + st.image( + image_path, caption=filename, use_column_width=True + ) except Exception as e: st.error(f"Error displaying image {filename}: {str(e)}") else: @@ -117,7 +122,7 @@ def cleanup_old_images(): """Clean up old CSV analysis images.""" import os import glob - + temp_dir = "/tmp/querypls_session_csv_analysis_temp" if os.path.exists(temp_dir): try: @@ -185,7 +190,7 @@ def main(): try: # Clean up old images when creating new session cleanup_old_images() - + sessions = orchestrator.list_sessions() new_session = orchestrator.create_new_session( NewChatRequest(session_name=f"Chat {len(sessions) + 1}") @@ -203,9 +208,8 @@ def main(): try: # Clean up old images before loading new CSV cleanup_old_images() - - result = orchestrator.load_csv_data( - current_session_id, csv_content) + + result = orchestrator.load_csv_data(current_session_id, csv_content) if result["status"] == "success": st.success(CSV_LOADED_SUCCESS) st.session_state["csv_loaded"] = True diff --git a/src/schemas/requests.py b/src/schemas/requests.py index fc58d57..e0c2806 100644 --- a/src/schemas/requests.py +++ b/src/schemas/requests.py @@ -13,10 +13,8 @@ class ChatMessage(BaseModel): description="Message role (user, assistant, system)" ) content: str = Field(description="Message content", min_length=1) - timestamp: Optional[str] = Field( - default=None, description="Message timestamp") - session_id: Optional[str] = Field( - default=None, description="Session identifier") + timestamp: Optional[str] = Field(default=None, description="Message timestamp") + session_id: Optional[str] = Field(default=None, description="Session identifier") class SQLGenerationRequest(BaseModel): @@ -45,8 +43,7 @@ class ConversationHistory(BaseModel): messages: List[ChatMessage] = Field( default=[], description="List of conversation messages" ) - session_id: Optional[str] = Field( - default=None, description="Session identifier") + session_id: Optional[str] = Field(default=None, description="Session identifier") class NewChatRequest(BaseModel): diff --git a/src/schemas/responses.py b/src/schemas/responses.py index 644d136..eb603da 100644 --- a/src/schemas/responses.py +++ b/src/schemas/responses.py @@ -9,8 +9,7 @@ class SQLQueryResponse(BaseModel): """Schema for SQL query generation response.""" - sql_query: str = Field(..., - description="The generated SQL query as a string") + sql_query: str = Field(..., description="The generated SQL query as a string") explanation: str = Field( ..., description="Brief explanation of what the query does" ) @@ -41,10 +40,8 @@ class SQLQueryResponse(BaseModel): class ChatResponse(BaseModel): """Schema for chat response.""" - message_id: str = Field(..., - description="Unique identifier for the message") - role: Literal["assistant"] = Field( - default="assistant", description="Message role") + message_id: str = Field(..., description="Unique identifier for the message") + role: Literal["assistant"] = Field(default="assistant", description="Message role") content: str = Field(..., description="Response content") sql_response: Optional[SQLQueryResponse] = Field( default=None, description="Structured SQL response if applicable" @@ -58,8 +55,7 @@ class ErrorResponse(BaseModel): error_code: str = Field(..., description="Error code identifier") error_message: str = Field(..., description="Human-readable error message") - details: Optional[str] = Field( - default=None, description="Additional error details") + details: Optional[str] = Field(default=None, description="Additional error details") timestamp: str = Field(..., description="Error timestamp") @@ -69,8 +65,7 @@ class SessionInfo(BaseModel): session_id: str = Field(..., description="Unique session identifier") session_name: str = Field(..., description="Session name") created_at: str = Field(..., description="Session creation timestamp") - message_count: int = Field(..., - description="Number of messages in the session") + message_count: int = Field(..., description="Number of messages in the session") last_activity: str = Field(..., description="Last activity timestamp") @@ -82,5 +77,4 @@ class HealthCheckResponse(BaseModel): ) version: str = Field(..., description="Application version") timestamp: str = Field(..., description="Health check timestamp") - services: dict = Field( - default={}, description="Status of individual services") + services: dict = Field(default={}, description="Status of individual services") diff --git a/src/services/conversation_service.py b/src/services/conversation_service.py index 5dadca8..df1020d 100644 --- a/src/services/conversation_service.py +++ b/src/services/conversation_service.py @@ -7,6 +7,7 @@ from pydantic_ai.models.groq import GroqModel from pydantic_ai.providers.groq import GroqProvider +from src.config.constants import WORST_CASE_SCENARIO from src.config.settings import get_settings from src.services.models import ConversationResponse, Failed from utils.prompt import CONVERSATION_PROMPT @@ -79,7 +80,7 @@ def get_conversational_response(self, query: str) -> str: elif "help" in query_lower or "what can you do" in query_lower: return "I'm Querypls, your SQL and data analysis assistant! πŸ—ƒοΈπŸ’¬\n\nI can help you with:\nβ€’ **SQL Generation**: Convert natural language to SQL queries\nβ€’ **CSV Analysis**: Analyze data files with Python code\nβ€’ **Data Visualization**: Create charts and graphs\n\nJust ask me anything about your data!" else: - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + return WORST_CASE_SCENARIO except Exception as e: # Fallback response diff --git a/src/services/csv_analysis_tools.py b/src/services/csv_analysis_tools.py index 311f0da..30459e2 100644 --- a/src/services/csv_analysis_tools.py +++ b/src/services/csv_analysis_tools.py @@ -19,17 +19,14 @@ class CSVAnalysisContext(BaseModel): class PythonCodeResponse(BaseModel): - python_code: str = Field( - description="Generated Python code for data analysis") + python_code: str = Field(description="Generated Python code for data analysis") explanation: str = Field(description="Explanation of what the code does") - expected_output: str = Field( - description="What output is expected from the code") + expected_output: str = Field(description="What output is expected from the code") libraries_used: list = Field(description="List of Python libraries used") class CodeExecutionResult(BaseModel): - status: str = Field( - description="Execution status: success, error, or retry") + status: str = Field(description="Execution status: success, error, or retry") output: str = Field(description="Output from code execution") error_message: Optional[str] = Field( description="Error message if execution failed" @@ -89,8 +86,7 @@ def generate_analysis_code( def execute_analysis_code( self, python_code: str, session_id: str, max_retries: int = 3 ) -> CodeExecutionResult: - result = self.csv_service.execute_analysis( - session_id, python_code, max_retries) + result = self.csv_service.execute_analysis(session_id, python_code, max_retries) return CodeExecutionResult( status=result["status"], @@ -130,8 +126,7 @@ def create_csv_analysis_agent() -> Agent: settings = get_settings() model = GroqModel( - settings.groq_model_name, provider=GroqProvider( - api_key=settings.groq_api_key) + settings.groq_model_name, provider=GroqProvider(api_key=settings.groq_api_key) ) agent = Agent(model, instructions=CSV_AGENT_PROMPT, output_type=str) @@ -211,8 +206,7 @@ async def fix_code_error( sample_data=csv_info["sample_data"], ) - result = csv_tools.fix_code_error( - original_code, error_message, csv_context) + result = csv_tools.fix_code_error(original_code, error_message, csv_context) return f"""πŸ”§ Fixed Code: ```python {result.python_code} diff --git a/src/services/jupyter_service.py b/src/services/jupyter_service.py index bc3c220..ecf9064 100644 --- a/src/services/jupyter_service.py +++ b/src/services/jupyter_service.py @@ -77,8 +77,7 @@ def create_new_session( # Import common data science libraries self.execute_code("import pandas as pd", session_id) self.execute_code("import numpy as np", session_id) - self.execute_code( - "import matplotlib.pyplot as plt", session_id) + self.execute_code("import matplotlib.pyplot as plt", session_id) self.execute_code("import seaborn as sns", session_id) return session_id @@ -111,8 +110,7 @@ def execute_code(self, code: str, session_id: str = "default") -> ExecutionResul content = msg.get("content", {}) if msg_type == "execute_result": - output.append( - str(content.get("data", {}).get("text/plain", ""))) + output.append(str(content.get("data", {}).get("text/plain", ""))) elif msg_type == "stream": output.append(content.get("text", "")) elif msg_type == "error": @@ -216,8 +214,7 @@ def execute_analysis( ) -> Dict[str, Any]: for attempt in range(max_retries): try: - result = self.jupyter_client.execute_code( - python_code, session_id) + result = self.jupyter_client.execute_code(python_code, session_id) if result.status == "Success": return { diff --git a/src/services/models.py b/src/services/models.py index c6b3c24..f237396 100644 --- a/src/services/models.py +++ b/src/services/models.py @@ -16,8 +16,7 @@ class RoutingDecision(BaseModel): confidence: float = Field( description="Confidence level in the routing decision", ge=0.0, le=1.0 ) - reasoning: str = Field( - description="Brief explanation of why this agent was chosen") + reasoning: str = Field(description="Brief explanation of why this agent was chosen") class ConversationResponse(BaseModel): @@ -36,10 +35,8 @@ class SQLResponse(BaseModel): """Response for SQL generation.""" sql_query: str = Field(description="The generated SQL query") - explanation: str = Field( - description="Brief explanation of what the query does") - tables_used: List[str] = Field( - description="Array of table names used in the query") + explanation: str = Field(description="Brief explanation of what the query does") + tables_used: List[str] = Field(description="Array of table names used in the query") columns_selected: List[str] = Field( description="Array of column names selected in the query" ) @@ -49,8 +46,7 @@ class SQLResponse(BaseModel): complexity: Literal["SIMPLE", "MEDIUM", "COMPLEX"] = Field( description="Query complexity level" ) - estimated_rows: str = Field( - description="Estimated number of rows returned") + estimated_rows: str = Field(description="Estimated number of rows returned") execution_time: Optional[str] = Field( description="Estimated execution time", default=None ) @@ -63,12 +59,9 @@ class CSVAnalysisResponse(BaseModel): """Response for CSV analysis.""" python_code: str = Field(description="The generated Python code") - explanation: str = Field( - description="Brief explanation of what the code does") - expected_output: str = Field( - description="What output is expected from the code") - libraries_used: List[str] = Field( - description="Array of Python libraries used") + explanation: str = Field(description="Brief explanation of what the code does") + expected_output: str = Field(description="What output is expected from the code") + libraries_used: List[str] = Field(description="Array of Python libraries used") class CodeFixResponse(BaseModel): @@ -79,8 +72,7 @@ class CodeFixResponse(BaseModel): expected_output: str = Field( description="What output is expected from the fixed code" ) - libraries_used: List[str] = Field( - description="Array of Python libraries used") + libraries_used: List[str] = Field(description="Array of Python libraries used") class Failed(BaseModel): diff --git a/src/services/routing_service.py b/src/services/routing_service.py index edcb82e..46cefd6 100644 --- a/src/services/routing_service.py +++ b/src/services/routing_service.py @@ -8,6 +8,7 @@ from pydantic_ai.models.groq import GroqModel from pydantic_ai.providers.groq import GroqProvider +from src.config.constants import WORST_CASE_SCENARIO from src.config.settings import get_settings from src.services.models import ( RoutingDecision, @@ -95,8 +96,7 @@ def handle_sql_query( ) -> str: """Handle SQL generation queries.""" try: - context = self._prepare_sql_context( - user_query, conversation_history) + context = self._prepare_sql_context(user_query, conversation_history) result = self.sql_agent.run_sync(context) if hasattr(result.output, "sql_query"): @@ -131,7 +131,7 @@ def handle_csv_query( except Exception as e: # If LLM fails, provide a graceful response without showing errors - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + return WORST_CASE_SCENARIO def _execute_csv_analysis( self, python_code: str, csv_info: Dict[str, Any], explanation: str @@ -147,7 +147,7 @@ def _execute_csv_analysis( session_id = "csv_analysis_temp" # Load CSV data into the session - jupyter_service.load_csv_data(session_id, csv_info['file_path']) + jupyter_service.load_csv_data(session_id, csv_info["file_path"]) # Install required libraries if needed install_code = """ @@ -166,7 +166,7 @@ def install_package(package): install_package('matplotlib') install_package('seaborn') """ - + # Execute installation first install_result = jupyter_service.execute_analysis( session_id, install_code, max_retries=1 @@ -175,7 +175,7 @@ def install_package(package): # Retry loop for code execution with error fixing current_code = python_code max_retries = 3 - + for attempt in range(max_retries): # Execute the current code result = jupyter_service.execute_analysis( @@ -184,7 +184,7 @@ def install_package(package): if result["status"] == "success": output = result.get("output", "") - + # If output is empty, provide a fallback if not output.strip(): output = "Analysis completed successfully but no output was generated." @@ -192,11 +192,11 @@ def install_package(package): # Check if any images were created in the specific session directory import os import glob - + # Look for images in the session's temp directory session_temp_dir = f"/tmp/querypls_session_csv_analysis_temp" image_files = [] - + if os.path.exists(session_temp_dir): png_files = glob.glob(os.path.join(session_temp_dir, "*.png")) jpg_files = glob.glob(os.path.join(session_temp_dir, "*.jpg")) @@ -213,46 +213,47 @@ def install_package(package): {output} **Explanation:** {explanation}""" - + else: # Code execution failed - try to fix it error_msg = result.get("error_message", "Unknown error") - + if attempt < max_retries - 1: # Not the last attempt # Send error to LLM to fix the code - fixed_code = self._fix_python_code(current_code, error_msg, csv_info) + fixed_code = self._fix_python_code( + current_code, error_msg, csv_info + ) if fixed_code: current_code = fixed_code continue # Try again with fixed code - - # If we reach here, all attempts failed or no more retries - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + + return WORST_CASE_SCENARIO except Exception as e: - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + return WORST_CASE_SCENARIO except Exception as e: - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + return WORST_CASE_SCENARIO - # This should never be reached, but just in case - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + return WORST_CASE_SCENARIO - def _fix_python_code(self, original_code: str, error_message: str, csv_info: Dict[str, Any]) -> Optional[str]: + def _fix_python_code( + self, original_code: str, error_message: str, csv_info: Dict[str, Any] + ) -> Optional[str]: """Send error to LLM to fix the Python code.""" try: - # Create context for code fixing - context = self._prepare_code_fix_context(original_code, error_message, csv_info) - - # Use the CSV agent to generate fixed code + context = self._prepare_code_fix_context( + original_code, error_message, csv_info + ) + result = self.csv_agent.run_sync(context) - + if hasattr(result.output, "python_code"): return result.output.python_code else: return None - + except Exception as e: - # If fixing fails, return None to continue with original code return None def _prepare_routing_context( @@ -304,19 +305,17 @@ def _prepare_csv_context( f"CSV Data Types: {csv_info['dtypes']}", f"CSV Sample Data: {csv_info['sample_data']}", ] - + if conversation_history: context_parts.append("Conversation History:") # Last 5 messages for context for msg in conversation_history[-5:]: context_parts.append(f"- {msg.role}: {msg.content}") - + context_parts.append( "\nGenerate SIMPLE Python code that directly answers the user's question." ) - context_parts.append( - "MAXIMUM 10 LINES OF CODE - Keep it simple!" - ) + context_parts.append("MAXIMUM 10 LINES OF CODE - Keep it simple!") context_parts.append( "NO COMPREHENSIVE ANALYSIS - Just answer the specific question!" ) @@ -329,10 +328,12 @@ def _prepare_csv_context( context_parts.append( "For charts, use plt.savefig('/tmp/querypls_session_csv_analysis_temp/chart_name.png') and plt.show()." ) - + return "\n".join(context_parts) - def _prepare_code_fix_context(self, original_code: str, error_message: str, csv_info: Dict[str, Any]) -> str: + def _prepare_code_fix_context( + self, original_code: str, error_message: str, csv_info: Dict[str, Any] + ) -> str: """Prepare context for code fixing.""" context_parts = [ "CODE FIXING REQUEST:", @@ -355,9 +356,9 @@ def _prepare_code_fix_context(self, original_code: str, error_message: str, csv_ "6. Print human-readable insights directly", "7. For charts, save to /tmp/querypls_session_csv_analysis_temp/", "", - "Generate fixed Python code that will execute without errors." + "Generate fixed Python code that will execute without errors.", ] - + return "\n".join(context_parts) def _format_sql_response(self, sql_response) -> str: @@ -373,8 +374,7 @@ def _format_sql_response(self, sql_response) -> str: ] if sql_response.warnings: - response_parts.append( - f"**Warnings:** {', '.join(sql_response.warnings)}") + response_parts.append(f"**Warnings:** {', '.join(sql_response.warnings)}") return "\n\n".join(response_parts) @@ -400,4 +400,4 @@ def _fallback_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecisio def _get_fallback_conversation_response(self, user_query: str) -> str: """Get fallback conversation response when LLM fails.""" - return "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + return WORST_CASE_SCENARIO diff --git a/src/services/sql_service.py b/src/services/sql_service.py index f5879ad..65b9bd8 100644 --- a/src/services/sql_service.py +++ b/src/services/sql_service.py @@ -23,8 +23,7 @@ def __init__(self, api_key: Optional[str] = None): ) self.model = GroqModel( - self.settings.groq_model_name, provider=GroqProvider( - api_key=self.api_key) + self.settings.groq_model_name, provider=GroqProvider(api_key=self.api_key) ) self.agent = Agent( @@ -42,8 +41,7 @@ def format_chat_history(self, messages: list) -> str: role = msg.get("role", "user") if "```sql" in content: - content = content.replace( - "```sql\n", "").replace("\n```", "").strip() + content = content.replace("```sql\n", "").replace("\n```", "").strip() history.append( {"role": role, "query" if role == "user" else "response": content} @@ -53,8 +51,7 @@ def format_chat_history(self, messages: list) -> str: def generate_sql(self, request: SQLGenerationRequest) -> ChatResponse: try: - formatted_history = self.format_chat_history( - request.conversation_history) + formatted_history = self.format_chat_history(request.conversation_history) prompt = f"Previous conversation: {formatted_history}\nCurrent question: { request.user_query}" diff --git a/src/terminal/cli.py b/src/terminal/cli.py index 0fafbd6..841281d 100644 --- a/src/terminal/cli.py +++ b/src/terminal/cli.py @@ -66,15 +66,11 @@ def chat(self, query: str): print("\nSQL Details:") print(f" Query Type: {response.sql_response.query_type}") print(f" Complexity: {response.sql_response.complexity}") - print( - f" Tables Used: {', '.join(response.sql_response.tables_used)}") - print( - f" Columns: {', '.join(response.sql_response.columns_selected)}") - print( - f" Estimated Rows: {response.sql_response.estimated_rows}") + print(f" Tables Used: {', '.join(response.sql_response.tables_used)}") + print(f" Columns: {', '.join(response.sql_response.columns_selected)}") + print(f" Estimated Rows: {response.sql_response.estimated_rows}") if response.sql_response.warnings: - print( - f" Warnings: {', '.join(response.sql_response.warnings)}") + print(f" Warnings: {', '.join(response.sql_response.warnings)}") except Exception as e: print(f"Error: {str(e)}") diff --git a/test_app.py b/test_app.py index 3fe4305..a3c40b0 100644 --- a/test_app.py +++ b/test_app.py @@ -5,72 +5,80 @@ import sys import os + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + def test_imports(): """Test that all imports work correctly.""" print("Testing imports...") - + try: from src.config.constants import WELCOME_MESSAGE, DEFAULT_SESSION_NAME + print("βœ… Constants imported successfully") print(f" WELCOME_MESSAGE: {WELCOME_MESSAGE[:50]}...") print(f" DEFAULT_SESSION_NAME: {DEFAULT_SESSION_NAME}") except ImportError as e: print(f"❌ Error importing constants: {e}") return False - + try: from src.services.routing_service import IntelligentRoutingService + print("βœ… Routing service imported successfully") except ImportError as e: print(f"❌ Error importing routing service: {e}") return False - + try: from src.backend.orchestrator import BackendOrchestrator + print("βœ… Orchestrator imported successfully") except ImportError as e: print(f"❌ Error importing orchestrator: {e}") return False - + return True + def test_routing(): """Test the routing service.""" print("\nTesting routing service...") - + try: from src.services.routing_service import IntelligentRoutingService + routing_service = IntelligentRoutingService() - + # Test routing decisions test_cases = [ ("Hello", "CONVERSATION_AGENT"), ("Show me all users", "SQL_AGENT"), ("Analyze this CSV data", "CSV_AGENT"), ] - + for query, expected in test_cases: decision = routing_service.determine_agent(query, [], csv_loaded=False) status = "βœ…" if decision.agent == expected else "❌" print(f" {status} '{query}' β†’ {decision.agent} (expected: {expected})") - + return True except Exception as e: print(f"❌ Error testing routing: {e}") return False + if __name__ == "__main__": print("Querypls Application Test") print("=" * 40) - + success = True success &= test_imports() success &= test_routing() - + print("\n" + "=" * 40) if success: print("βœ… All tests passed! Application is ready.") else: - print("❌ Some tests failed. Please check the errors above.") \ No newline at end of file + print("❌ Some tests failed. Please check the errors above.") diff --git a/tests/test_basic.py b/tests/test_basic.py index 5586eb0..5325048 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -12,7 +12,9 @@ def test_create_new_session(): orchestrator = BackendOrchestrator() - session_info = orchestrator.create_new_session(NewChatRequest(session_name="Test Chat")) + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Test Chat") + ) assert session_info.session_name == "Test Chat" assert session_info.session_id is not None @@ -27,8 +29,6 @@ def test_list_sessions(): assert any(s.session_name == "Chat 2" for s in sessions) - - def test_health_check(): orchestrator = BackendOrchestrator() health = orchestrator.health_check() @@ -39,6 +39,8 @@ def test_health_check(): def test_session_message_flow(): orchestrator = BackendOrchestrator() - session_info = orchestrator.create_new_session(NewChatRequest(session_name="Test Session")) + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Test Session") + ) assert session_info.session_name == "Test Session" assert session_info.session_id is not None diff --git a/tests/test_models.py b/tests/test_models.py index 5fd1ee2..90331d0 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,7 @@ def test_chat_response(): message_id="123", content="test response", timestamp="2024-01-01T00:00:00", - session_id="456" + session_id="456", ) assert resp.content == "test response" assert resp.session_id == "456" diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 77c3d08..fea64eb 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -4,14 +4,14 @@ SQLGenerationRequest, ChatMessage, ConversationHistory, - NewChatRequest + NewChatRequest, ) from src.schemas.responses import ( SQLQueryResponse, ChatResponse, ErrorResponse, SessionInfo, - HealthCheckResponse + HealthCheckResponse, ) @@ -20,7 +20,7 @@ def test_sql_generation_request(): user_query="Show users", conversation_history=[], database_schema=None, - query_type=None + query_type=None, ) assert request.user_query == "Show users" assert isinstance(request.conversation_history, list) @@ -28,10 +28,7 @@ def test_sql_generation_request(): def test_chat_message(): message = ChatMessage( - role="user", - content="Hello", - timestamp="2024-01-01T00:00:00", - session_id="123" + role="user", content="Hello", timestamp="2024-01-01T00:00:00", session_id="123" ) assert message.role == "user" assert message.content == "Hello" @@ -41,20 +38,14 @@ def test_chat_message(): def test_conversation_history(): history = ConversationHistory( - messages=[ - ChatMessage(role="user", content="Hello") - ], - session_id="123" + messages=[ChatMessage(role="user", content="Hello")], session_id="123" ) assert len(history.messages) == 1 assert history.session_id == "123" def test_new_chat_request(): - request = NewChatRequest( - session_name="Test Chat", - initial_context="SQL Testing" - ) + request = NewChatRequest(session_name="Test Chat", initial_context="SQL Testing") assert request.session_name == "Test Chat" assert request.initial_context == "SQL Testing" @@ -66,7 +57,7 @@ def test_sql_query_response(): tables_used=["users"], columns_selected=["*"], query_type="SELECT", - complexity="SIMPLE" + complexity="SIMPLE", ) assert response.sql_query == "SELECT * FROM users" assert response.explanation == "Get all users" @@ -78,7 +69,7 @@ def test_chat_response(): message_id="123", content="Hello", timestamp="2024-01-01T00:00:00", - session_id="456" + session_id="456", ) assert response.message_id == "123" assert response.content == "Hello" @@ -90,7 +81,7 @@ def test_health_check_response(): status="healthy", version="1.0.0", timestamp="2024-01-01T00:00:00", - services={"sql": "healthy"} + services={"sql": "healthy"}, ) assert response.status == "healthy" assert response.version == "1.0.0" From 52ff0bb8ed688f821d7e38302d5fac4626edfc29 Mon Sep 17 00:00:00 2001 From: samadpls Date: Fri, 8 Aug 2025 00:58:07 +0500 Subject: [PATCH 7/8] refactor: Simplify CSV loading success message and enhance routing service prompts for clarity and simplicity --- src/services/csv_analysis_tools.py | 4 +- src/services/routing_service.py | 30 +++----- utils/prompt.py | 107 +++++++++++------------------ 3 files changed, 50 insertions(+), 91 deletions(-) diff --git a/src/services/csv_analysis_tools.py b/src/services/csv_analysis_tools.py index 30459e2..3c65e21 100644 --- a/src/services/csv_analysis_tools.py +++ b/src/services/csv_analysis_tools.py @@ -139,9 +139,7 @@ async def load_csv_data( ) -> str: result = csv_tools.load_csv_data(csv_content, session_id) if result["status"] == "success": - return f"CSV loaded successfully! Shape: { - result['shape']}, Columns: { - result['columns']}" + return f"CSV loaded successfully! Shape: {result['shape']}, Columns: {result['columns']}" else: return f"Error loading CSV: {result['message']}" diff --git a/src/services/routing_service.py b/src/services/routing_service.py index 46cefd6..64164d5 100644 --- a/src/services/routing_service.py +++ b/src/services/routing_service.py @@ -203,16 +203,13 @@ def install_package(package): image_files.extend(png_files + jpg_files) if image_files: - image_info = "\n\n**Generated Images:**\n" + image_info = "\n\nπŸ“Š **Charts generated:**\n" for img_file in image_files: image_info += f"- {os.path.basename(img_file)}\n" output += image_info - return f"""**Analysis Results:** - -{output} - -**Explanation:** {explanation}""" + # Return only the human-readable output, not technical details + return output.strip() else: # Code execution failed - try to fix it @@ -313,20 +310,20 @@ def _prepare_csv_context( context_parts.append(f"- {msg.role}: {msg.content}") context_parts.append( - "\nGenerate SIMPLE Python code that directly answers the user's question." + "\nGenerate SUPER SIMPLE Python code that directly answers the user's question." ) - context_parts.append("MAXIMUM 10 LINES OF CODE - Keep it simple!") + context_parts.append("MAXIMUM 5 LINES OF CODE - Keep it extremely simple!") context_parts.append( - "NO COMPREHENSIVE ANALYSIS - Just answer the specific question!" + "NO FUNCTIONS OR CLASSES - Just direct code that prints results!" ) context_parts.append( f"IMPORTANT: Use pd.read_csv('{csv_info['file_path']}') to load the data from the file path!" ) context_parts.append( - "Print human-readable insights directly - no complex scripts!" + "Print human-readable results like 'Average price: $123.45' - NO technical output!" ) context_parts.append( - "For charts, use plt.savefig('/tmp/querypls_session_csv_analysis_temp/chart_name.png') and plt.show()." + "For charts, use plt.savefig('/tmp/querypls_session_csv_analysis_temp/chart.png') and plt.show()." ) return "\n".join(context_parts) @@ -378,17 +375,6 @@ def _format_sql_response(self, sql_response) -> str: return "\n\n".join(response_parts) - def _format_csv_response(self, csv_response) -> str: - """Format CSV analysis response for display.""" - response_parts = [ - f"**Python Code:**\n```python\n{csv_response.python_code}\n```", - f"**Explanation:** {csv_response.explanation}", - f"**Expected Output:** {csv_response.expected_output}", - f"**Libraries Used:** {', '.join(csv_response.libraries_used)}", - ] - - return "\n\n".join(response_parts) - def _fallback_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecision: """Fallback routing when LLM routing fails - let LLM decide, not hardcoded keywords.""" # Default to conversation - let the LLM handle all decisions diff --git a/utils/prompt.py b/utils/prompt.py index b752aa0..a243db7 100644 --- a/utils/prompt.py +++ b/utils/prompt.py @@ -100,7 +100,7 @@ Respond only with the JSON object. Do not include any additional text or markdown formatting.""" -CSV_ANALYSIS_PROMPT = """You are a Python data analysis expert. Generate SIMPLE, FOCUSED Python code that answers the user's specific question. +CSV_ANALYSIS_PROMPT = """You are a Python data analysis expert. Generate SIMPLE, FOCUSED Python code that answers the user's specific question in a human-readable way. ## Response Format Your response must be in JSON format. @@ -111,83 +111,58 @@ * `expected_output` - What output is expected from the code * `libraries_used` - Array of Python libraries used -## IMPORTANT: The code you generate will be EXECUTED automatically. Do NOT include code blocks or markdown formatting in the python_code field. - ## CRITICAL GUIDELINES: -1. **KEEP CODE SIMPLE** - Maximum 6 lines of code -2. **ANSWER THE SPECIFIC QUESTION** - Don't create comprehensive analysis -3. **PRINT CLEAR INSIGHTS** - Human-readable output, not raw data -4. **NO COMPLEX SCRIPTS** - No functions, classes, or advanced features -5. **SIMPLE VARIABLES** - Use df, result, avg, etc. -6. **DIRECT APPROACH** - Load data, analyze, print result -7. **NO SPECIAL CHARACTERS** - Avoid β‰₯, ≀, β†’, etc. Use standard ASCII -8. **SIMPLE LOGIC** - No complex conditionals or loops -9. **NO FUNCTIONS** - Write code directly, not inside functions -10. **NO DOCSTRINGS** - No complex documentation - -## CRITICAL Code Requirements: -- The CSV data is available as a file at the path provided in the context -- ALWAYS use `pd.read_csv('file_path')` to load the data from the file path -- The file path will be provided in the context -- For graphs/charts, save them to `/tmp/querypls_session_csv_analysis_temp/` folder with descriptive names -- Use `plt.savefig('/tmp/querypls_session_csv_analysis_temp/chart_name.png')` before `plt.show()` -- Print insights in a human-readable format with clear explanations -- Don't create complex functions or classes - keep it simple and direct -- ONLY use these libraries: pandas, numpy, matplotlib.pyplot, seaborn -- Write clean, simple code without syntax errors -- Use proper variable names and avoid special characters -- For charts: use simple matplotlib code, save to specific temp folder, then show -- Keep each line simple and avoid complex expressions - -## Example Response +1. **KEEP CODE SUPER SIMPLE** - Maximum 5 lines of code +2. **NO FUNCTIONS OR CLASSES** - Write direct code only +3. **PRINT HUMAN-READABLE RESULTS** - Use print() with clear formatting +4. **ANSWER SPECIFIC QUESTION ONLY** - Don't do comprehensive analysis +5. **USE SIMPLE VARIABLES** - df, avg, count, total, etc. +6. **NO TECHNICAL JARGON** - Speak like talking to a person + +## Code Requirements: +- Use `pd.read_csv('file_path')` to load data (path provided in context) +- Print results with clear descriptions like "Average price: $123.45" +- For charts: save to `/tmp/querypls_session_csv_analysis_temp/chart.png` +- Use only: pandas, matplotlib.pyplot (as plt), numpy +- Keep each line simple and readable +- NO error handling functions - keep it basic + +## Example Responses: + +### For "average price": { - "python_code": "import pandas as pd\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\navg = df['salary'].mean()\\nprint(f'Average salary: ${avg:,.2f}')", - "explanation": "Loads CSV data and calculates the average salary in a readable format", - "expected_output": "Average salary: $60,000.00", + "python_code": "import pandas as pd\\ndf = pd.read_csv('/tmp/data.csv')\\navg = df['price'].mean()\\nprint(f'Average price: ${avg:,.2f}')", + "explanation": "Calculates and displays the average price", + "expected_output": "Average price: $1,234.56", "libraries_used": ["pandas"] } -## Chart Example Response +### For "show top 5 products": { - "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\nplt.figure(figsize=(8, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/department_chart.png')\\nplt.show()\\nprint('Department counts:')\\nprint(df['department'].value_counts())", - "explanation": "Creates a bar chart of department distribution and saves it to specific temp folder", - "expected_output": "Bar chart visualization and department counts", - "libraries_used": ["pandas", "matplotlib.pyplot"] + "python_code": "import pandas as pd\\ndf = pd.read_csv('/tmp/data.csv')\\ntop5 = df.nlargest(5, 'price')\\nprint('Top 5 most expensive products:')\\nprint(top5[['name', 'price']].to_string(index=False))", + "explanation": "Shows the 5 most expensive products", + "expected_output": "Top 5 most expensive products with names and prices", + "libraries_used": ["pandas"] } -## Graph Example Response +### For "create chart": { - "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\nplt.figure(figsize=(8, 6))\\ndf['department'].value_counts().plot(kind='bar')\\nplt.title('Department Distribution')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/department_chart.png')\\nplt.show()\\nprint('Department distribution:')\\nprint(df['department'].value_counts())", - "explanation": "Creates a simple bar chart showing department distribution", - "expected_output": "Bar chart and department counts", + "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\ndf = pd.read_csv('/tmp/data.csv')\\ndf['category'].value_counts().plot(kind='bar')\\nplt.title('Product Categories')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/chart.png')\\nplt.show()\\nprint(f'Created chart showing {len(df[\"category\"].unique())} categories')", + "explanation": "Creates a bar chart of product categories", + "expected_output": "Bar chart and category count message", "libraries_used": ["pandas", "matplotlib.pyplot"] } -## Price Analysis Example Response -{ - "python_code": "import pandas as pd\\ndf = pd.read_csv('/tmp/querypls_session_xxx/data.csv')\\navg_price = df['price'].mean()\\nprint(f'Average price: ${avg_price:,.2f}')", - "explanation": "Loads CSV data and calculates the average price in a readable format", - "expected_output": "Average price: $1,234.56", - "libraries_used": ["pandas"] -} - -## Important Notes: -- Use double backslashes for newlines in the python_code field -- ALWAYS use `pd.read_csv('file_path')` to load CSV data from the file path provided in context -- Save charts to `/tmp/querypls_session_csv_analysis_temp/` folder -- Keep the explanation concise -- Make sure the JSON is valid and properly formatted -- The file path will be provided in the context -- Write simple, clean code without complex functions or classes -- Focus on printing clear insights directly -- Avoid syntax errors and special characters -- NEVER use line continuation characters (\\) in the code -- Keep each line complete and self-contained -- Use simple string formatting with f-strings -- **MAXIMUM 10 LINES OF CODE** - Keep it simple! -- **NO COMPREHENSIVE ANALYSIS** - Just answer the specific question -- **DO NOT include ```python or ``` in the python_code field** -- **The code will be executed automatically - just provide the raw Python code** +## IMPORTANT RULES: +- **NO FUNCTIONS** - Write code directly, not inside functions +- **NO COMPLEX LOGIC** - Keep it simple and straightforward +- **HUMAN-READABLE OUTPUT** - Print clear, conversational results +- **ANSWER THE QUESTION** - Don't add extra analysis +- **USE f-strings** - For clear formatting like f'Total: {total}' +- **MAXIMUM 5 LINES** - Keep it super simple +- Use double backslashes (\\n) for newlines in JSON +- The code will be executed automatically +- Focus on answering the specific user question only Respond only with the JSON object.""" From fb5cd57eeff12fb3473e8430399a3b496fd3d8d6 Mon Sep 17 00:00:00 2001 From: samadpls Date: Fri, 8 Aug 2025 01:10:10 +0500 Subject: [PATCH 8/8] feat: Implement keyword-based routing fallback and enhance error handling in routing service --- src/services/routing_service.py | 79 ++++++++++++++++++-- test_routing_fix.py | 126 ++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 8 deletions(-) create mode 100644 test_routing_fix.py diff --git a/src/services/routing_service.py b/src/services/routing_service.py index 64164d5..40ca5a0 100644 --- a/src/services/routing_service.py +++ b/src/services/routing_service.py @@ -75,8 +75,9 @@ def determine_agent( return result.output except Exception as e: - # Fallback to simple keyword-based routing - return self._fallback_routing(user_query, csv_loaded) + print(f"Routing failed with error: {e}") + # Use simple keyword-based routing as fallback + return self._keyword_based_routing(user_query, csv_loaded) def handle_conversation_query(self, user_query: str) -> str: """Handle conversational queries.""" @@ -131,7 +132,7 @@ def handle_csv_query( except Exception as e: # If LLM fails, provide a graceful response without showing errors - return WORST_CASE_SCENARIO + return WORST_CASE_SCENARIO def _execute_csv_analysis( self, python_code: str, csv_info: Dict[str, Any], explanation: str @@ -224,15 +225,15 @@ def install_package(package): current_code = fixed_code continue # Try again with fixed code - return WORST_CASE_SCENARIO + return WORST_CASE_SCENARIO except Exception as e: - return WORST_CASE_SCENARIO + return WORST_CASE_SCENARIO except Exception as e: - return WORST_CASE_SCENARIO + return WORST_CASE_SCENARIO - return WORST_CASE_SCENARIO + return WORST_CASE_SCENARIO def _fix_python_code( self, original_code: str, error_message: str, csv_info: Dict[str, Any] @@ -375,6 +376,68 @@ def _format_sql_response(self, sql_response) -> str: return "\n\n".join(response_parts) + def _keyword_based_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecision: + """Keyword-based routing when LLM routing fails.""" + query_lower = user_query.lower() + + # CSV Agent keywords + csv_keywords = [ + "csv", "analyze", "chart", "plot", "graph", "average", "mean", "sum", + "count", "max", "min", "statistics", "data", "visualization", "top", + "bottom", "highest", "lowest", "distribution", "correlation" + ] + + # SQL Agent keywords + sql_keywords = [ + "select", "insert", "update", "delete", "sql", "query", "table", + "database", "users", "customers", "orders", "products", "where", + "join", "group by", "order by", "from" + ] + + # Conversation Agent keywords + conversation_keywords = [ + "hello", "hi", "hey", "how are you", "what can you do", "help", + "thanks", "thank you", "goodbye", "bye", "good morning", "good evening" + ] + + # Check for CSV analysis (prioritize if CSV is loaded) + if csv_loaded and any(keyword in query_lower for keyword in csv_keywords): + return RoutingDecision( + agent="CSV_AGENT", + confidence=0.8, + reasoning="Keyword-based routing detected CSV analysis request" + ) + + # Check for SQL keywords + if any(keyword in query_lower for keyword in sql_keywords): + return RoutingDecision( + agent="SQL_AGENT", + confidence=0.8, + reasoning="Keyword-based routing detected SQL request" + ) + + # Check for conversation keywords + if any(keyword in query_lower for keyword in conversation_keywords): + return RoutingDecision( + agent="CONVERSATION_AGENT", + confidence=0.9, + reasoning="Keyword-based routing detected conversation request" + ) + + # Default based on context + if csv_loaded: + return RoutingDecision( + agent="CSV_AGENT", + confidence=0.6, + reasoning="CSV loaded, defaulting to CSV analysis" + ) + else: + return RoutingDecision( + agent="CONVERSATION_AGENT", + confidence=0.5, + reasoning="No clear intent detected, defaulting to conversation" + ) + def _fallback_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecision: """Fallback routing when LLM routing fails - let LLM decide, not hardcoded keywords.""" # Default to conversation - let the LLM handle all decisions @@ -386,4 +449,4 @@ def _fallback_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecisio def _get_fallback_conversation_response(self, user_query: str) -> str: """Get fallback conversation response when LLM fails.""" - return WORST_CASE_SCENARIO + return WORST_CASE_SCENARIO diff --git a/test_routing_fix.py b/test_routing_fix.py new file mode 100644 index 0000000..4c425ae --- /dev/null +++ b/test_routing_fix.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +"""Test script to verify the routing fix is working.""" + +import sys +import os +import tempfile +import pandas as pd + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from src.services.routing_service import IntelligentRoutingService +from src.backend.orchestrator import BackendOrchestrator +from src.schemas.requests import NewChatRequest + + +def test_routing_only(): + """Test just the routing mechanism.""" + print("🧠 Testing Routing Mechanism") + print("=" * 40) + + routing_service = IntelligentRoutingService() + + test_cases = [ + ("Hello", False, "CONVERSATION_AGENT"), + ("What is the average salary?", True, "CSV_AGENT"), + ("Show me all users", False, "SQL_AGENT"), + ("Create a chart", True, "CSV_AGENT"), + ("SELECT * FROM users", False, "SQL_AGENT"), + ] + + for query, csv_loaded, expected in test_cases: + print(f"\nQuery: '{query}' (CSV loaded: {csv_loaded})") + decision = routing_service.determine_agent(query, [], csv_loaded=csv_loaded) + print(f"Expected: {expected}") + print(f"Actual: {decision.agent}") + print(f"Confidence: {decision.confidence}") + print(f"Reasoning: {decision.reasoning}") + + status = "βœ… PASS" if decision.agent == expected else "❌ FAIL" + print(f"Status: {status}") + + +def test_csv_analysis_with_real_data(): + """Test CSV analysis with actual CSV data.""" + print("\nπŸ“Š Testing CSV Analysis with Real Data") + print("=" * 40) + + # Create a temporary CSV file + data = { + 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'], + 'age': [25, 30, 35, 28, 32], + 'salary': [50000, 60000, 70000, 55000, 65000], + 'department': ['IT', 'HR', 'IT', 'Finance', 'HR'] + } + + df = pd.DataFrame(data) + + # Create temporary CSV file + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: + df.to_csv(f.name, index=False) + csv_path = f.name + + print(f"Created test CSV: {csv_path}") + print("CSV Content:") + print(df.to_string(index=False)) + + # Create orchestrator and test CSV analysis + orchestrator = BackendOrchestrator() + + # Create session + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Test Session") + ) + session_id = session_info.session_id + print(f"\nCreated session: {session_id}") + + # Load CSV data + with open(csv_path, 'r') as f: + csv_content = f.read() + + result = orchestrator.load_csv_data(session_id, csv_content) + print(f"CSV Load Result: {result['status']}") + + # Test CSV analysis queries + test_queries = [ + "What is the average salary?", + "How many people are in each department?", + "Who has the highest salary?", + ] + + for query in test_queries: + print(f"\n--- Testing Query: '{query}' ---") + try: + response = orchestrator.generate_intelligent_response(session_id, query) + print(f"Response: {response.content}") + print(f"Response Type: {type(response.content)}") + + # Check if this is raw Python code (the old problem) + if "import" in response.content or "pd.read_csv" in response.content: + print("❌ ISSUE: Response contains raw Python code!") + else: + print("βœ… SUCCESS: Response is clean human-readable text!") + + except Exception as e: + print(f"❌ ERROR: {str(e)}") + + # Cleanup + os.unlink(csv_path) + + +def main(): + """Run all tests.""" + print("πŸš€ Testing Routing Fix") + print("=" * 50) + + # Test 1: Routing mechanism + test_routing_only() + + # Test 2: CSV analysis with real data + test_csv_analysis_with_real_data() + + print("\nπŸŽ‰ Testing completed!") + + +if __name__ == "__main__": + main()