forked from hackintoshrao/sqlglot
-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathfinal_testing.py
More file actions
370 lines (334 loc) · 14.7 KB
/
final_testing.py
File metadata and controls
370 lines (334 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
from typing import List, Dict, Any, Optional
import sqlglot
from sqlglot.expressions import (
Expression,
Select,
Column,
Table,
Alias,
With,
CTE,
Join,
Limit,
Literal,
Star,
Order,
Ordered,
And,
Or,
EQ,
GT,
LTE,
GTE,
Paren,
Subquery,
Exists,
Window,
)
from collections import defaultdict
import logging
# Configure the root logger at import time so that the warnings this module
# logs (unresolved table aliases, non-literal LIMIT values, ...) are shown at
# WARNING level and above.
logging.basicConfig(level=logging.WARNING)
# Module-level logger named after this module; used by the extraction code below.
logger = logging.getLogger(__name__)
def extract_sql_components_per_table_with_alias(
    expressions: List[Expression],
) -> List[Dict[str, Any]]:
    """
    Summarise, per base table, the columns, WHERE columns and LIMIT values used by parsed SQL.

    Every SELECT found in the statements is inspected on its own. Tables named in its FROM
    clause and in JOINs are registered (names of top-level CTEs are skipped), table aliases
    are resolved back to the real table name, and the columns of the SELECT list, the columns
    of the WHERE clause and the LIMIT value are recorded against those tables. Unqualified
    columns, an unqualified '*' and the LIMIT are attributed to every table of that SELECT.

    Subqueries and EXISTS bodies are additionally processed with the enclosing SELECT's alias
    mapping, so that a correlated reference (an inner query using an outer alias) resolves to
    the outer table. Problems such as an alias that matches no table are reported through the
    module-level ``logger`` as warnings; they never raise.

    Args:
        expressions (List[Expression]): Parsed SQL expressions from sqlglot.parse().
            ``None`` entries are ignored.

    Returns:
        List[Dict[str, Any]]: One dictionary per table, in order of first appearance, with the
        keys "table" (name as first seen), "columns", "where_columns" and "limits". Each of
        the three lists is de-duplicated and sorted.
    """
    components: List[Dict[str, Any]] = []
    # Lower-cased table name -> its entry in ``components``. Gives the
    # case-insensitive lookup without rescanning the list on every column.
    entries_by_key: Dict[str, Dict[str, Any]] = {}
    cte_names = set()

    def find_entry(table_name: str) -> Optional[Dict[str, Any]]:
        """Return the entry registered for ``table_name`` (case-insensitive), if any."""
        return entries_by_key.get(table_name.lower())

    # Helper function to find or create a table entry
    def get_or_create_table_entry(table_name: str) -> Dict[str, Any]:
        """Return the entry for ``table_name``, creating an empty one on first sight."""
        key = table_name.lower()
        entry = entries_by_key.get(key)
        if entry is None:
            entry = {"table": table_name, "columns": [], "where_columns": [], "limits": []}
            entries_by_key[key] = entry
            components.append(entry)
        return entry

    def add_unique(entry: Dict[str, Any], field: str, value: Any) -> None:
        """Append ``value`` to ``entry[field]`` unless it is already there."""
        if value not in entry[field]:
            entry[field].append(value)

    def register_table(
        table: Table, alias_mapping: Dict[str, str], scope_tables: List[str]
    ) -> None:
        """Record a FROM/JOIN table and its alias for the SELECT being processed."""
        table_name = getattr(table.this, "name", None)
        # Nameless tables and references to CTEs are not base tables.
        if not table_name or table_name.lower() in cte_names:
            return
        alias = table.parent.alias if isinstance(table.parent, Alias) else table.alias
        if alias:
            alias_mapping[alias] = table_name
        if table_name not in scope_tables:
            scope_tables.append(table_name)
        get_or_create_table_entry(table_name)

    def record_column(
        column: Column,
        field: str,
        label: str,
        alias_mapping: Dict[str, str],
        scope_tables: List[str],
    ) -> None:
        """Add ``column`` to ``field`` of the table(s) it belongs to.

        ``label`` is only used as the prefix of the warning messages.
        """
        column_name = column.name
        table_alias = column.table
        if table_alias:
            # Qualified column: resolve the alias; an unknown alias is tried as a table name.
            entry = find_entry(alias_mapping.get(table_alias, table_alias))
            if entry is not None and column_name:
                add_unique(entry, field, column_name)
            else:
                logger.warning(
                    "%s '%s' has alias '%s' which does not match any table.",
                    label,
                    column_name,
                    table_alias,
                )
            return
        if not scope_tables:
            logger.warning(
                "%s '%s' has no table alias and no tables found in SELECT.",
                label,
                column_name,
            )
            return
        # Unqualified column: the owner is unknown, so credit every table of this SELECT.
        for table_name in scope_tables:
            entry = find_entry(table_name)
            if entry is not None and column_name:
                add_unique(entry, field, column_name)

    def record_star(
        star: Star, alias_mapping: Dict[str, str], scope_tables: List[str]
    ) -> None:
        """Add '*' to the columns of the table(s) a wildcard refers to."""
        # NOTE(review): a qualified wildcard (t.*) appears to be parsed by sqlglot as a Column
        # wrapping a Star, in which case record_column has already handled it and the parent's
        # alias_or_name here matches no table, making this a silent no-op - confirm.
        table_alias = getattr(star.parent, "alias_or_name", None)
        if table_alias:
            entry = find_entry(alias_mapping.get(table_alias, table_alias))
            if entry is not None:
                add_unique(entry, "columns", "*")
            return
        if not scope_tables:
            logger.warning(
                "Unqualified '*' found but no tables are associated with the current SELECT."
            )
            return
        # Unqualified '*', associate with all current SELECT tables
        for table_name in scope_tables:
            entry = find_entry(table_name)
            if entry is not None:
                add_unique(entry, "columns", "*")

    # Recursive function to process SELECT nodes
    def process_select(select_node: Select, parent_alias_mapping: Dict[str, str]):
        """
        Processes a SELECT node, extracts tables, columns, where_columns, and limits.

        Args:
            select_node (Select): The SELECT node to process.
            parent_alias_mapping (Dict[str, str]): The alias mapping from the parent scope.
        """
        # Work on a copy so aliases declared here never leak into the parent scope.
        alias_mapping = parent_alias_mapping.copy()
        scope_tables: List[str] = []

        # Collect tables from FROM clause
        from_clause = select_node.args.get("from")
        if from_clause is not None:
            for table in from_clause.find_all(Table):
                register_table(table, alias_mapping, scope_tables)

        # Collect tables from JOIN clauses. find_all searches the whole subtree, so joins
        # of nested subqueries are picked up here as well (kept from the original behaviour).
        for join in select_node.find_all(Join):
            if isinstance(join.this, Table):
                register_table(join.this, alias_mapping, scope_tables)

        # Extract columns from SELECT expressions. find_all does the type filtering itself
        # and visits nodes in the same order as walk(), so this does not depend on the shape
        # of the items walk() yields.
        for select_expr in select_node.expressions:
            for node in select_expr.find_all(Column, Star):
                if isinstance(node, Column):
                    record_column(node, "columns", "Column", alias_mapping, scope_tables)
                else:
                    record_star(node, alias_mapping, scope_tables)

        # Extract WHERE columns
        where_clause = select_node.args.get("where")
        if where_clause is not None:
            for column in where_clause.find_all(Column):
                record_column(
                    column,
                    "where_columns",
                    "WHERE condition column",
                    alias_mapping,
                    scope_tables,
                )

        # Extract LIMIT clauses, associating with current SELECT's tables
        limit_node = select_node.args.get("limit")
        if limit_node is not None:
            # sqlglot's Limit keeps the row count under "expression"; "this" is normally
            # unset for SELECT ... LIMIT n. Fall back to "this" if "expression" is absent.
            limit_value = limit_node.args.get("expression")
            if limit_value is None:
                limit_value = limit_node.this
            if isinstance(limit_value, Literal):
                try:
                    limit_num = int(limit_value.this)
                except ValueError:
                    limit_num = limit_value.this  # Keep as is if not an integer
                if scope_tables:
                    for table_name in scope_tables:
                        entry = find_entry(table_name)
                        if entry is not None:
                            add_unique(entry, "limits", limit_num)
                else:
                    logger.warning(
                        "LIMIT '%s' found but no tables are associated with the current SELECT.",
                        limit_num,
                    )
            else:
                logger.warning("LIMIT value is not a Literal: %s", limit_value)

        # Process nested SELECTs (e.g., subqueries in WHERE) with this scope's aliases so
        # that correlated references to outer aliases can be resolved.
        for subquery in select_node.find_all(Subquery):
            if isinstance(subquery.this, Select):
                process_select(subquery.this, alias_mapping)

        # Process nested SELECTs in EXISTS or other constructs
        for exists in select_node.find_all(Exists):
            if isinstance(exists.this, Select):
                process_select(exists.this, alias_mapping)

    # sqlglot.parse() is typed as returning optional expressions; skip missing ones
    # instead of failing on ``None.args``.
    statements = [expr for expr in expressions if expr is not None]

    # First Pass: Collect all CTE names (before any table is registered, so that a CTE
    # declared in one statement is never mistaken for a base table).
    for expr in statements:
        with_clause = expr.args.get("with")
        if with_clause is not None:
            for cte in with_clause.find_all(CTE):
                cte_name = cte.alias_or_name
                if cte_name:
                    cte_names.add(cte_name.lower())  # Use lowercase for consistent comparison

    # Second Pass: process every SELECT in every statement with a fresh alias mapping.
    # SELECTs inside CTE bodies are handled exactly like any other SELECT: their base tables
    # are extracted, while the CTE name itself is filtered out by register_table.
    for expr in statements:
        for select_node in expr.find_all(Select):
            process_select(select_node, parent_alias_mapping={})

    # Post-process to remove duplicates within each table entry
    for entry in components:
        entry["columns"] = sorted(set(entry["columns"]))
        entry["where_columns"] = sorted(set(entry["where_columns"]))
        # Limits are ints, or the raw literal text when it was not an integer. Sort ints
        # first and strings after, so a mixed list cannot raise TypeError.
        entry["limits"] = sorted(
            set(entry["limits"]), key=lambda value: (isinstance(value, str), value)
        )
    return components
if __name__ == "__main__":
    from pprint import pprint

    # Demo input: a correlated subquery, where the inner WHERE refers to the
    # outer alias "e" while both aliases point at the same table.
    sql = """
SELECT
e.employee_id,
e.full_name,
e.salary,
e.department_id
FROM employees e
WHERE e.salary > (
SELECT AVG(e2.salaries)
FROM employees e2
WHERE e2.department_id = e.department_id
);
"""
    # Turn the text into sqlglot expressions using the Snowflake dialect.
    parsed = sqlglot.parse(sql, read="snowflake", error_level=None)

    # Build the per-table summary (aliases resolved to real table names).
    components = extract_sql_components_per_table_with_alias(parsed)

    # Pretty-print each table entry, followed by blank lines as a separator.
    for table_summary in components:
        pprint(table_summary)
        print("\n")