Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions sqlglot/dialects/e6.py
Original file line number Diff line number Diff line change
Expand Up @@ -2033,6 +2033,60 @@ def string_agg_sql(self: E6.Generator, expression: exp.GroupConcat) -> str:
# Generate SQL using STRING_AGG/LISTAGG, with separator or default ''
return self.func("LISTAGG", expr_1, separator or exp.Literal.string(""))

def concat_ws_sql(self: E6.Generator, expression: exp.ConcatWs) -> str:
"""
Generate the SQL for the CONCAT_WS function in E6.

Implements Databricks CONCAT_WS behavior:
- If sep is NULL the result is NULL (handled by e6 engine)
- exprN that are NULL are ignored
- If only separator provided or all exprN are NULL, returns empty string
- Each exprN can be STRING or ARRAY of STRING
- NULLs within arrays are filtered out
- Arrays are flattened and individual elements joined with separator
"""
if not expression.expressions:
return "''"

# Extract separator and arguments
separator = expression.expressions[0]
args = expression.expressions[1:] if len(expression.expressions) > 1 else []

# If no arguments provided (only separator), return empty string
if not args:
return "''"

# Collect all non-NULL expression nodes (flattening arrays)
array_expressions = []

for arg in args:
if isinstance(arg, exp.Array):
# For array arguments: add non-NULL elements
for element in arg.expressions:
if not isinstance(element, exp.Null):
array_expressions.append(element)
else:
# For string arguments: add if not NULL
if not isinstance(arg, exp.Null):
array_expressions.append(arg)

# If no elements after filtering, return empty string
if not array_expressions:
return "''"

# Single element case - just return the element
if len(array_expressions) == 1:
return self.sql(array_expressions[0])

# Multiple elements: create array and join with separator
# Build: ARRAY_TO_STRING(ARRAY[element1, element2, ...], separator)
# Create Array expression with the actual expression nodes
array_expr = exp.Array(expressions=array_expressions)

# Use ARRAY_TO_STRING function directly instead of exp.ArrayToString
# to avoid the ARRAY_JOIN mapping in TRANSFORMS
return self.func("ARRAY_TO_STRING", array_expr, separator)

# def struct_sql(self, expression: exp.Struct) -> str:
# struct_expr = expression.expressions
# return f"{struct_expr}"
Expand Down Expand Up @@ -2247,6 +2301,7 @@ def split_sql(self, expression: exp.Split | exp.RegexpSplit):
# We mapped this believing that for most of the cases,
# CONCAT function in other dialects would mostly use for ARRAY concatenation
exp.Concat: rename_func("CONCAT"),
exp.ConcatWs: concat_ws_sql,
exp.Contains: rename_func("CONTAINS_SUBSTR"),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
Expand Down
43 changes: 43 additions & 0 deletions tests/dialects/test_e6.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,49 @@ def test_string(self):
read={"databricks": "SELECT to_varchar(x'537061726b2053514c', 'hex')"},
)

# CONCAT_WS tests - based on Databricks documentation
# Basic string concatenation: concat_ws(' ', 'Spark', 'SQL') -> 'Spark SQL'
self.validate_all(
"SELECT ARRAY_TO_STRING(ARRAY['Spark', 'SQL'], ' ')",
read={"databricks": "SELECT concat_ws(' ', 'Spark', 'SQL')"},
)

# Only separator provided: concat_ws('s') -> ''
self.validate_all(
"SELECT ''",
read={"databricks": "SELECT concat_ws('s')"},
)

# Mixed strings, arrays and NULLs: concat_ws(',', 'Spark', array('S', 'Q', NULL, 'L'), NULL) -> 'Spark,S,Q,L'
self.validate_all(
"SELECT ARRAY_TO_STRING(ARRAY['Spark', 'S', 'Q', 'L'], ',')",
read={"databricks": "SELECT concat_ws(',', 'Spark', array('S', 'Q', NULL, 'L'), NULL)"},
)

# Single string argument with separator
self.validate_all(
"SELECT 'test'",
read={"databricks": "SELECT concat_ws('-', 'test')"},
)

# Multiple string arguments
self.validate_all(
"SELECT ARRAY_TO_STRING(ARRAY['a', 'b', 'c'], '-')",
read={"databricks": "SELECT concat_ws('-', 'a', 'b', 'c')"},
)

# Empty separator
self.validate_all(
"SELECT ARRAY_TO_STRING(ARRAY['hello', 'world'], '')",
read={"databricks": "SELECT concat_ws('', 'hello', 'world')"},
)

# Array with all valid elements (no NULLs)
self.validate_all(
"SELECT ARRAY_TO_STRING(ARRAY['x', 'y', 'z'], '|')",
read={"databricks": "SELECT concat_ws('|', array('x', 'y', 'z'))"},
)

def test_to_utf(self):
self.validate_all(
"TO_UTF8(x)",
Expand Down