Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apis/utils/supported_functions_in_all_dialects.json
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@
"LAST_DAY_OF_MONTH",
"FORMAT_DATETIME",
"COUNT_IF",
"ARRAY_INTERSECT"
"ARRAY_INTERSECT",
"WIDTH_BUCKET",
"RAND",
"CORR",
Expand Down
20 changes: 16 additions & 4 deletions sqlglot/dialects/e6.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,7 @@ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition:
),
"TRUNC": date_trunc_to_time,
"TRIM": lambda self: self._parse_trim(),
"TYPEOF": lambda args: exp.TypeOf(this=seq_get(args, 0)),
"TYPEOF": lambda args: exp.Typeof(this=seq_get(args, 0)),
"UNNEST": lambda args: exp.Explode(this=seq_get(args, 0)),
# TODO: _parse_unnest_sql has been removed — was it actually required?
# It was added for earlier requirements that were later dropped, so it should no longer matter.
Expand Down Expand Up @@ -1762,15 +1762,24 @@ def _last_day_sql(self: E6.Generator, expression: exp.LastDay) -> str:
def extract_sql(self: E6.Generator, expression: exp.Extract | exp.DayOfYear) -> str:
unit = expression.this.name
unit_mapped = E6.UNIT_PART_MAPPING.get(f"'{unit.lower()}'", unit)
expression_sql = self.sql(expression, "expression")
date_expr = (
expression.expression if isinstance(expression, exp.Extract) else expression.this
)

if isinstance(expression, exp.DayOfYear):
unit_mapped = "DOY"
date_expr = expression.this
if isinstance(
date_expr, (exp.TsOrDsToDate, exp.TsOrDsToTimestamp, exp.TsOrDsToTime)
):
date_expr = exp.Cast(this=date_expr.this, to=exp.DataType(this="TIMESTAMP"))
expression_sql = self.sql(date_expr)
if isinstance(expression, exp.Extract):
# For regular Extract operations, cast string literals and columns to DATE for robust E6 syntax
if (isinstance(date_expr, exp.Literal) and date_expr.is_string) or isinstance(
date_expr, exp.Column
):
date_expr = exp.Cast(this=date_expr, to=exp.DataType(this="DATE"))
expression_sql = self.sql(date_expr)

extract_str = f"EXTRACT({unit_mapped} FROM {expression_sql})"
return extract_str
Expand Down Expand Up @@ -2209,7 +2218,7 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit):
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: array_sql,
exp.TypeOf: rename_func("TYPEOF"),
exp.Typeof: rename_func("TYPEOF"),
exp.ArrayAgg: rename_func("ARRAY_AGG"),
exp.ArrayConcat: rename_func("ARRAY_CONCAT"),
exp.ArrayIntersect: rename_func("ARRAY_INTERSECT"),
Expand Down Expand Up @@ -2295,6 +2304,9 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit):
exp.Mod: lambda self, e: self.func("MOD", e.this, e.expression),
exp.Nullif: rename_func("NULLIF"),
exp.Pow: rename_func("POWER"),
exp.Quarter: lambda self, e: self.extract_sql(
exp.Extract(this=exp.Var(this="QUARTER"), expression=e.this)
),
exp.RegexpExtract: rename_func("REGEXP_EXTRACT"),
exp.RegexpLike: lambda self, e: self.func("REGEXP_LIKE", e.this, e.expression),
# Handle the replacement arg carefully: when it is not provided/extracted it comes through as None, which would otherwise override the default in E6
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class Parser(Spark2.Parser):
"DATEDIFF": _build_datediff,
"DATE_DIFF": _build_datediff,
"LISTAGG": exp.GroupConcat.from_arg_list,
"TYPEOF": lambda args: exp.TypeOf(this=seq_get(args, 0)),
"TYPEOF": lambda args: exp.Typeof(this=seq_get(args, 0)),
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
"TIMESTAMP_MILLIS": lambda args: exp.UnixToTime(
Expand Down
4 changes: 0 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5546,10 +5546,6 @@ class ToArray(Func):
pass


class TypeOf(Func):
arg_types = {"this": True}


# https://materialize.com/docs/sql/types/list/
class List(Func):
arg_types = {"expressions": False}
Expand Down
4 changes: 3 additions & 1 deletion sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,7 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:

joins_ons[left_join_table.pop()].append(cond)


old_joins = {join.alias_or_name: join for join in joins}
new_joins = {}
query_from = query.args["from"]
Expand Down Expand Up @@ -972,6 +973,7 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
len(only_old_joins) >= 1
), "Cannot determine which table to use in the new FROM clause"


new_from_name = list(only_old_joins)[0]
query.set("from", exp.From(this=old_joins[new_from_name].this))

Expand Down Expand Up @@ -1041,4 +1043,4 @@ def _inline_inherited_window(window: exp.Expression) -> None:
for window in find_all_in_scope(expression, exp.Window):
_inline_inherited_window(window)

return expression
return expression
46 changes: 29 additions & 17 deletions tests/dialects/test_e6.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,10 @@ def test_E6(self):
)

self.validate_all(
"SELECT EXTRACT(fieldStr FROM date_expr)",
"SELECT EXTRACT(FIELDSTR FROM CAST(date_expr AS DATE))",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NiranjGaurav Where is QUARTER in the query?

read={
"databricks": "SELECT DATE_PART(fieldStr, date_expr)",
"e6": "SELECT DATEPART(fieldStr, date_expr)",
"databricks": "SELECT DATE_PART(FIELDSTR, date_expr)",
"e6": "SELECT EXTRACT(FIELDSTR FROM CAST(date_expr AS DATE))",
},
)

Expand All @@ -593,6 +593,18 @@ def test_E6(self):
write={"databricks": "SELECT NOT A IS NULL"},
)

self.validate_all(
"SELECT EXTRACT(QUARTER FROM CAST('2016-08-31' AS DATE))",
read={"databricks": "SELECT QUARTER('2016-08-31')"},
)

self.validate_all(
"SELECT MIN(DATE) AS C1 FROM (SELECT DATE FROM cdr_adhoc_analysis.default.gr_3p_demand_ix_revenue WHERE (((mappedflag = 'mapped' AND parent_advertiser_name_clean = 'toronto-dominion bank (td bank group)') AND seller_defined = 'yes') AND COALESCE(YEAR(TO_DATE(DATE)), 0) = 2025) AND COALESCE(EXTRACT(QUARTER FROM CAST(DATE AS DATE)), 0) = 3) AS ITBL",
read={
"databricks": "SELECT MIN(DATE) AS C1 FROM (SELECT DATE FROM cdr_adhoc_analysis.default.gr_3p_demand_ix_revenue WHERE (((mappedflag = 'mapped' AND parent_advertiser_name_clean = 'toronto-dominion bank (td bank group)') AND seller_defined = 'yes') AND COALESCE(YEAR(DATE), 0) = 2025) AND COALESCE(QUARTER(DATE), 0) = 3) AS ITBL"
},
)

self.validate_all(
"SELECT A IS NULL",
read={"databricks": "SELECT ISNULL(A)"},
Expand Down Expand Up @@ -1977,41 +1989,41 @@ def test_statistical_funcs(self):
"databricks": "SELECT percentile_cont(0.50) WITHIN GROUP (ORDER BY col) FROM VALUES (0), (6), (6), (7), (9), (10) AS tab(col)"
},
)

# Additional STDDEV tests from multiple dialects
self.validate_all(
"SELECT STDDEV(col) FROM (VALUES (1), (2), (3), (4)) AS tab(col)",
read={
"databricks": "SELECT stddev(col) FROM VALUES (1), (2), (3), (4) AS tab(col)",
"snowflake": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)",
"postgres": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)"
"postgres": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)",
},
)
self.validate_all(
"SELECT STDDEV_SAMP(col) FROM (VALUES (1), (2), (3), (4)) AS tab(col)",
read={
"databricks": "SELECT stddev_samp(col) FROM VALUES (1), (2), (3), (4) AS tab(col)",
"snowflake": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)",
"postgres": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)"
"postgres": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)",
},
)

# COVAR_SAMP tests from multiple dialects
self.validate_all(
"SELECT COVAR_SAMP(x, y) FROM (VALUES (1, 10), (2, 20), (3, 30)) AS tab(x, y)",
read={
"databricks": "SELECT covar_samp(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)",
"snowflake": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)",
"postgres": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)"
"postgres": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)",
},
)

# VARIANCE_SAMP tests from multiple dialects
self.validate_all(
"SELECT VARIANCE_SAMP(col) FROM (VALUES (1), (2), (3), (4), (5)) AS tab(col)",
read={
"databricks": "SELECT variance_samp(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)",
"snowflake": "SELECT VARIANCE_SAMP(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)"
"snowflake": "SELECT VARIANCE_SAMP(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)",
},
)
self.validate_all(
Expand All @@ -2020,13 +2032,13 @@ def test_statistical_funcs(self):
"databricks": "SELECT variance_samp(DISTINCT col) FROM VALUES (1), (2), (2), (3), (3), (3) AS tab(col)"
},
)

# VAR_SAMP tests from multiple dialects
self.validate_all(
"SELECT VAR_SAMP(col) FROM (VALUES (1), (2), (3), (4)) AS tab(col)",
read={
"databricks": "SELECT var_samp(col) FROM VALUES (1), (2), (3), (4) AS tab(col)",
"snowflake": "SELECT VAR_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)"
"snowflake": "SELECT VAR_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)",
},
)

Expand Down Expand Up @@ -2350,39 +2362,39 @@ def test_group_by_all(self):
"SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL",
read={
"databricks": "SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL"
}
},
)

# GROUP BY ALL with CTE
self.validate_all(
"""WITH products AS (SELECT 'Electronics' AS category, 'BrandA' AS brand, 100 AS price UNION ALL SELECT 'Electronics' AS category, 'BrandA' AS brand, 150 AS price) SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL""",
read={
"databricks": """WITH products AS (SELECT 'Electronics' AS category, 'BrandA' AS brand, 100 AS price UNION ALL SELECT 'Electronics' AS category, 'BrandA' AS brand, 150 AS price) SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL"""
}
},
)

# GROUP BY ALL with ORDER BY
self.validate_all(
"SELECT department, COUNT(*) AS employee_count FROM employees GROUP BY ALL ORDER BY employee_count DESC",
read={
"databricks": "SELECT department, COUNT(*) AS employee_count FROM employees GROUP BY ALL ORDER BY employee_count DESC"
}
},
)

# GROUP BY ALL with HAVING clause
self.validate_all(
"SELECT region, SUM(sales) AS total_sales FROM sales_data GROUP BY ALL HAVING SUM(sales) > 1000",
read={
"databricks": "SELECT region, SUM(sales) AS total_sales FROM sales_data GROUP BY ALL HAVING SUM(sales) > 1000"
}
},
)

# GROUP BY ALL with multiple aggregations
self.validate_all(
"SELECT product_category, COUNT(*) AS item_count, AVG(price) AS avg_price, MAX(price) AS max_price FROM inventory GROUP BY ALL",
read={
"databricks": "SELECT product_category, COUNT(*) AS item_count, AVG(price) AS avg_price, MAX(price) AS max_price FROM inventory GROUP BY ALL"
}
},
)

def test_keywords(self):
Expand Down