From 3d96caa3b41dc0add0636e8bb4f80c6de3ef9d3b Mon Sep 17 00:00:00 2001
From: Vincent Beck
Date: Tue, 8 Aug 2023 16:04:50 -0400
Subject: [PATCH 1/3] Improve fetching logs from AWS

---
 airflow/providers/amazon/aws/hooks/logs.py    | 21 ++++++++++++-----
 .../amazon/aws/log/cloudwatch_task_handler.py | 23 +++++++++++++++----
 .../providers/amazon/aws/utils/__init__.py    |  7 +++++-
 3 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/logs.py b/airflow/providers/amazon/aws/hooks/logs.py
index ba9ef09112110..27f9d03081b95 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -22,6 +22,7 @@
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.utils.helpers import prune_dict
 
 # Guidance received from the AWS team regarding the correct way to check for the end of a stream is that the
 # value of the nextForwardToken is the same in subsequent calls.
@@ -60,6 +61,7 @@ def get_log_events(
         log_group: str,
         log_stream_name: str,
         start_time: int = 0,
+        end_time: int | None = None,
         skip: int = 0,
         start_from_head: bool | None = None,
         continuation_token: ContinuationToken | None = None,
@@ -72,7 +74,9 @@
 
         :param log_group: The name of the log group.
         :param log_stream_name: The name of the specific stream.
-        :param start_time: The time stamp value to start reading the logs from (default: 0).
+        :param start_time: The timestamp value to start reading the logs from (default: 0).
+        :param end_time: The timestamp value to stop reading the logs from (default: None).
+            If None is provided, reads until the end of the log stream.
         :param skip: The number of log entries to skip at the start (default: 0).
             This is for when there are multiple entries at the same timestamp.
         :param start_from_head: Deprecated. Do not use with False, logs would be retrieved out of order.
@@ -110,11 +114,16 @@
                 token_arg = {}
 
             response = self.conn.get_log_events(
-                logGroupName=log_group,
-                logStreamName=log_stream_name,
-                startTime=start_time,
-                startFromHead=start_from_head,
-                **token_arg,
+                **prune_dict(
+                    {
+                        "logGroupName": log_group,
+                        "logStreamName": log_stream_name,
+                        "startTime": start_time,
+                        "endTime": end_time,
+                        "startFromHead": start_from_head,
+                        **token_arg,
+                    }
+                )
             )
 
             events = response["events"]
diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index 5d1074b8402a6..0f96e791c59b5 100644
--- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -17,13 +17,15 @@
 # under the License.
 from __future__ import annotations
 
-from datetime import datetime
+from datetime import datetime, timedelta
 from functools import cached_property
 
 import watchtower
 
 from airflow.configuration import conf
+from airflow.models import TaskInstance
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -90,7 +92,8 @@ def _read(self, task_instance, try_number, metadata=None):
         try:
             return (
                 f"*** Reading remote log from Cloudwatch log_group: {self.log_group} "
-                f"log_stream: {stream_name}.\n{self.get_cloudwatch_logs(stream_name=stream_name)}\n",
+                f"log_stream: {stream_name}.\n"
+                f"{self.get_cloudwatch_logs(stream_name=stream_name, task_instance=task_instance)}\n",
                 {"end_of_log": True},
             )
         except Exception as e:
@@ -103,17 +106,29 @@ def _read(self, task_instance, try_number, metadata=None):
             log += local_log
         return log, metadata
 
-    def get_cloudwatch_logs(self, stream_name: str) -> str:
+    def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) -> str:
         """
         Return all logs from the given log stream.
 
         :param stream_name: name of the Cloudwatch log stream to get all logs from
+        :param task_instance: the task instance for which the logs are fetched
         :return: string of all logs from the given log stream
         """
+        start_time = (
+            None if task_instance.start_date is None else datetime_to_epoch_utc_ms(task_instance.start_date)
+        )
+        # If the task instance has an end_date, fetch logs until that date + 30 seconds
+        # (an arbitrary buffer so that we don't miss logs emitted shortly after the task finishes)
+        end_time = (
+            None
+            if task_instance.end_date is None
+            else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30))
+        )
         events = self.hook.get_log_events(
             log_group=self.log_group,
             log_stream_name=stream_name,
-            start_from_head=True,
+            start_time=start_time,
+            end_time=end_time,
         )
         return "\n".join(self._event_to_str(event) for event in events)
 
diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py
index 312366df26ae7..0a26f3bba592b 100644
--- a/airflow/providers/amazon/aws/utils/__init__.py
+++ b/airflow/providers/amazon/aws/utils/__init__.py
@@ -18,7 +18,7 @@
 
 import logging
 import re
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 
 from airflow.utils.helpers import prune_dict
@@ -55,6 +55,11 @@ def datetime_to_epoch_ms(date_time: datetime) -> int:
     return int(date_time.timestamp() * 1_000)
 
 
+def datetime_to_epoch_utc_ms(date_time: datetime) -> int:
+    """Convert a datetime object to an epoch integer (milliseconds), treating the datetime as UTC."""
+    return int(date_time.replace(tzinfo=timezone.utc).timestamp() * 1_000)
+
+
 def datetime_to_epoch_us(date_time: datetime) -> int:
     """Convert a datetime object to an epoch integer (microseconds)."""
     return int(date_time.timestamp() * 1_000_000)

From 017920792ba2588a0069176ed43459bf7cb1a71d Mon Sep 17 00:00:00 2001
From: Vincent Beck
Date: Wed, 9 Aug 2023 11:38:00 -0400
Subject: [PATCH 2/3] Add unit tests

---
 .../aws/log/test_cloudwatch_task_handler.py   | 37 ++++++++++++++++++-
 1 file changed, 35 insertions(+), 2 deletions(-)

diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
index 257aa144407af..ce9908cecf70d 100644
--- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import time
-from datetime import datetime as dt
+from datetime import datetime as dt, timedelta
 from unittest import mock
 from unittest.mock import call
 
@@ -31,6 +31,7 @@
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudwatchTaskHandler
+from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
@@ -60,7 +61,6 @@ def setup_tests(self, create_log_template, tmp_path_factory):
             self.local_log_location,
             f"arn:aws:logs:{self.region_name}:11111111:log-group:{self.remote_log_group}",
         )
-        self.cloudwatch_task_handler.hook
 
         date = datetime(2020, 1, 1)
         dag_id = "dag_for_testing_cloudwatch_task_handler"
@@ -154,6 +154,39 @@ def test_read(self):
             [{"end_of_log": True}],
         )
 
+    @pytest.mark.parametrize(
+        "start_date, end_date, expected_start_time, expected_end_time",
+        [
+            (None, None, None, None),
+            (datetime(2020, 1, 1), None, datetime_to_epoch_utc_ms(datetime(2020, 1, 1)), None),
+            (
+                None,
+                datetime(2020, 1, 2),
+                None,
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
+            ),
+            (
+                datetime(2020, 1, 1),
+                datetime(2020, 1, 2),
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 1)),
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
+            ),
+        ],
+    )
+    @mock.patch.object(AwsLogsHook, "get_log_events")
+    def test_get_cloudwatch_logs(
+        self, mock_get_log_events, start_date, end_date, expected_start_time, expected_end_time
+    ):
+        self.ti.start_date = start_date
+        self.ti.end_date = end_date
+        self.cloudwatch_task_handler.get_cloudwatch_logs(self.remote_log_stream, self.ti)
+        mock_get_log_events.assert_called_once_with(
+            log_group=self.remote_log_group,
+            log_stream_name=self.remote_log_stream,
+            start_time=expected_start_time,
+            end_time=expected_end_time,
+        )
+
     def test_close_prevents_duplicate_calls(self):
         with mock.patch("watchtower.CloudWatchLogHandler.close") as mock_log_handler_close:
             with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.set_context"):

From b3b784beb1258a291165f38e0a6a09fd3b2b36fe Mon Sep 17 00:00:00 2001
From: Vincent Beck
Date: Wed, 9 Aug 2023 15:01:29 -0400
Subject: [PATCH 3/3] Add unit test

---
 airflow/providers/amazon/aws/hooks/logs.py    |  4 +--
 .../amazon/aws/log/cloudwatch_task_handler.py |  2 +-
 tests/providers/amazon/aws/hooks/test_logs.py | 29 +++++++++++++++----
 .../aws/log/test_cloudwatch_task_handler.py   |  4 +--
 4 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/logs.py b/airflow/providers/amazon/aws/hooks/logs.py
index 27f9d03081b95..db0b29db539d7 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -74,8 +74,8 @@ def get_log_events(
 
         :param log_group: The name of the log group.
         :param log_stream_name: The name of the specific stream.
-        :param start_time: The timestamp value to start reading the logs from (default: 0).
-        :param end_time: The timestamp value to stop reading the logs from (default: None).
+        :param start_time: The timestamp value in ms to start reading the logs from (default: 0).
+        :param end_time: The timestamp value in ms to stop reading the logs from (default: None).
             If None is provided, reads until the end of the log stream.
         :param skip: The number of log entries to skip at the start (default: 0).
             This is for when there are multiple entries at the same timestamp.
diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index 0f96e791c59b5..5f74468d04ce4 100644
--- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -115,7 +115,7 @@ def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) ->
         :return: string of all logs from the given log stream
         """
         start_time = (
-            None if task_instance.start_date is None else datetime_to_epoch_utc_ms(task_instance.start_date)
+            0 if task_instance.start_date is None else datetime_to_epoch_utc_ms(task_instance.start_date)
         )
         # If the task instance has an end_date, fetch logs until that date + 30 seconds
         # (an arbitrary buffer so that we don't miss logs emitted shortly after the task finishes)
diff --git a/tests/providers/amazon/aws/hooks/test_logs.py b/tests/providers/amazon/aws/hooks/test_logs.py
index 00cf38f1ab298..4e2ec3a5954d2 100644
--- a/tests/providers/amazon/aws/hooks/test_logs.py
+++ b/tests/providers/amazon/aws/hooks/test_logs.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import ANY, patch
 
 import pytest
 from moto import mock_logs
@@ -29,7 +29,7 @@
 @mock_logs
 class TestAwsLogsHook:
     @pytest.mark.parametrize(
-        "get_log_events_response, num_skip_events, expected_num_events",
+        "get_log_events_response, num_skip_events, expected_num_events, end_time",
         [
             # 3 empty responses with different tokens
             (
@@ -40,6 +40,7 @@ class TestAwsLogsHook:
                 ],
                 0,
                 0,
+                None,
             ),
             # 2 events on the second response with same token
             (
@@ -49,6 +50,7 @@ class TestAwsLogsHook:
                 ],
                 0,
                 2,
+                None,
             ),
             # Different tokens, 2 events on the second response then 3 empty responses
             (
@@ -63,6 +65,7 @@ class TestAwsLogsHook:
                 ],
                 0,
                 2,
+                10,
             ),
             # 2 events on the second response, then 2 empty responses, then 2 consecutive responses with
             # 2 events with the same token
@@ -79,20 +82,36 @@ class TestAwsLogsHook:
                 ],
                 0,
                 6,
+                20,
             ),
         ],
     )
     @patch("airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.conn", new_callable=mock.PropertyMock)
-    def test_get_log_events(self, mock_conn, get_log_events_response, num_skip_events, expected_num_events):
+    def test_get_log_events(
+        self, mock_conn, get_log_events_response, num_skip_events, expected_num_events, end_time
+    ):
         mock_conn().get_log_events.side_effect = get_log_events_response
+        log_group_name = "example-group"
+        log_stream_name = "example-log-stream"
 
         hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")
         events = hook.get_log_events(
-            log_group="example-group",
-            log_stream_name="example-log-stream",
+            log_group=log_group_name,
+            log_stream_name=log_stream_name,
             skip=num_skip_events,
+            end_time=end_time,
         )
 
         events = list(events)
 
         assert len(events) == expected_num_events
+        kwargs = {
+            "logGroupName": log_group_name,
+            "logStreamName": log_stream_name,
+            "startFromHead": True,
+            "startTime": 0,
+            "nextToken": ANY,
+        }
+        if end_time:
+            kwargs["endTime"] = end_time
+        mock_conn().get_log_events.assert_called_with(**kwargs)
diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
index ce9908cecf70d..219f594604a4a 100644
--- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -157,12 +157,12 @@ def test_read(self):
     @pytest.mark.parametrize(
         "start_date, end_date, expected_start_time, expected_end_time",
         [
-            (None, None, None, None),
+            (None, None, 0, None),
             (datetime(2020, 1, 1), None, datetime_to_epoch_utc_ms(datetime(2020, 1, 1)), None),
             (
                 None,
                 datetime(2020, 1, 2),
-                None,
+                0,
                 datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
             ),
             (
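
Reviewer note (not part of the patches above): a minimal, runnable sketch of how the new
fetch window is computed, using only the standard library. The start/end values below are
hypothetical stand-ins for task_instance.start_date and task_instance.end_date; the helper
mirrors the datetime_to_epoch_utc_ms added in patch 1.

    from datetime import datetime, timedelta, timezone

    def datetime_to_epoch_utc_ms(date_time: datetime) -> int:
        # Same conversion as the new helper: treat the datetime as UTC.
        return int(date_time.replace(tzinfo=timezone.utc).timestamp() * 1_000)

    start = datetime(2023, 8, 8, 16, 0, 0)  # hypothetical TaskInstance.start_date
    end = datetime(2023, 8, 8, 16, 5, 0)    # hypothetical TaskInstance.end_date

    # Mirrors get_cloudwatch_logs() after patch 3: start_time falls back to 0,
    # end_time gets a 30-second buffer so late events are not cut off.
    start_time = 0 if start is None else datetime_to_epoch_utc_ms(start)
    end_time = None if end is None else datetime_to_epoch_utc_ms(end + timedelta(seconds=30))

    print(start_time, end_time)  # 1691510400000 1691510730000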
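A companion sketch for the hook-side change: prune_dict (from airflow.utils.helpers,
imported in patch 1) drops None entries in its default "strict" mode, so endTime and
startFromHead only reach boto3's get_log_events when they are actually set. The dict
comprehension below is a dependency-free stand-in for prune_dict, not the real helper.

    # Stand-in for the kwargs pruning done in AwsLogsHook.get_log_events.
    raw_kwargs = {
        "logGroupName": "example-group",
        "logStreamName": "example-log-stream",
        "startTime": 0,
        "endTime": None,  # caller did not set an end_time
        "startFromHead": True,
    }
    pruned = {k: v for k, v in raw_kwargs.items() if v is not None}
    # self.conn.get_log_events(**pruned) now omits endTime entirely, which is
    # what test_get_log_events asserts via assert_called_with(**kwargs).
    assert "endTime" not in pruned and pruned["startTime"] == 0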