Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from typing import TYPE_CHECKING, Callable, Generator, Iterable

from sqlalchemy import Table, and_, column, delete, exc, func, inspect, or_, select, table, text, tuple_
from sqlalchemy.orm.session import Session

import airflow
from airflow import settings
Expand All @@ -45,9 +44,10 @@
if TYPE_CHECKING:
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, Session

from airflow.models.base import Base
from airflow.models.connection import Connection

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -90,9 +90,9 @@ def _format_airflow_moved_table_name(source_table, version, category):


@provide_session
def merge_conn(conn, session: Session = NEW_SESSION):
def merge_conn(conn: Connection, session: Session = NEW_SESSION):
"""Add new Connection."""
if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)):
if not session.scalar(select(1).where(conn.__class__.conn_id == conn.conn_id)):
session.add(conn)
session.commit()

Expand Down Expand Up @@ -957,20 +957,20 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
from airflow.models.connection import Connection

dups = []
try:
dups = session.execute(
dups = session.scalars(
select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
).all()
except (exc.OperationalError, exc.ProgrammingError):
# fallback if tables hasn't been created yet
session.rollback()
return
if dups:
yield (
"Seems you have non unique conn_id in connection table.\n"
"You have to manage those duplicate connections "
"before upgrading the database.\n"
f"Duplicated conn_id: {[dup.conn_id for dup in dups]}"
f"Duplicated conn_id: {dups}"
)


Expand Down Expand Up @@ -1057,11 +1057,11 @@ def check_task_fail_for_duplicates(session):
:param uniqueness: uniqueness constraint to evaluate against
:param session: session of the sqlalchemy
"""
minimal_table_obj = table(table_name, *[column(x) for x in uniqueness])
minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
try:
subquery = session.execute(
select(minimal_table_obj, func.count().label("dupe_count"))
.group_by(*[text(x) for x in uniqueness])
.group_by(*(text(x) for x in uniqueness))
.having(func.count() > text("1"))
.subquery()
)
Expand Down Expand Up @@ -1100,20 +1100,20 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
"""
from airflow.models.connection import Connection

n_nulls = []
try:
n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
# fallback if tables hasn't been created yet
session.rollback()
return

if n_nulls:
yield (
"The conn_type column in the connection "
"table must contain content.\n"
"Make sure you don't have null "
"in the conn_type column.\n"
f"Null conn_type conn_id: {list(n_nulls)}"
f"Null conn_type conn_id: {n_nulls}"
)


Expand Down Expand Up @@ -1265,7 +1265,7 @@ def _dangling_against_dag_run(session, source_table, dag_run):
)

return (
select(*[c.label(c.name) for c in source_table.c])
select(*(c.label(c.name) for c in source_table.c))
.join(dag_run, source_to_dag_run_join_cond, isouter=True)
.where(dag_run.c.dag_id.is_(None))
)
Expand Down Expand Up @@ -1306,9 +1306,9 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instanc
)

return (
select(*[c.label(c.name) for c in source_table.c])
.join(dag_run, dr_join_cond, isouter=True)
.join(task_instance, ti_join_cond, isouter=True)
select(*(c.label(c.name) for c in source_table.c))
.outerjoin(dag_run, dr_join_cond)
.outerjoin(task_instance, ti_join_cond)
.where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
)

Expand All @@ -1335,9 +1335,9 @@ def _move_duplicate_data_to_new_table(
dialect_name = bind.dialect.name

query = (
select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns))
.select_from(source_table)
.join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness]))
.join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness)))
)

_create_table_as(
Expand All @@ -1353,7 +1353,7 @@ def _move_duplicate_data_to_new_table(

metadata = reflect_tables([target_table_name], session)
target_table = metadata.tables[target_table_name]
where_clause = and_(*[getattr(source_table.c, x) == getattr(target_table.c, x) for x in uniqueness])
where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness))

if dialect_name == "sqlite":
subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")])
Expand Down Expand Up @@ -1410,7 +1410,7 @@ class BadReferenceConfig:
(TaskFail, "2.3", missing_ti_config),
(XCom, "2.3", missing_ti_config),
]
metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)
metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session)

if (
not metadata.tables
Expand Down