From 989e07a58f06fec171a73ac37b7ea3d08b1ab161 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Tue, 28 May 2024 15:18:18 +0200 Subject: [PATCH] Add error stacktrace to OpenLineage task event Signed-off-by: Kacper Muda --- .../providers/openlineage/plugins/adapter.py | 12 +- .../providers/openlineage/plugins/listener.py | 66 ++++++++--- .../openlineage/plugins/test_listener.py | 109 +++++++----------- 3 files changed, 101 insertions(+), 86 deletions(-) diff --git a/airflow/providers/openlineage/plugins/adapter.py b/airflow/providers/openlineage/plugins/adapter.py index 5a5b8ed34bd80..608bd568e11c9 100644 --- a/airflow/providers/openlineage/plugins/adapter.py +++ b/airflow/providers/openlineage/plugins/adapter.py @@ -276,6 +276,7 @@ def fail_task( parent_run_id: str | None, end_time: str, task: OperatorLineage, + error: str | BaseException | None = None, ) -> RunEvent: """ Emit openlineage event of type FAIL. @@ -287,7 +288,16 @@ def fail_task( :param parent_run_id: identifier of job spawning this task :param end_time: time of task completion :param task: metadata container with information extracted from operator + :param error: error """ + error_facet = {} + if error: + if isinstance(error, BaseException): + import traceback + + error = "\\n".join(traceback.format_exception(type(error), error, error.__traceback__)) + error_facet = {"errorMessage": ErrorMessageRunFacet(message=error, programmingLanguage="python")} + event = RunEvent( eventType=RunState.FAIL, eventTime=end_time, @@ -296,7 +306,7 @@ def fail_task( job_name=job_name, parent_job_name=parent_job_name, parent_run_id=parent_run_id, - run_facets=task.run_facets, + run_facets={**task.run_facets, **error_facet}, ), job=self._build_job(job_name, job_type=_JOB_TYPE_TASK, job_facets=task.job_facets), inputs=task.inputs, diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index cd2901a721e5b..4a1085168aa8e 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -22,8 +22,9 @@ from typing import TYPE_CHECKING from openlineage.client.serde import Serde +from packaging.version import Version -from airflow import __version__ as airflow_version, settings +from airflow import __version__ as AIRFLOW_VERSION, settings from airflow.listeners import hookimpl from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors import ExtractorManager @@ -43,18 +44,17 @@ from sqlalchemy.orm import Session from airflow.models import DagRun, TaskInstance + from airflow.utils.state import TaskInstanceState _openlineage_listener: OpenLineageListener | None = None +_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0") def _get_try_number_success(val): # todo: remove when min airflow version >= 2.10.0 - from packaging.version import parse - - if parse(parse(airflow_version).base_version) < parse("2.10.0"): - return val.try_number - 1 - else: + if _IS_AIRFLOW_2_10_OR_HIGHER: return val.try_number + return val.try_number - 1 class OpenLineageListener: @@ -69,10 +69,10 @@ def __init__(self): @hookimpl def on_task_instance_running( self, - previous_state, + previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session, # This will always be QUEUED - ): + ) -> None: if not getattr(task_instance, "task", None) is not None: self.log.warning( "No task set for TI object task_id: %s - dag_id: %s - run_id %s", @@ -159,7 +159,9 @@ def on_running(): on_running() @hookimpl - def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session): + def on_task_instance_success( + self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session + ) -> None: self.log.debug("OpenLineage listener got notification about task instance success") dagrun = task_instance.dag_run @@ -223,8 +225,37 @@ def on_success(): on_success() - @hookimpl - def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session): + if _IS_AIRFLOW_2_10_OR_HIGHER: + + @hookimpl + def on_task_instance_failed( + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + error: None | str | BaseException, + session: Session, + ) -> None: + self._on_task_instance_failed( + previous_state=previous_state, task_instance=task_instance, error=error, session=session + ) + + else: + + @hookimpl + def on_task_instance_failed( + self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session + ) -> None: + self._on_task_instance_failed( + previous_state=previous_state, task_instance=task_instance, error=None, session=session + ) + + def _on_task_instance_failed( + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + session: Session, + error: None | str | BaseException = None, + ) -> None: self.log.debug("OpenLineage listener got notification about task instance failure") dagrun = task_instance.dag_run @@ -280,6 +311,7 @@ def on_failure(): parent_run_id=parent_run_id, end_time=end_date.isoformat(), task=task_metadata, + error=error, ) Stats.gauge( f"ol.event.size.{event_type}.{operator_name}", @@ -289,7 +321,7 @@ def on_failure(): on_failure() @property - def executor(self): + def executor(self) -> ProcessPoolExecutor: def initializer(): # Re-configure the ORM engine as there are issues with multiple processes # if process calls Airflow DB. @@ -303,17 +335,17 @@ def initializer(): return self._executor @hookimpl - def on_starting(self, component): + def on_starting(self, component) -> None: self.log.debug("on_starting: %s", component.__class__.__name__) @hookimpl - def before_stopping(self, component): + def before_stopping(self, component) -> None: self.log.debug("before_stopping: %s", component.__class__.__name__) with timeout(30): self.executor.shutdown(wait=True) @hookimpl - def on_dag_run_running(self, dag_run: DagRun, msg: str): + def on_dag_run_running(self, dag_run: DagRun, msg: str) -> None: if dag_run.dag and not is_selective_lineage_enabled(dag_run.dag): self.log.debug( "Skipping OpenLineage event emission for DAG `%s` " @@ -338,7 +370,7 @@ def on_dag_run_running(self, dag_run: DagRun, msg: str): ) @hookimpl - def on_dag_run_success(self, dag_run: DagRun, msg: str): + def on_dag_run_success(self, dag_run: DagRun, msg: str) -> None: if dag_run.dag and not is_selective_lineage_enabled(dag_run.dag): self.log.debug( "Skipping OpenLineage event emission for DAG `%s` " @@ -355,7 +387,7 @@ def on_dag_run_success(self, dag_run: DagRun, msg: str): self.executor.submit(self.adapter.dag_success, dag_run=dag_run, msg=msg) @hookimpl - def on_dag_run_failed(self, dag_run: DagRun, msg: str): + def on_dag_run_failed(self, dag_run: DagRun, msg: str) -> None: if dag_run.dag and not is_selective_lineage_enabled(dag_run.dag): self.log.debug( "Skipping OpenLineage event emission for DAG `%s` " diff --git a/tests/providers/openlineage/plugins/test_listener.py b/tests/providers/openlineage/plugins/test_listener.py index ca5708bbba690..15d928ec03b58 100644 --- a/tests/providers/openlineage/plugins/test_listener.py +++ b/tests/providers/openlineage/plugins/test_listener.py @@ -39,7 +39,6 @@ pytestmark = pytest.mark.db_test EXPECTED_TRY_NUMBER_1 = 1 if AIRFLOW_V_2_10_PLUS else 0 -EXPECTED_TRY_NUMBER_2 = 2 if AIRFLOW_V_2_10_PLUS else 1 TRY_NUMBER_BEFORE_EXECUTION = 0 if AIRFLOW_V_2_10_PLUS else 1 TRY_NUMBER_RUNNING = 0 if AIRFLOW_V_2_10_PLUS else 1 @@ -276,7 +275,13 @@ def mock_task_id(dag_id, task_id, try_number, execution_date): mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id mock_disabled.return_value = False - listener.on_task_instance_failed(None, task_instance, None) + err = ValueError("test") + on_task_failed_listener_kwargs = {"error": err} if AIRFLOW_V_2_10_PLUS else {} + expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None} + + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, session=None, **on_task_failed_listener_kwargs + ) listener.adapter.fail_task.assert_called_once_with( end_time="2023-01-03T13:01:01", job_name="job_name", @@ -284,6 +289,7 @@ def mock_task_id(dag_id, task_id, try_number, execution_date): parent_run_id="execution_date.dag_id", run_id="execution_date.dag_id.task_id.1", task=listener.extractor_manager.extract_metadata(), + **expected_err_kwargs, ) @@ -316,7 +322,7 @@ def mock_task_id(dag_id, task_id, try_number, execution_date): listener.on_task_instance_success(None, task_instance, None) # This run_id will be different as we did NOT simulate increase of the try_number attribute, - # which happens in Airflow. + # which happens in Airflow < 2.10. calls = listener.adapter.complete_task.call_args_list assert len(calls) == 1 assert calls[0][1] == dict( @@ -328,65 +334,8 @@ def mock_task_id(dag_id, task_id, try_number, execution_date): task=listener.extractor_manager.extract_metadata(), ) - # Now we simulate the increase of try_number, and the run_id should reflect that change. - listener.adapter.complete_task.reset_mock() - task_instance.try_number += 1 - listener.on_task_instance_success(None, task_instance, None) - calls = listener.adapter.complete_task.call_args_list - assert len(calls) == 1 - assert calls[0][1] == dict( - end_time="2023-01-03T13:01:01", - job_name="job_name", - parent_job_name="dag_id", - parent_run_id="execution_date.dag_id", - run_id=f"execution_date.dag_id.task_id.{EXPECTED_TRY_NUMBER_2}", - task=listener.extractor_manager.extract_metadata(), - ) - - -@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") -def test_run_id_is_constant_across_all_methods(mocked_adapter): - """Tests that the run_id remains constant across different methods of the listener. - - It ensures that the run_id generated for starting, failing, and completing a task is consistent, - reflecting the task's identity and execution context. The test also simulates the change in the - try_number attribute, as it would occur in Airflow, to verify that the run_id updates accordingly. - """ - - def mock_task_id(dag_id, task_id, try_number, execution_date): - returned_try_number = try_number if AIRFLOW_V_2_10_PLUS else max(try_number - 1, 1) - return f"{execution_date}.{dag_id}.{task_id}.{returned_try_number}" - listener, task_instance = _create_listener_and_task_instance() - mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id - expected_run_id_1 = "execution_date.dag_id.task_id.1" - expected_run_id_2 = "execution_date.dag_id.task_id.2" - listener.on_task_instance_running(None, task_instance, None) - assert listener.adapter.start_task.call_args.kwargs["run_id"] == expected_run_id_1 - - listener.on_task_instance_failed(None, task_instance, None) - assert ( - listener.adapter.fail_task.call_args.kwargs["run_id"] == expected_run_id_1 - if AIRFLOW_V_2_10_PLUS - else expected_run_id_2 - ) - - # This run_id will not be different as we did NOT simulate increase of the try_number attribute, - listener.on_task_instance_success(None, task_instance, None) - assert listener.adapter.complete_task.call_args.kwargs["run_id"] == expected_run_id_1 - - # Now we simulate the increase of try_number, and the run_id should reflect that change. - # This is how airflow works, and that's why we expect the run_id to remain constant across all methods. - task_instance.try_number += 1 - listener.on_task_instance_success(None, task_instance, None) - assert ( - listener.adapter.complete_task.call_args.kwargs["run_id"] == expected_run_id_2 - if AIRFLOW_V_2_10_PLUS - else expected_run_id_1 - ) - - -def test_running_task_correctly_calls_openlineage_adapter_run_id_method(): +def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method(): """Tests the OpenLineageListener's response when a task instance is in the running state. This test ensures that when an Airflow task instance transitions to the running state, @@ -404,7 +353,7 @@ def test_running_task_correctly_calls_openlineage_adapter_run_id_method(): @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") -def test_failed_task_correctly_calls_openlineage_adapter_run_id_method(mock_adapter): +def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method(mock_adapter): """Tests the OpenLineageListener's response when a task instance is in the failed state. This test ensures that when an Airflow task instance transitions to the failed state, @@ -412,7 +361,11 @@ def test_failed_task_correctly_calls_openlineage_adapter_run_id_method(mock_adap parameters derived from the task instance. """ listener, task_instance = _create_listener_and_task_instance() - listener.on_task_instance_failed(None, task_instance, None) + on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} + + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, session=None, **on_task_failed_kwargs + ) mock_adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", @@ -422,7 +375,7 @@ def test_failed_task_correctly_calls_openlineage_adapter_run_id_method(mock_adap @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") -def test_successful_task_correctly_calls_openlineage_adapter_run_id_method(mock_adapter): +def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method(mock_adapter): """Tests the OpenLineageListener's response when a task instance is in the success state. This test ensures that when an Airflow task instance transitions to the success state, @@ -530,7 +483,11 @@ def test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_oper listener, task_instance = _create_listener_and_task_instance() mock_disabled.return_value = True - listener.on_task_instance_failed(None, task_instance, None) + on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} + + listener.on_task_instance_failed( + previous_state=None, task_instance=task_instance, session=None, **on_task_failed_kwargs + ) mock_disabled.assert_called_once_with(task_instance.task) mocked_adapter.build_dag_run_id.assert_not_called() mocked_adapter.build_task_instance_run_id.assert_not_called() @@ -645,6 +602,8 @@ def test_listener_with_task_enabled( if enable_task: enable_lineage(self.task_1) + on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} + conf.selective_enable.cache_clear() with conf_vars({("openlineage", "selective_enable"): selective_enable}): listener = OpenLineageListener() @@ -662,14 +621,24 @@ def test_listener_with_task_enabled( # run TaskInstance-related hooks for lineage enabled task listener.on_task_instance_running(None, self.task_instance_1, None) listener.on_task_instance_success(None, self.task_instance_1, None) - listener.on_task_instance_failed(None, self.task_instance_1, None) + listener.on_task_instance_failed( + previous_state=None, + task_instance=self.task_instance_1, + session=None, + **on_task_failed_kwargs, + ) assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count # run TaskInstance-related hooks for lineage disabled task listener.on_task_instance_running(None, self.task_instance_2, None) listener.on_task_instance_success(None, self.task_instance_2, None) - listener.on_task_instance_failed(None, self.task_instance_2, None) + listener.on_task_instance_failed( + previous_state=None, + task_instance=self.task_instance_2, + session=None, + **on_task_failed_kwargs, + ) # with selective-enable disabled both task_1 and task_2 should trigger metadata extraction if selective_enable == "False": @@ -697,6 +666,8 @@ def test_listener_with_dag_disabled_task_enabled( if enable_task: enable_lineage(self.task_1) + on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} + conf.selective_enable.cache_clear() with conf_vars({("openlineage", "selective_enable"): selective_enable}): listener = OpenLineageListener() @@ -712,7 +683,9 @@ def test_listener_with_dag_disabled_task_enabled( # run TaskInstance-related hooks for lineage enabled task listener.on_task_instance_running(None, self.task_instance_1, None) listener.on_task_instance_success(None, self.task_instance_1, None) - listener.on_task_instance_failed(None, self.task_instance_1, None) + listener.on_task_instance_failed( + previous_state=None, task_instance=self.task_instance_1, session=None, **on_task_failed_kwargs + ) try: assert expected_call_count == listener._executor.submit.call_count