25 changes: 25 additions & 0 deletions airflow/task/task_runner/standard_task_runner.py
@@ -21,6 +21,8 @@

import logging
import os
import threading
import time
from typing import TYPE_CHECKING

import psutil
@@ -29,6 +31,7 @@
from airflow.api_internal.internal_api_call import InternalApiConfig
from airflow.models.taskinstance import TaskReturnCode
from airflow.settings import CAN_FORK
from airflow.stats import Stats
from airflow.task.task_runner.base_task_runner import BaseTaskRunner
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils.process_utils import reap_process_group, set_new_process_group
@@ -53,6 +56,11 @@ def start(self):
else:
self.process = self._start_by_exec()

if self.process:
resource_monitor = threading.Thread(target=self._read_task_utilization)
resource_monitor.daemon = True
resource_monitor.start()

def _start_by_exec(self) -> psutil.Process:
subprocess = self.run_command()
self.process = psutil.Process(subprocess.pid)
@@ -186,3 +194,20 @@ def get_process_pid(self) -> int:
if self.process is None:
raise RuntimeError("Process is not started yet")
return self.process.pid

def _read_task_utilization(self):
dag_id = self._task_instance.dag_id
task_id = self._task_instance.task_id

try:
while True:
with self.process.oneshot():
mem_usage = self.process.memory_percent()
cpu_usage = self.process.cpu_percent()

Stats.gauge(f"task.mem_usage.{dag_id}.{task_id}", mem_usage)
Stats.gauge(f"task.cpu_usage.{dag_id}.{task_id}", cpu_usage)
time.sleep(5)
except (psutil.NoSuchProcess, psutil.AccessDenied, AttributeError):
self.log.info("Process not found (most likely exited), stop collecting metrics")
return
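The loop above is the heart of the change: a daemon thread started alongside the task process samples it via psutil every five seconds and emits two gauges until the process disappears. A minimal standalone sketch of the same pattern (`poll_utilization` and the `report` callback are illustrative stand-ins for the runner method and `Stats.gauge`, not part of the PR):

```python
import time

import psutil


def poll_utilization(pid: int, report, interval: float = 5.0) -> None:
    """Illustrative helper: sample a process until it exits."""
    proc = psutil.Process(pid)
    try:
        while True:
            # oneshot() caches the underlying process reads, so the two
            # samples below cost roughly a single stat of the process.
            with proc.oneshot():
                mem = proc.memory_percent()  # share of total system RAM
                cpu = proc.cpu_percent()     # CPU share since the last call
            report("mem_usage", mem)
            report("cpu_usage", cpu)
            time.sleep(interval)
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        # The task process exited (or we lost access); stop sampling.
        pass
```

Note that psutil's `cpu_percent()` measures usage since the previous call, so the first sample of each task reads as 0.0. Because the thread is marked `daemon = True` and the loop exits on `NoSuchProcess`, no explicit shutdown hook is wired into the runner's terminate path.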
2 changes: 2 additions & 0 deletions docs/apache-airflow/administration-and-deployment/logging-monitoring/metrics.rst
@@ -242,6 +242,8 @@ Name Description
``pool.scheduled_tasks`` Number of scheduled tasks in the pool. Metric with pool_name tagging.
``pool.starving_tasks.<pool_name>`` Number of starving tasks in the pool
``pool.starving_tasks`` Number of starving tasks in the pool. Metric with pool_name tagging.
``task.cpu_usage.<dag_id>.<task_id>`` Percentage of CPU used by a task
``task.mem_usage.<dag_id>.<task_id>`` Percentage of memory used by a task
``triggers.running.<hostname>`` Number of triggers currently running for a triggerer (described by hostname)
``triggers.running`` Number of triggers currently running for a triggerer (described by hostname).
Metric with hostname tagging.
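These names go through Airflow's `Stats` facade, so they only leave the process when a metrics backend is enabled (e.g. `statsd_on = True` plus `statsd_host`, `statsd_port`, and `statsd_prefix` under `[metrics]` in `airflow.cfg`). To eyeball the new gauges locally, a throwaway UDP listener is enough — a sketch, assuming StatsD is pointed at 127.0.0.1:8125 with the default `airflow` prefix:

```python
import socket

# Bind where statsd_host/statsd_port point; each datagram is one metric in
# StatsD wire format, e.g. b"airflow.task.cpu_usage.my_dag.my_task:3.2|g"
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(("127.0.0.1", 8125))
while True:
    datagram, _ = sock.recvfrom(4096)
    print(datagram.decode())
```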
37 changes: 35 additions & 2 deletions tests/task/task_runner/test_standard_task_runner.py
@@ -28,6 +28,7 @@
import psutil
import pytest

from airflow.exceptions import AirflowTaskTimeout
from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.listeners.listener import get_listener_manager
Expand Down Expand Up @@ -96,8 +97,9 @@ def clean_listener_manager(self):
yield
get_listener_manager().clear()

@mock.patch.object(StandardTaskRunner, "_read_task_utilization")
@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
def test_start_and_terminate(self, mock_init):
def test_start_and_terminate(self, mock_init, mock_read_task_utilization):
mock_init.return_value = "/tmp/any"
Job = mock.Mock()
Job.job_type = None
@@ -131,6 +133,7 @@ def test_start_and_terminate(self, mock_init):
assert not psutil.pid_exists(process.pid), f"{process} is still alive"

assert task_runner.return_code() is not None
mock_read_task_utilization.assert_called()

@pytest.mark.db_test
def test_notifies_about_start_and_stop(self, tmp_path):
@@ -260,8 +263,9 @@ def test_ol_does_not_block_xcoms(self, tmp_path):
assert f.readline() == "on_task_instance_success\n"
assert f.readline() == "listener\n"

@mock.patch.object(StandardTaskRunner, "_read_task_utilization")
@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
def test_start_and_terminate_run_as_user(self, mock_init):
def test_start_and_terminate_run_as_user(self, mock_init, mock_read_task_utilization):
mock_init.return_value = "/tmp/any"
Job = mock.Mock()
Job.job_type = None
@@ -296,6 +300,7 @@ def test_start_and_terminate_run_as_user(self, mock_init):
assert not psutil.pid_exists(process.pid), f"{process} is still alive"

assert task_runner.return_code() is not None
mock_read_task_utilization.assert_called()

@propagate_task_logger()
@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
@@ -444,6 +449,34 @@ def test_parsing_context(self):
"_AIRFLOW_PARSING_CONTEXT_TASK_ID=task1\n"
)

@mock.patch("airflow.task.task_runner.standard_task_runner.Stats.gauge")
@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
def test_read_task_utilization(self, mock_init, mock_stats):
mock_init.return_value = "/tmp/any"
Job = mock.Mock()
Job.job_type = None
Job.task_instance = mock.MagicMock()
Job.task_instance.task_id = "task_id"
Job.task_instance.dag_id = "dag_id"
Job.task_instance.run_as_user = None
Job.task_instance.command_as_list.return_value = [
"airflow",
"tasks",
"run",
"test_on_kill",
"task1",
"2016-01-01",
]
job_runner = LocalTaskJobRunner(job=Job, task_instance=Job.task_instance)
task_runner = StandardTaskRunner(job_runner)
task_runner.start()
try:
with timeout(1):
task_runner._read_task_utilization()
except AirflowTaskTimeout:
pass
assert mock_stats.call_count == 2

@staticmethod
def _procs_in_pgroup(pgid):
for proc in psutil.process_iter(attrs=["pid", "name"]):
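The interesting bit of `test_read_task_utilization` is how it unit-tests an infinite loop: `airflow.utils.timeout` raises `AirflowTaskTimeout` after one second, which is the only exit from the `while True`, and the patched `Stats.gauge` then proves exactly two gauges (memory and CPU) were emitted before the 5-second sleep consumed the budget. The same shape in isolation — `run_bounded` and `noisy_loop` are hypothetical names for illustration:

```python
import time

from airflow.exceptions import AirflowTaskTimeout
from airflow.utils.timeout import timeout


def run_bounded(loop_fn, seconds: int = 1) -> None:
    """Run an endless loop for at most `seconds`, then return cleanly."""
    try:
        # timeout() is signal-based on POSIX, so it must run in the
        # main thread there.
        with timeout(seconds):
            loop_fn()
    except AirflowTaskTimeout:
        pass


calls = []


def noisy_loop():
    while True:
        calls.append("tick")  # stand-in for Stats.gauge(...)
        time.sleep(5)         # mirrors the monitor's polling interval


run_bounded(noisy_loop)
assert calls == ["tick"]  # one pass completes before the timeout fires
```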