diff --git a/airflow/models.py b/airflow/models.py index 2d5d04125909a..97a0fa92fcf87 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -590,21 +590,6 @@ def dagbag_report(self): table=pprinttable(stats), ) - @provide_session - def deactivate_inactive_dags(self, session=None): - active_dag_ids = [dag.dag_id for dag in list(self.dags.values())] - for dag in session.query( - DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all(): - dag.is_active = False - session.merge(dag) - session.commit() - - @provide_session - def paused_dags(self, session=None): - dag_ids = [dp.dag_id for dp in session.query(DagModel).filter( - DagModel.is_paused.__eq__(True))] - return dag_ids - class User(Base): __tablename__ = "users" @@ -4202,16 +4187,6 @@ def add_tasks(self, tasks): for task in tasks: self.add_task(task) - @provide_session - def db_merge(self, session=None): - BO = BaseOperator - tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all() - for t in tasks: - session.delete(t) - session.commit() - session.merge(self) - session.commit() - def run( self, start_date=None, diff --git a/tests/models.py b/tests/models.py index 838a47f938eaa..7fedfb2085192 100644 --- a/tests/models.py +++ b/tests/models.py @@ -43,17 +43,20 @@ from airflow.models import clear_task_instances from airflow.models import XCom from airflow.models import Connection +from airflow.models import SkipMixin +from airflow.models import KubeResourceVersion, KubeWorkerIdentifier from airflow.jobs import LocalTaskJob from airflow.operators.dummy_operator import DummyOperator from airflow.operators.bash_operator import BashOperator from airflow.operators.python_operator import PythonOperator from airflow.operators.python_operator import ShortCircuitOperator +from airflow.operators.subdag_operator import SubDagOperator from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone from airflow.utils.weight_rule import WeightRule from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule -from mock import patch, ANY +from mock import patch, Mock, ANY from parameterized import parameterized from tempfile import mkdtemp, NamedTemporaryFile @@ -575,6 +578,38 @@ def test_cycle(self): with self.assertRaises(AirflowDagCycleException): dag.test_cycle() + @patch('airflow.models.timezone.utcnow') + def test_sync_to_db(self, mock_now): + dag = DAG( + 'dag', + start_date=DEFAULT_DATE, + ) + with dag: + DummyOperator(task_id='task', owner='owner1') + SubDagOperator( + task_id='subtask', + owner='owner2', + subdag=DAG( + 'dag.subtask', + start_date=DEFAULT_DATE, + ) + ) + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) + mock_now.return_value = now + session = settings.Session() + dag.sync_to_db(session=session) + + orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one() + self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'}) + self.assertEqual(orm_dag.last_scheduler_run, now) + self.assertTrue(orm_dag.is_active) + + orm_subdag = session.query(DagModel).filter( + DagModel.dag_id == 'dag.subtask').one() + self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'}) + self.assertEqual(orm_subdag.last_scheduler_run, now) + self.assertTrue(orm_subdag.is_active) + class DagStatTest(unittest.TestCase): def test_dagstats_crud(self): @@ -625,6 +660,25 @@ def test_dagstats_crud(self): for stat in res: self.assertFalse(stat.dirty) + def test_update_exception(self): + session = Mock() + (session.query.return_value + .filter.return_value + .with_for_update.return_value + .all.side_effect) = RuntimeError('it broke') + DagStat.update(session=session) + session.rollback.assert_called() + + def test_set_dirty_exception(self): + session = Mock() + session.query.return_value.filter.return_value.all.return_value = [] + (session.query.return_value + .filter.return_value + .with_for_update.return_value + .all.side_effect) = RuntimeError('it broke') + DagStat.set_dirty('dag', session) + session.rollback.assert_called() + class DagRunTest(unittest.TestCase): @@ -2349,6 +2403,35 @@ def test_overwrite_params_with_dag_run_conf_none(self): self.assertEqual(False, params["override"]) + @patch('airflow.models.send_email') + def test_email_alert(self, mock_send_email): + task = DummyOperator(task_id='op', email='test@test.test') + ti = TI(task=task, execution_date=datetime.datetime.now()) + ti.email_alert(RuntimeError('it broke')) + + self.assertTrue(mock_send_email.called) + (email, title, body), _ = mock_send_email.call_args + self.assertEqual(email, 'test@test.test') + self.assertIn(repr(ti), title) + self.assertIn('it broke', body) + + def test_set_duration(self): + task = DummyOperator(task_id='op', email='test@test.test') + ti = TI( + task=task, + execution_date=datetime.datetime.now(), + ) + ti.start_date = datetime.datetime(2018, 10, 1, 1) + ti.end_date = datetime.datetime(2018, 10, 1, 2) + ti.set_duration() + self.assertEqual(ti.duration, 3600) + + def test_set_duration_empty_dates(self): + task = DummyOperator(task_id='op', email='test@test.test') + ti = TI(task=task, execution_date=datetime.datetime.now()) + ti.set_duration() + self.assertIsNone(ti.duration) + class ClearTasksTest(unittest.TestCase): @@ -2705,3 +2788,99 @@ def test_connection_from_uri_with_extras(self): self.assertEqual(connection.port, 1234) self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value', 'extra2': '/path/'}) + + +class TestSkipMixin(unittest.TestCase): + + @patch('airflow.models.timezone.utcnow') + def test_skip(self, mock_now): + session = settings.Session() + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) + mock_now.return_value = now + dag = DAG( + 'dag', + start_date=DEFAULT_DATE, + ) + with dag: + tasks = [DummyOperator(task_id='task')] + dag_run = dag.create_dagrun( + run_id='manual__' + now.isoformat(), + state=State.FAILED, + ) + SkipMixin().skip( + dag_run=dag_run, + execution_date=now, + tasks=tasks, + session=session) + + session.query(TI).filter( + TI.dag_id == 'dag', + TI.task_id == 'task', + TI.state == State.SKIPPED, + TI.start_date == now, + TI.end_date == now, + ).one() + + @patch('airflow.models.timezone.utcnow') + def test_skip_none_dagrun(self, mock_now): + session = settings.Session() + now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) + mock_now.return_value = now + dag = DAG( + 'dag', + start_date=DEFAULT_DATE, + ) + with dag: + tasks = [DummyOperator(task_id='task')] + SkipMixin().skip( + dag_run=None, + execution_date=now, + tasks=tasks, + session=session) + + session.query(TI).filter( + TI.dag_id == 'dag', + TI.task_id == 'task', + TI.state == State.SKIPPED, + TI.start_date == now, + TI.end_date == now, + ).one() + + def test_skip_none_tasks(self): + session = Mock() + SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session) + self.assertFalse(session.query.called) + self.assertFalse(session.commit.called) + + +class TestKubeResourceVersion(unittest.TestCase): + + def test_checkpoint_resource_version(self): + session = settings.Session() + KubeResourceVersion.checkpoint_resource_version('7', session) + self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7') + + def test_reset_resource_version(self): + session = settings.Session() + version = KubeResourceVersion.reset_resource_version(session) + self.assertEqual(version, '0') + self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0') + + +class TestKubeWorkerIdentifier(unittest.TestCase): + + @patch('airflow.models.uuid.uuid4') + def test_get_or_create_not_exist(self, mock_uuid): + session = settings.Session() + session.query(KubeWorkerIdentifier).update({ + KubeWorkerIdentifier.worker_uuid: '' + }) + mock_uuid.return_value = 'abcde' + worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session) + self.assertEqual(worker_uuid, 'abcde') + + def test_get_or_create_exist(self): + session = settings.Session() + KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session) + worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session) + self.assertEqual(worker_uuid, 'fghij')