-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsql_ops.py
More file actions
113 lines (91 loc) · 4.29 KB
/
sql_ops.py
File metadata and controls
113 lines (91 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from app_state import app_state, Query
from utils import REMOVE_THINK_TAGS, IS_SAFE_QUERY
def generate_sql(user_request, max_attempts=3, groq_client=None):
    """
    Generate a safe SQL query from a natural-language user request.

    Retries up to ``max_attempts`` times if the model produces an unsafe
    query; each retry appends an explicit safety instruction to the request
    before asking the model again.

    Args:
        user_request (str): The user's request for data.
        max_attempts (int): Maximum number of attempts to generate a safe query.
        groq_client: The Groq client instance to use for API calls.

    Returns:
        str: A safe SQL query.

    Raises:
        ValueError: If ``groq_client`` is not provided.
        Exception: If unable to generate a safe query after ``max_attempts``
            attempts, or if clarification is needed (the exception message
            is the clarification question).
    """
    # Prefix the model uses to signal it needs more information.
    clarification_prefix = "NEED_CLARIFICATION:"
    # Fail fast before mutating any shared application state.
    if groq_client is None:
        raise ValueError("groq_client must be provided")
    app_state.set_response_mode("sql_generation")
    app_state.add_to_chat_context("user", user_request)
    for attempt in range(1, max_attempts + 1):
        # Generate the SQL query, replaying the accumulated chat context.
        response = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                *app_state.chat_context,
                {"role": "user", "content": user_request},
            ],
            temperature=0.5,
            # max_tokens=4000, ### DO NOT remove this line.
        )
        resp = response.choices[0].message.content
        if isinstance(resp, str):
            # Clarification requests are surfaced to the caller as an
            # exception carrying the question (prefix stripped).
            if resp.startswith(clarification_prefix):
                app_state.set_response_mode("info_clarification")
                app_state.add_to_chat_context("assistant", resp)
                raise Exception(resp[len(clarification_prefix):].strip())
            sql = REMOVE_THINK_TAGS(resp)
        else:
            sql = resp
        # Record every generated query, safe or not, in the history.
        is_safe, error_message = IS_SAFE_QUERY(sql)
        app_state.add_query_to_history(Query(sql=sql, is_safe=is_safe))
        if is_safe:
            app_state.add_to_chat_context("assistant", sql)
            return sql
        # If we're here, the query wasn't safe.
        print(f"Attempt {attempt}: Generated unsafe query. {error_message}")
        if attempt == max_attempts:
            error_msg = f"Failed to generate a safe query after {max_attempts} attempts. Last error: {error_message}"
            app_state.add_to_chat_context("assistant", error_msg)
            raise Exception(error_msg)
        # Add safety requirement to the user request for next attempt.
        user_request = f"{user_request} (IMPORTANT: Generate only SELECT queries. ***DO NOT*** include any DDL operations like CREATE, DROP, ALTER, etc.)"
def analyze_query_output(query_output, groq_client=None):
    """
    Format SQL query results as a readable table and ask the LLM to analyze them.

    Args:
        query_output: A ``(columns, results)`` pair as returned by the query
            executor, or ``None`` when the query produced nothing.
        groq_client: The Groq client instance to use for API calls.

    Returns:
        str: The LLM's analysis of the results, or a short status message
        when there is nothing to analyze.

    Raises:
        ValueError: If ``groq_client`` is not provided.
    """
    # Fail fast before mutating any shared application state.
    if groq_client is None:
        raise ValueError("groq_client must be provided")
    app_state.set_response_mode("sql_analysis")
    content = f"Here is the output of the SQL query: {query_output}"
    app_state.add_to_chat_context("user", content)
    if query_output is None or len(query_output) != 2:
        return "No output from the query."
    columns, results = query_output
    if not columns or not results:
        return "Query executed successfully but returned no data."
    # Build the table with str.join (linear) instead of repeated string
    # concatenation (quadratic). The resulting text is identical.
    divider = "=" * 50
    lines = ["", "Query Results:", divider]
    # Column headers.
    lines.append(" | ".join(str(col).upper() for col in columns))
    lines.append(divider)
    # Data rows; None renders as NULL.
    for row in results:
        lines.append(" | ".join("NULL" if val is None else str(val) for val in row))
    lines.append(divider)
    lines.append(f"Total rows: {len(results)}")
    output = "\n".join(lines) + "\n"
    # Perform actual analysis of the query output via LLM.
    completion = groq_client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[
            *app_state.chat_context,
            {"role": "user", "content": f"Here is the output of the SQL query: {output}"},
        ],
    )
    analysis = completion.choices[0].message.content
    # Return the analysis of the query output.
    app_state.add_to_chat_context("assistant", analysis)
    return analysis