diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 46ddcbfb3453f..f846b3f3a650d 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -30,7 +30,6 @@ from typing import TYPE_CHECKING, Callable, Generator, Iterable
 
 from sqlalchemy import Table, and_, column, delete, exc, func, inspect, or_, select, table, text, tuple_
-from sqlalchemy.orm.session import Session
 
 import airflow
 from airflow import settings
@@ -45,9 +44,10 @@ if TYPE_CHECKING:
     from alembic.runtime.environment import EnvironmentContext
     from alembic.script import ScriptDirectory
-    from sqlalchemy.orm import Query
+    from sqlalchemy.orm import Query, Session
 
     from airflow.models.base import Base
+    from airflow.models.connection import Connection
 
 log = logging.getLogger(__name__)
@@ -90,9 +90,9 @@ def _format_airflow_moved_table_name(source_table, version, category):
 
 
 @provide_session
-def merge_conn(conn, session: Session = NEW_SESSION):
+def merge_conn(conn: Connection, session: Session = NEW_SESSION):
     """Add new Connection."""
-    if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)):
+    if not session.scalar(select(1).where(conn.__class__.conn_id == conn.conn_id)):
         session.add(conn)
         session.commit()
@@ -957,20 +957,20 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
     """
     from airflow.models.connection import Connection
 
-    dups = []
     try:
-        dups = session.execute(
+        dups = session.scalars(
             select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
         ).all()
     except (exc.OperationalError, exc.ProgrammingError):
         # fallback if tables hasn't been created yet
         session.rollback()
+        return
     if dups:
         yield (
             "Seems you have non unique conn_id in connection table.\n"
             "You have to manage those duplicate connections "
             "before upgrading the database.\n"
-            f"Duplicated conn_id: {[dup.conn_id for dup in dups]}"
+            f"Duplicated conn_id: {dups}"
         )
@@ -1057,11 +1057,11 @@ def check_task_fail_for_duplicates(session):
     :param uniqueness: uniqueness constraint to evaluate against
     :param session: session of the sqlalchemy
     """
-    minimal_table_obj = table(table_name, *[column(x) for x in uniqueness])
+    minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
     try:
         subquery = session.execute(
             select(minimal_table_obj, func.count().label("dupe_count"))
-            .group_by(*[text(x) for x in uniqueness])
+            .group_by(*(text(x) for x in uniqueness))
             .having(func.count() > text("1"))
             .subquery()
         )
@@ -1100,12 +1100,12 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
     """
     from airflow.models.connection import Connection
 
-    n_nulls = []
     try:
         n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
     except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
         # fallback if tables hasn't been created yet
         session.rollback()
+        return
 
     if n_nulls:
         yield (
@@ -1113,7 +1113,7 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
             "table must contain content.\n"
             "Make sure you don't have null "
             "in the conn_type column.\n"
-            f"Null conn_type conn_id: {list(n_nulls)}"
+            f"Null conn_type conn_id: {n_nulls}"
         )
@@ -1265,7 +1265,7 @@ def _dangling_against_dag_run(session, source_table, dag_run):
     )
 
     return (
-        select(*[c.label(c.name) for c in source_table.c])
+        select(*(c.label(c.name) for c in source_table.c))
         .join(dag_run, source_to_dag_run_join_cond, isouter=True)
         .where(dag_run.c.dag_id.is_(None))
     )
@@ -1306,9 +1306,9 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instanc
     )
 
     return (
-        select(*[c.label(c.name) for c in source_table.c])
-        .join(dag_run, dr_join_cond, isouter=True)
-        .join(task_instance, ti_join_cond, isouter=True)
+        select(*(c.label(c.name) for c in source_table.c))
+        .outerjoin(dag_run, dr_join_cond)
+        .outerjoin(task_instance, ti_join_cond)
         .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
     )
@@ -1335,9 +1335,9 @@ def _move_duplicate_data_to_new_table(
    dialect_name = bind.dialect.name
 
     query = (
-        select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
+        select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns))
         .select_from(source_table)
-        .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness]))
+        .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness)))
     )
 
     _create_table_as(
@@ -1353,7 +1353,7 @@ def _move_duplicate_data_to_new_table(
     metadata = reflect_tables([target_table_name], session)
     target_table = metadata.tables[target_table_name]
-    where_clause = and_(*[getattr(source_table.c, x) == getattr(target_table.c, x) for x in uniqueness])
+    where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness))
 
     if dialect_name == "sqlite":
         subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")])
@@ -1410,7 +1410,7 @@ class BadReferenceConfig:
         (TaskFail, "2.3", missing_ti_config),
         (XCom, "2.3", missing_ti_config),
     ]
-    metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)
+    metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session)
 
     if (
         not metadata.tables
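The behavioral highlight of this diff is the `merge_conn` change: the old query fetched and hydrated a full `Connection` ORM object only to discard it after a truthiness test, while `select(1)` asks the database nothing more than whether a matching row exists. A minimal standalone sketch of that pattern (hypothetical model and in-memory engine, not the Airflow code itself):

```python
# Sketch of the existence-check pattern adopted in merge_conn.
# The model, engine, and data here are illustrative assumptions.
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Connection(Base):
    __tablename__ = "connection"
    id = Column(Integer, primary_key=True)
    conn_id = Column(String(250), unique=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    conn = Connection(conn_id="my_conn")
    # Before: select(conn.__class__).filter_by(conn_id=...).limit(1)
    # loads a whole Connection row just to test for presence.
    # After: compiles to SELECT 1 ... WHERE connection.conn_id = :conn_id
    # and returns 1 if a row exists, None otherwise.
    if not session.scalar(select(1).where(Connection.conn_id == conn.conn_id)):
        session.add(conn)
        session.commit()
```

The rest of the diff is mechanical modernization in the same spirit: generator expressions instead of throwaway lists when unpacking with `*`, `source_table.c[x]` indexing instead of `getattr(source_table.c, x)`, and `outerjoin(...)` instead of `join(..., isouter=True)`.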
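The `check_*` pre-upgrade helpers also gain an explicit `return` after `session.rollback()`, which is what lets the `dups = []` / `n_nulls = []` pre-initializations go away, and switching to `session.scalars(...)` yields bare `conn_id` strings so the messages no longer need a row-unpacking comprehension. A self-contained sketch of that generator-check shape, with a hypothetical table standing in for Airflow's `connection`:

```python
# Sketch of the pre-upgrade check pattern this diff tightens up.
# Table definition and sample data are illustrative assumptions.
from typing import Iterable

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, exc, func, select
from sqlalchemy.orm import Session

metadata = MetaData()
connection = Table(
    "connection",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("conn_id", String(250)),
)


def check_conn_id_duplicates(session: Session) -> Iterable[str]:
    """Yield one message listing every conn_id that appears more than once."""
    try:
        # scalars() returns the bare conn_id strings directly.
        dups = session.scalars(
            select(connection.c.conn_id).group_by(connection.c.conn_id).having(func.count() > 1)
        ).all()
    except (exc.OperationalError, exc.ProgrammingError):
        # Fallback if the tables haven't been created yet: undo the
        # failed statement and skip this check entirely.
        session.rollback()
        return
    if dups:
        yield f"Duplicated conn_id: {dups}"


engine = create_engine("sqlite://")
metadata.create_all(engine)

with Session(engine) as session:
    session.execute(connection.insert(), [{"conn_id": "a"}, {"conn_id": "a"}, {"conn_id": "b"}])
    for problem in check_conn_id_duplicates(session):
        print(problem)  # Duplicated conn_id: ['a']
```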