diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index a9b3ef6cae..d70a919517 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -821,7 +821,7 @@ def _parse_join( if join: method = join.args.get("method") join.set("method", None) - join.set("global", method) + join.set("global_", method) # tbl ARRAY JOIN arr <-- this should be a `Column` reference, not a `Table` # https://clickhouse.com/docs/en/sql-reference/statements/select/array-join diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 6a5d4fce37..1b7ceea4ef 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -662,11 +662,9 @@ def _parse_force(self) -> exp.Install | exp.Command: def _parse_install(self, force: bool = False) -> exp.Install: return self.expression( exp.Install, - **{ # type: ignore - "this": self._parse_id_var(), - "from": self._parse_var_or_string() if self._match(TokenType.FROM) else None, - "force": force, - }, + this=self._parse_id_var(), + from_=self._parse_var_or_string() if self._match(TokenType.FROM) else None, + force=force, ) def _parse_primary(self) -> t.Optional[exp.Expression]: @@ -1008,7 +1006,7 @@ def show_sql(self, expression: exp.Show) -> str: def install_sql(self, expression: exp.Install) -> str: force = "FORCE " if expression.args.get("force") else "" this = self.sql(expression, "this") - from_clause = expression.args.get("from") + from_clause = expression.args.get("from_") from_clause = f" FROM {from_clause}" if from_clause else "" return f"{force}INSTALL {this}{from_clause}" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 34a5ba10b4..581a92f51c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -834,7 +834,7 @@ def alterset_sql(self, expression: exp.AlterSet) -> str: return f"SET{serde}{exprs}{location}{file_format}{tags}" def serdeproperties_sql(self, expression: exp.SerdeProperties) -> str: - prefix = "WITH " if expression.args.get("with") else "" + prefix = "WITH " if expression.args.get("with_") else "" exprs = self.expressions(expression, flat=True) return f"{prefix}SERDEPROPERTIES ({exprs})" diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 0bf249618a..39f3069af7 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -671,7 +671,7 @@ def _parse_show_mysql( for_role=for_role, into_outfile=into_outfile, json=json, - **{"global": global_}, # type: ignore + global_=global_, ) def _parse_oldstyle_limit( @@ -1229,7 +1229,7 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> def show_sql(self, expression: exp.Show) -> str: this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" - global_ = " GLOBAL" if expression.args.get("global") else "" + global_ = " GLOBAL" if expression.args.get("global_") else "" target = self.sql(expression, "target") target = f" {target}" if target else "" diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 0b73260f4c..10f25359af 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -417,7 +417,13 @@ class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: rename_func("ARBITRARY"), - exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), + exp.ApproxQuantile: lambda self, e: self.func( + "APPROX_PERCENTILE", + e.this, + e.args.get("weight"), + e.args.get("quantile"), + e.args.get("accuracy"), + ), exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), exp.Array: transforms.preprocess( diff --git a/sqlglot/dialects/singlestore.py b/sqlglot/dialects/singlestore.py index 5475fe3d09..58fbf98801 100644 --- a/sqlglot/dialects/singlestore.py +++ b/sqlglot/dialects/singlestore.py @@ -542,7 +542,7 @@ def _unicode_substitute(m: re.Match[str]) -> str: "offset", "starts_with", "limit", - "from", + "from_", "scope", "scope_kind", "mutex", diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index a4fb5e375c..4b247cad73 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1114,19 +1114,17 @@ def _parse_show_snowflake(self, this: str) -> exp.Show: return self.expression( exp.Show, - **{ - "terse": terse, - "this": this, - "history": history, - "like": like, - "scope": scope, - "scope_kind": scope_kind, - "starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(), - "limit": self._parse_limit(), - "from": self._parse_string() if self._match(TokenType.FROM) else None, - "privileges": self._match_text_seq("WITH", "PRIVILEGES") - and self._parse_csv(lambda: self._parse_var(any_token=True, upper=True)), - }, + terse=terse, + this=this, + history=history, + like=like, + scope=scope, + scope_kind=scope_kind, + starts_with=self._match_text_seq("STARTS", "WITH") and self._parse_string(), + limit=self._parse_limit(), + from_=self._parse_string() if self._match(TokenType.FROM) else None, + privileges=self._match_text_seq("WITH", "PRIVILEGES") + and self._parse_csv(lambda: self._parse_var(any_token=True, upper=True)), ) def _parse_put(self) -> exp.Put | exp.Command: @@ -1652,7 +1650,7 @@ def show_sql(self, expression: exp.Show) -> str: limit = self.sql(expression, "limit") - from_ = self.sql(expression, "from") + from_ = self.sql(expression, "from_") if from_: from_ = f" FROM {from_}" diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 2f6c22a276..eb435a1ae8 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -213,13 +213,10 @@ def _parse_translate(self) -> exp.TranslateCharacters: def _parse_update(self) -> exp.Update: return self.expression( exp.Update, - **{ # type: ignore - "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), - "from": self._parse_from(joins=True), - "expressions": self._match(TokenType.SET) - and self._parse_csv(self._parse_equality), - "where": self._parse_where(), - }, + this=self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), + from_=self._parse_from(joins=True), + expressions=self._match(TokenType.SET) and self._parse_csv(self._parse_equality), + where=self._parse_where(), ) def _parse_rangen(self): @@ -387,7 +384,7 @@ def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> st # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause def update_sql(self, expression: exp.Update) -> str: this = self.sql(expression, "this") - from_sql = self.sql(expression, "from") + from_sql = self.sql(expression, "from_") set_sql = self.expressions(expression, flat=True) where_sql = self.sql(expression, "where") sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 251f689219..ce801f71e3 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -575,7 +575,7 @@ class Parser(parser.Parser): QUERY_MODIFIER_PARSERS = { **parser.Parser.QUERY_MODIFIER_PARSERS, TokenType.OPTION: lambda self: ("options", self._parse_options()), - TokenType.FOR: lambda self: ("for", self._parse_for()), + TokenType.FOR: lambda self: ("for_", self._parse_for()), } # T-SQL does not allow BEGIN to be used as an identifier @@ -825,7 +825,6 @@ def _parse_convert( args = [this, *self._parse_csv(self._parse_assignment)] convert = exp.Convert.from_arg_list(args) convert.set("safe", safe) - convert.set("strict", strict) return convert def _parse_column_def( @@ -882,7 +881,7 @@ def _parse_id_var( this = super()._parse_id_var(any_token=any_token, tokens=tokens) if this: if is_global: - this.set("global", True) + this.set("global_", True) elif is_temporary: this.set("temporary", True) @@ -1241,12 +1240,12 @@ def create_sql(self, expression: exp.Create) -> str: if kind == "VIEW": expression.this.set("catalog", None) - with_ = expression.args.get("with") + with_ = expression.args.get("with_") if ctas_expression and with_: # We've already preprocessed the Create expression to bubble up any nested CTEs, # but CREATE VIEW actually requires the WITH clause to come after it so we need # to amend the AST by moving the CTEs to the CREATE VIEW statement's query. - ctas_expression.set("with", with_.pop()) + ctas_expression.set("with_", with_.pop()) table = expression.find(exp.Table) @@ -1362,7 +1361,7 @@ def rollback_sql(self, expression: exp.Rollback) -> str: def identifier_sql(self, expression: exp.Identifier) -> str: identifier = super().identifier_sql(expression) - if expression.args.get("global"): + if expression.args.get("global_"): identifier = f"##{identifier}" elif expression.args.get("temporary"): identifier = f"#{identifier}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 372501b576..1ce740c5df 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -16,6 +16,7 @@ import math import numbers import re +import sys import textwrap import typing as t from collections import deque @@ -54,6 +55,7 @@ def __new__(cls, clsname, bases, attrs): # When an Expression class is created, its key is automatically set # to be the lowercase version of the class' name. klass.key = clsname.lower() + klass.required_args = {k for k, v in klass.arg_types.items() if v} # This is so that docstrings are not inherited in pdoc klass.__doc__ = klass.__doc__ or "" @@ -66,6 +68,7 @@ def __new__(cls, clsname, bases, attrs): TABLE_PARTS = ("this", "db", "catalog") COLUMN_PARTS = ("this", "table", "db", "catalog") POSITION_META_KEYS = ("line", "col", "start", "end") +UNITTEST = "unittest" in sys.modules or "pytest" in sys.modules class Expression(metaclass=_Expression): @@ -102,6 +105,7 @@ class Expression(metaclass=_Expression): key = "expression" arg_types = {"this": True} + required_args = {"this"} __slots__ = ("args", "parent", "arg_key", "index", "comments", "_type", "_meta", "_hash") def __init__(self, **args: t.Any): @@ -768,12 +772,14 @@ def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]: """ errors: t.List[str] = [] - for k in self.args: - if k not in self.arg_types: - errors.append(f"Unexpected keyword: '{k}' for {self.__class__}") - for k, mandatory in self.arg_types.items(): + if UNITTEST: + for k in self.args: + if k not in self.arg_types: + raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}") + + for k in self.required_args: v = self.args.get(k) - if mandatory and (v is None or (isinstance(v, list) and not v)): + if v is None or (type(v) is list and not v): errors.append(f"Required keyword: '{k}' missing for {self.__class__}") if ( @@ -883,29 +889,39 @@ def not_(self, copy: bool = True): return not_(self, copy=copy) def update_positions( - self: E, other: t.Optional[Token | Expression] = None, **kwargs: t.Any + self: E, + other: t.Optional[Token | Expression] = None, + line: t.Optional[int] = None, + col: t.Optional[int] = None, + start: t.Optional[int] = None, + end: t.Optional[int] = None, ) -> E: """ Update this expression with positions from a token or other expression. Args: other: a token or expression to update this expression with. + line: the line number to use if other is None + col: column number + start: start char index + end: end char index Returns: The updated expression. """ - if isinstance(other, Expression): - self.meta.update({k: v for k, v in other.meta.items() if k in POSITION_META_KEYS}) - elif other is not None: - self.meta.update( - { - "line": other.line, - "col": other.col, - "start": other.start, - "end": other.end, - } - ) - self.meta.update({k: v for k, v in kwargs.items() if k in POSITION_META_KEYS}) + if other is None: + self.meta["line"] = line + self.meta["col"] = col + self.meta["start"] = start + self.meta["end"] = end + elif hasattr(other, "meta"): + for k in POSITION_META_KEYS: + self.meta[k] = other.meta[k] + else: + self.meta["line"] = other.line + self.meta["col"] = other.col + self.meta["start"] = other.start + self.meta["end"] = other.end return self def as_( @@ -1247,7 +1263,7 @@ def order_by( @property def ctes(self) -> t.List[CTE]: """Returns a list of all the CTEs attached to this query.""" - with_ = self.args.get("with") + with_ = self.args.get("with_") return with_.expressions if with_ else [] @property @@ -1475,7 +1491,7 @@ class DDL(Expression): @property def ctes(self) -> t.List[CTE]: """Returns a list of all the CTEs attached to this statement.""" - with_ = self.args.get("with") + with_ = self.args.get("with_") return with_.expressions if with_ else [] @property @@ -1536,7 +1552,7 @@ def returning( class Create(DDL): arg_types = { - "with": False, + "with_": False, "this": True, "kind": True, "expression": False, @@ -1615,7 +1631,7 @@ class Detach(Expression): # https://duckdb.org/docs/sql/statements/load_and_install.html class Install(Expression): - arg_types = {"this": True, "from": False, "force": False} + arg_types = {"this": True, "from_": False, "force": False} # https://duckdb.org/docs/guides/meta/summarize.html @@ -1653,7 +1669,7 @@ class SetItem(Expression): "expressions": False, "kind": False, "collate": False, # MySQL SET NAMES statement - "global": False, + "global_": False, } @@ -1670,7 +1686,7 @@ class Show(Expression): "offset": False, "starts_with": False, "limit": False, - "from": False, + "from_": False, "like": False, "where": False, "db": False, @@ -1680,7 +1696,7 @@ class Show(Expression): "mutex": False, "query": False, "channel": False, - "global": False, + "global_": False, "log": False, "position": False, "types": False, @@ -2120,7 +2136,7 @@ class Constraint(Expression): class Delete(DML): arg_types = { - "with": False, + "with_": False, "this": False, "using": False, "where": False, @@ -2337,7 +2353,7 @@ class JoinHint(Expression): class Identifier(Expression): - arg_types = {"this": True, "quoted": False, "global": False, "temporary": False} + arg_types = {"this": True, "quoted": False, "global_": False, "temporary": False} @property def quoted(self) -> bool: @@ -2380,7 +2396,7 @@ class IndexParameters(Expression): class Insert(DDL, DML): arg_types = { "hint": False, - "with": False, + "with_": False, "is_function": False, "this": False, "expression": False, @@ -2607,7 +2623,7 @@ class Join(Expression): "kind": False, "using": False, "method": False, - "global": False, + "global_": False, "hint": False, "match_condition": False, # Snowflake "expressions": False, @@ -2786,7 +2802,7 @@ class Order(Expression): # https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier class WithFill(Expression): arg_types = { - "from": False, + "from_": False, "to": False, "step": False, "interpolate": False, @@ -2890,7 +2906,7 @@ class DataBlocksizeProperty(Property): class DataDeletionProperty(Property): - arg_types = {"on": True, "filter_col": False, "retention_period": False} + arg_types = {"on": True, "filter_column": False, "retention_period": False} class DefinerProperty(Property): @@ -3210,7 +3226,7 @@ class SemanticView(Expression): class SerdeProperties(Property): - arg_types = {"expressions": True, "with": False} + arg_types = {"expressions": True, "with_": False} class SetProperty(Property): @@ -3306,7 +3322,7 @@ class WithSystemVersioningProperty(Property): "this": False, "data_consistency": False, "retention_period": False, - "with": True, + "with_": True, } @@ -3574,7 +3590,7 @@ def to_column(self, copy: bool = True) -> Expression: class SetOperation(Query): arg_types = { - "with": False, + "with_": False, "this": True, "expression": True, "distinct": False, @@ -3643,10 +3659,10 @@ class Intersect(SetOperation): class Update(DML): arg_types = { - "with": False, + "with_": False, "this": False, "expressions": True, - "from": False, + "from_": False, "where": False, "returning": False, "order": False, @@ -3794,7 +3810,7 @@ def from_( return _apply_builder( expression=expression, instance=self, - arg="from", + arg="from_", into=From, prefix="FROM", dialect=dialect, @@ -3890,13 +3906,13 @@ class Lock(Expression): class Select(Query): arg_types = { - "with": False, + "with_": False, "kind": False, "expressions": False, "hint": False, "distinct": False, "into": False, - "from": False, + "from_": False, "operation_modifiers": False, **QUERY_MODIFIERS, } @@ -3925,7 +3941,7 @@ def from_( return _apply_builder( expression=expression, instance=self, - arg="from", + arg="from_", into=From, prefix="FROM", dialect=dialect, @@ -4427,7 +4443,7 @@ class Subquery(DerivedTable, Query): arg_types = { "this": True, "alias": False, - "with": False, + "with_": False, **QUERY_MODIFIERS, } @@ -4564,7 +4580,7 @@ class Where(Expression): class Star(Expression): - arg_types = {"except": False, "replace": False, "rename": False} + arg_types = {"except_": False, "replace": False, "rename": False} @property def name(self) -> str: @@ -5737,7 +5753,7 @@ class Transform(Func): class Translate(Func): - arg_types = {"this": True, "from": True, "to": True} + arg_types = {"this": True, "from_": True, "to": True} class Grouping(AggFunc): @@ -5863,7 +5879,7 @@ class Columns(Func): # https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax class Convert(Func): - arg_types = {"this": True, "expression": True, "style": False} + arg_types = {"this": True, "expression": True, "style": False, "safe": False} # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/CONVERT.html @@ -6771,7 +6787,7 @@ class IsNullValue(Func): # https://www.postgresql.org/docs/current/functions-json.html class JSON(Expression): - arg_types = {"this": False, "with": False, "unique": False} + arg_types = {"this": False, "with_": False, "unique": False} class JSONPath(Expression): @@ -7296,7 +7312,7 @@ class Normalize(Func): class Overlay(Func): - arg_types = {"this": True, "expression": True, "from": True, "for": False} + arg_types = {"this": True, "expression": True, "from_": True, "for_": False} # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function @@ -7996,7 +8012,7 @@ class Merge(DML): "on": False, "using_cond": False, "whens": True, - "with": False, + "with_": False, "returning": False, } @@ -8317,7 +8333,7 @@ def _apply_cte_builder( return _apply_child_list_builder( cte, instance=instance, - arg="with", + arg="with_", append=append, copy=copy, into=With, @@ -8550,7 +8566,7 @@ def update( ) if from_: update_expr.set( - "from", + "from_", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), ) if isinstance(where, Condition): @@ -8566,7 +8582,7 @@ def update( for alias, qry in with_.items() ] update_expr.set( - "with", + "with_", With(expressions=cte_list), ) return update_expr diff --git a/sqlglot/generator.py b/sqlglot/generator.py index d08b261232..e026cd1d01 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1306,7 +1306,7 @@ def heredoc_sql(self, expression: exp.Heredoc) -> str: return f"${tag}${self.sql(expression, 'this')}${tag}$" def prepend_ctes(self, expression: exp.Expression, sql: str) -> str: - with_ = self.sql(expression, "with") + with_ = self.sql(expression, "with_") if with_: sql = f"{with_}{self.sep()}{sql}" return sql @@ -1944,7 +1944,7 @@ def withsystemversioningproperty_sql(self, expression: exp.WithSystemVersioningP sql = f"SYSTEM_VERSIONING={on_sql}" - return f"WITH({sql})" if expression.args.get("with") else sql + return f"WITH({sql})" if expression.args.get("with_") else sql def insert_sql(self, expression: exp.Insert) -> str: hint = self.sql(expression, "hint") @@ -2212,7 +2212,7 @@ def tuple_sql(self, expression: exp.Tuple) -> str: def update_sql(self, expression: exp.Update) -> str: this = self.sql(expression, "this") set_sql = self.expressions(expression, flat=True) - from_sql = self.sql(expression, "from") + from_sql = self.sql(expression, "from_") where_sql = self.sql(expression, "where") returning = self.sql(expression, "returning") order = self.sql(expression, "order") @@ -2351,7 +2351,7 @@ def join_sql(self, expression: exp.Join) -> str: op for op in ( expression.method, - "GLOBAL" if expression.args.get("global") else None, + "GLOBAL" if expression.args.get("global_") else None, side, expression.kind, expression.hint if self.JOIN_HINTS else None, @@ -2466,7 +2466,7 @@ def setitem_sql(self, expression: exp.SetItem) -> str: expressions = self.expressions(expression) collate = self.sql(expression, "collate") collate = f" COLLATE {collate}" if collate else "" - global_ = "GLOBAL " if expression.args.get("global") else "" + global_ = "GLOBAL " if expression.args.get("global_") else "" return f"{global_}{kind}{this}{expressions}{collate}" def set_sql(self, expression: exp.Set) -> str: @@ -2564,7 +2564,7 @@ def order_sql(self, expression: exp.Order, flat: bool = False) -> str: return self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat) # type: ignore def withfill_sql(self, expression: exp.WithFill) -> str: - from_sql = self.sql(expression, "from") + from_sql = self.sql(expression, "from_") from_sql = f" FROM {from_sql}" if from_sql else "" to_sql = self.sql(expression, "to") to_sql = f" TO {to_sql}" if to_sql else "" @@ -2725,7 +2725,7 @@ def options_modifier(self, expression: exp.Expression) -> str: return f" {options}" if options else "" def for_modifiers(self, expression: exp.Expression) -> str: - for_modifiers = self.expressions(expression, key="for") + for_modifiers = self.expressions(expression, key="for_") return f"{self.sep()}FOR XML{self.seg(for_modifiers)}" if for_modifiers else "" def queryoption_sql(self, expression: exp.QueryOption) -> str: @@ -2796,11 +2796,11 @@ def select_sql(self, expression: exp.Select) -> str: expression, f"SELECT{top_distinct}{operation_modifiers}{kind}{expressions}", self.sql(expression, "into", comment=False), - self.sql(expression, "from", comment=False), + self.sql(expression, "from_", comment=False), ) # If both the CTE and SELECT clauses have comments, generate the latter earlier - if expression.args.get("with"): + if expression.args.get("with_"): sql = self.maybe_comment(sql, expression) expression.pop_comments() @@ -2828,7 +2828,7 @@ def schema_columns_sql(self, expression: exp.Schema) -> str: return "" def star_sql(self, expression: exp.Star) -> str: - except_ = self.expressions(expression, key="except", flat=True) + except_ = self.expressions(expression, key="except_", flat=True) except_ = f"{self.seg(self.STAR_EXCEPT)} ({except_})" if except_ else "" replace = self.expressions(expression, key="replace", flat=True) replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" @@ -4826,7 +4826,7 @@ def json_sql(self, expression: exp.JSON) -> str: this = self.sql(expression, "this") this = f" {this}" if this else "" - _with = expression.args.get("with") + _with = expression.args.get("with_") if _with is None: with_sql = "" @@ -5010,8 +5010,8 @@ def columns_sql(self, expression: exp.Columns): def overlay_sql(self, expression: exp.Overlay): this = self.sql(expression, "this") expr = self.sql(expression, "expression") - from_sql = self.sql(expression, "from") - for_sql = self.sql(expression, "for") + from_sql = self.sql(expression, "from_") + for_sql = self.sql(expression, "for_") for_sql = f" FOR {for_sql}" if for_sql else "" return f"OVERLAY({this} PLACING {expr} FROM {from_sql}{for_sql})" diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index cc1aedd5b6..5a80a45fa3 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -61,7 +61,7 @@ def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: } for d in node.downstream: - edges.append({"from": node_id, "to": id(d)}) + edges.append({"from_": node_id, "to": id(d)}) return GraphHTML(nodes, edges, **opts) diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index da1fb7a217..2578fdd4f7 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -110,7 +110,7 @@ def _has_single_output_row(scope): return isinstance(scope.expression, exp.Select) and ( all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects) or _is_limit_1(scope) - or not scope.expression.args.get("from") + or not scope.expression.args.get("from_") ) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index b661003690..58d71477ea 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -65,7 +65,7 @@ def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: # Existing CTES in the root expression. We'll use this for deduplication. existing_ctes: ExistingCTEsMapping = {} - with_ = root.expression.args.get("with") + with_ = root.expression.args.get("with_") recursive = False if with_: recursive = with_.args.get("recursive") @@ -97,7 +97,7 @@ def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: if new_ctes: query = expression.expression if isinstance(expression, exp.DDL) else expression - query.set("with", exp.With(expressions=new_ctes, recursive=recursive)) + query.set("with_", exp.With(expressions=new_ctes, recursive=recursive)) return expression diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 9b68c7da8e..b514785dff 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -48,7 +48,7 @@ def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E: # If a derived table has these Select args, it can't be merged UNMERGABLE_ARGS = set(exp.Select.arg_types) - { "expressions", - "from", + "from_", "joins", "where", "order", @@ -165,7 +165,7 @@ def _outer_select_joins_on_inner_select_join(): if not on: return False selections = [c.name for c in on.find_all(exp.Column) if c.table == alias] - inner_from = inner_scope.expression.args.get("from") + inner_from = inner_scope.expression.args.get("from_") if not inner_from: return False inner_from_table = inner_from.alias_or_name @@ -197,7 +197,7 @@ def _is_recursive(): and not outer_scope.expression.is_star and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) - and inner_select.args.get("from") is not None + and inner_select.args.get("from_") is not None and not outer_scope.pivots and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) @@ -261,7 +261,7 @@ def _merge_from( """ Merge FROM clause of inner query into outer query. """ - new_subquery = inner_scope.expression.args["from"].this + new_subquery = inner_scope.expression.args["from_"].this new_subquery.set("joins", node_to_replace.args.get("joins")) node_to_replace.replace(new_subquery) for join_hint in outer_scope.join_hints: @@ -357,7 +357,7 @@ def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoi if isinstance(from_or_join, exp.Join): # Merge predicates from an outer join to the ON clause # if it only has columns that are already joined - from_ = expression.args.get("from") + from_ = expression.args.get("from_") sources = {from_.alias_or_name} if from_ else set() for join in expression.args["joins"]: diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 4e65f24fe1..9ac198645c 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -181,7 +181,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): - with_ = source.parent.expression.args.get("with") + with_ = source.parent.expression.args.get("with_") if with_ and with_.recursive: return {} node = source.expression diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e95c63321f..d8f6e405b7 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -853,7 +853,7 @@ def _expand_stars( def _add_except_columns( expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] ) -> None: - except_ = expression.args.get("except") + except_ = expression.args.get("except_") if not except_: return @@ -960,7 +960,8 @@ def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: else: projection = alias(projection, alias=_alias) new_expressions.append(projection) - cte.this.set("expressions", new_expressions) + if isinstance(cte.this, exp.Select): + cte.this.set("expressions", new_expressions) return expression @@ -1205,7 +1206,7 @@ def _get_available_source_columns( args = self.scope.expression.args # Collect tables in order: FROM clause tables + joined tables up to current join - from_name = args["from"].alias_or_name + from_name = args["from_"].alias_or_name available_sources = {from_name: self.get_source_columns(from_name)} for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index baff507965..cc50dfa6d7 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -68,7 +68,7 @@ def _qualify(table: exp.Table) -> None: table.set("catalog", catalog.copy()) if (db or catalog) and not isinstance(expression, exp.Query): - with_ = expression.args.get("with") or exp.With() + with_ = expression.args.get("with_") or exp.With() cte_names = {cte.alias_or_name for cte in with_.expressions} for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 85041fdb1f..96d3a60a1c 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -308,7 +308,7 @@ def columns(self): or column.name not in named_selects ) ) - or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") + or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_") ): self._columns.append(column) @@ -663,7 +663,7 @@ def _traverse_ctes(scope): # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. # thus the recursive scope is the first section of the union. - with_ = scope.expression.args.get("with") + with_ = scope.expression.args.get("with_") if with_ and with_.recursive: union = cte.this @@ -720,7 +720,7 @@ def _traverse_tables(scope): # Traverse FROMs, JOINs, and LATERALs in the order they are defined expressions = [] - from_ = scope.expression.args.get("from") + from_ = scope.expression.args.get("from_") if from_: expressions.append(from_.this) diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 163a5a8ec0..b723b2fc63 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -44,7 +44,7 @@ def unnest(select, parent_select, next_alias_name): if ( not predicate or parent_select is not predicate.parent_select - or not parent_select.args.get("from") + or not parent_select.args.get("from_") ): return diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c1c7bd48fb..f476bbc01f 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -2382,10 +2382,8 @@ def _parse_system_versioning_property( self._match(TokenType.EQ) prop = self.expression( exp.WithSystemVersioningProperty, - **{ # type: ignore - "on": True, - "with": with_, - }, + on=True, + with_=with_, ) if self._match_text_seq("OFF"): @@ -3056,10 +3054,8 @@ def _parse_serde_properties(self, with_: bool = False) -> t.Optional[exp.SerdePr return None return self.expression( exp.SerdeProperties, - **{ # type: ignore - "expressions": self._parse_wrapped_properties(), - "with": with_, - }, + expressions=self._parse_wrapped_properties(), + with_=with_, ) def _parse_row_format( @@ -3147,7 +3143,7 @@ def _parse_update(self) -> exp.Update: elif self._match(TokenType.RETURNING, advance=False): kwargs["returning"] = self._parse_returning() elif self._match(TokenType.FROM, advance=False): - kwargs["from"] = self._parse_from(joins=True) + kwargs["from_"] = self._parse_from(joins=True) elif self._match(TokenType.WHERE, advance=False): kwargs["where"] = self._parse_where() elif self._match(TokenType.ORDER_BY, advance=False): @@ -3237,8 +3233,8 @@ def _parse_wrapped_select(self, table: bool = False) -> t.Optional[exp.Expressio # Support parentheses for duckdb FROM-first syntax select = self._parse_select(from_=from_) if select: - if not select.args.get("from"): - select.set("from", from_) + if not select.args.get("from_"): + select.set("from_", from_) this = select else: this = exp.select("*").from_(t.cast(exp.From, from_)) @@ -3304,8 +3300,8 @@ def _parse_select_query( while isinstance(this, exp.Subquery) and this.is_wrapper: this = this.this - if "with" in this.arg_types: - this.set("with", cte) + if "with_" in this.arg_types: + this.set("with_", cte) else: self.raise_error(f"{this.key} does not support CTE") this = cte @@ -3371,7 +3367,7 @@ def _parse_select_query( from_ = self._parse_from() if from_: - this.set("from", from_) + this.set("from_", from_) this = self._parse_query_modifiers(this) elif (table or nested) and self._match(TokenType.L_PAREN): @@ -3537,7 +3533,7 @@ def _parse_subquery( def _implicit_unnests_to_explicit(self, this: E) -> E: from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm - refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name} + refs = {_norm(this.args["from_"].this.copy(), dialect=self.dialect).alias_or_name} for i, join in enumerate(this.args.get("joins") or []): table = join.this normalized_table = table.copy() @@ -3602,7 +3598,7 @@ def _parse_query_modifiers(self, this): continue break - if self.SUPPORTS_IMPLICIT_UNNEST and this and this.args.get("from"): + if self.SUPPORTS_IMPLICIT_UNNEST and this and this.args.get("from_"): this = self._implicit_unnests_to_explicit(this) return this @@ -4769,12 +4765,10 @@ def _parse_ordered( if self._match_text_seq("WITH", "FILL"): with_fill = self.expression( exp.WithFill, - **{ # type: ignore - "from": self._match(TokenType.FROM) and self._parse_bitwise(), - "to": self._match_text_seq("TO") and self._parse_bitwise(), - "step": self._match_text_seq("STEP") and self._parse_bitwise(), - "interpolate": self._parse_interpolate(), - }, + from_=self._match(TokenType.FROM) and self._parse_bitwise(), + to=self._match_text_seq("TO") and self._parse_bitwise(), + step=self._match_text_seq("STEP") and self._parse_bitwise(), + interpolate=self._parse_interpolate(), ) else: with_fill = None @@ -5087,7 +5081,10 @@ def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expressi unique = self._match(TokenType.UNIQUE) self._match_text_seq("KEYS") expression: t.Optional[exp.Expression] = self.expression( - exp.JSON, **{"this": kind, "with": _with, "unique": unique} + exp.JSON, + this=kind, + with_=_with, + unique=unique, ) else: expression = self._parse_null() or self._parse_bitwise() @@ -8098,7 +8095,7 @@ def _parse_set_transaction(self, global_: bool = False) -> exp.Expression: exp.SetItem, expressions=characteristics, kind="TRANSACTION", - **{"global": global_}, # type: ignore + global_=global_, ) def _parse_set_item(self) -> t.Optional[exp.Expression]: @@ -8581,11 +8578,9 @@ def _parse_star_ops(self) -> t.Optional[exp.Expression]: return self.expression( exp.Star, - **{ # type: ignore - "except": self._parse_star_op("EXCEPT", "EXCLUDE"), - "replace": self._parse_star_op("REPLACE"), - "rename": self._parse_star_op("RENAME"), - }, + except_=self._parse_star_op("EXCEPT", "EXCLUDE"), + replace=self._parse_star_op("REPLACE"), + rename=self._parse_star_op("RENAME"), ).update_positions(star_token) def _parse_grant_privilege(self) -> t.Optional[exp.GrantPrivilege]: @@ -8685,12 +8680,10 @@ def _parse_revoke(self) -> exp.Revoke | exp.Command: def _parse_overlay(self) -> exp.Overlay: return self.expression( exp.Overlay, - **{ # type: ignore - "this": self._parse_bitwise(), - "expression": self._match_text_seq("PLACING") and self._parse_bitwise(), - "from": self._match_text_seq("FROM") and self._parse_bitwise(), - "for": self._match_text_seq("FOR") and self._parse_bitwise(), - }, + this=self._parse_bitwise(), + expression=self._match_text_seq("PLACING") and self._parse_bitwise(), + from_=self._match_text_seq("FROM") and self._parse_bitwise(), + for_=self._match_text_seq("FOR") and self._parse_bitwise(), ) def _parse_format_name(self) -> exp.Property: @@ -8733,12 +8726,12 @@ def _build_pipe_cte( self._pipe_cte_counter += 1 new_cte = f"__tmp{self._pipe_cte_counter}" - with_ = query.args.get("with") + with_ = query.args.get("with_") ctes = with_.pop() if with_ else None new_select = exp.select(*expressions, copy=False).from_(new_cte, copy=False) if ctes: - new_select.set("with", ctes) + new_select.set("with_", ctes) return new_select.with_(new_cte, as_=query, copy=False) @@ -8833,7 +8826,7 @@ def _parse_and_unwrap_query() -> t.Optional[exp.Select]: ] query = self._build_pipe_cte(query=query, expressions=[exp.Star()]) - with_ = query.args.get("with") + with_ = query.args.get("with_") ctes = with_.pop() if with_ else None if isinstance(first_setop, exp.Union): @@ -8843,7 +8836,7 @@ def _parse_and_unwrap_query() -> t.Optional[exp.Select]: else: query = query.intersect(*setops, copy=False, **first_setop.args) - query.set("with", ctes) + query.set("with_", ctes) return self._build_pipe_cte(query=query, expressions=[exp.Star()]) @@ -8862,7 +8855,7 @@ def _parse_pipe_syntax_pivot(self, query: exp.Select) -> exp.Select: if not pivots: return query - from_ = query.args.get("from") + from_ = query.args.get("from_") if from_: from_.this.set("pivots", pivots) @@ -8876,7 +8869,7 @@ def _parse_pipe_syntax_extend(self, query: exp.Select) -> exp.Select: def _parse_pipe_syntax_tablesample(self, query: exp.Select) -> exp.Select: sample = self._parse_table_sample() - with_ = query.args.get("with") + with_ = query.args.get("with_") if with_: with_.expressions[-1].this.set("sample", sample) else: @@ -8888,7 +8881,7 @@ def _parse_pipe_syntax_query(self, query: exp.Query) -> t.Optional[exp.Query]: if isinstance(query, exp.Subquery): query = exp.select("*").from_(query, copy=False) - if not query.args.get("from"): + if not query.args.get("from_"): query = exp.select("*").from_(query.subquery(copy=False), copy=False) while self._match(TokenType.PIPE_GT): diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 687bffb9fa..ebd34fe1e1 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -94,7 +94,7 @@ def from_expression( """ ctes = ctes or {} expression = expression.unnest() - with_ = expression.args.get("with") + with_ = expression.args.get("with_") # CTEs break the mold of scope and introduce themselves to all in the context. if with_: @@ -104,7 +104,7 @@ def from_expression( step.name = cte.alias ctes[step.name] = step # type: ignore - from_ = expression.args.get("from") + from_ = expression.args.get("from_") if isinstance(expression, exp.Select) and from_: step = Scan.from_expression(from_.this, ctes) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index c321ff680c..fbccdaf728 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -114,10 +114,10 @@ def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) - count += 1 if recursive_ctes: - with_expression = expression.args.get("with") or exp.With() + with_expression = expression.args.get("with_") or exp.With() with_expression.set("recursive", True) with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions]) - expression.set("with", with_expression) + expression.set("with_", with_expression) return expression @@ -314,14 +314,14 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]: return exp.Inline if has_multi_expr else exp.Explode if isinstance(expression, exp.Select): - from_ = expression.args.get("from") + from_ = expression.args.get("from_") if from_ and isinstance(from_.this, exp.Unnest): unnest = from_.this alias = unnest.args.get("alias") exprs = unnest.expressions has_multi_expr = len(exprs) > 1 - this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr) + this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr) columns = alias.columns if alias else [] offset = unnest.args.get("offset") @@ -332,10 +332,7 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]: unnest.replace( exp.Table( - this=_udtf_type(unnest, has_multi_expr)( - this=this, - expressions=expressions, - ), + this=_udtf_type(unnest, has_multi_expr)(this=this), alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None, ) ) @@ -498,7 +495,7 @@ def new_name(names: t.Set[str], name: str) -> str: expression.set("expressions", expressions) if not arrays: - if expression.args.get("from"): + if expression.args.get("from_"): expression.join(series, copy=False, join_type="CROSS") else: expression.from_(series, copy=False) @@ -642,7 +639,7 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: expression.set("limit", None) index, full_outer_join = full_outer_joins[0] - tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name) + tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name) join_conditions = full_outer_join.args.get("on") or exp.and_( *[ exp.column(col, tables[0]).eq(exp.column(col, tables[1])) @@ -651,10 +648,12 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: ) full_outer_join.set("side", "left") - anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions) + anti_join_clause = ( + exp.select("1").from_(expression.args["from_"]).where(join_conditions) + ) expression_copy.args["joins"][index].set("side", "right") expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_()) - expression_copy.set("with", None) # remove CTEs from RIGHT side + expression_copy.set("with_", None) # remove CTEs from RIGHT side expression.set("order", None) # remove order by from LEFT side return exp.union(expression, expression_copy, copy=False, distinct=False) @@ -674,14 +673,14 @@ def move_ctes_to_top_level(expression: E) -> E: TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). """ - top_level_with = expression.args.get("with") + top_level_with = expression.args.get("with_") for inner_with in expression.find_all(exp.With): if inner_with.parent is expression: continue if not top_level_with: top_level_with = inner_with.pop() - expression.set("with", top_level_with) + expression.set("with_", top_level_with) else: if inner_with.recursive: top_level_with.set("recursive", True) @@ -908,7 +907,7 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: old_joins = {join.alias_or_name: join for join in joins} new_joins = {} - query_from = query.args["from"] + query_from = query.args["from_"] for table, predicates in joins_ons.items(): join_what = old_joins.get(table, query_from).this.copy() @@ -934,11 +933,11 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: ), "Cannot determine which table to use in the new FROM clause" new_from_name = list(only_old_joins)[0] - query.set("from", exp.From(this=old_joins[new_from_name].this)) + query.set("from_", exp.From(this=old_joins[new_from_name].this)) if new_joins: for n, j in old_joins.items(): # preserve any other joins - if n not in new_joins and n != query.args["from"].name: + if n not in new_joins and n != query.args["from_"].name: if not j.kind: j.set("kind", "CROSS") new_joins[n] = j diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index fb6bdc673a..02ffbb7abf 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -2986,29 +2986,29 @@ def test_identifier_meta(self): self.assertEqual(set(identifier.meta), {"line", "col", "start", "end"}) self.assertEqual( - ast.this.args["from"].this.args["this"].meta, + ast.this.args["from_"].this.args["this"].meta, {"line": 1, "col": 41, "start": 29, "end": 40}, ) self.assertEqual( - ast.this.args["from"].this.args["db"].meta, + ast.this.args["from_"].this.args["db"].meta, {"line": 1, "col": 28, "start": 17, "end": 27}, ) self.assertEqual( - ast.expression.args["from"].this.args["this"].meta, + ast.expression.args["from_"].this.args["this"].meta, {"line": 1, "col": 106, "start": 94, "end": 105}, ) self.assertEqual( - ast.expression.args["from"].this.args["db"].meta, + ast.expression.args["from_"].this.args["db"].meta, {"line": 1, "col": 93, "start": 82, "end": 92}, ) self.assertEqual( - ast.expression.args["from"].this.args["catalog"].meta, + ast.expression.args["from_"].this.args["catalog"].meta, {"line": 1, "col": 81, "start": 69, "end": 80}, ) information_schema_sql = "SELECT a, b FROM region.INFORMATION_SCHEMA.COLUMNS" ast = parse_one(information_schema_sql, dialect="bigquery") - meta = ast.args["from"].this.this.meta + meta = ast.args["from_"].this.this.meta self.assertEqual(meta, {"line": 1, "col": 50, "start": 24, "end": 49}) assert ( information_schema_sql[meta["start"] : meta["end"] + 1] == "INFORMATION_SCHEMA.COLUMNS" @@ -3017,14 +3017,14 @@ def test_identifier_meta(self): def test_quoted_identifier_meta(self): sql = "SELECT `a` FROM `test_schema`.`test_table_a`" ast = parse_one(sql, dialect="bigquery") - db_meta = ast.args["from"].this.args["db"].meta + db_meta = ast.args["from_"].this.args["db"].meta self.assertEqual(sql[db_meta["start"] : db_meta["end"] + 1], "`test_schema`") - table_meta = ast.args["from"].this.this.meta + table_meta = ast.args["from_"].this.this.meta self.assertEqual(sql[table_meta["start"] : table_meta["end"] + 1], "`test_table_a`") information_schema_sql = "SELECT a, b FROM `region.INFORMATION_SCHEMA.COLUMNS`" ast = parse_one(information_schema_sql, dialect="bigquery") - table_meta = ast.args["from"].this.this.meta + table_meta = ast.args["from_"].this.this.meta assert ( information_schema_sql[table_meta["start"] : table_meta["end"] + 1] == "`region.INFORMATION_SCHEMA.COLUMNS`" diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 4ec7d7449e..93a474b839 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -716,8 +716,8 @@ def test_cte(self): self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") query = parse_one("""WITH (SELECT 1) AS y SELECT * FROM y""", read="clickhouse") - self.assertIsInstance(query.args["with"].expressions[0].this, exp.Subquery) - self.assertEqual(query.args["with"].expressions[0].alias, "y") + self.assertIsInstance(query.args["with_"].expressions[0].this, exp.Subquery) + self.assertEqual(query.args["with_"].expressions[0].alias, "y") query = "WITH 1 AS var SELECT var" for error_level in [ErrorLevel.IGNORE, ErrorLevel.RAISE, ErrorLevel.IMMEDIATE]: diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index de855283f6..fbbc6683b5 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -610,12 +610,12 @@ def test_column_unnesting(self): ) ast = parse_one("SELECT * FROM t.t JOIN t.c1 ON c1.c2 = t.c3", read="redshift") - ast.args["from"].this.assert_is(exp.Table) + ast.args["from_"].this.assert_is(exp.Table) ast.args["joins"][0].this.assert_is(exp.Table) self.assertEqual(ast.sql("redshift"), "SELECT * FROM t.t JOIN t.c1 ON c1.c2 = t.c3") ast = parse_one("SELECT * FROM t AS t CROSS JOIN t.c1", read="redshift") - ast.args["from"].this.assert_is(exp.Table) + ast.args["from_"].this.assert_is(exp.Table) ast.args["joins"][0].this.assert_is(exp.Unnest) self.assertEqual(ast.sql("redshift"), "SELECT * FROM t AS t CROSS JOIN t.c1") @@ -623,7 +623,7 @@ def test_column_unnesting(self): "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l", read="redshift" ) joins = ast.args["joins"] - ast.args["from"].this.assert_is(exp.Table) + ast.args["from_"].this.assert_is(exp.Table) joins[0].this.assert_is(exp.Unnest) joins[1].this.assert_is(exp.Unnest) joins[2].this.assert_is(exp.Unnest).expressions[0].assert_is(exp.Dot) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index cb98cfe8e8..59042d392a 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -243,7 +243,7 @@ def test_snowflake(self): "SELECT STRTOK('hello world', ' ', 2)", "SELECT SPLIT_PART('hello world', ' ', 2)" ) self.validate_identity("SELECT FILE_URL FROM DIRECTORY(@mystage) WHERE SIZE > 100000").args[ - "from" + "from_" ].this.this.assert_is(exp.DirectoryStage).this.assert_is(exp.Var) self.validate_identity( "SELECT AI_CLASSIFY('text', ['travel', 'cooking'], OBJECT_CONSTRUCT('output_mode', 'multi'))" diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index a21156b9c6..c67ae2231b 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -2025,11 +2025,11 @@ def test_identifier_prefixes(self): self.validate_identity("##x") .assert_is(exp.Column) .this.assert_is(exp.Identifier) - .args.get("global") + .args.get("global_") ) self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var) - self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is( + self.validate_identity("SELECT * FROM @x").args["from_"].this.assert_is( exp.Table ).this.assert_is(exp.Parameter).this.assert_is(exp.Var) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 30a92bf706..fb6a4320ce 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -211,11 +211,11 @@ def test_alias_or_name(self): ) self.assertEqual( - [e.alias_or_name for e in expression.args["with"].expressions], + [e.alias_or_name for e in expression.args["with_"].expressions], ["first", "second"], ) - self.assertEqual("first", expression.args["from"].alias_or_name) + self.assertEqual("first", expression.args["from_"].alias_or_name) self.assertEqual( [e.alias_or_name for e in expression.args["joins"]], ["second", "third"], @@ -1177,10 +1177,10 @@ def test_unnest(self): self.assertIs(ast.selects[0].unnest(), ast.find(exp.Literal)) ast = parse_one("SELECT * FROM (((SELECT * FROM t)))") - self.assertIs(ast.args["from"].this.unnest(), list(ast.find_all(exp.Select))[1]) + self.assertIs(ast.args["from_"].this.unnest(), list(ast.find_all(exp.Select))[1]) ast = parse_one("SELECT * FROM ((((SELECT * FROM t))) AS foo)") - second_subquery = ast.args["from"].this.this + second_subquery = ast.args["from_"].this.this innermost_subquery = list(ast.find_all(exp.Select))[1].parent self.assertIs(second_subquery, innermost_subquery.unwrap()) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 63a3dc8f46..1f9dc12459 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -626,7 +626,7 @@ def test_simplify(self): self.assertEqual( optimizer.simplify.gen(sql), """ -SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True +SELECT :with_,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from_,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from_,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True """.strip(), ) self.assertEqual( @@ -1131,7 +1131,7 @@ def test_derived_tables_column_annotation(self): expression.expressions[0].type.this, exp.DataType.Type.FLOAT ) # a.cola AS cola - addition_alias = expression.args["from"].this.this.expressions[0] + addition_alias = expression.args["from_"].this.this.expressions[0] self.assertEqual( addition_alias.type.this, exp.DataType.Type.FLOAT ) # x.cola + y.cola AS cola @@ -1177,7 +1177,7 @@ def test_cte_column_annotation(self): # WHERE tbl.colc = True self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN) - cte_select = expression.args["with"].expressions[0].this + cte_select = expression.args["with_"].expressions[0].this self.assertEqual( cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola + 'bla' AS cola diff --git a/tests/test_parser.py b/tests/test_parser.py index 11fbf41e89..f10c03111b 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -231,7 +231,7 @@ def test_identify(self): assert expression.expressions[2].alias == "c" assert expression.expressions[3].alias == "D" assert expression.expressions[4].alias == "y|z'" - table = expression.args["from"].this + table = expression.args["from_"].this assert table.name == "z" assert table.args["db"].name == "y" @@ -243,8 +243,8 @@ def test_multi(self): ) assert len(expressions) == 2 - assert expressions[0].args["from"].name == "a" - assert expressions[1].args["from"].name == "b" + assert expressions[0].args["from_"].name == "a" + assert expressions[1].args["from_"].name == "b" expressions = parse("SELECT 1; ; SELECT 2") @@ -253,19 +253,20 @@ def test_multi(self): def test_expression(self): ignore = Parser(error_level=ErrorLevel.IGNORE) - self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint) + self.assertIsInstance(ignore.expression(exp.Hint, expressions=[]), exp.Hint) self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint) self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint) default = Parser(error_level=ErrorLevel.RAISE) - self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint) - default.expression(exp.Hint, y="") + with self.assertRaises(TypeError): + default.expression(exp.Hint, y="") + self.assertIsInstance(default.expression(exp.Hint, expressions=[]), exp.Hint) default.expression(exp.Hint) - self.assertEqual(len(default.errors), 3) + self.assertEqual(len(default.errors), 2) warn = Parser(error_level=ErrorLevel.WARN) - warn.expression(exp.Hint, y="") - self.assertEqual(len(warn.errors), 2) + warn.expression(exp.Hint) + self.assertEqual(len(warn.errors), 1) def test_parse_errors(self): with self.assertRaises(ParseError): @@ -374,8 +375,8 @@ def test_comments_select_cte(self): ) self.assertEqual(expression.comments, ["comment2"]) - self.assertEqual(expression.args.get("from").comments, ["comment3"]) - self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + self.assertEqual(expression.args.get("from_").comments, ["comment3"]) + self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"]) def test_comments_insert(self): expression = parse_one( @@ -406,7 +407,7 @@ def test_comments_insert_cte(self): self.assertEqual(expression.comments, ["comment2"]) self.assertEqual(expression.this.comments, ["comment3"]) - self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"]) def test_comments_update(self): expression = parse_one( @@ -440,7 +441,7 @@ def test_comments_update_cte(self): self.assertEqual(expression.comments, ["comment2"]) self.assertEqual(expression.this.comments, ["comment3"]) - self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"]) def test_comments_delete(self): expression = parse_one( @@ -472,7 +473,7 @@ def test_comments_delete_cte(self): self.assertEqual(expression.comments, ["comment2"]) self.assertEqual(expression.this.comments, ["comment3"]) - self.assertEqual(expression.args["with"].comments, ["comment1.1", "comment1.2"]) + self.assertEqual(expression.args["with_"].comments, ["comment1.1", "comment1.2"]) def test_type_literals(self): self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) @@ -768,7 +769,7 @@ def test_pivot_columns(self): for dialect, expected_columns in dialect_columns.items(): with self.subTest(f"Testing query '{query}' for dialect {dialect}"): expr = parse_one(query, read=dialect) - columns = expr.args["from"].this.args["pivots"][0].args["columns"] + columns = expr.args["from_"].this.args["pivots"][0].args["columns"] self.assertEqual( expected_columns, [col.sql(dialect=dialect) for col in columns] ) @@ -957,23 +958,23 @@ def test_token_position_meta(self): self.assertEqual(set(identifier.meta), {"line", "col", "start", "end"}) self.assertEqual( - ast.this.args["from"].this.args["this"].meta, + ast.this.args["from_"].this.args["this"].meta, {"line": 1, "col": 41, "start": 29, "end": 40}, ) self.assertEqual( - ast.this.args["from"].this.args["db"].meta, + ast.this.args["from_"].this.args["db"].meta, {"line": 1, "col": 28, "start": 17, "end": 27}, ) self.assertEqual( - ast.expression.args["from"].this.args["this"].meta, + ast.expression.args["from_"].this.args["this"].meta, {"line": 1, "col": 106, "start": 94, "end": 105}, ) self.assertEqual( - ast.expression.args["from"].this.args["db"].meta, + ast.expression.args["from_"].this.args["db"].meta, {"line": 1, "col": 93, "start": 82, "end": 92}, ) self.assertEqual( - ast.expression.args["from"].this.args["catalog"].meta, + ast.expression.args["from_"].this.args["catalog"].meta, {"line": 1, "col": 81, "start": 69, "end": 80}, ) @@ -993,10 +994,10 @@ def test_quoted_identifier_meta(self): sql = 'SELECT "a" FROM "test_schema"."test_table_a"' ast = parse_one(sql) - db_meta = ast.args["from"].this.args["db"].meta + db_meta = ast.args["from_"].this.args["db"].meta self.assertEqual(sql[db_meta["start"] : db_meta["end"] + 1], '"test_schema"') - table_meta = ast.args["from"].this.this.meta + table_meta = ast.args["from_"].this.this.meta self.assertEqual(sql[table_meta["start"] : table_meta["end"] + 1], '"test_table_a"') def test_qualified_function(self):