From abdb6806081e3fd55774ba8769323b81095cc0aa Mon Sep 17 00:00:00 2001 From: qian Date: Sat, 1 Feb 2020 18:25:06 +0800 Subject: [PATCH] [AIRFLOW-6657] Deprecate BaseBranchOperator --- airflow/operators/branch_operator.py | 3 + airflow/operators/latest_only_operator.py | 52 +++---- docs/concepts.rst | 25 +-- tests/operators/test_branch_operator.py | 177 ---------------------- 4 files changed, 42 insertions(+), 215 deletions(-) delete mode 100644 tests/operators/test_branch_operator.py diff --git a/airflow/operators/branch_operator.py b/airflow/operators/branch_operator.py index 247d4cc23375b..1cc3441a9f21a 100644 --- a/airflow/operators/branch_operator.py +++ b/airflow/operators/branch_operator.py @@ -17,10 +17,13 @@ # under the License. """Branching operators""" +import warnings from typing import Dict, Iterable, Union from airflow.models import BaseOperator, SkipMixin +warnings.warn("This module is deprecated. Please use `airflow.operators.python.BranchPythonOperator`.") + class BaseBranchOperator(BaseOperator, SkipMixin): """ diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py index 6b6c82c1482af..ee03cb2696d19 100644 --- a/airflow/operators/latest_only_operator.py +++ b/airflow/operators/latest_only_operator.py @@ -19,14 +19,12 @@ This module contains an operator to run downstream tasks only for the latest scheduled DagRun """ -from typing import Dict, Iterable, Union - import pendulum -from airflow.operators.branch_operator import BaseBranchOperator +from airflow.operators.python import BranchPythonOperator -class LatestOnlyOperator(BaseBranchOperator): +class LatestOnlyOperator(BranchPythonOperator): """ Allows a workflow to skip tasks that are not running during the most recent schedule interval. @@ -40,28 +38,30 @@ class LatestOnlyOperator(BaseBranchOperator): ui_color = '#e9ffdb' # nyanza - def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]: - # If the DAG Run is externally triggered, then return without - # skipping downstream tasks - if context['dag_run'] and context['dag_run'].external_trigger: + def __init__(self, *args, **kwargs): + def python_callable(dag_run, task, dag, execution_date, **_): + # If the DAG Run is externally triggered, then return without + # skipping downstream tasks + if dag_run and dag_run.external_trigger: + self.log.info( + "Externally triggered DAG_Run: allowing execution to proceed.") + return list(task.get_direct_relative_ids(upstream=False)) + + now = pendulum.utcnow() + left_window = dag.following_schedule(execution_date) + right_window = dag.following_schedule(left_window) self.log.info( - "Externally triggered DAG_Run: allowing execution to proceed.") - return context['task'].get_direct_relative_ids(upstream=False) + 'Checking latest only with left_window: %s right_window: %s now: %s', + left_window, right_window, now + ) - now = pendulum.utcnow() - left_window = context['dag'].following_schedule( - context['execution_date']) - right_window = context['dag'].following_schedule(left_window) - self.log.info( - 'Checking latest only with left_window: %s right_window: %s now: %s', - left_window, right_window, now - ) + if not left_window < now <= right_window: + self.log.info('Not latest execution, skipping downstream.') + # we return an empty list, thus the parent BranchPythonOperator + # won't exclude any downstream tasks from skipping. + return [] + else: + self.log.info('Latest, allowing execution to proceed.') + return list(task.get_direct_relative_ids(upstream=False)) - if not left_window < now <= right_window: - self.log.info('Not latest execution, skipping downstream.') - # we return an empty list, thus the parent BaseBranchOperator - # won't exclude any downstream tasks from skipping. - return [] - else: - self.log.info('Latest, allowing execution to proceed.') - return context['task'].get_direct_relative_ids(upstream=False) + super().__init__(python_callable=python_callable, *args, **kwargs) diff --git a/docs/concepts.rst b/docs/concepts.rst index 31054b00d6ecf..322e1d9c0f5cb 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -763,25 +763,26 @@ For example: start_op >> branch_op >> [continue_op, stop_op] -If you wish to implement your own operators with branching functionality, you -can inherit from :class:`~airflow.operators.branch_operator.BaseBranchOperator`, -which behaves similarly to ``BranchPythonOperator`` but expects you to provide -an implementation of the method ``choose_branch``. As with the callable for -``BranchPythonOperator``, this method should return the ID of a downstream task, -or a list of task IDs, which will be run, and all others will be skipped. +Most of the times you can get away with passing a ``python_callable`` to ``BranchPythonOperator`` +to create branching logic. However, if you do wish to implement your own operators with branching +functionality, you can inherit from ``BranchPythonOperator`` too. If jinja templating is needed for +the arguments of ``python_callable``, pass them as ``op_args`` or ``op_kwargs`` to +``BranchPythonOperator``. .. code:: python - class MyBranchOperator(BaseBranchOperator): - def choose_branch(self, context): + class MyBranchOperator(BranchPythonOperator): + def __init__(self, *args, **kwargs): """ Run an extra branch on the first day of the month """ - if context['execution_date'].day == 1: - return ['daily_task_id', 'monthly_task_id'] - else: - return 'daily_task_id' + def python_callable(execution_date, **_): + if execution_date.day == 1: + return ['daily_task_id', 'monthly_task_id'] + else: + return 'daily_task_id' + super().__init__(python_callable=python_callable, *args, **kwargs) SubDAGs ======= diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py deleted file mode 100644 index 0c303f90f7ae8..0000000000000 --- a/tests/operators/test_branch_operator.py +++ /dev/null @@ -1,177 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import datetime -import unittest - -from airflow.models import DAG, DagRun, TaskInstance as TI -from airflow.operators.branch_operator import BaseBranchOperator -from airflow.operators.dummy_operator import DummyOperator -from airflow.utils import timezone -from airflow.utils.session import create_session -from airflow.utils.state import State - -DEFAULT_DATE = timezone.datetime(2016, 1, 1) -INTERVAL = datetime.timedelta(hours=12) - - -class ChooseBranchOne(BaseBranchOperator): - def choose_branch(self, context): - return 'branch_1' - - -class ChooseBranchOneTwo(BaseBranchOperator): - def choose_branch(self, context): - return ['branch_1', 'branch_2'] - - -class TestBranchOperator(unittest.TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() - - def setUp(self): - self.dag = DAG('branch_operator_test', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE}, - schedule_interval=INTERVAL) - - self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) - self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag) - self.branch_3 = None - self.branch_op = None - - def tearDown(self): - super().tearDown() - - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() - - def test_without_dag_run(self): - """This checks the defensive against non existent tasks in a dag run""" - self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag) - self.branch_1.set_upstream(self.branch_op) - self.branch_2.set_upstream(self.branch_op) - self.dag.clear() - - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE - ) - - for ti in tis: - if ti.task_id == 'make_choice': - self.assertEqual(ti.state, State.SUCCESS) - elif ti.task_id == 'branch_1': - # should exist with state None - self.assertEqual(ti.state, State.NONE) - elif ti.task_id == 'branch_2': - self.assertEqual(ti.state, State.SKIPPED) - else: - raise Exception - - def test_branch_list_without_dag_run(self): - """This checks if the BranchOperator supports branching off to a list of tasks.""" - self.branch_op = ChooseBranchOneTwo(task_id='make_choice', dag=self.dag) - self.branch_1.set_upstream(self.branch_op) - self.branch_2.set_upstream(self.branch_op) - self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag) - self.branch_3.set_upstream(self.branch_op) - self.dag.clear() - - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - with create_session() as session: - tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, - TI.execution_date == DEFAULT_DATE - ) - - expected = { - "make_choice": State.SUCCESS, - "branch_1": State.NONE, - "branch_2": State.NONE, - "branch_3": State.SKIPPED, - } - - for ti in tis: - if ti.task_id in expected: - self.assertEqual(ti.state, expected[ti.task_id]) - else: - raise Exception - - def test_with_dag_run(self): - self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag) - self.branch_1.set_upstream(self.branch_op) - self.branch_2.set_upstream(self.branch_op) - self.dag.clear() - - dagrun = self.dag.create_dagrun( - run_id="manual__", - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING - ) - - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - tis = dagrun.get_task_instances() - for ti in tis: - if ti.task_id == 'make_choice': - self.assertEqual(ti.state, State.SUCCESS) - elif ti.task_id == 'branch_1': - self.assertEqual(ti.state, State.NONE) - elif ti.task_id == 'branch_2': - self.assertEqual(ti.state, State.SKIPPED) - else: - raise Exception - - def test_with_skip_in_branch_downstream_dependencies(self): - self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag) - self.branch_op >> self.branch_1 >> self.branch_2 - self.branch_op >> self.branch_2 - self.dag.clear() - - dagrun = self.dag.create_dagrun( - run_id="manual__", - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING - ) - - self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - tis = dagrun.get_task_instances() - for ti in tis: - if ti.task_id == 'make_choice': - self.assertEqual(ti.state, State.SUCCESS) - elif ti.task_id == 'branch_1': - self.assertEqual(ti.state, State.NONE) - elif ti.task_id == 'branch_2': - self.assertEqual(ti.state, State.NONE) - else: - raise Exception