diff --git a/airflow/providers/amazon/aws/hooks/logs.py b/airflow/providers/amazon/aws/hooks/logs.py
index ba9ef09112110..db0b29db539d7 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -22,6 +22,7 @@
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.utils.helpers import prune_dict
 
 # Guidance received from the AWS team regarding the correct way to check for the end of a stream is that the
 # value of the nextForwardToken is the same in subsequent calls.
@@ -60,6 +61,7 @@ def get_log_events(
         log_group: str,
         log_stream_name: str,
         start_time: int = 0,
+        end_time: int | None = None,
         skip: int = 0,
         start_from_head: bool | None = None,
         continuation_token: ContinuationToken | None = None,
@@ -72,7 +74,9 @@
 
         :param log_group: The name of the log group.
         :param log_stream_name: The name of the specific stream.
-        :param start_time: The time stamp value to start reading the logs from (default: 0).
+        :param start_time: The timestamp value in ms to start reading the logs from (default: 0).
+        :param end_time: The timestamp value in ms at which to stop reading the logs (default: None).
+            If None is provided, logs are read until the end of the log stream.
         :param skip: The number of log entries to skip at the start (default: 0). This is for
             when there are multiple entries at the same timestamp.
         :param start_from_head: Deprecated. Do not use with False, logs would be retrieved out of order.
@@ -110,11 +114,16 @@
             token_arg = {}
 
         response = self.conn.get_log_events(
-            logGroupName=log_group,
-            logStreamName=log_stream_name,
-            startTime=start_time,
-            startFromHead=start_from_head,
-            **token_arg,
+            **prune_dict(
+                {
+                    "logGroupName": log_group,
+                    "logStreamName": log_stream_name,
+                    "startTime": start_time,
+                    "endTime": end_time,
+                    "startFromHead": start_from_head,
+                    **token_arg,
+                }
+            )
         )
 
         events = response["events"]
diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index 5d1074b8402a6..5f74468d04ce4 100644
--- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -17,13 +17,15 @@
 # under the License.
 from __future__ import annotations
 
-from datetime import datetime
+from datetime import datetime, timedelta
 from functools import cached_property
 
 import watchtower
 
 from airflow.configuration import conf
+from airflow.models import TaskInstance
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -90,7 +92,8 @@ def _read(self, task_instance, try_number, metadata=None):
         try:
             return (
                 f"*** Reading remote log from Cloudwatch log_group: {self.log_group} "
-                f"log_stream: {stream_name}.\n{self.get_cloudwatch_logs(stream_name=stream_name)}\n",
+                f"log_stream: {stream_name}.\n"
+                f"{self.get_cloudwatch_logs(stream_name=stream_name, task_instance=task_instance)}\n",
                 {"end_of_log": True},
             )
         except Exception as e:
@@ -103,17 +106,29 @@ def _read(self, task_instance, try_number, metadata=None):
             log += local_log
         return log, metadata
 
-    def get_cloudwatch_logs(self, stream_name: str) -> str:
+    def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) -> str:
         """
         Return all logs from the given log stream.
 
         :param stream_name: name of the Cloudwatch log stream to get all logs from
+        :param task_instance: the task instance to get logs about
         :return: string of all logs from the given log stream
         """
+        start_time = (
+            0 if task_instance.start_date is None else datetime_to_epoch_utc_ms(task_instance.start_date)
+        )
+        # If the task instance has an end_date, fetch logs until that date + 30 seconds.
+        # 30 seconds is an arbitrary buffer so that we don't miss any logs emitted shortly after the task finished.
+        end_time = (
+            None
+            if task_instance.end_date is None
+            else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30))
+        )
         events = self.hook.get_log_events(
             log_group=self.log_group,
             log_stream_name=stream_name,
-            start_from_head=True,
+            start_time=start_time,
+            end_time=end_time,
         )
         return "\n".join(self._event_to_str(event) for event in events)
diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py
index 312366df26ae7..0a26f3bba592b 100644
--- a/airflow/providers/amazon/aws/utils/__init__.py
+++ b/airflow/providers/amazon/aws/utils/__init__.py
@@ -18,7 +18,7 @@
 
 import logging
 import re
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 
 from airflow.utils.helpers import prune_dict
@@ -55,6 +55,11 @@ def datetime_to_epoch_ms(date_time: datetime) -> int:
     return int(date_time.timestamp() * 1_000)
 
 
+def datetime_to_epoch_utc_ms(date_time: datetime) -> int:
+    """Convert a datetime object to an epoch integer (milliseconds) in UTC timezone."""
+    return int(date_time.replace(tzinfo=timezone.utc).timestamp() * 1_000)
+
+
 def datetime_to_epoch_us(date_time: datetime) -> int:
     """Convert a datetime object to an epoch integer (microseconds)."""
     return int(date_time.timestamp() * 1_000_000)
diff --git a/tests/providers/amazon/aws/hooks/test_logs.py b/tests/providers/amazon/aws/hooks/test_logs.py
index 00cf38f1ab298..4e2ec3a5954d2 100644
--- a/tests/providers/amazon/aws/hooks/test_logs.py
+++ b/tests/providers/amazon/aws/hooks/test_logs.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import ANY, patch
 
 import pytest
 from moto import mock_logs
@@ -29,7 +29,7 @@
 @mock_logs
 class TestAwsLogsHook:
     @pytest.mark.parametrize(
-        "get_log_events_response, num_skip_events, expected_num_events",
+        "get_log_events_response, num_skip_events, expected_num_events, end_time",
         [
             # 3 empty responses with different tokens
             (
@@ -40,6 +40,7 @@ class TestAwsLogsHook:
                 ],
                 0,
                 0,
+                None,
             ),
             # 2 events on the second response with same token
             (
@@ -49,6 +50,7 @@ class TestAwsLogsHook:
                 ],
                 0,
                 2,
+                None,
             ),
             # Different tokens, 2 events on the second response then 3 empty responses
             (
@@ -63,6 +65,7 @@ class TestAwsLogsHook:
                 ],
                 0,
                 2,
+                10,
             ),
             # 2 events on the second response, then 2 empty responses, then 2 consecutive responses with
             # 2 events with the same token
@@ -79,20 +82,36 @@ class TestAwsLogsHook:
                 ],
                 0,
                 6,
+                20,
             ),
         ],
     )
     @patch("airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.conn", new_callable=mock.PropertyMock)
-    def test_get_log_events(self, mock_conn, get_log_events_response, num_skip_events, expected_num_events):
+    def test_get_log_events(
+        self, mock_conn, get_log_events_response, num_skip_events, expected_num_events, end_time
+    ):
         mock_conn().get_log_events.side_effect = get_log_events_response
+        log_group_name = "example-group"
+        log_stream_name = "example-log-stream"
 
         hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")
         events = hook.get_log_events(
-            log_group="example-group",
-            log_stream_name="example-log-stream",
+            log_group=log_group_name,
+            log_stream_name=log_stream_name,
             skip=num_skip_events,
+            end_time=end_time,
         )
 
         events = list(events)
 
         assert len(events) == expected_num_events
+        kwargs = {
+            "logGroupName": log_group_name,
+            "logStreamName": log_stream_name,
+            "startFromHead": True,
+            "startTime": 0,
+            "nextToken": ANY,
+        }
+        if end_time:
+            kwargs["endTime"] = end_time
+        mock_conn().get_log_events.assert_called_with(**kwargs)
diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
index 257aa144407af..219f594604a4a 100644
--- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import time
-from datetime import datetime as dt
+from datetime import datetime as dt, timedelta
 from unittest import mock
 from unittest.mock import call
 
@@ -31,6 +31,7 @@
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudwatchTaskHandler
+from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
@@ -60,7 +61,6 @@ def setup_tests(self, create_log_template, tmp_path_factory):
             self.local_log_location,
             f"arn:aws:logs:{self.region_name}:11111111:log-group:{self.remote_log_group}",
         )
-        self.cloudwatch_task_handler.hook
 
         date = datetime(2020, 1, 1)
         dag_id = "dag_for_testing_cloudwatch_task_handler"
@@ -154,6 +154,39 @@ def test_read(self):
             [{"end_of_log": True}],
         )
 
+    @pytest.mark.parametrize(
+        "start_date, end_date, expected_start_time, expected_end_time",
+        [
+            (None, None, 0, None),
+            (datetime(2020, 1, 1), None, datetime_to_epoch_utc_ms(datetime(2020, 1, 1)), None),
+            (
+                None,
+                datetime(2020, 1, 2),
+                0,
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
+            ),
+            (
+                datetime(2020, 1, 1),
+                datetime(2020, 1, 2),
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 1)),
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
+            ),
+        ],
+    )
+    @mock.patch.object(AwsLogsHook, "get_log_events")
+    def test_get_cloudwatch_logs(
+        self, mock_get_log_events, start_date, end_date, expected_start_time, expected_end_time
+    ):
+        self.ti.start_date = start_date
+        self.ti.end_date = end_date
+        self.cloudwatch_task_handler.get_cloudwatch_logs(self.remote_log_stream, self.ti)
+        mock_get_log_events.assert_called_once_with(
+            log_group=self.remote_log_group,
+            log_stream_name=self.remote_log_stream,
+            start_time=expected_start_time,
+            end_time=expected_end_time,
+        )
+
     def test_close_prevents_duplicate_calls(self):
         with mock.patch("watchtower.CloudWatchLogHandler.close") as mock_log_handler_close:
             with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.set_context"):
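Taken together, these changes bound CloudWatch log fetching to the task instance's lifetime instead of always reading the stream from head to tail. Below is a minimal sketch of how a caller might pair the new `end_time` parameter with the `datetime_to_epoch_utc_ms` helper; the connection id, region, log group/stream names, and timestamps are placeholder values, and the 30-second buffer simply mirrors the one `CloudwatchTaskHandler` now applies.

```python
from datetime import datetime, timedelta, timezone

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms

# Placeholder connection and log identifiers.
hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")

# Hypothetical task window; the handler derives these from ti.start_date / ti.end_date.
run_start = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc)
run_end = datetime(2023, 1, 1, 12, 5, tzinfo=timezone.utc)

events = hook.get_log_events(
    log_group="example-group",
    log_stream_name="example-log-stream",
    start_time=datetime_to_epoch_utc_ms(run_start),
    # Same arbitrary 30-second buffer the handler uses, so trailing events are not cut off.
    end_time=datetime_to_epoch_utc_ms(run_end + timedelta(seconds=30)),
)
for event in events:  # get_log_events yields boto3 event dicts lazily
    print(event["timestamp"], event["message"])
```

Because `prune_dict` drops `None` values before the boto3 call, omitting `end_time` (or passing `None`) reproduces the old unbounded behavior, which is what the conditional `endTime` assertion in the updated hook test exercises.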