diff --git a/apis/utils/supported_functions_in_all_dialects.json b/apis/utils/supported_functions_in_all_dialects.json index 353e4ac7c8..d46d5af1ff 100644 --- a/apis/utils/supported_functions_in_all_dialects.json +++ b/apis/utils/supported_functions_in_all_dialects.json @@ -787,7 +787,7 @@ "LAST_DAY_OF_MONTH", "FORMAT_DATETIME", "COUNT_IF", - "ARRAY_INTERSECT" + "ARRAY_INTERSECT", "WIDTH_BUCKET", "RAND", "CORR", diff --git a/sqlglot/dialects/e6.py b/sqlglot/dialects/e6.py index 81a618509a..a77a47de59 100644 --- a/sqlglot/dialects/e6.py +++ b/sqlglot/dialects/e6.py @@ -1499,7 +1499,7 @@ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: ), "TRUNC": date_trunc_to_time, "TRIM": lambda self: self._parse_trim(), - "TYPEOF": lambda args: exp.TypeOf(this=seq_get(args, 0)), + "TYPEOF": lambda args: exp.Typeof(this=seq_get(args, 0)), "UNNEST": lambda args: exp.Explode(this=seq_get(args, 0)), # TODO:: I have removed the _parse_unnest_sql, was it really required # It was added due to some requirements before but those were asked to remove afterwards so it should not matter now @@ -1762,15 +1762,24 @@ def _last_day_sql(self: E6.Generator, expression: exp.LastDay) -> str: def extract_sql(self: E6.Generator, expression: exp.Extract | exp.DayOfYear) -> str: unit = expression.this.name unit_mapped = E6.UNIT_PART_MAPPING.get(f"'{unit.lower()}'", unit) - expression_sql = self.sql(expression, "expression") + date_expr = ( + expression.expression if isinstance(expression, exp.Extract) else expression.this + ) + if isinstance(expression, exp.DayOfYear): unit_mapped = "DOY" - date_expr = expression.this if isinstance( date_expr, (exp.TsOrDsToDate, exp.TsOrDsToTimestamp, exp.TsOrDsToTime) ): date_expr = exp.Cast(this=date_expr.this, to=exp.DataType(this="TIMESTAMP")) expression_sql = self.sql(date_expr) + if isinstance(expression, exp.Extract): + # For regular Extract operations, cast string literals and columns to DATE for robust E6 syntax + if 
(isinstance(date_expr, exp.Literal) and date_expr.is_string) or isinstance( + date_expr, exp.Column + ): + date_expr = exp.Cast(this=date_expr, to=exp.DataType(this="DATE")) + expression_sql = self.sql(date_expr) extract_str = f"EXTRACT({unit_mapped} FROM {expression_sql})" return extract_str @@ -2209,7 +2218,7 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit): exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), exp.Array: array_sql, - exp.TypeOf: rename_func("TYPEOF"), + exp.Typeof: rename_func("TYPEOF"), exp.ArrayAgg: rename_func("ARRAY_AGG"), exp.ArrayConcat: rename_func("ARRAY_CONCAT"), exp.ArrayIntersect: rename_func("ARRAY_INTERSECT"), @@ -2295,6 +2304,9 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit): exp.Mod: lambda self, e: self.func("MOD", e.this, e.expression), exp.Nullif: rename_func("NULLIF"), exp.Pow: rename_func("POWER"), + exp.Quarter: lambda self, e: self.extract_sql( + exp.Extract(this=exp.Var(this="QUARTER"), expression=e.this) + ), exp.RegexpExtract: rename_func("REGEXP_EXTRACT"), exp.RegexpLike: lambda self, e: self.func("REGEXP_LIKE", e.this, e.expression), # here I handled replacement arg carefully because, sometimes if replacement arg is not provided/extracted then it is getting None there overriding in E6 diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 507ec28e90..ecabbc32fb 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -135,7 +135,7 @@ class Parser(Spark2.Parser): "DATEDIFF": _build_datediff, "DATE_DIFF": _build_datediff, "LISTAGG": exp.GroupConcat.from_arg_list, - "TYPEOF": lambda args: exp.TypeOf(this=seq_get(args, 0)), + "TYPEOF": lambda args: exp.Typeof(this=seq_get(args, 0)), "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"), "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"), "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime( diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a6c705ecd1..de986f3b14 100644 --- 
a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5546,10 +5546,6 @@ class ToArray(Func): pass -class TypeOf(Func): - arg_types = {"this": True} - - # https://materialize.com/docs/sql/types/list/ class List(Func): arg_types = {"expressions": False} diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 8a83023056..79b244d61c 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -945,6 +945,7 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: joins_ons[left_join_table.pop()].append(cond) + old_joins = {join.alias_or_name: join for join in joins} new_joins = {} query_from = query.args["from"] @@ -972,6 +973,7 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: len(only_old_joins) >= 1 ), "Cannot determine which table to use in the new FROM clause" + new_from_name = list(only_old_joins)[0] query.set("from", exp.From(this=old_joins[new_from_name].this)) @@ -1041,4 +1043,4 @@ def _inline_inherited_window(window: exp.Expression) -> None: for window in find_all_in_scope(expression, exp.Window): _inline_inherited_window(window) - return expression + return expression \ No newline at end of file diff --git a/tests/dialects/test_e6.py b/tests/dialects/test_e6.py index 4fb93e2703..3cd485d60f 100644 --- a/tests/dialects/test_e6.py +++ b/tests/dialects/test_e6.py @@ -580,10 +580,10 @@ def test_E6(self): ) self.validate_all( - "SELECT EXTRACT(fieldStr FROM date_expr)", + "SELECT EXTRACT(FIELDSTR FROM CAST(date_expr AS DATE))", read={ - "databricks": "SELECT DATE_PART(fieldStr, date_expr)", - "e6": "SELECT DATEPART(fieldStr, date_expr)", + "databricks": "SELECT DATE_PART(FIELDSTR, date_expr)", + "e6": "SELECT EXTRACT(FIELDSTR FROM CAST(date_expr AS DATE))", }, ) @@ -593,6 +593,18 @@ def test_E6(self): write={"databricks": "SELECT NOT A IS NULL"}, ) + self.validate_all( + "SELECT EXTRACT(QUARTER FROM CAST('2016-08-31' AS DATE))", + read={"databricks": "SELECT QUARTER('2016-08-31')"}, + ) + + 
self.validate_all( + "SELECT MIN(DATE) AS C1 FROM (SELECT DATE FROM cdr_adhoc_analysis.default.gr_3p_demand_ix_revenue WHERE (((mappedflag = 'mapped' AND parent_advertiser_name_clean = 'toronto-dominion bank (td bank group)') AND seller_defined = 'yes') AND COALESCE(YEAR(TO_DATE(DATE)), 0) = 2025) AND COALESCE(EXTRACT(QUARTER FROM CAST(DATE AS DATE)), 0) = 3) AS ITBL", + read={ + "databricks": "SELECT MIN(DATE) AS C1 FROM (SELECT DATE FROM cdr_adhoc_analysis.default.gr_3p_demand_ix_revenue WHERE (((mappedflag = 'mapped' AND parent_advertiser_name_clean = 'toronto-dominion bank (td bank group)') AND seller_defined = 'yes') AND COALESCE(YEAR(DATE), 0) = 2025) AND COALESCE(QUARTER(DATE), 0) = 3) AS ITBL" + }, + ) + self.validate_all( "SELECT A IS NULL", read={"databricks": "SELECT ISNULL(A)"}, @@ -1977,14 +1989,14 @@ def test_statistical_funcs(self): "databricks": "SELECT percentile_cont(0.50) WITHIN GROUP (ORDER BY col) FROM VALUES (0), (6), (6), (7), (9), (10) AS tab(col)" }, ) - + # Additional STDDEV tests from multiple dialects self.validate_all( "SELECT STDDEV(col) FROM (VALUES (1), (2), (3), (4)) AS tab(col)", read={ "databricks": "SELECT stddev(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", "snowflake": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", - "postgres": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)" + "postgres": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", }, ) self.validate_all( @@ -1992,26 +2004,26 @@ def test_statistical_funcs(self): read={ "databricks": "SELECT stddev_samp(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", "snowflake": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", - "postgres": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)" + "postgres": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", }, ) - + # COVAR_SAMP tests from multiple dialects self.validate_all( "SELECT COVAR_SAMP(x, y) FROM (VALUES (1, 10), (2, 
20), (3, 30)) AS tab(x, y)", read={ "databricks": "SELECT covar_samp(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)", "snowflake": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)", - "postgres": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)" + "postgres": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)", }, ) - + # VARIANCE_SAMP tests from multiple dialects self.validate_all( "SELECT VARIANCE_SAMP(col) FROM (VALUES (1), (2), (3), (4), (5)) AS tab(col)", read={ "databricks": "SELECT variance_samp(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)", - "snowflake": "SELECT VARIANCE_SAMP(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)" + "snowflake": "SELECT VARIANCE_SAMP(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)", }, ) self.validate_all( @@ -2020,13 +2032,13 @@ def test_statistical_funcs(self): "databricks": "SELECT variance_samp(DISTINCT col) FROM VALUES (1), (2), (2), (3), (3), (3) AS tab(col)" }, ) - + # VAR_SAMP tests from multiple dialects self.validate_all( "SELECT VAR_SAMP(col) FROM (VALUES (1), (2), (3), (4)) AS tab(col)", read={ "databricks": "SELECT var_samp(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", - "snowflake": "SELECT VAR_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)" + "snowflake": "SELECT VAR_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", }, ) @@ -2350,7 +2362,7 @@ def test_group_by_all(self): "SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL", read={ "databricks": "SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL" - } + }, ) # GROUP BY ALL with CTE @@ -2358,7 +2370,7 @@ def test_group_by_all(self): """WITH products AS (SELECT 'Electronics' AS category, 'BrandA' AS brand, 100 AS price UNION ALL SELECT 'Electronics' AS category, 'BrandA' AS brand, 150 AS price) SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL""", read={ 
"databricks": """WITH products AS (SELECT 'Electronics' AS category, 'BrandA' AS brand, 100 AS price UNION ALL SELECT 'Electronics' AS category, 'BrandA' AS brand, 150 AS price) SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL""" - } + }, ) # GROUP BY ALL with ORDER BY @@ -2366,7 +2378,7 @@ def test_group_by_all(self): "SELECT department, COUNT(*) AS employee_count FROM employees GROUP BY ALL ORDER BY employee_count DESC", read={ "databricks": "SELECT department, COUNT(*) AS employee_count FROM employees GROUP BY ALL ORDER BY employee_count DESC" - } + }, ) # GROUP BY ALL with HAVING clause @@ -2374,7 +2386,7 @@ def test_group_by_all(self): "SELECT region, SUM(sales) AS total_sales FROM sales_data GROUP BY ALL HAVING SUM(sales) > 1000", read={ "databricks": "SELECT region, SUM(sales) AS total_sales FROM sales_data GROUP BY ALL HAVING SUM(sales) > 1000" - } + }, ) # GROUP BY ALL with multiple aggregations @@ -2382,7 +2394,7 @@ def test_group_by_all(self): "SELECT product_category, COUNT(*) AS item_count, AVG(price) AS avg_price, MAX(price) AS max_price FROM inventory GROUP BY ALL", read={ "databricks": "SELECT product_category, COUNT(*) AS item_count, AVG(price) AS avg_price, MAX(price) AS max_price FROM inventory GROUP BY ALL" - } + }, ) def test_keywords(self):