Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""add pool_slots field to task_instance

Revision ID: a4c2fd67d16b
Revises: 7939bcff74ba
Create Date: 2020-01-14 03:35:01.161519

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = 'a4c2fd67d16b'
down_revision = '7939bcff74ba'
branch_labels = None
depends_on = None


def upgrade():
    """Apply this revision: add the integer ``pool_slots`` column to ``task_instance``.

    The column is created with both a client-side and a server-side default of 1.
    """
    op.add_column(
        'task_instance',
        # ``default=1`` is only applied by SQLAlchemy when it issues an INSERT; it is
        # not part of the generated DDL. ``server_default`` puts the default into the
        # ALTER TABLE statement so rows that already exist receive 1 instead of NULL.
        # NULLs would be skipped by SUM(pool_slots) aggregations, making task
        # instances created before this migration count as occupying no slots.
        sa.Column('pool_slots', sa.Integer, default=1, server_default='1'),
    )


def downgrade():
    """Revert this revision: drop the ``pool_slots`` column from ``task_instance``."""
    table, column = 'task_instance', 'pool_slots'
    op.drop_column(table, column)
8 changes: 8 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ class derived from this one results in the creation of a task object,
:param pool: the slot pool this task should run in, slot pools are a
way to limit concurrency for certain tasks
:type pool: str
:param pool_slots: the number of pool slots this task should use (>= 1).
Values less than 1 are not allowed.
:type pool_slots: int
:param sla: time by which the job is expected to succeed. Note that
this represents the ``timedelta`` after the period is closed. For
example if you set an SLA of 1 hour, the scheduler would send an email
Expand Down Expand Up @@ -313,6 +316,7 @@ def __init__(
weight_rule: str = WeightRule.DOWNSTREAM,
queue: str = conf.get('celery', 'default_queue'),
pool: str = Pool.DEFAULT_POOL_NAME,
pool_slots: int = 1,
sla: Optional[timedelta] = None,
execution_timeout: Optional[timedelta] = None,
on_execute_callback: Optional[Callable] = None,
Expand Down Expand Up @@ -381,6 +385,10 @@ def __init__(
self.retries = retries
self.queue = queue
self.pool = pool
self.pool_slots = pool_slots
if self.pool_slots < 1:
raise AirflowException("pool slots for %s in dag %s cannot be less than 1"
% (self.task_id, dag.dag_id))
self.sla = sla
self.execution_timeout = execution_timeout
self.on_execute_callback = on_execute_callback
Expand Down
12 changes: 6 additions & 6 deletions airflow/models/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def occupied_slots(self, session):
from airflow.models.taskinstance import TaskInstance # Avoid circular import
return (
session
.query(func.count())
.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
.filter(TaskInstance.state.in_(STATES_TO_COUNT_AS_RUNNING))
.scalar()
)
) or 0

@provide_session
def used_slots(self, session):
Expand All @@ -80,11 +80,11 @@ def used_slots(self, session):

running = (
session
.query(func.count())
.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
.filter(TaskInstance.state == State.RUNNING)
.scalar()
)
) or 0
return running

@provide_session
Expand All @@ -96,11 +96,11 @@ def queued_slots(self, session):

return (
session
.query(func.count())
.query(func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.pool == self.pool)
.filter(TaskInstance.state == State.QUEUED)
.scalar()
)
) or 0

@provide_session
def open_slots(self, session):
Expand Down
5 changes: 5 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class TaskInstance(Base, LoggingMixin):
unixname = Column(String(1000))
job_id = Column(Integer)
pool = Column(String(50), nullable=False)
pool_slots = Column(Integer, default=1)
queue = Column(String(256))
priority_weight = Column(Integer)
operator = Column(String(1000))
Expand Down Expand Up @@ -194,6 +195,7 @@ def __init__(self, task, execution_date, state=None):

self.queue = task.queue
self.pool = task.pool
self.pool_slots = task.pool_slots
self.priority_weight = task.priority_weight_total
self.try_number = 0
self.max_tries = self.task.retries
Expand Down Expand Up @@ -458,6 +460,7 @@ def refresh_from_db(self, session=None, lock_for_update=False, refresh_executor_
self.unixname = ti.unixname
self.job_id = ti.job_id
self.pool = ti.pool
self.pool_slots = ti.pool_slots
self.queue = ti.queue
self.priority_weight = ti.priority_weight
self.operator = ti.operator
Expand Down Expand Up @@ -770,6 +773,7 @@ def _check_and_change_state_before_execution(
"""
task = self.task
self.pool = pool or task.pool
self.pool_slots = task.pool_slots
self.test_mode = test_mode
self.refresh_from_db(session=session, lock_for_update=True)
self.job_id = job_id
Expand Down Expand Up @@ -885,6 +889,7 @@ def _run_raw_task(

task = self.task
self.pool = pool or task.pool
self.pool_slots = task.pool_slots
self.test_mode = test_mode
self.refresh_from_db(session=session)
self.job_id = job_id
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
"retries": { "type": "number" },
"queue": { "type": "string" },
"pool": { "type": "string" },
"pool_slots": { "type": "number" },
"execution_timeout": { "$ref": "#/definitions/timedelta" },
"retry_delay": { "$ref": "#/definitions/timedelta" },
"retry_exponential_backoff": { "type": "boolean" },
Expand Down
9 changes: 5 additions & 4 deletions airflow/ti_deps/deps/pool_slots_available_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ def _get_dep_statuses(self, ti, session, dep_context=None):
open_slots = pools[0].open_slots()

if ti.state in STATES_TO_COUNT_AS_RUNNING:
open_slots += 1
open_slots += ti.pool_slots

if open_slots <= 0:
if open_slots <= (ti.pool_slots - 1):
yield self._failing_status(
reason=("Not scheduling since there are %s open slots in pool %s",
open_slots, pool_name)
reason=("Not scheduling since there are %s open slots in pool %s "
"and require %s pool slots",
open_slots, pool_name, ti.pool_slots)
)
else:
yield self._passing_status(
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_default_pool_open_slots(self):
dag_id='test_default_pool_open_slots',
start_date=DEFAULT_DATE, )
op1 = DummyOperator(task_id='dummy1', dag=dag)
op2 = DummyOperator(task_id='dummy2', dag=dag)
op2 = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
ti2 = TI(task=op2, execution_date=DEFAULT_DATE)
ti1.state = State.RUNNING
Expand All @@ -109,4 +109,4 @@ def test_default_pool_open_slots(self):
session.commit()
session.close()

self.assertEqual(3, Pool.get_default_pool().open_slots())
self.assertEqual(2, Pool.get_default_pool().open_slots())
13 changes: 13 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,19 @@ def test_run_pooling_task(self):
db.clear_db_pools()
self.assertEqual(ti.state, State.SUCCESS)

def test_pool_slots_property(self):
    """
    Creating a task whose pool_slots is less than 1 must raise AirflowException.
    """
    with self.assertRaises(AirflowException):
        dag = models.DAG(dag_id='test_run_pooling_task')
        invalid_task = DummyOperator(
            task_id='test_run_pooling_task_op',
            dag=dag,
            pool='test_pool',
            pool_slots=0,
            owner='airflow',
            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
        )
        # Only reached if the operator accepted pool_slots=0; kept so the guarded
        # statements match what the original helper executed.
        TI(task=invalid_task, execution_date=timezone.utcnow())

@provide_session
def test_ti_updates_with_task(self, session=None):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ def test_no_new_fields_added_to_base_operator(self):
'owner': 'airflow',
'params': {},
'pool': 'default_pool',
'pool_slots': 1,
'priority_weight': 1,
'queue': 'default',
'resources': None,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def setUp(self):
self.dag.task_ids = [TASK_ID]

# Mock the task
self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID, params=[])
self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID, params=[], pool_slots=1)
self.task.__class__.__name__ = OPERATOR

self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE)
Expand Down
4 changes: 2 additions & 2 deletions tests/ti_deps/deps/test_dag_ti_slots_available_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_concurrency_reached(self):
Test concurrency reached should fail dep
"""
dag = Mock(concurrency=1, concurrency_reached=True)
task = Mock(dag=dag)
task = Mock(dag=dag, pool_slots=1)
ti = TaskInstance(task, execution_date=None)

self.assertFalse(DagTISlotsAvailableDep().is_met(ti=ti))
Expand All @@ -41,7 +41,7 @@ def test_all_conditions_met(self):
Test all conditions met should pass dep
"""
dag = Mock(concurrency=1, concurrency_reached=False)
task = Mock(dag=dag)
task = Mock(dag=dag, pool_slots=1)
ti = TaskInstance(task, execution_date=None)

self.assertTrue(DagTISlotsAvailableDep().is_met(ti=ti))
8 changes: 4 additions & 4 deletions tests/ti_deps/deps/test_pool_slots_available_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def tearDown(self):
@patch('airflow.models.Pool.open_slots', return_value=0)
# pylint: disable=unused-argument
def test_pooled_task_reached_concurrency(self, mock_open_slots):
ti = Mock(pool='test_pool')
ti = Mock(pool='test_pool', pool_slots=1)
self.assertFalse(PoolSlotsAvailableDep().is_met(ti=ti))

@patch('airflow.models.Pool.open_slots', return_value=1)
# pylint: disable=unused-argument
def test_pooled_task_pass(self, mock_open_slots):
ti = Mock(pool='test_pool')
ti = Mock(pool='test_pool', pool_slots=1)
self.assertTrue(PoolSlotsAvailableDep().is_met(ti=ti))

@patch('airflow.models.Pool.open_slots', return_value=0)
# pylint: disable=unused-argument
def test_running_pooled_task_pass(self, mock_open_slots):
for state in STATES_TO_COUNT_AS_RUNNING:
ti = Mock(pool='test_pool', state=state)
ti = Mock(pool='test_pool', state=state, pool_slots=1)
self.assertTrue(PoolSlotsAvailableDep().is_met(ti=ti))

def test_task_with_nonexistent_pool(self):
ti = Mock(pool='nonexistent_pool')
ti = Mock(pool='nonexistent_pool', pool_slots=1)
self.assertFalse(PoolSlotsAvailableDep().is_met(ti=ti))