diff --git a/apis/utils/supported_functions_in_all_dialects.json b/apis/utils/supported_functions_in_all_dialects.json index 2ceddc4d01..412e8f2ba3 100644 --- a/apis/utils/supported_functions_in_all_dialects.json +++ b/apis/utils/supported_functions_in_all_dialects.json @@ -740,6 +740,7 @@ "DISTINCT", "STDDEV", "FILTER_ARRAY", + "FIND_IN_SET", "TIMESTAMP", "REGEXP_CONTAINS", "CASE", @@ -785,13 +786,15 @@ "LAST_DAY_OF_MONTH", "FORMAT_DATETIME", "COUNT_IF", + "ARRAY_INTERSECT", "WIDTH_BUCKET", "RAND", "CORR", "COVAR_POP", - "URL_DECODE" - "TRANSFORM", - "ARRAY_INTERSECT" + "URL_DECODE", + "TYPEOF", + "TIMEDIFF", + "INTERVAL" ], "databricks": [ "ABS", diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 7ea21400f3..ebefaa1664 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -105,6 +105,7 @@ class Parser(Spark.Parser): "DATE_ADD": build_date_delta(exp.DateAdd), "DATEDIFF": build_date_delta(exp.DateDiff), "DATE_DIFF": build_date_delta(exp.DateDiff), + "FIND_IN_SET": exp.FindInSet.from_arg_list, "GETDATE": exp.CurrentTimestamp.from_arg_list, "GET_JSON_OBJECT": _build_json_extract, "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "databricks"), @@ -115,6 +116,9 @@ class Parser(Spark.Parser): "TIMEDIFF": lambda args: exp.TimestampDiff( unit=seq_get(args, 0), this=seq_get(args, 1), expression=seq_get(args, 2) ), + "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime( + this=seq_get(args, 0), scale=exp.Literal.string("seconds") + ), } FACTOR = { diff --git a/sqlglot/dialects/e6.py b/sqlglot/dialects/e6.py index 69036d25be..03ee9a49a7 100644 --- a/sqlglot/dialects/e6.py +++ b/sqlglot/dialects/e6.py @@ -1494,6 +1494,7 @@ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: ), "TRUNC": date_trunc_to_time, "TRIM": lambda self: self._parse_trim(), + "TYPEOF": lambda args: exp.TypeOf(this=seq_get(args, 0)), "UNNEST": lambda args: exp.Explode(this=seq_get(args, 0)), # TODO:: I have removed the
_parse_unnest_sql, was it really required # It was added due to some requirements before but those were asked to remove afterwards so it should not matter now @@ -2188,12 +2189,16 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.Anonymous: anonymous_sql, + exp.FindInSet: lambda self, e: self.func( + "ARRAY_POSITION", e.this, self.func("SPLIT", e.expression, exp.Literal.string(",")) + ), exp.AnyValue: rename_func("ARBITRARY"), exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), exp.Array: array_sql, + exp.TypeOf: rename_func("TYPEOF"), exp.ArrayAgg: rename_func("ARRAY_AGG"), exp.ArrayConcat: rename_func("ARRAY_CONCAT"), exp.ArrayIntersect: rename_func("ARRAY_INTERSECT"), diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index f8237aa64a..0bddcfddaf 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -120,6 +120,7 @@ class Parser(Spark2.Parser): "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff), "DATEDIFF": _build_datediff, "DATE_DIFF": _build_datediff, + "TYPEOF": lambda args: exp.TypeOf(this=seq_get(args, 0)), "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"), "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"), "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime( diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 91c5ad8cf4..eb9d0f3bcf 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5495,6 +5495,10 @@ class ToArray(Func): pass +class TypeOf(Func): + arg_types = {"this": True} + + # https://materialize.com/docs/sql/types/list/ class List(Func): arg_types = {"expressions": False} @@ -6858,6 +6862,21 @@ class StrPosition(Func): } +class FindInSet(Func): + """ + FIND_IN_SET function that returns the position of a string within a comma-separated list of strings. 
+ + Returns: + The position (1-based) of searchExpr in sourceExpr, or 0 if not found or if searchExpr contains a comma. + + Args: + this: The string to search for (searchExpr) + expression: The comma-separated list of strings to search in (sourceExpr) + """ + + arg_types = {"this": True, "expression": True} + + class StrToDate(Func): arg_types = {"this": True, "format": False, "safe": False} diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 6d7c5aebe8..e4385e6750 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -888,9 +888,9 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: continue predicate = column.find_ancestor(exp.Predicate, exp.Select) - assert isinstance(predicate, exp.Binary), ( - "Columns can only be marked with (+) when involved in a binary operation" - ) + assert isinstance( + predicate, exp.Binary + ), "Columns can only be marked with (+) when involved in a binary operation" predicate_parent = predicate.parent join_predicate = predicate.pop() @@ -902,9 +902,9 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark") ] - assert not (left_columns and right_columns), ( - "The (+) marker cannot appear in both sides of a binary predicate" - ) + assert not ( + left_columns and right_columns + ), "The (+) marker cannot appear in both sides of a binary predicate" marked_column_tables = set() for col in left_columns or right_columns: @@ -914,9 +914,9 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: col.set("join_mark", False) marked_column_tables.add(table) - assert len(marked_column_tables) == 1, ( - "Columns of only a single table can be marked with (+) in a given binary predicate" - ) + assert ( + len(marked_column_tables) == 1 + ), "Columns of only a single table can be marked with (+) in a given binary predicate" # Add predicate if join already copied, or add join if it is new 
join_this = old_joins.get(col.table, query_from).this @@ -938,9 +938,9 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: only_old_join_sources = old_joins.keys() - new_joins.keys() if query_from.alias_or_name in new_joins: - assert len(only_old_join_sources) >= 1, ( - "Cannot determine which table to use in the new FROM clause" - ) + assert ( + len(only_old_join_sources) >= 1 + ), "Cannot determine which table to use in the new FROM clause" new_from_name = list(only_old_join_sources)[0] query.set("from", exp.From(this=old_joins.pop(new_from_name).this)) diff --git a/tests/dialects/test_e6.py b/tests/dialects/test_e6.py index 0608a968c4..470b699604 100644 --- a/tests/dialects/test_e6.py +++ b/tests/dialects/test_e6.py @@ -49,6 +49,15 @@ def test_E6(self): }, ) + self.validate_all( + "SELECT TYPEOF('hello')", + read={ + "databricks": "SELECT TYPEOF('hello');", + "spark": "SELECT TYPEOF('hello');", + "spark2": "SELECT TYPEOF('hello');", + "snowflake": "SELECT TYPEOF('hello');", + }, + ) self.validate_all( "SELECT ARRAY_INTERSECT(ARRAY[1, 2, 3], ARRAY[1, 3, 3, 5])", read={ @@ -694,6 +703,29 @@ def test_E6(self): }, ) + # FIND_IN_SET function tests - Databricks to E6 transpilation + self.validate_all( + "SELECT ARRAY_POSITION('ab', SPLIT('abc,b,ab,c,def', ','))", + read={ + "databricks": "SELECT FIND_IN_SET('ab', 'abc,b,ab,c,def')", + }, + ) + + self.validate_all( + "SELECT ARRAY_POSITION('test', SPLIT('hello,world,test', ','))", + read={ + "databricks": "SELECT FIND_IN_SET('test', 'hello,world,test')", + }, + ) + + # Test FIND_IN_SET with column references + self.validate_all( + "SELECT ARRAY_POSITION(search_col, SPLIT(list_col, ',')) FROM table1", + read={ + "databricks": "SELECT FIND_IN_SET(search_col, list_col) FROM table1", + }, + ) + def test_regex(self): self.validate_all( "REGEXP_REPLACE('abcd', 'ab', '')",