Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,7 +1446,7 @@ class BadReferenceConfig:

dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, "dangling")
if dangling_table_name in existing_table_names:
invalid_row_count = bad_rows_query.count()
invalid_row_count = get_query_count(bad_rows_query, session=session)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test to protect future changes here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will add it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ephraimbuddy I've finally found the time to add the test, although it required more work than expected :)

if invalid_row_count:
yield _format_dangling_error(
source_table=source_table.name,
Expand Down
1 change: 1 addition & 0 deletions newsfragments/34348.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed ``AttributeError: 'Select' object has no attribute 'count'`` raised when running the ``airflow db migrate`` command.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to have but not required

86 changes: 84 additions & 2 deletions tests/utils/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@
from alembic.migration import MigrationContext
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy import MetaData
from sqlalchemy import MetaData, Table
from sqlalchemy.sql import Select

from airflow.exceptions import AirflowException
from airflow.models import Base as airflow_base
from airflow.settings import engine
from airflow.utils.db import (
_get_alembic_config,
check_bad_references,
check_migrations,
compare_server_default,
compare_type,
Expand All @@ -49,6 +51,7 @@
resetdb,
upgradedb,
)
from airflow.utils.session import NEW_SESSION


class TestDb:
Expand All @@ -57,7 +60,7 @@ def test_database_schema_and_sqlalchemy_model_are_in_sync(self):

airflow.models.import_all_models()
all_meta_data = MetaData()
for (table_name, table) in airflow_base.metadata.tables.items():
for table_name, table in airflow_base.metadata.tables.items():
all_meta_data._add_table(table_name, table.schema, table)

# create diff between database schema and SQLAlchemy model
Expand Down Expand Up @@ -251,3 +254,82 @@ def test_alembic_configuration(self):
import airflow

assert config.config_file_name == os.path.join(os.path.dirname(airflow.__file__), "alembic.ini")

@mock.patch("airflow.utils.db._move_dangling_data_to_new_table")
@mock.patch("airflow.utils.db.get_query_count")
@mock.patch("airflow.utils.db._dangling_against_task_instance")
@mock.patch("airflow.utils.db._dangling_against_dag_run")
@mock.patch("airflow.utils.db.reflect_tables")
@mock.patch("airflow.utils.db.inspect")
def test_check_bad_references(
    self,
    mock_inspect: MagicMock,
    mock_reflect_tables: MagicMock,
    mock_dangling_against_dag_run: MagicMock,
    mock_dangling_against_task_instance: MagicMock,
    mock_get_query_count: MagicMock,
    mock_move_dangling_data_to_new_table: MagicMock,
):
    """Check how ``check_bad_references`` treats moved "dangling" tables.

    The database is fully mocked. Three tables are reflected (``dag_run``,
    ``task_instance`` and ``task_fail``) and the inspector reports that a
    moved dangling table already exists for ``task_instance`` only.

    Expected outcome, as asserted below:

    * bad rows are counted through ``get_query_count`` for the table whose
      dangling table already exists, and exactly one error naming that
      dangling table is yielded;
    * dangling rows of ``task_fail`` are handed over to
      ``_move_dangling_data_to_new_table``;
    * the session is rolled back once.
    """
    # Imported locally, like the model imports below.
    from sqlalchemy.orm import Session

    from airflow.models.dagrun import DagRun
    from airflow.models.renderedtifields import RenderedTaskInstanceFields
    from airflow.models.taskfail import TaskFail
    from airflow.models.taskinstance import TaskInstance
    from airflow.models.taskreschedule import TaskReschedule
    from airflow.models.xcom import XCom

    # ``NEW_SESSION`` is a typing placeholder that is ``None`` at runtime, and
    # ``MagicMock(spec=None)`` applies no spec at all. Spec the mock against the
    # real ``Session`` class so that any access to an attribute a session does
    # not have fails the test, the same way the ``Table``/``Select`` specs do.
    mock_session = MagicMock(spec=Session)
    mock_bind = MagicMock()
    mock_session.get_bind.return_value = mock_bind

    # ``name`` must be assigned after construction: passing ``name=`` to the
    # MagicMock constructor would name the mock itself instead.
    task_instance_table = MagicMock(spec=Table)
    task_instance_table.name = TaskInstance.__tablename__
    dag_run_table = MagicMock(spec=Table)
    task_fail_table = MagicMock(spec=Table)
    task_fail_table.name = TaskFail.__tablename__

    mock_reflect_tables.return_value = MagicMock(
        tables={
            DagRun.__tablename__: dag_run_table,
            TaskInstance.__tablename__: task_instance_table,
            TaskFail.__tablename__: task_fail_table,
        }
    )

    # Simulate that there is a moved `task_instance` table from the
    # previous run, but no moved `task_fail` table
    dangling_task_instance_table_name = f"_airflow_moved__2_2__dangling__{task_instance_table.name}"
    dangling_task_fail_table_name = f"_airflow_moved__2_3__dangling__{task_fail_table.name}"
    mock_get_table_names = MagicMock(
        return_value=[
            TaskInstance.__tablename__,
            DagRun.__tablename__,
            TaskFail.__tablename__,
            dangling_task_instance_table_name,
        ]
    )
    mock_inspect.return_value = MagicMock(
        get_table_names=mock_get_table_names,
    )

    # Both "bad rows" helpers return the same Select-spec'd mock; a non-zero
    # count makes the existing dangling table count as still holding bad rows.
    mock_select = MagicMock(spec=Select)
    mock_dangling_against_dag_run.return_value = mock_select
    mock_dangling_against_task_instance.return_value = mock_select
    mock_get_query_count.return_value = 1

    # Should return a single error related to the dangling `task_instance` table
    errs = list(check_bad_references(session=mock_session))
    assert len(errs) == 1
    assert dangling_task_instance_table_name in errs[0]

    mock_reflect_tables.assert_called_once_with(
        [TaskInstance, TaskReschedule, RenderedTaskInstanceFields, TaskFail, XCom, DagRun, TaskInstance],
        mock_session,
    )
    mock_inspect.assert_called_once_with(mock_bind)
    mock_get_table_names.assert_called_once()
    mock_dangling_against_dag_run.assert_called_once_with(
        mock_session, task_instance_table, dag_run=dag_run_table
    )
    mock_get_query_count.assert_called_once_with(mock_select, session=mock_session)
    mock_move_dangling_data_to_new_table.assert_called_once_with(
        mock_session, task_fail_table, mock_select, dangling_task_fail_table_name
    )
    mock_session.rollback.assert_called_once()