From c9de188e2cb257f78ca8643a97953c46bd896a4c Mon Sep 17 00:00:00 2001 From: Maciej Obuchowski Date: Sun, 13 Aug 2023 18:10:09 +0200 Subject: [PATCH] openlineage: don't run task instance listener in executor Signed-off-by: Maciej Obuchowski --- .../providers/openlineage/plugins/listener.py | 20 ++++--- tests/dags/test_dag_xcom_openlineage.py | 41 ++++++++++++++ tests/listeners/test_listeners.py | 4 ++ tests/listeners/xcom_listener.py | 46 +++++++++++++++ .../task_runner/test_standard_task_runner.py | 56 ++++++++++++++++++- 5 files changed, 157 insertions(+), 10 deletions(-) create mode 100644 tests/dags/test_dag_xcom_openlineage.py create mode 100644 tests/listeners/xcom_listener.py diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index d85a559f56a59..4a6b75f677c94 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from concurrent.futures import Executor, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import TYPE_CHECKING @@ -42,8 +42,8 @@ class OpenLineageListener: """OpenLineage listener sends events on task instance and dag run starts, completes and failures.""" def __init__(self): + self._executor = None self.log = logging.getLogger(__name__) - self.executor: Executor = None # type: ignore self.extractor_manager = ExtractorManager() self.adapter = OpenLineageAdapter() @@ -102,7 +102,7 @@ def on_running(): }, ) - self.executor.submit(on_running) + on_running() @hookimpl def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session): @@ -130,7 +130,7 @@ def on_success(): task=task_metadata, ) - self.executor.submit(on_success) + on_success() @hookimpl def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session): @@ -158,12 +158,17 @@ def on_failure(): task=task_metadata, ) - self.executor.submit(on_failure) + on_failure() + + @property + def executor(self): + if not self._executor: + self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") + return self._executor @hookimpl def on_starting(self, component): self.log.debug("on_starting: %s", component.__class__.__name__) - self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_") @hookimpl def before_stopping(self, component): @@ -174,9 +179,6 @@ def before_stopping(self, component): @hookimpl def on_dag_run_running(self, dag_run: DagRun, msg: str): - if not self.executor: - self.log.error("Executor have not started before `on_dag_run_running`") - return data_interval_start = dag_run.data_interval_start.isoformat() if dag_run.data_interval_start else None data_interval_end = dag_run.data_interval_end.isoformat() if dag_run.data_interval_end else None self.executor.submit( diff --git a/tests/dags/test_dag_xcom_openlineage.py b/tests/dags/test_dag_xcom_openlineage.py new file mode 100644 index 0000000000000..6236c8b4ec3b5 --- /dev/null +++ b/tests/dags/test_dag_xcom_openlineage.py @@ -0,0 +1,41 @@ +## +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime + +from airflow.models import DAG +from airflow.operators.python import PythonOperator + +dag = DAG( + dag_id="test_dag_xcom_openlineage", + default_args={"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)}, + schedule="0 0 * * *", + dagrun_timeout=datetime.timedelta(minutes=60), +) + + +def push_and_pull(ti, **kwargs): + ti.xcom_push(key="pushed_key", value="asdf") + ti.xcom_pull(key="pushed_key") + + +task = PythonOperator(task_id="push_and_pull", python_callable=push_and_pull, dag=dag) + +if __name__ == "__main__": + dag.cli() diff --git a/tests/listeners/test_listeners.py b/tests/listeners/test_listeners.py index 6369bd60dad65..f2958f2c01515 100644 --- a/tests/listeners/test_listeners.py +++ b/tests/listeners/test_listeners.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import os + import pytest as pytest from airflow import AirflowException @@ -46,6 +48,8 @@ TASK_ID = "test_listener_task" EXECUTION_DATE = timezone.utcnow() +TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"] + @pytest.fixture(autouse=True) def clean_listener_manager(): diff --git a/tests/listeners/xcom_listener.py b/tests/listeners/xcom_listener.py new file mode 100644 index 0000000000000..a7ffc19178589 --- /dev/null +++ b/tests/listeners/xcom_listener.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.listeners import hookimpl + + +class XComListener: + def __init__(self, path: str, task_id: str): + self.path = path + self.task_id = task_id + + def write(self, line: str): + with open(self.path, "a") as f: + f.write(line + "\n") + + @hookimpl + def on_task_instance_running(self, previous_state, task_instance, session): + task_instance.xcom_push(key="listener", value="listener") + task_instance.xcom_pull(task_ids=task_instance.task_id, key="listener") + self.write("on_task_instance_running") + + @hookimpl + def on_task_instance_success(self, previous_state, task_instance, session): + read = task_instance.xcom_pull(task_ids=self.task_id, key="listener") + self.write("on_task_instance_success") + self.write(read) + + +def clear(): + pass diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index 7d81025b4d6f9..52ac864a41abf 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -39,6 +39,7 @@ from airflow.utils.platform import getuser from airflow.utils.state import State from airflow.utils.timeout import timeout +from tests.listeners import xcom_listener from tests.listeners.file_write_listener import FileWriteListener from tests.test_utils.db import clear_db_runs @@ -85,10 +86,14 @@ def setup_class(self): (as the test environment does not have enough context for the normal way to run) and ensures they reset back to normal on the way out. """ - get_listener_manager().clear() clear_db_runs() yield clear_db_runs() + + @pytest.fixture(autouse=True) + def clean_listener_manager(self): + get_listener_manager().clear() + yield get_listener_manager().clear() @patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file") @@ -215,6 +220,55 @@ def test_notifies_about_fail(self): assert f.readline() == "on_task_instance_failed\n" assert f.readline() == "before_stopping\n" + def test_ol_does_not_block_xcoms(self): + """ + Test that ensures that pushing and pulling xcoms both in listener and task does not collide + """ + + path_listener_writer = "/tmp/test_ol_does_not_block_xcoms" + try: + os.unlink(path_listener_writer) + except OSError: + pass + + listener = xcom_listener.XComListener(path_listener_writer, "push_and_pull") + get_listener_manager().add_listener(listener) + + dagbag = DagBag( + dag_folder=TEST_DAG_FOLDER, + include_examples=False, + ) + dag = dagbag.dags.get("test_dag_xcom_openlineage") + task = dag.get_task("push_and_pull") + dag.create_dagrun( + run_id="test", + data_interval=(DEFAULT_DATE, DEFAULT_DATE), + state=State.RUNNING, + start_date=DEFAULT_DATE, + ) + + ti = TaskInstance(task=task, run_id="test") + job = Job(dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + task_runner = StandardTaskRunner(job_runner) + task_runner.start() + + # Wait until process makes itself the leader of its own process group + with timeout(seconds=1): + while True: + runner_pgid = os.getpgid(task_runner.process.pid) + if runner_pgid == task_runner.process.pid: + break + time.sleep(0.01) + + # Wait till process finishes + assert task_runner.return_code(timeout=10) is not None + + with open(path_listener_writer) as f: + assert f.readline() == "on_task_instance_running\n" + assert f.readline() == "on_task_instance_success\n" + assert f.readline() == "listener\n" + @patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file") def test_start_and_terminate_run_as_user(self, mock_init): mock_init.return_value = "/tmp/any"