From 2a0388015919973164590d3a89098a6ccc6b9894 Mon Sep 17 00:00:00 2001 From: tobymao Date: Fri, 14 Nov 2025 21:02:38 -0800 Subject: [PATCH 1/4] fix: arg cleanup --- sqlglot/dialects/presto.py | 8 +++++++- sqlglot/dialects/tsql.py | 1 - sqlglot/expressions.py | 4 ++-- sqlglot/optimizer/qualify_columns.py | 3 ++- sqlglot/transforms.py | 7 ++----- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 0b73260f4c..10f25359af 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -417,7 +417,13 @@ class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: rename_func("ARBITRARY"), - exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), + exp.ApproxQuantile: lambda self, e: self.func( + "APPROX_PERCENTILE", + e.this, + e.args.get("weight"), + e.args.get("quantile"), + e.args.get("accuracy"), + ), exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), exp.Array: transforms.preprocess( diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 251f689219..a417e5f229 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -825,7 +825,6 @@ def _parse_convert( args = [this, *self._parse_csv(self._parse_assignment)] convert = exp.Convert.from_arg_list(args) convert.set("safe", safe) - convert.set("strict", strict) return convert def _parse_column_def( diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 372501b576..7b3d3907cc 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2890,7 +2890,7 @@ class DataBlocksizeProperty(Property): class DataDeletionProperty(Property): - arg_types = {"on": True, "filter_col": False, "retention_period": False} + arg_types = {"on": True, "filter_column": False, "retention_period": False} class DefinerProperty(Property): @@ -5863,7 +5863,7 @@ class Columns(Func): # 
https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax class Convert(Func): - arg_types = {"this": True, "expression": True, "style": False} + arg_types = {"this": True, "expression": True, "style": False, "safe": False} # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/CONVERT.html diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e95c63321f..b390e8e0ab 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -960,7 +960,8 @@ def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: else: projection = alias(projection, alias=_alias) new_expressions.append(projection) - cte.this.set("expressions", new_expressions) + if isinstance(cte.this, exp.Select): + cte.this.set("expressions", new_expressions) return expression diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index c321ff680c..7fdbcaa8bd 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -321,7 +321,7 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]: alias = unnest.args.get("alias") exprs = unnest.expressions has_multi_expr = len(exprs) > 1 - this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr) + this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr) columns = alias.columns if alias else [] offset = unnest.args.get("offset") @@ -332,10 +332,7 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]: unnest.replace( exp.Table( - this=_udtf_type(unnest, has_multi_expr)( - this=this, - expressions=expressions, - ), + this=_udtf_type(unnest, has_multi_expr)(this=this), alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None, ) ) From ddaf432fc343345148ee8339edc990a43d82db88 Mon Sep 17 00:00:00 2001 From: tobymao Date: Fri, 14 Nov 2025 21:18:17 -0800 Subject: [PATCH 2/4] refactor!!: rename reserved python kwargs --- 
sqlglot/dialects/clickhouse.py | 2 +- sqlglot/dialects/duckdb.py | 10 ++- sqlglot/dialects/hive.py | 2 +- sqlglot/dialects/mysql.py | 4 +- sqlglot/dialects/singlestore.py | 2 +- sqlglot/dialects/snowflake.py | 26 ++++---- sqlglot/dialects/teradata.py | 13 ++-- sqlglot/dialects/tsql.py | 10 +-- sqlglot/expressions.py | 60 +++++++++--------- sqlglot/generator.py | 26 ++++---- sqlglot/lineage.py | 2 +- sqlglot/optimizer/eliminate_joins.py | 2 +- sqlglot/optimizer/eliminate_subqueries.py | 4 +- sqlglot/optimizer/merge_subqueries.py | 10 +-- sqlglot/optimizer/pushdown_predicates.py | 2 +- sqlglot/optimizer/qualify_columns.py | 4 +- sqlglot/optimizer/qualify_tables.py | 2 +- sqlglot/optimizer/scope.py | 6 +- sqlglot/optimizer/unnest_subqueries.py | 2 +- sqlglot/parser.py | 77 +++++++++++------------ sqlglot/planner.py | 4 +- sqlglot/transforms.py | 26 ++++---- tests/dialects/test_bigquery.py | 18 +++--- tests/dialects/test_clickhouse.py | 4 +- tests/dialects/test_redshift.py | 6 +- tests/dialects/test_snowflake.py | 2 +- tests/dialects/test_tsql.py | 4 +- tests/test_expressions.py | 8 +-- tests/test_optimizer.py | 6 +- tests/test_parser.py | 32 +++++----- 30 files changed, 182 insertions(+), 194 deletions(-) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index a9b3ef6cae..d70a919517 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -821,7 +821,7 @@ def _parse_join( if join: method = join.args.get("method") join.set("method", None) - join.set("global", method) + join.set("global_", method) # tbl ARRAY JOIN arr <-- this should be a `Column` reference, not a `Table` # https://clickhouse.com/docs/en/sql-reference/statements/select/array-join diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 6a5d4fce37..1b7ceea4ef 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -662,11 +662,9 @@ def _parse_force(self) -> exp.Install | exp.Command: def _parse_install(self, 
force: bool = False) -> exp.Install: return self.expression( exp.Install, - **{ # type: ignore - "this": self._parse_id_var(), - "from": self._parse_var_or_string() if self._match(TokenType.FROM) else None, - "force": force, - }, + this=self._parse_id_var(), + from_=self._parse_var_or_string() if self._match(TokenType.FROM) else None, + force=force, ) def _parse_primary(self) -> t.Optional[exp.Expression]: @@ -1008,7 +1006,7 @@ def show_sql(self, expression: exp.Show) -> str: def install_sql(self, expression: exp.Install) -> str: force = "FORCE " if expression.args.get("force") else "" this = self.sql(expression, "this") - from_clause = expression.args.get("from") + from_clause = expression.args.get("from_") from_clause = f" FROM {from_clause}" if from_clause else "" return f"{force}INSTALL {this}{from_clause}" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 34a5ba10b4..581a92f51c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -834,7 +834,7 @@ def alterset_sql(self, expression: exp.AlterSet) -> str: return f"SET{serde}{exprs}{location}{file_format}{tags}" def serdeproperties_sql(self, expression: exp.SerdeProperties) -> str: - prefix = "WITH " if expression.args.get("with") else "" + prefix = "WITH " if expression.args.get("with_") else "" exprs = self.expressions(expression, flat=True) return f"{prefix}SERDEPROPERTIES ({exprs})" diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 0bf249618a..39f3069af7 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -671,7 +671,7 @@ def _parse_show_mysql( for_role=for_role, into_outfile=into_outfile, json=json, - **{"global": global_}, # type: ignore + global_=global_, ) def _parse_oldstyle_limit( @@ -1229,7 +1229,7 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> def show_sql(self, expression: exp.Show) -> str: this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" - global_ 
= " GLOBAL" if expression.args.get("global") else "" + global_ = " GLOBAL" if expression.args.get("global_") else "" target = self.sql(expression, "target") target = f" {target}" if target else "" diff --git a/sqlglot/dialects/singlestore.py b/sqlglot/dialects/singlestore.py index 5475fe3d09..58fbf98801 100644 --- a/sqlglot/dialects/singlestore.py +++ b/sqlglot/dialects/singlestore.py @@ -542,7 +542,7 @@ def _unicode_substitute(m: re.Match[str]) -> str: "offset", "starts_with", "limit", - "from", + "from_", "scope", "scope_kind", "mutex", diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index a4fb5e375c..4b247cad73 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1114,19 +1114,17 @@ def _parse_show_snowflake(self, this: str) -> exp.Show: return self.expression( exp.Show, - **{ - "terse": terse, - "this": this, - "history": history, - "like": like, - "scope": scope, - "scope_kind": scope_kind, - "starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(), - "limit": self._parse_limit(), - "from": self._parse_string() if self._match(TokenType.FROM) else None, - "privileges": self._match_text_seq("WITH", "PRIVILEGES") - and self._parse_csv(lambda: self._parse_var(any_token=True, upper=True)), - }, + terse=terse, + this=this, + history=history, + like=like, + scope=scope, + scope_kind=scope_kind, + starts_with=self._match_text_seq("STARTS", "WITH") and self._parse_string(), + limit=self._parse_limit(), + from_=self._parse_string() if self._match(TokenType.FROM) else None, + privileges=self._match_text_seq("WITH", "PRIVILEGES") + and self._parse_csv(lambda: self._parse_var(any_token=True, upper=True)), ) def _parse_put(self) -> exp.Put | exp.Command: @@ -1652,7 +1650,7 @@ def show_sql(self, expression: exp.Show) -> str: limit = self.sql(expression, "limit") - from_ = self.sql(expression, "from") + from_ = self.sql(expression, "from_") if from_: from_ = f" FROM {from_}" diff --git 
a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 2f6c22a276..eb435a1ae8 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -213,13 +213,10 @@ def _parse_translate(self) -> exp.TranslateCharacters: def _parse_update(self) -> exp.Update: return self.expression( exp.Update, - **{ # type: ignore - "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), - "from": self._parse_from(joins=True), - "expressions": self._match(TokenType.SET) - and self._parse_csv(self._parse_equality), - "where": self._parse_where(), - }, + this=self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), + from_=self._parse_from(joins=True), + expressions=self._match(TokenType.SET) and self._parse_csv(self._parse_equality), + where=self._parse_where(), ) def _parse_rangen(self): @@ -387,7 +384,7 @@ def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> st # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause def update_sql(self, expression: exp.Update) -> str: this = self.sql(expression, "this") - from_sql = self.sql(expression, "from") + from_sql = self.sql(expression, "from_") set_sql = self.expressions(expression, flat=True) where_sql = self.sql(expression, "where") sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a417e5f229..ce801f71e3 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -575,7 +575,7 @@ class Parser(parser.Parser): QUERY_MODIFIER_PARSERS = { **parser.Parser.QUERY_MODIFIER_PARSERS, TokenType.OPTION: lambda self: ("options", self._parse_options()), - TokenType.FOR: lambda self: ("for", self._parse_for()), + TokenType.FOR: lambda self: ("for_", self._parse_for()), } # T-SQL does not allow BEGIN to be used as an identifier @@ -881,7 +881,7 @@ def _parse_id_var( this = 
super()._parse_id_var(any_token=any_token, tokens=tokens) if this: if is_global: - this.set("global", True) + this.set("global_", True) elif is_temporary: this.set("temporary", True) @@ -1240,12 +1240,12 @@ def create_sql(self, expression: exp.Create) -> str: if kind == "VIEW": expression.this.set("catalog", None) - with_ = expression.args.get("with") + with_ = expression.args.get("with_") if ctas_expression and with_: # We've already preprocessed the Create expression to bubble up any nested CTEs, # but CREATE VIEW actually requires the WITH clause to come after it so we need # to amend the AST by moving the CTEs to the CREATE VIEW statement's query. - ctas_expression.set("with", with_.pop()) + ctas_expression.set("with_", with_.pop()) table = expression.find(exp.Table) @@ -1361,7 +1361,7 @@ def rollback_sql(self, expression: exp.Rollback) -> str: def identifier_sql(self, expression: exp.Identifier) -> str: identifier = super().identifier_sql(expression) - if expression.args.get("global"): + if expression.args.get("global_"): identifier = f"##{identifier}" elif expression.args.get("temporary"): identifier = f"#{identifier}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7b3d3907cc..62fcd8ee2d 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1247,7 +1247,7 @@ def order_by( @property def ctes(self) -> t.List[CTE]: """Returns a list of all the CTEs attached to this query.""" - with_ = self.args.get("with") + with_ = self.args.get("with_") return with_.expressions if with_ else [] @property @@ -1475,7 +1475,7 @@ class DDL(Expression): @property def ctes(self) -> t.List[CTE]: """Returns a list of all the CTEs attached to this statement.""" - with_ = self.args.get("with") + with_ = self.args.get("with_") return with_.expressions if with_ else [] @property @@ -1536,7 +1536,7 @@ def returning( class Create(DDL): arg_types = { - "with": False, + "with_": False, "this": True, "kind": True, "expression": False, @@ -1615,7 +1615,7 @@ 
class Detach(Expression): # https://duckdb.org/docs/sql/statements/load_and_install.html class Install(Expression): - arg_types = {"this": True, "from": False, "force": False} + arg_types = {"this": True, "from_": False, "force": False} # https://duckdb.org/docs/guides/meta/summarize.html @@ -1653,7 +1653,7 @@ class SetItem(Expression): "expressions": False, "kind": False, "collate": False, # MySQL SET NAMES statement - "global": False, + "global_": False, } @@ -1670,7 +1670,7 @@ class Show(Expression): "offset": False, "starts_with": False, "limit": False, - "from": False, + "from_": False, "like": False, "where": False, "db": False, @@ -1680,7 +1680,7 @@ class Show(Expression): "mutex": False, "query": False, "channel": False, - "global": False, + "global_": False, "log": False, "position": False, "types": False, @@ -2120,7 +2120,7 @@ class Constraint(Expression): class Delete(DML): arg_types = { - "with": False, + "with_": False, "this": False, "using": False, "where": False, @@ -2337,7 +2337,7 @@ class JoinHint(Expression): class Identifier(Expression): - arg_types = {"this": True, "quoted": False, "global": False, "temporary": False} + arg_types = {"this": True, "quoted": False, "global_": False, "temporary": False} @property def quoted(self) -> bool: @@ -2380,7 +2380,7 @@ class IndexParameters(Expression): class Insert(DDL, DML): arg_types = { "hint": False, - "with": False, + "with_": False, "is_function": False, "this": False, "expression": False, @@ -2607,7 +2607,7 @@ class Join(Expression): "kind": False, "using": False, "method": False, - "global": False, + "global_": False, "hint": False, "match_condition": False, # Snowflake "expressions": False, @@ -2786,7 +2786,7 @@ class Order(Expression): # https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier class WithFill(Expression): arg_types = { - "from": False, + "from_": False, "to": False, "step": False, "interpolate": False, @@ -3210,7 +3210,7 @@ class 
SemanticView(Expression): class SerdeProperties(Property): - arg_types = {"expressions": True, "with": False} + arg_types = {"expressions": True, "with_": False} class SetProperty(Property): @@ -3306,7 +3306,7 @@ class WithSystemVersioningProperty(Property): "this": False, "data_consistency": False, "retention_period": False, - "with": True, + "with_": True, } @@ -3574,7 +3574,7 @@ def to_column(self, copy: bool = True) -> Expression: class SetOperation(Query): arg_types = { - "with": False, + "with_": False, "this": True, "expression": True, "distinct": False, @@ -3643,10 +3643,10 @@ class Intersect(SetOperation): class Update(DML): arg_types = { - "with": False, + "with_": False, "this": False, "expressions": True, - "from": False, + "from_": False, "where": False, "returning": False, "order": False, @@ -3794,7 +3794,7 @@ def from_( return _apply_builder( expression=expression, instance=self, - arg="from", + arg="from_", into=From, prefix="FROM", dialect=dialect, @@ -3890,13 +3890,13 @@ class Lock(Expression): class Select(Query): arg_types = { - "with": False, + "with_": False, "kind": False, "expressions": False, "hint": False, "distinct": False, "into": False, - "from": False, + "from_": False, "operation_modifiers": False, **QUERY_MODIFIERS, } @@ -3925,7 +3925,7 @@ def from_( return _apply_builder( expression=expression, instance=self, - arg="from", + arg="from_", into=From, prefix="FROM", dialect=dialect, @@ -4427,7 +4427,7 @@ class Subquery(DerivedTable, Query): arg_types = { "this": True, "alias": False, - "with": False, + "with_": False, **QUERY_MODIFIERS, } @@ -4564,7 +4564,7 @@ class Where(Expression): class Star(Expression): - arg_types = {"except": False, "replace": False, "rename": False} + arg_types = {"except_": False, "replace": False, "rename": False} @property def name(self) -> str: @@ -5737,7 +5737,7 @@ class Transform(Func): class Translate(Func): - arg_types = {"this": True, "from": True, "to": True} + arg_types = {"this": True, "from_": 
True, "to": True} class Grouping(AggFunc): @@ -6771,7 +6771,7 @@ class IsNullValue(Func): # https://www.postgresql.org/docs/current/functions-json.html class JSON(Expression): - arg_types = {"this": False, "with": False, "unique": False} + arg_types = {"this": False, "with_": False, "unique": False} class JSONPath(Expression): @@ -7296,7 +7296,7 @@ class Normalize(Func): class Overlay(Func): - arg_types = {"this": True, "expression": True, "from": True, "for": False} + arg_types = {"this": True, "expression": True, "from_": True, "for_": False} # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function @@ -7996,7 +7996,7 @@ class Merge(DML): "on": False, "using_cond": False, "whens": True, - "with": False, + "with_": False, "returning": False, } @@ -8317,7 +8317,7 @@ def _apply_cte_builder( return _apply_child_list_builder( cte, instance=instance, - arg="with", + arg="with_", append=append, copy=copy, into=With, @@ -8550,7 +8550,7 @@ def update( ) if from_: update_expr.set( - "from", + "from_", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), ) if isinstance(where, Condition): @@ -8566,7 +8566,7 @@ def update( for alias, qry in with_.items() ] update_expr.set( - "with", + "with_", With(expressions=cte_list), ) return update_expr diff --git a/sqlglot/generator.py b/sqlglot/generator.py index d08b261232..e026cd1d01 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1306,7 +1306,7 @@ def heredoc_sql(self, expression: exp.Heredoc) -> str: return f"${tag}${self.sql(expression, 'this')}${tag}$" def prepend_ctes(self, expression: exp.Expression, sql: str) -> str: - with_ = self.sql(expression, "with") + with_ = self.sql(expression, "with_") if with_: sql = f"{with_}{self.sep()}{sql}" return sql @@ -1944,7 +1944,7 @@ def withsystemversioningproperty_sql(self, expression: exp.WithSystemVersioningP sql = f"SYSTEM_VERSIONING={on_sql}" - return f"WITH({sql})" if expression.args.get("with") 
else sql + return f"WITH({sql})" if expression.args.get("with_") else sql def insert_sql(self, expression: exp.Insert) -> str: hint = self.sql(expression, "hint") @@ -2212,7 +2212,7 @@ def tuple_sql(self, expression: exp.Tuple) -> str: def update_sql(self, expression: exp.Update) -> str: this = self.sql(expression, "this") set_sql = self.expressions(expression, flat=True) - from_sql = self.sql(expression, "from") + from_sql = self.sql(expression, "from_") where_sql = self.sql(expression, "where") returning = self.sql(expression, "returning") order = self.sql(expression, "order") @@ -2351,7 +2351,7 @@ def join_sql(self, expression: exp.Join) -> str: op for op in ( expression.method, - "GLOBAL" if expression.args.get("global") else None, + "GLOBAL" if expression.args.get("global_") else None, side, expression.kind, expression.hint if self.JOIN_HINTS else None, @@ -2466,7 +2466,7 @@ def setitem_sql(self, expression: exp.SetItem) -> str: expressions = self.expressions(expression) collate = self.sql(expression, "collate") collate = f" COLLATE {collate}" if collate else "" - global_ = "GLOBAL " if expression.args.get("global") else "" + global_ = "GLOBAL " if expression.args.get("global_") else "" return f"{global_}{kind}{this}{expressions}{collate}" def set_sql(self, expression: exp.Set) -> str: @@ -2564,7 +2564,7 @@ def order_sql(self, expression: exp.Order, flat: bool = False) -> str: return self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat) # type: ignore def withfill_sql(self, expression: exp.WithFill) -> str: - from_sql = self.sql(expression, "from") + from_sql = self.sql(expression, "from_") from_sql = f" FROM {from_sql}" if from_sql else "" to_sql = self.sql(expression, "to") to_sql = f" TO {to_sql}" if to_sql else "" @@ -2725,7 +2725,7 @@ def options_modifier(self, expression: exp.Expression) -> str: return f" {options}" if options else "" def for_modifiers(self, expression: exp.Expression) -> str: - for_modifiers = 
self.expressions(expression, key="for") + for_modifiers = self.expressions(expression, key="for_") return f"{self.sep()}FOR XML{self.seg(for_modifiers)}" if for_modifiers else "" def queryoption_sql(self, expression: exp.QueryOption) -> str: @@ -2796,11 +2796,11 @@ def select_sql(self, expression: exp.Select) -> str: expression, f"SELECT{top_distinct}{operation_modifiers}{kind}{expressions}", self.sql(expression, "into", comment=False), - self.sql(expression, "from", comment=False), + self.sql(expression, "from_", comment=False), ) # If both the CTE and SELECT clauses have comments, generate the latter earlier - if expression.args.get("with"): + if expression.args.get("with_"): sql = self.maybe_comment(sql, expression) expression.pop_comments() @@ -2828,7 +2828,7 @@ def schema_columns_sql(self, expression: exp.Schema) -> str: return "" def star_sql(self, expression: exp.Star) -> str: - except_ = self.expressions(expression, key="except", flat=True) + except_ = self.expressions(expression, key="except_", flat=True) except_ = f"{self.seg(self.STAR_EXCEPT)} ({except_})" if except_ else "" replace = self.expressions(expression, key="replace", flat=True) replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" @@ -4826,7 +4826,7 @@ def json_sql(self, expression: exp.JSON) -> str: this = self.sql(expression, "this") this = f" {this}" if this else "" - _with = expression.args.get("with") + _with = expression.args.get("with_") if _with is None: with_sql = "" @@ -5010,8 +5010,8 @@ def columns_sql(self, expression: exp.Columns): def overlay_sql(self, expression: exp.Overlay): this = self.sql(expression, "this") expr = self.sql(expression, "expression") - from_sql = self.sql(expression, "from") - for_sql = self.sql(expression, "for") + from_sql = self.sql(expression, "from_") + for_sql = self.sql(expression, "for_") for_sql = f" FOR {for_sql}" if for_sql else "" return f"OVERLAY({this} PLACING {expr} FROM {from_sql}{for_sql})" diff --git a/sqlglot/lineage.py 
b/sqlglot/lineage.py index cc1aedd5b6..5a80a45fa3 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -61,7 +61,7 @@ def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: } for d in node.downstream: - edges.append({"from": node_id, "to": id(d)}) + edges.append({"from_": node_id, "to": id(d)}) return GraphHTML(nodes, edges, **opts) diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index da1fb7a217..2578fdd4f7 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -110,7 +110,7 @@ def _has_single_output_row(scope): return isinstance(scope.expression, exp.Select) and ( all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects) or _is_limit_1(scope) - or not scope.expression.args.get("from") + or not scope.expression.args.get("from_") ) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index b661003690..58d71477ea 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -65,7 +65,7 @@ def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: # Existing CTES in the root expression. We'll use this for deduplication. 
existing_ctes: ExistingCTEsMapping = {} - with_ = root.expression.args.get("with") + with_ = root.expression.args.get("with_") recursive = False if with_: recursive = with_.args.get("recursive") @@ -97,7 +97,7 @@ def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: if new_ctes: query = expression.expression if isinstance(expression, exp.DDL) else expression - query.set("with", exp.With(expressions=new_ctes, recursive=recursive)) + query.set("with_", exp.With(expressions=new_ctes, recursive=recursive)) return expression diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 9b68c7da8e..b514785dff 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -48,7 +48,7 @@ def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E: # If a derived table has these Select args, it can't be merged UNMERGABLE_ARGS = set(exp.Select.arg_types) - { "expressions", - "from", + "from_", "joins", "where", "order", @@ -165,7 +165,7 @@ def _outer_select_joins_on_inner_select_join(): if not on: return False selections = [c.name for c in on.find_all(exp.Column) if c.table == alias] - inner_from = inner_scope.expression.args.get("from") + inner_from = inner_scope.expression.args.get("from_") if not inner_from: return False inner_from_table = inner_from.alias_or_name @@ -197,7 +197,7 @@ def _is_recursive(): and not outer_scope.expression.is_star and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) - and inner_select.args.get("from") is not None + and inner_select.args.get("from_") is not None and not outer_scope.pivots and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) @@ -261,7 +261,7 @@ def _merge_from( """ Merge FROM clause of inner query into outer query. 
""" - new_subquery = inner_scope.expression.args["from"].this + new_subquery = inner_scope.expression.args["from_"].this new_subquery.set("joins", node_to_replace.args.get("joins")) node_to_replace.replace(new_subquery) for join_hint in outer_scope.join_hints: @@ -357,7 +357,7 @@ def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoi if isinstance(from_or_join, exp.Join): # Merge predicates from an outer join to the ON clause # if it only has columns that are already joined - from_ = expression.args.get("from") + from_ = expression.args.get("from_") sources = {from_.alias_or_name} if from_ else set() for join in expression.args["joins"]: diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 4e65f24fe1..9ac198645c 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -181,7 +181,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): - with_ = source.parent.expression.args.get("with") + with_ = source.parent.expression.args.get("with_") if with_ and with_.recursive: return {} node = source.expression diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index b390e8e0ab..d8f6e405b7 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -853,7 +853,7 @@ def _expand_stars( def _add_except_columns( expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] ) -> None: - except_ = expression.args.get("except") + except_ = expression.args.get("except_") if not except_: return @@ -1206,7 +1206,7 @@ def _get_available_source_columns( args = self.scope.expression.args # Collect tables in order: FROM clause tables + joined tables up to current join - from_name = args["from"].alias_or_name + from_name = 
args["from_"].alias_or_name available_sources = {from_name: self.get_source_columns(from_name)} for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index baff507965..cc50dfa6d7 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -68,7 +68,7 @@ def _qualify(table: exp.Table) -> None: table.set("catalog", catalog.copy()) if (db or catalog) and not isinstance(expression, exp.Query): - with_ = expression.args.get("with") or exp.With() + with_ = expression.args.get("with_") or exp.With() cte_names = {cte.alias_or_name for cte in with_.expressions} for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 85041fdb1f..96d3a60a1c 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -308,7 +308,7 @@ def columns(self): or column.name not in named_selects ) ) - or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") + or (isinstance(ancestor, exp.Star) and not column.arg_key == "except_") ): self._columns.append(column) @@ -663,7 +663,7 @@ def _traverse_ctes(scope): # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. # thus the recursive scope is the first section of the union. 
- with_ = scope.expression.args.get("with") + with_ = scope.expression.args.get("with_") if with_ and with_.recursive: union = cte.this @@ -720,7 +720,7 @@ def _traverse_tables(scope): # Traverse FROMs, JOINs, and LATERALs in the order they are defined expressions = [] - from_ = scope.expression.args.get("from") + from_ = scope.expression.args.get("from_") if from_: expressions.append(from_.this) diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 163a5a8ec0..b723b2fc63 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -44,7 +44,7 @@ def unnest(select, parent_select, next_alias_name): if ( not predicate or parent_select is not predicate.parent_select - or not parent_select.args.get("from") + or not parent_select.args.get("from_") ): return diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c1c7bd48fb..f476bbc01f 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -2382,10 +2382,8 @@ def _parse_system_versioning_property( self._match(TokenType.EQ) prop = self.expression( exp.WithSystemVersioningProperty, - **{ # type: ignore - "on": True, - "with": with_, - }, + on=True, + with_=with_, ) if self._match_text_seq("OFF"): @@ -3056,10 +3054,8 @@ def _parse_serde_properties(self, with_: bool = False) -> t.Optional[exp.SerdePr return None return self.expression( exp.SerdeProperties, - **{ # type: ignore - "expressions": self._parse_wrapped_properties(), - "with": with_, - }, + expressions=self._parse_wrapped_properties(), + with_=with_, ) def _parse_row_format( @@ -3147,7 +3143,7 @@ def _parse_update(self) -> exp.Update: elif self._match(TokenType.RETURNING, advance=False): kwargs["returning"] = self._parse_returning() elif self._match(TokenType.FROM, advance=False): - kwargs["from"] = self._parse_from(joins=True) + kwargs["from_"] = self._parse_from(joins=True) elif self._match(TokenType.WHERE, advance=False): kwargs["where"] = self._parse_where() elif 
self._match(TokenType.ORDER_BY, advance=False): @@ -3237,8 +3233,8 @@ def _parse_wrapped_select(self, table: bool = False) -> t.Optional[exp.Expressio # Support parentheses for duckdb FROM-first syntax select = self._parse_select(from_=from_) if select: - if not select.args.get("from"): - select.set("from", from_) + if not select.args.get("from_"): + select.set("from_", from_) this = select else: this = exp.select("*").from_(t.cast(exp.From, from_)) @@ -3304,8 +3300,8 @@ def _parse_select_query( while isinstance(this, exp.Subquery) and this.is_wrapper: this = this.this - if "with" in this.arg_types: - this.set("with", cte) + if "with_" in this.arg_types: + this.set("with_", cte) else: self.raise_error(f"{this.key} does not support CTE") this = cte @@ -3371,7 +3367,7 @@ def _parse_select_query( from_ = self._parse_from() if from_: - this.set("from", from_) + this.set("from_", from_) this = self._parse_query_modifiers(this) elif (table or nested) and self._match(TokenType.L_PAREN): @@ -3537,7 +3533,7 @@ def _parse_subquery( def _implicit_unnests_to_explicit(self, this: E) -> E: from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm - refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name} + refs = {_norm(this.args["from_"].this.copy(), dialect=self.dialect).alias_or_name} for i, join in enumerate(this.args.get("joins") or []): table = join.this normalized_table = table.copy() @@ -3602,7 +3598,7 @@ def _parse_query_modifiers(self, this): continue break - if self.SUPPORTS_IMPLICIT_UNNEST and this and this.args.get("from"): + if self.SUPPORTS_IMPLICIT_UNNEST and this and this.args.get("from_"): this = self._implicit_unnests_to_explicit(this) return this @@ -4769,12 +4765,10 @@ def _parse_ordered( if self._match_text_seq("WITH", "FILL"): with_fill = self.expression( exp.WithFill, - **{ # type: ignore - "from": self._match(TokenType.FROM) and self._parse_bitwise(), - "to": self._match_text_seq("TO") and 
self._parse_bitwise(), - "step": self._match_text_seq("STEP") and self._parse_bitwise(), - "interpolate": self._parse_interpolate(), - }, + from_=self._match(TokenType.FROM) and self._parse_bitwise(), + to=self._match_text_seq("TO") and self._parse_bitwise(), + step=self._match_text_seq("STEP") and self._parse_bitwise(), + interpolate=self._parse_interpolate(), ) else: with_fill = None @@ -5087,7 +5081,10 @@ def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expressi unique = self._match(TokenType.UNIQUE) self._match_text_seq("KEYS") expression: t.Optional[exp.Expression] = self.expression( - exp.JSON, **{"this": kind, "with": _with, "unique": unique} + exp.JSON, + this=kind, + with_=_with, + unique=unique, ) else: expression = self._parse_null() or self._parse_bitwise() @@ -8098,7 +8095,7 @@ def _parse_set_transaction(self, global_: bool = False) -> exp.Expression: exp.SetItem, expressions=characteristics, kind="TRANSACTION", - **{"global": global_}, # type: ignore + global_=global_, ) def _parse_set_item(self) -> t.Optional[exp.Expression]: @@ -8581,11 +8578,9 @@ def _parse_star_ops(self) -> t.Optional[exp.Expression]: return self.expression( exp.Star, - **{ # type: ignore - "except": self._parse_star_op("EXCEPT", "EXCLUDE"), - "replace": self._parse_star_op("REPLACE"), - "rename": self._parse_star_op("RENAME"), - }, + except_=self._parse_star_op("EXCEPT", "EXCLUDE"), + replace=self._parse_star_op("REPLACE"), + rename=self._parse_star_op("RENAME"), ).update_positions(star_token) def _parse_grant_privilege(self) -> t.Optional[exp.GrantPrivilege]: @@ -8685,12 +8680,10 @@ def _parse_revoke(self) -> exp.Revoke | exp.Command: def _parse_overlay(self) -> exp.Overlay: return self.expression( exp.Overlay, - **{ # type: ignore - "this": self._parse_bitwise(), - "expression": self._match_text_seq("PLACING") and self._parse_bitwise(), - "from": self._match_text_seq("FROM") and self._parse_bitwise(), - "for": self._match_text_seq("FOR") and 
self._parse_bitwise(), - }, + this=self._parse_bitwise(), + expression=self._match_text_seq("PLACING") and self._parse_bitwise(), + from_=self._match_text_seq("FROM") and self._parse_bitwise(), + for_=self._match_text_seq("FOR") and self._parse_bitwise(), ) def _parse_format_name(self) -> exp.Property: @@ -8733,12 +8726,12 @@ def _build_pipe_cte( self._pipe_cte_counter += 1 new_cte = f"__tmp{self._pipe_cte_counter}" - with_ = query.args.get("with") + with_ = query.args.get("with_") ctes = with_.pop() if with_ else None new_select = exp.select(*expressions, copy=False).from_(new_cte, copy=False) if ctes: - new_select.set("with", ctes) + new_select.set("with_", ctes) return new_select.with_(new_cte, as_=query, copy=False) @@ -8833,7 +8826,7 @@ def _parse_and_unwrap_query() -> t.Optional[exp.Select]: ] query = self._build_pipe_cte(query=query, expressions=[exp.Star()]) - with_ = query.args.get("with") + with_ = query.args.get("with_") ctes = with_.pop() if with_ else None if isinstance(first_setop, exp.Union): @@ -8843,7 +8836,7 @@ def _parse_and_unwrap_query() -> t.Optional[exp.Select]: else: query = query.intersect(*setops, copy=False, **first_setop.args) - query.set("with", ctes) + query.set("with_", ctes) return self._build_pipe_cte(query=query, expressions=[exp.Star()]) @@ -8862,7 +8855,7 @@ def _parse_pipe_syntax_pivot(self, query: exp.Select) -> exp.Select: if not pivots: return query - from_ = query.args.get("from") + from_ = query.args.get("from_") if from_: from_.this.set("pivots", pivots) @@ -8876,7 +8869,7 @@ def _parse_pipe_syntax_extend(self, query: exp.Select) -> exp.Select: def _parse_pipe_syntax_tablesample(self, query: exp.Select) -> exp.Select: sample = self._parse_table_sample() - with_ = query.args.get("with") + with_ = query.args.get("with_") if with_: with_.expressions[-1].this.set("sample", sample) else: @@ -8888,7 +8881,7 @@ def _parse_pipe_syntax_query(self, query: exp.Query) -> t.Optional[exp.Query]: if isinstance(query, exp.Subquery): query 
= exp.select("*").from_(query, copy=False) - if not query.args.get("from"): + if not query.args.get("from_"): query = exp.select("*").from_(query.subquery(copy=False), copy=False) while self._match(TokenType.PIPE_GT): diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 687bffb9fa..ebd34fe1e1 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -94,7 +94,7 @@ def from_expression( """ ctes = ctes or {} expression = expression.unnest() - with_ = expression.args.get("with") + with_ = expression.args.get("with_") # CTEs break the mold of scope and introduce themselves to all in the context. if with_: @@ -104,7 +104,7 @@ def from_expression( step.name = cte.alias ctes[step.name] = step # type: ignore - from_ = expression.args.get("from") + from_ = expression.args.get("from_") if isinstance(expression, exp.Select) and from_: step = Scan.from_expression(from_.this, ctes) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 7fdbcaa8bd..fbccdaf728 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -114,10 +114,10 @@ def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) - count += 1 if recursive_ctes: - with_expression = expression.args.get("with") or exp.With() + with_expression = expression.args.get("with_") or exp.With() with_expression.set("recursive", True) with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions]) - expression.set("with", with_expression) + expression.set("with_", with_expression) return expression @@ -314,7 +314,7 @@ def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]: return exp.Inline if has_multi_expr else exp.Explode if isinstance(expression, exp.Select): - from_ = expression.args.get("from") + from_ = expression.args.get("from_") if from_ and isinstance(from_.this, exp.Unnest): unnest = from_.this @@ -495,7 +495,7 @@ def new_name(names: t.Set[str], name: str) -> str: expression.set("expressions", expressions) if not arrays: - if 
expression.args.get("from"): + if expression.args.get("from_"): expression.join(series, copy=False, join_type="CROSS") else: expression.from_(series, copy=False) @@ -639,7 +639,7 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: expression.set("limit", None) index, full_outer_join = full_outer_joins[0] - tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name) + tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name) join_conditions = full_outer_join.args.get("on") or exp.and_( *[ exp.column(col, tables[0]).eq(exp.column(col, tables[1])) @@ -648,10 +648,12 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: ) full_outer_join.set("side", "left") - anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions) + anti_join_clause = ( + exp.select("1").from_(expression.args["from_"]).where(join_conditions) + ) expression_copy.args["joins"][index].set("side", "right") expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_()) - expression_copy.set("with", None) # remove CTEs from RIGHT side + expression_copy.set("with_", None) # remove CTEs from RIGHT side expression.set("order", None) # remove order by from LEFT side return exp.union(expression, expression_copy, copy=False, distinct=False) @@ -671,14 +673,14 @@ def move_ctes_to_top_level(expression: E) -> E: TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 
""" - top_level_with = expression.args.get("with") + top_level_with = expression.args.get("with_") for inner_with in expression.find_all(exp.With): if inner_with.parent is expression: continue if not top_level_with: top_level_with = inner_with.pop() - expression.set("with", top_level_with) + expression.set("with_", top_level_with) else: if inner_with.recursive: top_level_with.set("recursive", True) @@ -905,7 +907,7 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: old_joins = {join.alias_or_name: join for join in joins} new_joins = {} - query_from = query.args["from"] + query_from = query.args["from_"] for table, predicates in joins_ons.items(): join_what = old_joins.get(table, query_from).this.copy() @@ -931,11 +933,11 @@ def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: ), "Cannot determine which table to use in the new FROM clause" new_from_name = list(only_old_joins)[0] - query.set("from", exp.From(this=old_joins[new_from_name].this)) + query.set("from_", exp.From(this=old_joins[new_from_name].this)) if new_joins: for n, j in old_joins.items(): # preserve any other joins - if n not in new_joins and n != query.args["from"].name: + if n not in new_joins and n != query.args["from_"].name: if not j.kind: j.set("kind", "CROSS") new_joins[n] = j diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index fb6bdc673a..02ffbb7abf 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -2986,29 +2986,29 @@ def test_identifier_meta(self): self.assertEqual(set(identifier.meta), {"line", "col", "start", "end"}) self.assertEqual( - ast.this.args["from"].this.args["this"].meta, + ast.this.args["from_"].this.args["this"].meta, {"line": 1, "col": 41, "start": 29, "end": 40}, ) self.assertEqual( - ast.this.args["from"].this.args["db"].meta, + ast.this.args["from_"].this.args["db"].meta, {"line": 1, "col": 28, "start": 17, "end": 27}, ) self.assertEqual( - 
ast.expression.args["from"].this.args["this"].meta, + ast.expression.args["from_"].this.args["this"].meta, {"line": 1, "col": 106, "start": 94, "end": 105}, ) self.assertEqual( - ast.expression.args["from"].this.args["db"].meta, + ast.expression.args["from_"].this.args["db"].meta, {"line": 1, "col": 93, "start": 82, "end": 92}, ) self.assertEqual( - ast.expression.args["from"].this.args["catalog"].meta, + ast.expression.args["from_"].this.args["catalog"].meta, {"line": 1, "col": 81, "start": 69, "end": 80}, ) information_schema_sql = "SELECT a, b FROM region.INFORMATION_SCHEMA.COLUMNS" ast = parse_one(information_schema_sql, dialect="bigquery") - meta = ast.args["from"].this.this.meta + meta = ast.args["from_"].this.this.meta self.assertEqual(meta, {"line": 1, "col": 50, "start": 24, "end": 49}) assert ( information_schema_sql[meta["start"] : meta["end"] + 1] == "INFORMATION_SCHEMA.COLUMNS" @@ -3017,14 +3017,14 @@ def test_identifier_meta(self): def test_quoted_identifier_meta(self): sql = "SELECT `a` FROM `test_schema`.`test_table_a`" ast = parse_one(sql, dialect="bigquery") - db_meta = ast.args["from"].this.args["db"].meta + db_meta = ast.args["from_"].this.args["db"].meta self.assertEqual(sql[db_meta["start"] : db_meta["end"] + 1], "`test_schema`") - table_meta = ast.args["from"].this.this.meta + table_meta = ast.args["from_"].this.this.meta self.assertEqual(sql[table_meta["start"] : table_meta["end"] + 1], "`test_table_a`") information_schema_sql = "SELECT a, b FROM `region.INFORMATION_SCHEMA.COLUMNS`" ast = parse_one(information_schema_sql, dialect="bigquery") - table_meta = ast.args["from"].this.this.meta + table_meta = ast.args["from_"].this.this.meta assert ( information_schema_sql[table_meta["start"] : table_meta["end"] + 1] == "`region.INFORMATION_SCHEMA.COLUMNS`" diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 4ec7d7449e..93a474b839 100644 --- a/tests/dialects/test_clickhouse.py +++ 
b/tests/dialects/test_clickhouse.py @@ -716,8 +716,8 @@ def test_cte(self): self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") query = parse_one("""WITH (SELECT 1) AS y SELECT * FROM y""", read="clickhouse") - self.assertIsInstance(query.args["with"].expressions[0].this, exp.Subquery) - self.assertEqual(query.args["with"].expressions[0].alias, "y") + self.assertIsInstance(query.args["with_"].expressions[0].this, exp.Subquery) + self.assertEqual(query.args["with_"].expressions[0].alias, "y") query = "WITH 1 AS var SELECT var" for error_level in [ErrorLevel.IGNORE, ErrorLevel.RAISE, ErrorLevel.IMMEDIATE]: diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index de855283f6..fbbc6683b5 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -610,12 +610,12 @@ def test_column_unnesting(self): ) ast = parse_one("SELECT * FROM t.t JOIN t.c1 ON c1.c2 = t.c3", read="redshift") - ast.args["from"].this.assert_is(exp.Table) + ast.args["from_"].this.assert_is(exp.Table) ast.args["joins"][0].this.assert_is(exp.Table) self.assertEqual(ast.sql("redshift"), "SELECT * FROM t.t JOIN t.c1 ON c1.c2 = t.c3") ast = parse_one("SELECT * FROM t AS t CROSS JOIN t.c1", read="redshift") - ast.args["from"].this.assert_is(exp.Table) + ast.args["from_"].this.assert_is(exp.Table) ast.args["joins"][0].this.assert_is(exp.Unnest) self.assertEqual(ast.sql("redshift"), "SELECT * FROM t AS t CROSS JOIN t.c1") @@ -623,7 +623,7 @@ def test_column_unnesting(self): "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l", read="redshift" ) joins = ast.args["joins"] - ast.args["from"].this.assert_is(exp.Table) + ast.args["from_"].this.assert_is(exp.Table) joins[0].this.assert_is(exp.Unnest) joins[1].this.assert_is(exp.Unnest) joins[2].this.assert_is(exp.Unnest).expressions[0].assert_is(exp.Dot) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index cb98cfe8e8..59042d392a 
100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -243,7 +243,7 @@ def test_snowflake(self): "SELECT STRTOK('hello world', ' ', 2)", "SELECT SPLIT_PART('hello world', ' ', 2)" ) self.validate_identity("SELECT FILE_URL FROM DIRECTORY(@mystage) WHERE SIZE > 100000").args[ - "from" + "from_" ].this.this.assert_is(exp.DirectoryStage).this.assert_is(exp.Var) self.validate_identity( "SELECT AI_CLASSIFY('text', ['travel', 'cooking'], OBJECT_CONSTRUCT('output_mode', 'multi'))" diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index a21156b9c6..c67ae2231b 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -2025,11 +2025,11 @@ def test_identifier_prefixes(self): self.validate_identity("##x") .assert_is(exp.Column) .this.assert_is(exp.Identifier) - .args.get("global") + .args.get("global_") ) self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var) - self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is( + self.validate_identity("SELECT * FROM @x").args["from_"].this.assert_is( exp.Table ).this.assert_is(exp.Parameter).this.assert_is(exp.Var) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 30a92bf706..fb6a4320ce 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -211,11 +211,11 @@ def test_alias_or_name(self): ) self.assertEqual( - [e.alias_or_name for e in expression.args["with"].expressions], + [e.alias_or_name for e in expression.args["with_"].expressions], ["first", "second"], ) - self.assertEqual("first", expression.args["from"].alias_or_name) + self.assertEqual("first", expression.args["from_"].alias_or_name) self.assertEqual( [e.alias_or_name for e in expression.args["joins"]], ["second", "third"], @@ -1177,10 +1177,10 @@ def test_unnest(self): self.assertIs(ast.selects[0].unnest(), ast.find(exp.Literal)) ast = parse_one("SELECT * FROM (((SELECT * FROM t)))") - 
self.assertIs(ast.args["from"].this.unnest(), list(ast.find_all(exp.Select))[1]) + self.assertIs(ast.args["from_"].this.unnest(), list(ast.find_all(exp.Select))[1]) ast = parse_one("SELECT * FROM ((((SELECT * FROM t))) AS foo)") - second_subquery = ast.args["from"].this.this + second_subquery = ast.args["from_"].this.this innermost_subquery = list(ast.find_all(exp.Select))[1].parent self.assertIs(second_subquery, innermost_subquery.unwrap()) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 63a3dc8f46..1f9dc12459 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -626,7 +626,7 @@ def test_simplify(self): self.assertEqual( optimizer.simplify.gen(sql), """ -SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True +SELECT :with_,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from_,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from_,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True """.strip(), ) self.assertEqual( @@ -1131,7 +1131,7 @@ def test_derived_tables_column_annotation(self): expression.expressions[0].type.this, exp.DataType.Type.FLOAT ) # a.cola AS cola - addition_alias 
= expression.args["from"].this.this.expressions[0] + addition_alias = expression.args["from_"].this.this.expressions[0] self.assertEqual( addition_alias.type.this, exp.DataType.Type.FLOAT ) # x.cola + y.cola AS cola @@ -1177,7 +1177,7 @@ def test_cte_column_annotation(self): # WHERE tbl.colc = True self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN) - cte_select = expression.args["with"].expressions[0].this + cte_select = expression.args["with_"].expressions[0].this self.assertEqual( cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola + 'bla' AS cola diff --git a/tests/test_parser.py b/tests/test_parser.py index 11fbf41e89..d0cbdc1d2e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -231,7 +231,7 @@ def test_identify(self): assert expression.expressions[2].alias == "c" assert expression.expressions[3].alias == "D" assert expression.expressions[4].alias == "y|z'" - table = expression.args["from"].this + table = expression.args["from_"].this assert table.name == "z" assert table.args["db"].name == "y" @@ -243,8 +243,8 @@ def test_multi(self): ) assert len(expressions) == 2 - assert expressions[0].args["from"].name == "a" - assert expressions[1].args["from"].name == "b" + assert expressions[0].args["from_"].name == "a" + assert expressions[1].args["from_"].name == "b" expressions = parse("SELECT 1; ; SELECT 2") @@ -374,8 +374,8 @@ def test_comments_select_cte(self): ) self.assertEqual(expression.comments, ["comment2"]) - self.assertEqual(expression.args.get("from").comments, ["comment3"]) - self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + self.assertEqual(expression.args.get("from_").comments, ["comment3"]) + self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"]) def test_comments_insert(self): expression = parse_one( @@ -406,7 +406,7 @@ def test_comments_insert_cte(self): self.assertEqual(expression.comments, ["comment2"]) 
self.assertEqual(expression.this.comments, ["comment3"]) - self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"]) def test_comments_update(self): expression = parse_one( @@ -440,7 +440,7 @@ def test_comments_update_cte(self): self.assertEqual(expression.comments, ["comment2"]) self.assertEqual(expression.this.comments, ["comment3"]) - self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"]) def test_comments_delete(self): expression = parse_one( @@ -472,7 +472,7 @@ def test_comments_delete_cte(self): self.assertEqual(expression.comments, ["comment2"]) self.assertEqual(expression.this.comments, ["comment3"]) - self.assertEqual(expression.args["with"].comments, ["comment1.1", "comment1.2"]) + self.assertEqual(expression.args["with_"].comments, ["comment1.1", "comment1.2"]) def test_type_literals(self): self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) @@ -768,7 +768,7 @@ def test_pivot_columns(self): for dialect, expected_columns in dialect_columns.items(): with self.subTest(f"Testing query '{query}' for dialect {dialect}"): expr = parse_one(query, read=dialect) - columns = expr.args["from"].this.args["pivots"][0].args["columns"] + columns = expr.args["from_"].this.args["pivots"][0].args["columns"] self.assertEqual( expected_columns, [col.sql(dialect=dialect) for col in columns] ) @@ -957,23 +957,23 @@ def test_token_position_meta(self): self.assertEqual(set(identifier.meta), {"line", "col", "start", "end"}) self.assertEqual( - ast.this.args["from"].this.args["this"].meta, + ast.this.args["from_"].this.args["this"].meta, {"line": 1, "col": 41, "start": 29, "end": 40}, ) self.assertEqual( - ast.this.args["from"].this.args["db"].meta, + ast.this.args["from_"].this.args["db"].meta, {"line": 1, "col": 28, "start": 17, "end": 
27}, ) self.assertEqual( - ast.expression.args["from"].this.args["this"].meta, + ast.expression.args["from_"].this.args["this"].meta, {"line": 1, "col": 106, "start": 94, "end": 105}, ) self.assertEqual( - ast.expression.args["from"].this.args["db"].meta, + ast.expression.args["from_"].this.args["db"].meta, {"line": 1, "col": 93, "start": 82, "end": 92}, ) self.assertEqual( - ast.expression.args["from"].this.args["catalog"].meta, + ast.expression.args["from_"].this.args["catalog"].meta, {"line": 1, "col": 81, "start": 69, "end": 80}, ) @@ -993,10 +993,10 @@ def test_quoted_identifier_meta(self): sql = 'SELECT "a" FROM "test_schema"."test_table_a"' ast = parse_one(sql) - db_meta = ast.args["from"].this.args["db"].meta + db_meta = ast.args["from_"].this.args["db"].meta self.assertEqual(sql[db_meta["start"] : db_meta["end"] + 1], '"test_schema"') - table_meta = ast.args["from"].this.this.meta + table_meta = ast.args["from_"].this.this.meta self.assertEqual(sql[table_meta["start"] : table_meta["end"] + 1], '"test_table_a"') def test_qualified_function(self): From 3954773b463fdd4bbe265e011331de073e4a9830 Mon Sep 17 00:00:00 2001 From: tobymao Date: Fri, 14 Nov 2025 21:41:32 -0800 Subject: [PATCH 3/4] feat: improve performance of parser by making checks optional/faster the check for unexpected args only runs at parse time. so for 3rd party libraries or other code like the optimizer setting args or creating expressions themselves, the check never ran or worked. moving this to unit testing only.
--- sqlglot/expressions.py | 16 +++++++++++----- tests/test_parser.py | 13 +++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 62fcd8ee2d..a1c0070899 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -16,6 +16,7 @@ import math import numbers import re +import sys import textwrap import typing as t from collections import deque @@ -54,6 +55,7 @@ def __new__(cls, clsname, bases, attrs): # When an Expression class is created, its key is automatically set # to be the lowercase version of the class' name. klass.key = clsname.lower() + klass.required_args = {k for k, v in klass.arg_types.items() if v} # This is so that docstrings are not inherited in pdoc klass.__doc__ = klass.__doc__ or "" @@ -66,6 +68,7 @@ def __new__(cls, clsname, bases, attrs): TABLE_PARTS = ("this", "db", "catalog") COLUMN_PARTS = ("this", "table", "db", "catalog") POSITION_META_KEYS = ("line", "col", "start", "end") +UNITTEST = "unittest" in sys.modules or "pytest" in sys.modules class Expression(metaclass=_Expression): @@ -102,6 +105,7 @@ class Expression(metaclass=_Expression): key = "expression" arg_types = {"this": True} + required_args = {"this"} __slots__ = ("args", "parent", "arg_key", "index", "comments", "_type", "_meta", "_hash") def __init__(self, **args: t.Any): @@ -768,12 +772,14 @@ def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]: """ errors: t.List[str] = [] - for k in self.args: - if k not in self.arg_types: - errors.append(f"Unexpected keyword: '{k}' for {self.__class__}") - for k, mandatory in self.arg_types.items(): + if UNITTEST: + for k in self.args: + if k not in self.arg_types: + raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}") + + for k in self.required_args: v = self.args.get(k) - if mandatory and (v is None or (isinstance(v, list) and not v)): + if v is None or (type(v) is list and not v): errors.append(f"Required keyword: '{k}' 
missing for {self.__class__}") if ( diff --git a/tests/test_parser.py b/tests/test_parser.py index d0cbdc1d2e..f10c03111b 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -253,19 +253,20 @@ def test_multi(self): def test_expression(self): ignore = Parser(error_level=ErrorLevel.IGNORE) - self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint) + self.assertIsInstance(ignore.expression(exp.Hint, expressions=[]), exp.Hint) self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint) self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint) default = Parser(error_level=ErrorLevel.RAISE) - self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint) - default.expression(exp.Hint, y="") + with self.assertRaises(TypeError): + default.expression(exp.Hint, y="") + self.assertIsInstance(default.expression(exp.Hint, expressions=[]), exp.Hint) default.expression(exp.Hint) - self.assertEqual(len(default.errors), 3) + self.assertEqual(len(default.errors), 2) warn = Parser(error_level=ErrorLevel.WARN) - warn.expression(exp.Hint, y="") - self.assertEqual(len(warn.errors), 2) + warn.expression(exp.Hint) + self.assertEqual(len(warn.errors), 1) def test_parse_errors(self): with self.assertRaises(ParseError): From 7e3ecee9b848ce3be102bccbe1ab35c5fc16e51d Mon Sep 17 00:00:00 2001 From: tobymao Date: Fri, 14 Nov 2025 22:12:08 -0800 Subject: [PATCH 4/4] feat: improve update_positions performance --- sqlglot/expressions.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a1c0070899..1ce740c5df 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -889,29 +889,39 @@ def not_(self, copy: bool = True): return not_(self, copy=copy) def update_positions( - self: E, other: t.Optional[Token | Expression] = None, **kwargs: t.Any + self: E, + other: t.Optional[Token | Expression] = None, + line: t.Optional[int] 
= None, + col: t.Optional[int] = None, + start: t.Optional[int] = None, + end: t.Optional[int] = None, ) -> E: """ Update this expression with positions from a token or other expression. Args: other: a token or expression to update this expression with. + line: the line number to use if other is None + col: column number + start: start char index + end: end char index Returns: The updated expression. """ - if isinstance(other, Expression): - self.meta.update({k: v for k, v in other.meta.items() if k in POSITION_META_KEYS}) - elif other is not None: - self.meta.update( - { - "line": other.line, - "col": other.col, - "start": other.start, - "end": other.end, - } - ) - self.meta.update({k: v for k, v in kwargs.items() if k in POSITION_META_KEYS}) + if other is None: + self.meta["line"] = line + self.meta["col"] = col + self.meta["start"] = start + self.meta["end"] = end + elif hasattr(other, "meta"): + for k in POSITION_META_KEYS: + self.meta[k] = other.meta[k] + else: + self.meta["line"] = other.line + self.meta["col"] = other.col + self.meta["start"] = other.start + self.meta["end"] = other.end return self def as_(