Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4926,6 +4926,10 @@ class AddConstraint(Expression):
arg_types = {"expressions": True}


class AddPartition(Expression):
arg_types = {"this": True, "exists": False}


class AttachOption(Expression):
arg_types = {"this": True, "expression": False}

Expand Down
4 changes: 4 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3498,6 +3498,10 @@ def droppartition_sql(self, expression: exp.DropPartition) -> str:
def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
return f"ADD {self.expressions(expression)}"

def addpartition_sql(self, expression: exp.AddPartition) -> str:
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
return f"ADD {exists}{self.sql(expression.this)}"

def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)

Expand Down
42 changes: 29 additions & 13 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7334,24 +7334,29 @@ def _parse_refresh(self) -> exp.Refresh:
self._match(TokenType.TABLE)
return self.expression(exp.Refresh, this=self._parse_string() or self._parse_table())

def _parse_add_column(self) -> t.Optional[exp.Expression]:
def _parse_add_column(self) -> t.Optional[exp.ColumnDef]:
if not self._prev.text.upper() == "ADD":
return None

start = self._index
self._match(TokenType.COLUMN)

exists_column = self._parse_exists(not_=True)
expression = self._parse_field_def()

if expression:
expression.set("exists", exists_column)
if not isinstance(expression, exp.ColumnDef):
self._retreat(start)
return None

# https://docs.databricks.com/delta/update-schema.html#explicitly-update-schema-to-add-columns
if self._match_texts(("FIRST", "AFTER")):
position = self._prev.text
column_position = self.expression(
exp.ColumnPosition, this=self._parse_column(), position=position
)
expression.set("position", column_position)
expression.set("exists", exists_column)

# https://docs.databricks.com/delta/update-schema.html#explicitly-update-schema-to-add-columns
if self._match_texts(("FIRST", "AFTER")):
position = self._prev.text
column_position = self.expression(
exp.ColumnPosition, this=self._parse_column(), position=position
)
expression.set("position", column_position)

return expression

Expand All @@ -7368,13 +7373,24 @@ def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.DropPart
)

def _parse_alter_table_add(self) -> t.List[exp.Expression]:
def _parse_add_column_or_constraint():
def _parse_add_alteration() -> t.Optional[exp.Expression]:
self._match_text_seq("ADD")
if self._match_set(self.ADD_CONSTRAINT_TOKENS, advance=False):
return self.expression(
exp.AddConstraint, expressions=self._parse_csv(self._parse_constraint)
)
return self._parse_add_column()

column_def = self._parse_add_column()
if isinstance(column_def, exp.ColumnDef):
return column_def

exists = self._parse_exists(not_=True)
if self._match_pair(TokenType.PARTITION, TokenType.L_PAREN, advance=False):
return self.expression(
exp.AddPartition, exists=exists, this=self._parse_field(any_token=True)
)

return None

if not self.dialect.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN or self._match_text_seq(
"COLUMNS"
Expand All @@ -7383,7 +7399,7 @@ def _parse_add_column_or_constraint():

return ensure_list(schema) if schema else self._parse_csv(self._parse_field_def)

return self._parse_csv(_parse_add_column_or_constraint)
return self._parse_csv(_parse_add_alteration)

def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]:
if self._match_texts(self.ALTER_ALTER_PARSERS):
Expand Down
2 changes: 2 additions & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def test_spark(self):
"REFRESH TABLE t",
)

self.validate_identity("ALTER TABLE foo ADD PARTITION(event = 'click')")
self.validate_identity("ALTER TABLE foo ADD IF NOT EXISTS PARTITION(event = 'click')")
self.validate_identity("IF(cond, foo AS bar, bla AS baz)")
self.validate_identity("any_value(col, true)", "ANY_VALUE(col) IGNORE NULLS")
self.validate_identity("first(col, true)", "FIRST(col) IGNORE NULLS")
Expand Down