diff --git a/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py b/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py new file mode 100644 index 0000000000000..4dd825e8b7de4 --- /dev/null +++ b/airflow/migrations/versions/a4c2fd67d16b_add_pool_slots_field_to_task_instance.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""add pool_slots field to task_instance + +Revision ID: a4c2fd67d16b +Revises: 7939bcff74ba +Create Date: 2020-01-14 03:35:01.161519 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'a4c2fd67d16b' +down_revision = '7939bcff74ba' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('task_instance', sa.Column('pool_slots', sa.Integer, default=1)) + + +def downgrade(): + op.drop_column('task_instance', 'pool_slots') diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index e6cbcd78ea298..390b493f04600 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -178,6 +178,9 @@ class derived from this one results in the creation of a task object, :param pool: the slot pool this task should run in, slot pools are a way to limit concurrency for certain tasks :type pool: str + :param pool_slots: the number of pool slots this task should use (>= 1) + Values less than 1 are not allowed. + :type pool_slots: int :param sla: time by which the job is expected to succeed. Note that this represents the ``timedelta`` after the period is closed. For example if you set an SLA of 1 hour, the scheduler would send an email @@ -313,6 +316,7 @@ def __init__( weight_rule: str = WeightRule.DOWNSTREAM, queue: str = conf.get('celery', 'default_queue'), pool: str = Pool.DEFAULT_POOL_NAME, + pool_slots: int = 1, sla: Optional[timedelta] = None, execution_timeout: Optional[timedelta] = None, on_execute_callback: Optional[Callable] = None, @@ -381,6 +385,10 @@ def __init__( self.retries = retries self.queue = queue self.pool = pool + self.pool_slots = pool_slots + if self.pool_slots < 1: + raise AirflowException("pool slots for %s in dag %s cannot be less than 1" + % (self.task_id, dag.dag_id)) self.sla = sla self.execution_timeout = execution_timeout self.on_execute_callback = on_execute_callback diff --git a/airflow/models/pool.py b/airflow/models/pool.py index 6ea4d4aa4d486..dde63e81f403e 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -65,11 +65,11 @@ def occupied_slots(self, session): from airflow.models.taskinstance import TaskInstance # Avoid circular import return ( session - .query(func.count()) + .query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state.in_(STATES_TO_COUNT_AS_RUNNING)) .scalar() - ) + ) or 0 @provide_session def used_slots(self, session): @@ -80,11 +80,11 @@ def used_slots(self, session): running = ( session - .query(func.count()) + .query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state == State.RUNNING) .scalar() - ) + ) or 0 return running @provide_session @@ -96,11 +96,11 @@ def queued_slots(self, session): return ( session - .query(func.count()) + .query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) .filter(TaskInstance.state == State.QUEUED) .scalar() - ) + ) or 0 @provide_session def open_slots(self, session): diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index d8d47d1694157..501bdc4a469a4 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -154,6 +154,7 @@ class TaskInstance(Base, LoggingMixin): unixname = Column(String(1000)) job_id = Column(Integer) pool = Column(String(50), nullable=False) + pool_slots = Column(Integer, default=1) queue = Column(String(256)) priority_weight = Column(Integer) operator = Column(String(1000)) @@ -194,6 +195,7 @@ def __init__(self, task, execution_date, state=None): self.queue = task.queue self.pool = task.pool + self.pool_slots = task.pool_slots self.priority_weight = task.priority_weight_total self.try_number = 0 self.max_tries = self.task.retries @@ -458,6 +460,7 @@ def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_ self.unixname = ti.unixname self.job_id = ti.job_id self.pool = ti.pool + self.pool_slots = ti.pool_slots self.queue = ti.queue self.priority_weight = ti.priority_weight self.operator = ti.operator @@ -770,6 +773,7 @@ def _check_and_change_state_before_execution( """ task = self.task self.pool = pool or task.pool + self.pool_slots = task.pool_slots self.test_mode = test_mode self.refresh_from_db(session=session, lock_for_update=True) self.job_id = job_id @@ -885,6 +889,7 @@ def _run_raw_task( task = self.task self.pool = pool or task.pool + self.pool_slots = task.pool_slots self.test_mode = test_mode self.refresh_from_db(session=session) self.job_id = job_id diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index 3c7677f0b8e5d..49de949677e8d 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -134,6 +134,7 @@ "retries": { "type": "number" }, "queue": { "type": "string" }, "pool": { "type": "string" }, + "pool_slots": { "type": "number" }, "execution_timeout": { "$ref": "#/definitions/timedelta" }, "retry_delay": { "$ref": "#/definitions/timedelta" }, "retry_exponential_backoff": { "type": "boolean" }, diff --git a/airflow/ti_deps/deps/pool_slots_available_dep.py b/airflow/ti_deps/deps/pool_slots_available_dep.py index e88f04f83fc65..3385881b4ad2c 100644 --- a/airflow/ti_deps/deps/pool_slots_available_dep.py +++ b/airflow/ti_deps/deps/pool_slots_available_dep.py @@ -62,12 +62,13 @@ def _get_dep_statuses(self, ti, session, dep_context=None): open_slots = pools[0].open_slots() if ti.state in STATES_TO_COUNT_AS_RUNNING: - open_slots += 1 + open_slots += ti.pool_slots - if open_slots <= 0: + if open_slots <= (ti.pool_slots - 1): yield self._failing_status( - reason=("Not scheduling since there are %s open slots in pool %s", - open_slots, pool_name) + reason=("Not scheduling since there are %s open slots in pool %s " + "and require %s pool slots", + open_slots, pool_name, ti.pool_slots) ) else: yield self._passing_status( diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py index 21ef645d059d6..083f61e16a13d 100644 --- a/tests/models/test_pool.py +++ b/tests/models/test_pool.py @@ -97,7 +97,7 @@ def test_default_pool_open_slots(self): dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE, ) op1 = DummyOperator(task_id='dummy1', dag=dag) - op2 = DummyOperator(task_id='dummy2', dag=dag) + op2 = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2) ti1 = TI(task=op1, execution_date=DEFAULT_DATE) ti2 = TI(task=op2, execution_date=DEFAULT_DATE) ti1.state = State.RUNNING @@ -109,4 +109,4 @@ def test_default_pool_open_slots(self): session.commit() session.close() - self.assertEqual(3, Pool.get_default_pool().open_slots()) + self.assertEqual(2, Pool.get_default_pool().open_slots()) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index cab4715610212..a5bdbe0a763b0 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -370,6 +370,19 @@ def test_run_pooling_task(self): db.clear_db_pools() self.assertEqual(ti.state, State.SUCCESS) + def test_pool_slots_property(self): + """ + test that try to create a task with pool_slots less than 1 + """ + def create_task_instance(): + dag = models.DAG(dag_id='test_run_pooling_task') + task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, + pool='test_pool', pool_slots=0, owner='airflow', + start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) + return TI(task=task, execution_date=timezone.utcnow()) + + self.assertRaises(AirflowException, create_task_instance) + @provide_session def test_ti_updates_with_task(self, session=None): """ diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index b245a1128751e..424238ed23b61 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -574,6 +574,7 @@ def test_no_new_fields_added_to_base_operator(self): 'owner': 'airflow', 'params': {}, 'pool': 'default_pool', + 'pool_slots': 1, 'priority_weight': 1, 'queue': 'default', 'resources': None, diff --git a/tests/test_sentry.py b/tests/test_sentry.py index 9c0cd076e31e4..5e341b58e27c7 100644 --- a/tests/test_sentry.py +++ b/tests/test_sentry.py @@ -70,7 +70,7 @@ def setUp(self): self.dag.task_ids = [TASK_ID] # Mock the task - self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID, params=[]) + self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID, params=[], pool_slots=1) self.task.__class__.__name__ = OPERATOR self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE) diff --git a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py index 0c03baa603430..5ff628a1b0d98 100644 --- a/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py +++ b/tests/ti_deps/deps/test_dag_ti_slots_available_dep.py @@ -31,7 +31,7 @@ def test_concurrency_reached(self): Test concurrency reached should fail dep """ dag = Mock(concurrency=1, concurrency_reached=True) - task = Mock(dag=dag) + task = Mock(dag=dag, pool_slots=1) ti = TaskInstance(task, execution_date=None) self.assertFalse(DagTISlotsAvailableDep().is_met(ti=ti)) @@ -41,7 +41,7 @@ def test_all_conditions_met(self): Test all conditions met should pass dep """ dag = Mock(concurrency=1, concurrency_reached=False) - task = Mock(dag=dag) + task = Mock(dag=dag, pool_slots=1) ti = TaskInstance(task, execution_date=None) self.assertTrue(DagTISlotsAvailableDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/test_pool_slots_available_dep.py b/tests/ti_deps/deps/test_pool_slots_available_dep.py index 74f277ecb9abc..74c854858e6f9 100644 --- a/tests/ti_deps/deps/test_pool_slots_available_dep.py +++ b/tests/ti_deps/deps/test_pool_slots_available_dep.py @@ -40,22 +40,22 @@ def tearDown(self): @patch('airflow.models.Pool.open_slots', return_value=0) # pylint: disable=unused-argument def test_pooled_task_reached_concurrency(self, mock_open_slots): - ti = Mock(pool='test_pool') + ti = Mock(pool='test_pool', pool_slots=1) self.assertFalse(PoolSlotsAvailableDep().is_met(ti=ti)) @patch('airflow.models.Pool.open_slots', return_value=1) # pylint: disable=unused-argument def test_pooled_task_pass(self, mock_open_slots): - ti = Mock(pool='test_pool') + ti = Mock(pool='test_pool', pool_slots=1) self.assertTrue(PoolSlotsAvailableDep().is_met(ti=ti)) @patch('airflow.models.Pool.open_slots', return_value=0) # pylint: disable=unused-argument def test_running_pooled_task_pass(self, mock_open_slots): for state in STATES_TO_COUNT_AS_RUNNING: - ti = Mock(pool='test_pool', state=state) + ti = Mock(pool='test_pool', state=state, pool_slots=1) self.assertTrue(PoolSlotsAvailableDep().is_met(ti=ti)) def test_task_with_nonexistent_pool(self): - ti = Mock(pool='nonexistent_pool') + ti = Mock(pool='nonexistent_pool', pool_slots=1) self.assertFalse(PoolSlotsAvailableDep().is_met(ti=ti))