diff --git a/sqlglot/dialects/e6.py b/sqlglot/dialects/e6.py
index a77a47de59..55f4ad6283 100644
--- a/sqlglot/dialects/e6.py
+++ b/sqlglot/dialects/e6.py
@@ -2033,6 +2033,60 @@ def string_agg_sql(self: E6.Generator, expression: exp.GroupConcat) -> str:
     # Generate SQL using STRING_AGG/LISTAGG, with separator or default ''
     return self.func("LISTAGG", expr_1, separator or exp.Literal.string(""))
 
+def concat_ws_sql(self: E6.Generator, expression: exp.ConcatWs) -> str:
+    """
+    Generate the SQL for the CONCAT_WS function in E6.
+
+    Implements Databricks CONCAT_WS behavior:
+    - If sep is NULL the result is NULL (handled by e6 engine)
+    - exprN that are NULL are ignored
+    - If only separator provided or all exprN are NULL, returns empty string
+    - Each exprN can be STRING or ARRAY of STRING
+    - NULLs within arrays are filtered out
+    - Arrays are flattened and individual elements joined with separator
+    """
+    if not expression.expressions:
+        return "''"
+
+    # Extract separator and arguments; an empty slice is already [] so no
+    # length guard is needed.
+    separator = expression.expressions[0]
+    args = expression.expressions[1:]
+
+    # If no arguments provided (only separator), return empty string
+    if not args:
+        return "''"
+
+    # Collect all non-NULL expression nodes (flattening arrays)
+    array_expressions = []
+
+    for arg in args:
+        if isinstance(arg, exp.Array):
+            # For array arguments: add non-NULL elements
+            for element in arg.expressions:
+                if not isinstance(element, exp.Null):
+                    array_expressions.append(element)
+        else:
+            # For string arguments: add if not NULL
+            if not isinstance(arg, exp.Null):
+                array_expressions.append(arg)
+
+    # If no elements after filtering, return empty string
+    if not array_expressions:
+        return "''"
+
+    # Single element case - just return the element
+    if len(array_expressions) == 1:
+        return self.sql(array_expressions[0])
+
+    # Multiple elements: create array and join with separator
+    # Build: ARRAY_TO_STRING(ARRAY[element1, element2, ...], separator)
+    # Create Array expression with the actual expression nodes
+    array_expr = exp.Array(expressions=array_expressions)
+
+    # Use ARRAY_TO_STRING function directly instead of exp.ArrayToString
+    # to avoid the ARRAY_JOIN mapping in TRANSFORMS
+    return self.func("ARRAY_TO_STRING", array_expr, separator)
+
 # def struct_sql(self, expression: exp.Struct) -> str:
 #     struct_expr = expression.expressions
 #     return f"{struct_expr}"
@@ -2247,6 +2301,7 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit):
         # We mapped this believing that for most of the cases,
         # CONCAT function in other dialects would mostly use for ARRAY concatenation
         exp.Concat: rename_func("CONCAT"),
+        exp.ConcatWs: concat_ws_sql,
         exp.Contains: rename_func("CONTAINS_SUBSTR"),
         exp.CurrentDate: lambda *_: "CURRENT_DATE",
         exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
diff --git a/tests/dialects/test_e6.py b/tests/dialects/test_e6.py
index 3cd485d60f..536e8212c2 100644
--- a/tests/dialects/test_e6.py
+++ b/tests/dialects/test_e6.py
@@ -1604,6 +1604,49 @@ def test_string(self):
             read={"databricks": "SELECT to_varchar(x'537061726b2053514c', 'hex')"},
         )
 
+        # CONCAT_WS tests - based on Databricks documentation
+        # Basic string concatenation: concat_ws(' ', 'Spark', 'SQL') -> 'Spark SQL'
+        self.validate_all(
+            "SELECT ARRAY_TO_STRING(ARRAY['Spark', 'SQL'], ' ')",
+            read={"databricks": "SELECT concat_ws(' ', 'Spark', 'SQL')"},
+        )
+
+        # Only separator provided: concat_ws('s') -> ''
+        self.validate_all(
+            "SELECT ''",
+            read={"databricks": "SELECT concat_ws('s')"},
+        )
+
+        # Mixed strings, arrays and NULLs: concat_ws(',', 'Spark', array('S', 'Q', NULL, 'L'), NULL) -> 'Spark,S,Q,L'
+        self.validate_all(
+            "SELECT ARRAY_TO_STRING(ARRAY['Spark', 'S', 'Q', 'L'], ',')",
+            read={"databricks": "SELECT concat_ws(',', 'Spark', array('S', 'Q', NULL, 'L'), NULL)"},
+        )
+
+        # Single string argument with separator
+        self.validate_all(
+            "SELECT 'test'",
+            read={"databricks": "SELECT concat_ws('-', 'test')"},
+        )
+
+        # Multiple string arguments
+        self.validate_all(
+            "SELECT ARRAY_TO_STRING(ARRAY['a', 'b', 'c'], '-')",
+            read={"databricks": "SELECT concat_ws('-', 'a', 'b', 'c')"},
+        )
+
+        # Empty separator
+        self.validate_all(
+            "SELECT ARRAY_TO_STRING(ARRAY['hello', 'world'], '')",
+            read={"databricks": "SELECT concat_ws('', 'hello', 'world')"},
+        )
+
+        # Array with all valid elements (no NULLs)
+        self.validate_all(
+            "SELECT ARRAY_TO_STRING(ARRAY['x', 'y', 'z'], '|')",
+            read={"databricks": "SELECT concat_ws('|', array('x', 'y', 'z'))"},
+        )
+
     def test_to_utf(self):
         self.validate_all(
             "TO_UTF8(x)",