Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,21 +590,6 @@ def dagbag_report(self):
table=pprinttable(stats),
)

@provide_session
def deactivate_inactive_dags(self, session=None):
active_dag_ids = [dag.dag_id for dag in list(self.dags.values())]
for dag in session.query(
DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
dag.is_active = False
session.merge(dag)
session.commit()

@provide_session
def paused_dags(self, session=None):
dag_ids = [dp.dag_id for dp in session.query(DagModel).filter(
DagModel.is_paused.__eq__(True))]
return dag_ids


class User(Base):
__tablename__ = "users"
Expand Down Expand Up @@ -4202,16 +4187,6 @@ def add_tasks(self, tasks):
for task in tasks:
self.add_task(task)

@provide_session
def db_merge(self, session=None):
BO = BaseOperator
tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all()
for t in tasks:
session.delete(t)
session.commit()
session.merge(self)
session.commit()

def run(
self,
start_date=None,
Expand Down
181 changes: 180 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,20 @@
from airflow.models import clear_task_instances
from airflow.models import XCom
from airflow.models import Connection
from airflow.models import SkipMixin
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier
from airflow.jobs import LocalTaskJob
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.operators.python_operator import ShortCircuitOperator
from airflow.operators.subdag_operator import SubDagOperator
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.weight_rule import WeightRule
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from mock import patch, ANY
from mock import patch, Mock, ANY
from parameterized import parameterized
from tempfile import mkdtemp, NamedTemporaryFile

Expand Down Expand Up @@ -575,6 +578,38 @@ def test_cycle(self):
with self.assertRaises(AirflowDagCycleException):
dag.test_cycle()

@patch('airflow.models.timezone.utcnow')
def test_sync_to_db(self, mock_now):
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
DummyOperator(task_id='task', owner='owner1')
SubDagOperator(
task_id='subtask',
owner='owner2',
subdag=DAG(
'dag.subtask',
start_date=DEFAULT_DATE,
)
)
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
session = settings.Session()
dag.sync_to_db(session=session)

orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'})
self.assertEqual(orm_dag.last_scheduler_run, now)
self.assertTrue(orm_dag.is_active)

orm_subdag = session.query(DagModel).filter(
DagModel.dag_id == 'dag.subtask').one()
self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'})
self.assertEqual(orm_subdag.last_scheduler_run, now)
self.assertTrue(orm_subdag.is_active)


class DagStatTest(unittest.TestCase):
def test_dagstats_crud(self):
Expand Down Expand Up @@ -625,6 +660,25 @@ def test_dagstats_crud(self):
for stat in res:
self.assertFalse(stat.dirty)

def test_update_exception(self):
session = Mock()
(session.query.return_value
.filter.return_value
.with_for_update.return_value
.all.side_effect) = RuntimeError('it broke')
DagStat.update(session=session)
session.rollback.assert_called()

def test_set_dirty_exception(self):
session = Mock()
session.query.return_value.filter.return_value.all.return_value = []
(session.query.return_value
.filter.return_value
.with_for_update.return_value
.all.side_effect) = RuntimeError('it broke')
DagStat.set_dirty('dag', session)
session.rollback.assert_called()


class DagRunTest(unittest.TestCase):

Expand Down Expand Up @@ -2349,6 +2403,35 @@ def test_overwrite_params_with_dag_run_conf_none(self):

self.assertEqual(False, params["override"])

@patch('airflow.models.send_email')
def test_email_alert(self, mock_send_email):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.email_alert(RuntimeError('it broke'))

self.assertTrue(mock_send_email.called)
(email, title, body), _ = mock_send_email.call_args
self.assertEqual(email, 'test@test.test')
self.assertIn(repr(ti), title)
self.assertIn('it broke', body)

def test_set_duration(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(
task=task,
execution_date=datetime.datetime.now(),
)
ti.start_date = datetime.datetime(2018, 10, 1, 1)
ti.end_date = datetime.datetime(2018, 10, 1, 2)
ti.set_duration()
self.assertEqual(ti.duration, 3600)

def test_set_duration_empty_dates(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.set_duration()
self.assertIsNone(ti.duration)


class ClearTasksTest(unittest.TestCase):

Expand Down Expand Up @@ -2705,3 +2788,99 @@ def test_connection_from_uri_with_extras(self):
self.assertEqual(connection.port, 1234)
self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
'extra2': '/path/'})


class TestSkipMixin(unittest.TestCase):

@patch('airflow.models.timezone.utcnow')
def test_skip(self, mock_now):
session = settings.Session()
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
tasks = [DummyOperator(task_id='task')]
dag_run = dag.create_dagrun(
run_id='manual__' + now.isoformat(),
state=State.FAILED,
)
SkipMixin().skip(
dag_run=dag_run,
execution_date=now,
tasks=tasks,
session=session)

session.query(TI).filter(
TI.dag_id == 'dag',
TI.task_id == 'task',
TI.state == State.SKIPPED,
TI.start_date == now,
TI.end_date == now,
).one()

@patch('airflow.models.timezone.utcnow')
def test_skip_none_dagrun(self, mock_now):
session = settings.Session()
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
tasks = [DummyOperator(task_id='task')]
SkipMixin().skip(
dag_run=None,
execution_date=now,
tasks=tasks,
session=session)

session.query(TI).filter(
TI.dag_id == 'dag',
TI.task_id == 'task',
TI.state == State.SKIPPED,
TI.start_date == now,
TI.end_date == now,
).one()

def test_skip_none_tasks(self):
session = Mock()
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
self.assertFalse(session.query.called)
self.assertFalse(session.commit.called)


class TestKubeResourceVersion(unittest.TestCase):

def test_checkpoint_resource_version(self):
session = settings.Session()
KubeResourceVersion.checkpoint_resource_version('7', session)
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7')

def test_reset_resource_version(self):
session = settings.Session()
version = KubeResourceVersion.reset_resource_version(session)
self.assertEqual(version, '0')
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0')


class TestKubeWorkerIdentifier(unittest.TestCase):

@patch('airflow.models.uuid.uuid4')
def test_get_or_create_not_exist(self, mock_uuid):
session = settings.Session()
session.query(KubeWorkerIdentifier).update({
KubeWorkerIdentifier.worker_uuid: ''
})
mock_uuid.return_value = 'abcde'
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
self.assertEqual(worker_uuid, 'abcde')

def test_get_or_create_exist(self):
session = settings.Session()
KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session)
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
self.assertEqual(worker_uuid, 'fghij')