From ce037700581d913962d36e3347ebe5dd68e314ac Mon Sep 17 00:00:00 2001 From: NiranjGaurav Date: Mon, 4 Aug 2025 15:43:53 +0530 Subject: [PATCH 1/2] Mapped the QUARTER to EXTRACT and also ran make check. I also fixed the comma missed in the json file. --- .../supported_functions_in_all_dialects.json | 4 +- sqlglot/dialects/e6.py | 16 ++++++- sqlglot/transforms.py | 24 +++++----- tests/dialects/test_e6.py | 46 ++++++++++++------- 4 files changed, 57 insertions(+), 33 deletions(-) diff --git a/apis/utils/supported_functions_in_all_dialects.json b/apis/utils/supported_functions_in_all_dialects.json index 790e5e9597..f6cf662416 100644 --- a/apis/utils/supported_functions_in_all_dialects.json +++ b/apis/utils/supported_functions_in_all_dialects.json @@ -787,7 +787,7 @@ "LAST_DAY_OF_MONTH", "FORMAT_DATETIME", "COUNT_IF", - "ARRAY_INTERSECT" + "ARRAY_INTERSECT", "WIDTH_BUCKET", "RAND", "CORR", @@ -795,7 +795,7 @@ "COVAR_SAMP", "VARIANCE_SAMP", "VAR_SAMP", - "URL_DECODE",, + "URL_DECODE", "TYPEOF", "TIMEDIFF", "INTERVAL" diff --git a/sqlglot/dialects/e6.py b/sqlglot/dialects/e6.py index 97106487ed..f02e138c0f 100644 --- a/sqlglot/dialects/e6.py +++ b/sqlglot/dialects/e6.py @@ -1757,15 +1757,24 @@ def _last_day_sql(self: E6.Generator, expression: exp.LastDay) -> str: def extract_sql(self: E6.Generator, expression: exp.Extract | exp.DayOfYear) -> str: unit = expression.this.name unit_mapped = E6.UNIT_PART_MAPPING.get(f"'{unit.lower()}'", unit) - expression_sql = self.sql(expression, "expression") + date_expr = ( + expression.expression if isinstance(expression, exp.Extract) else expression.this + ) + if isinstance(expression, exp.DayOfYear): unit_mapped = "DOY" - date_expr = expression.this if isinstance( date_expr, (exp.TsOrDsToDate, exp.TsOrDsToTimestamp, exp.TsOrDsToTime) ): date_expr = exp.Cast(this=date_expr.this, to=exp.DataType(this="TIMESTAMP")) expression_sql = self.sql(date_expr) + if isinstance(expression, exp.Extract): + # For regular Extract operations, cast string literals and columns to TIMESTAMP for robust E6 syntax + if (isinstance(date_expr, exp.Literal) and date_expr.is_string) or isinstance( + date_expr, exp.Column + ): + date_expr = exp.Cast(this=date_expr, to=exp.DataType(this="DATE")) + expression_sql = self.sql(date_expr) extract_str = f"EXTRACT({unit_mapped} FROM {expression_sql})" return extract_str @@ -2282,6 +2291,9 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit): exp.Mod: lambda self, e: self.func("MOD", e.this, e.expression), exp.Nullif: rename_func("NULLIF"), exp.Pow: rename_func("POWER"), + exp.Quarter: lambda self, e: self.extract_sql( + exp.Extract(this=exp.Var(this="QUARTER"), expression=e.this) + ), exp.RegexpExtract: rename_func("REGEXP_EXTRACT"), exp.RegexpLike: lambda self, e: self.func("REGEXP_LIKE", e.this, e.expression), # here I handled replacement arg carefully because, sometimes if replacement arg is not provided/extracted then it is getting None there overriding in E6 diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index e4385e6750..6d7c5aebe8 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -888,9 +888,9 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: continue predicate = column.find_ancestor(exp.Predicate, exp.Select) - assert isinstance( - predicate, exp.Binary - ), "Columns can only be marked with (+) when involved in a binary operation" + assert isinstance(predicate, exp.Binary), ( + "Columns can only be marked with (+) when involved in a binary operation" + ) predicate_parent = predicate.parent join_predicate = predicate.pop() @@ -902,9 +902,9 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark") ] - assert not ( - left_columns and right_columns - ), "The (+) marker cannot appear in both sides of a binary predicate" + assert not (left_columns and right_columns), ( + "The (+) marker cannot appear in both sides of a binary predicate" + ) marked_column_tables = set() for col in left_columns or right_columns: @@ -914,9 +914,9 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: col.set("join_mark", False) marked_column_tables.add(table) - assert ( - len(marked_column_tables) == 1 - ), "Columns of only a single table can be marked with (+) in a given binary predicate" + assert len(marked_column_tables) == 1, ( + "Columns of only a single table can be marked with (+) in a given binary predicate" + ) # Add predicate if join already copied, or add join if it is new join_this = old_joins.get(col.table, query_from).this @@ -938,9 +938,9 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: only_old_join_sources = old_joins.keys() - new_joins.keys() if query_from.alias_or_name in new_joins: - assert ( - len(only_old_join_sources) >= 1 - ), "Cannot determine which table to use in the new FROM clause" + assert len(only_old_join_sources) >= 1, ( + "Cannot determine which table to use in the new FROM clause" + ) new_from_name = list(only_old_join_sources)[0] query.set("from", exp.From(this=old_joins.pop(new_from_name).this)) diff --git a/tests/dialects/test_e6.py b/tests/dialects/test_e6.py index 2128ed16a1..fee9e28a05 100644 --- a/tests/dialects/test_e6.py +++ b/tests/dialects/test_e6.py @@ -580,10 +580,10 @@ def test_E6(self): ) self.validate_all( - "SELECT EXTRACT(fieldStr FROM date_expr)", + "SELECT EXTRACT(FIELDSTR FROM CAST(date_expr AS DATE))", read={ - "databricks": "SELECT DATE_PART(fieldStr, date_expr)", - "e6": "SELECT DATEPART(fieldStr, date_expr)", + "databricks": "SELECT DATE_PART(FIELDSTR, date_expr)", + "e6": "SELECT EXTRACT(FIELDSTR FROM CAST(date_expr AS DATE))", }, ) @@ -593,6 +593,18 @@ def test_E6(self): write={"databricks": "SELECT NOT A IS NULL"}, ) + self.validate_all( + "SELECT EXTRACT(QUARTER FROM CAST('2016-08-31' AS DATE))", + read={"databricks": "SELECT QUARTER('2016-08-31')"}, + ) + + self.validate_all( + "SELECT MIN(DATE) AS C1 FROM (SELECT DATE FROM cdr_adhoc_analysis.default.gr_3p_demand_ix_revenue WHERE (((mappedflag = 'mapped' AND parent_advertiser_name_clean = 'toronto-dominion bank (td bank group)') AND seller_defined = 'yes') AND COALESCE(YEAR(TO_DATE(DATE)), 0) = 2025) AND COALESCE(EXTRACT(QUARTER FROM CAST(DATE AS DATE)), 0) = 3) AS ITBL", + read={ + "databricks": "SELECT MIN(DATE) AS C1 FROM (SELECT DATE FROM cdr_adhoc_analysis.default.gr_3p_demand_ix_revenue WHERE (((mappedflag = 'mapped' AND parent_advertiser_name_clean = 'toronto-dominion bank (td bank group)') AND seller_defined = 'yes') AND COALESCE(YEAR(DATE), 0) = 2025) AND COALESCE(QUARTER(DATE), 0) = 3) AS ITBL" + }, + ) + self.validate_all( "SELECT A IS NULL", read={"databricks": "SELECT ISNULL(A)"}, @@ -1977,14 +1989,14 @@ def test_statistical_funcs(self): "databricks": "SELECT percentile_cont(0.50) WITHIN GROUP (ORDER BY col) FROM VALUES (0), (6), (6), (7), (9), (10) AS tab(col)" }, ) - + # Additional STDDEV tests from multiple dialects self.validate_all( "SELECT STDDEV(col) FROM (VALUES (1), (2), (3), (4)) AS tab(col)", read={ "databricks": "SELECT stddev(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", "snowflake": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", - "postgres": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)" + "postgres": "SELECT STDDEV(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", }, ) self.validate_all( @@ -1992,26 +2004,26 @@ def test_statistical_funcs(self): read={ "databricks": "SELECT stddev_samp(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", "snowflake": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", - "postgres": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)" + "postgres": "SELECT STDDEV_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", }, ) - + # COVAR_SAMP tests from multiple dialects self.validate_all( "SELECT COVAR_SAMP(x, y) FROM (VALUES (1, 10), (2, 20), (3, 30)) AS tab(x, y)", read={ "databricks": "SELECT covar_samp(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)", "snowflake": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)", - "postgres": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)" + "postgres": "SELECT COVAR_SAMP(x, y) FROM VALUES (1, 10), (2, 20), (3, 30) AS tab(x, y)", }, ) - + # VARIANCE_SAMP tests from multiple dialects self.validate_all( "SELECT VARIANCE_SAMP(col) FROM (VALUES (1), (2), (3), (4), (5)) AS tab(col)", read={ "databricks": "SELECT variance_samp(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)", - "snowflake": "SELECT VARIANCE_SAMP(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)" + "snowflake": "SELECT VARIANCE_SAMP(col) FROM VALUES (1), (2), (3), (4), (5) AS tab(col)", }, ) self.validate_all( @@ -2020,13 +2032,13 @@ def test_statistical_funcs(self): "databricks": "SELECT variance_samp(DISTINCT col) FROM VALUES (1), (2), (2), (3), (3), (3) AS tab(col)" }, ) - + # VAR_SAMP tests from multiple dialects self.validate_all( "SELECT VAR_SAMP(col) FROM (VALUES (1), (2), (3), (4)) AS tab(col)", read={ "databricks": "SELECT var_samp(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", - "snowflake": "SELECT VAR_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)" + "snowflake": "SELECT VAR_SAMP(col) FROM VALUES (1), (2), (3), (4) AS tab(col)", }, ) @@ -2350,7 +2362,7 @@ def test_group_by_all(self): "SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL", read={ "databricks": "SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL" - } + }, ) # GROUP BY ALL with CTE @@ -2358,7 +2370,7 @@ def test_group_by_all(self): """WITH products AS (SELECT 'Electronics' AS category, 'BrandA' AS brand, 100 AS price UNION ALL SELECT 'Electronics' AS category, 'BrandA' AS brand, 150 AS price) SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL""", read={ "databricks": """WITH products AS (SELECT 'Electronics' AS category, 'BrandA' AS brand, 100 AS price UNION ALL SELECT 'Electronics' AS category, 'BrandA' AS brand, 150 AS price) SELECT category, brand, AVG(price) AS average_price FROM products GROUP BY ALL""" - } + }, ) # GROUP BY ALL with ORDER BY @@ -2366,7 +2378,7 @@ def test_group_by_all(self): "SELECT department, COUNT(*) AS employee_count FROM employees GROUP BY ALL ORDER BY employee_count DESC", read={ "databricks": "SELECT department, COUNT(*) AS employee_count FROM employees GROUP BY ALL ORDER BY employee_count DESC" - } + }, ) # GROUP BY ALL with HAVING clause @@ -2374,7 +2386,7 @@ def test_group_by_all(self): "SELECT region, SUM(sales) AS total_sales FROM sales_data GROUP BY ALL HAVING SUM(sales) > 1000", read={ "databricks": "SELECT region, SUM(sales) AS total_sales FROM sales_data GROUP BY ALL HAVING SUM(sales) > 1000" - } + }, ) # GROUP BY ALL with multiple aggregations @@ -2382,7 +2394,7 @@ def test_group_by_all(self): "SELECT product_category, COUNT(*) AS item_count, AVG(price) AS avg_price, MAX(price) AS max_price FROM inventory GROUP BY ALL", read={ "databricks": "SELECT product_category, COUNT(*) AS item_count, AVG(price) AS avg_price, MAX(price) AS max_price FROM inventory GROUP BY ALL" - } + }, ) def test_keywords(self): From 42733ec802df6a851cd84c448998c389626a6bb0 Mon Sep 17 00:00:00 2001 From: NiranjGaurav Date: Mon, 4 Aug 2025 17:25:22 +0530 Subject: [PATCH 2/2] Fixed the TYPEOF issue we were getting after the rebase merge. there was some merge conflicts when rebase was merged that included some changes in the transforms.py. While solving merge conflicts i added some things on my branch that lead to error thats also sorted now. Along with this ran make check too. --- sqlglot/dialects/e6.py | 4 ++-- sqlglot/dialects/spark.py | 2 +- sqlglot/expressions.py | 4 ---- sqlglot/transforms.py | 48 +-------------------------------------- 4 files changed, 4 insertions(+), 54 deletions(-) diff --git a/sqlglot/dialects/e6.py b/sqlglot/dialects/e6.py index b4910a7658..a77a47de59 100644 --- a/sqlglot/dialects/e6.py +++ b/sqlglot/dialects/e6.py @@ -1499,7 +1499,7 @@ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: ), "TRUNC": date_trunc_to_time, "TRIM": lambda self: self._parse_trim(), - "TYPEOF": lambda args: exp.TypeOf(this=seq_get(args, 0)), + "TYPEOF": lambda args: exp.Typeof(this=seq_get(args, 0)), "UNNEST": lambda args: exp.Explode(this=seq_get(args, 0)), # TODO:: I have removed the _parse_unnest_sql, was it really required # It was added due to some requirements before but those were asked to remove afterwards so it should not matter now @@ -2218,7 +2218,7 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit): exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), exp.Array: array_sql, - exp.TypeOf: rename_func("TYPEOF"), + exp.Typeof: rename_func("TYPEOF"), exp.ArrayAgg: rename_func("ARRAY_AGG"), exp.ArrayConcat: rename_func("ARRAY_CONCAT"), exp.ArrayIntersect: rename_func("ARRAY_INTERSECT"), diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 507ec28e90..ecabbc32fb 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -135,7 +135,7 @@ class Parser(Spark2.Parser): "DATEDIFF": _build_datediff, "DATE_DIFF": _build_datediff, "LISTAGG": exp.GroupConcat.from_arg_list, - "TYPEOF": lambda args: exp.TypeOf(this=seq_get(args, 0)), + "TYPEOF": lambda args: exp.Typeof(this=seq_get(args, 0)), "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"), "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"), "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime( diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a6c705ecd1..de986f3b14 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5546,10 +5546,6 @@ class ToArray(Func): pass -class TypeOf(Func): - arg_types = {"this": True} - - # https://materialize.com/docs/sql/types/list/ class List(Func): arg_types = {"expressions": False} diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 836b29d36f..79b244d61c 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -936,54 +936,13 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: if not left_join_table: continue - - predicate = column.find_ancestor(exp.Predicate, exp.Select) - assert isinstance(predicate, exp.Binary), ( - "Columns can only be marked with (+) when involved in a binary operation" - ) - - predicate_parent = predicate.parent - join_predicate = predicate.pop() - - left_columns = [ - c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark") - ] - right_columns = [ - c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark") - ] - - assert not (left_columns and right_columns), ( - "The (+) marker cannot appear in both sides of a binary predicate" - ) - - marked_column_tables = set() - for col in left_columns or right_columns: - table = col.table - assert table, f"Column {col} needs to be qualified with a table" - assert not ( len(left_join_table) > 1 ), "Cannot combine JOIN predicates from different tables" - for col in join_cols: col.set("join_mark", False) - - assert len(marked_column_tables) == 1, ( - "Columns of only a single table can be marked with (+) in a given binary predicate" - ) - - # Add predicate if join already copied, or add join if it is new - join_this = old_joins.get(col.table, query_from).this - existing_join = new_joins.get(join_this.alias_or_name) - if existing_join: - existing_join.set("on", exp.and_(existing_join.args["on"], join_predicate)) - else: - new_joins[join_this.alias_or_name] = exp.Join( - this=join_this.copy(), on=join_predicate.copy(), kind="LEFT" - ) - joins_ons[left_join_table.pop()].append(cond) @@ -1009,11 +968,6 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: parent.pop() if query_from.alias_or_name in new_joins: - - assert len(only_old_join_sources) >= 1, ( - "Cannot determine which table to use in the new FROM clause" - ) - only_old_joins = old_joins.keys() - new_joins.keys() assert ( len(only_old_joins) >= 1 @@ -1089,4 +1043,4 @@ def _inline_inherited_window(window: exp.Expression) -> None: for window in find_all_in_scope(expression, exp.Window): _inline_inherited_window(window) - return expression + return expression \ No newline at end of file