21 changes: 15 additions & 6 deletions airflow/providers/amazon/aws/hooks/logs.py
@@ -22,6 +22,7 @@
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.utils.helpers import prune_dict
 
 # Guidance received from the AWS team regarding the correct way to check for the end of a stream is that the
 # value of the nextForwardToken is the same in subsequent calls.
@@ -60,6 +61,7 @@ def get_log_events(
         log_group: str,
         log_stream_name: str,
         start_time: int = 0,
+        end_time: int | None = None,
         skip: int = 0,
         start_from_head: bool | None = None,
         continuation_token: ContinuationToken | None = None,
@@ -72,7 +74,9 @@
 
         :param log_group: The name of the log group.
         :param log_stream_name: The name of the specific stream.
-        :param start_time: The time stamp value to start reading the logs from (default: 0).
+        :param start_time: The timestamp value in ms to start reading the logs from (default: 0).
+        :param end_time: The timestamp value in ms to stop reading the logs at (default: None).
+            If None is provided, reads until the end of the log stream.
         :param skip: The number of log entries to skip at the start (default: 0).
             This is for when there are multiple entries at the same timestamp.
         :param start_from_head: Deprecated. Do not use with False, logs would be retrieved out of order.
@@ -110,11 +114,16 @@ def get_log_events(
             token_arg = {}
 
         response = self.conn.get_log_events(
-            logGroupName=log_group,
-            logStreamName=log_stream_name,
-            startTime=start_time,
-            startFromHead=start_from_head,
-            **token_arg,
+            **prune_dict(
+                {
+                    "logGroupName": log_group,
+                    "logStreamName": log_stream_name,
+                    "startTime": start_time,
+                    "endTime": end_time,
+                    "startFromHead": start_from_head,
+                    **token_arg,
+                }
+            )
         )
 
         events = response["events"]
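The switch to `prune_dict` is the heart of this hunk: boto3's `get_log_events` rejects `None` values, so optional parameters such as `endTime` must be omitted entirely rather than passed as `None`. A minimal sketch of the behavior, with placeholder log group and stream names that are illustrative, not from this PR:

```python
from airflow.utils.helpers import prune_dict

# prune_dict in its default (strict) mode drops keys whose value is None,
# so endTime only reaches boto3 when an end_time was actually supplied.
call_kwargs = prune_dict(
    {
        "logGroupName": "example-group",        # placeholder
        "logStreamName": "example-log-stream",  # placeholder
        "startTime": 0,
        "endTime": None,                        # dropped by prune_dict
        "startFromHead": True,
    }
)
assert call_kwargs == {
    "logGroupName": "example-group",
    "logStreamName": "example-log-stream",
    "startTime": 0,
    "startFromHead": True,
}
```

Note that `"startTime": 0` survives pruning: strict mode removes only `None`, not falsy values, which is exactly what the hook needs here.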
23 changes: 19 additions & 4 deletions airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -17,13 +17,15 @@
 # under the License.
 from __future__ import annotations
 
-from datetime import datetime
+from datetime import datetime, timedelta
 from functools import cached_property
 
 import watchtower
 
 from airflow.configuration import conf
+from airflow.models import TaskInstance
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -90,7 +92,8 @@ def _read(self, task_instance, try_number, metadata=None):
         try:
             return (
                 f"*** Reading remote log from Cloudwatch log_group: {self.log_group} "
-                f"log_stream: {stream_name}.\n{self.get_cloudwatch_logs(stream_name=stream_name)}\n",
+                f"log_stream: {stream_name}.\n"
+                f"{self.get_cloudwatch_logs(stream_name=stream_name, task_instance=task_instance)}\n",
                 {"end_of_log": True},
             )
         except Exception as e:
@@ -103,17 +106,29 @@ def _read(self, task_instance, try_number, metadata=None):
             log += local_log
         return log, metadata
 
-    def get_cloudwatch_logs(self, stream_name: str) -> str:
+    def get_cloudwatch_logs(self, stream_name: str, task_instance: TaskInstance) -> str:
         """
         Return all logs from the given log stream.
 
         :param stream_name: name of the Cloudwatch log stream to get all logs from
+        :param task_instance: the task instance to get logs about
         :return: string of all logs from the given log stream
         """
+        start_time = (
+            0 if task_instance.start_date is None else datetime_to_epoch_utc_ms(task_instance.start_date)
+        )
+        # If there is an end_date to the task instance, fetch logs until that date + 30 seconds
+        # 30 seconds is an arbitrary buffer so that we don't miss any logs that were emitted
+        end_time = (
+            None
+            if task_instance.end_date is None
+            else datetime_to_epoch_utc_ms(task_instance.end_date + timedelta(seconds=30))
+        )
        events = self.hook.get_log_events(
             log_group=self.log_group,
             log_stream_name=stream_name,
-            start_from_head=True,
+            start_time=start_time,
+            end_time=end_time,
         )
         return "\n".join(self._event_to_str(event) for event in events)
 
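The new handler logic bounds the CloudWatch query to the task's lifetime instead of reading the whole stream. A small sketch of the window computation under the same assumptions as the code above; the dates are hypothetical stand-ins for `task_instance.start_date` and `task_instance.end_date`:

```python
from datetime import datetime, timedelta

from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms

start_date = datetime(2020, 1, 1)  # stand-in for task_instance.start_date
end_date = datetime(2020, 1, 2)    # stand-in for task_instance.end_date

# Missing start_date -> read from the epoch; missing end_date -> read to the end.
start_time = 0 if start_date is None else datetime_to_epoch_utc_ms(start_date)
# The 30-second buffer guards against events emitted just after end_date.
end_time = (
    None if end_date is None else datetime_to_epoch_utc_ms(end_date + timedelta(seconds=30))
)

print(start_time, end_time)  # 1577836800000 1577923230000
```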
7 changes: 6 additions & 1 deletion airflow/providers/amazon/aws/utils/__init__.py
@@ -18,7 +18,7 @@
 
 import logging
 import re
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 
 from airflow.utils.helpers import prune_dict
@@ -55,6 +55,11 @@ def datetime_to_epoch_ms(date_time: datetime) -> int:
     return int(date_time.timestamp() * 1_000)
 
 
+def datetime_to_epoch_utc_ms(date_time: datetime) -> int:
+    """Convert a datetime object to an epoch integer (milliseconds) in UTC timezone."""
+    return int(date_time.replace(tzinfo=timezone.utc).timestamp() * 1_000)
+
+
 def datetime_to_epoch_us(date_time: datetime) -> int:
     """Convert a datetime object to an epoch integer (microseconds)."""
     return int(date_time.timestamp() * 1_000_000)
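The reason a separate UTC variant is needed: `datetime.timestamp()` interprets a naive `datetime` in the machine's local timezone, so `datetime_to_epoch_ms` gives machine-dependent results, while the new helper pins naive values to UTC first. A quick worked check:

```python
from datetime import datetime, timezone

def datetime_to_epoch_utc_ms(date_time: datetime) -> int:
    """Convert a datetime object to an epoch integer (milliseconds) in UTC timezone."""
    return int(date_time.replace(tzinfo=timezone.utc).timestamp() * 1_000)

# 2020-01-01T00:00:00Z is 1_577_836_800 seconds after the epoch,
# regardless of the local timezone of the machine running this.
assert datetime_to_epoch_utc_ms(datetime(2020, 1, 1)) == 1_577_836_800_000
```

One caveat worth noting: `replace(tzinfo=timezone.utc)` also overwrites the timezone of an already-aware datetime, so callers are expected to pass UTC-based values.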
29 changes: 24 additions & 5 deletions tests/providers/amazon/aws/hooks/test_logs.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import ANY, patch
 
 import pytest
 from moto import mock_logs
@@ -29,7 +29,7 @@
 @mock_logs
 class TestAwsLogsHook:
     @pytest.mark.parametrize(
-        "get_log_events_response, num_skip_events, expected_num_events",
+        "get_log_events_response, num_skip_events, expected_num_events, end_time",
         [
             # 3 empty responses with different tokens
             (
Expand All @@ -40,6 +40,7 @@ class TestAwsLogsHook:
],
0,
0,
None,
),
# 2 events on the second response with same token
(
@@ -49,6 +50,7 @@
                 ],
                 0,
                 2,
+                None,
             ),
             # Different tokens, 2 events on the second response then 3 empty responses
             (
@@ -63,6 +65,7 @@
                 ],
                 0,
                 2,
+                10,
             ),
             # 2 events on the second response, then 2 empty responses, then 2 consecutive responses with
             # 2 events with the same token
@@ -79,20 +82,36 @@
                 ],
                 0,
                 6,
+                20,
             ),
         ],
     )
     @patch("airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.conn", new_callable=mock.PropertyMock)
-    def test_get_log_events(self, mock_conn, get_log_events_response, num_skip_events, expected_num_events):
+    def test_get_log_events(
+        self, mock_conn, get_log_events_response, num_skip_events, expected_num_events, end_time
+    ):
         mock_conn().get_log_events.side_effect = get_log_events_response
 
+        log_group_name = "example-group"
+        log_stream_name = "example-log-stream"
         hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")
         events = hook.get_log_events(
-            log_group="example-group",
-            log_stream_name="example-log-stream",
+            log_group=log_group_name,
+            log_stream_name=log_stream_name,
             skip=num_skip_events,
+            end_time=end_time,
         )
 
         events = list(events)
 
         assert len(events) == expected_num_events
+        kwargs = {
+            "logGroupName": log_group_name,
+            "logStreamName": log_stream_name,
+            "startFromHead": True,
+            "startTime": 0,
+            "nextToken": ANY,
+        }
+        if end_time:
+            kwargs["endTime"] = end_time
+        mock_conn().get_log_events.assert_called_with(**kwargs)
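The `ANY` sentinel in the expected kwargs is what lets a single assertion cover all pagination rounds: `nextToken` changes on every call, so the test pins only the arguments it controls. A tiny self-contained illustration with a hypothetical mock, not PR code:

```python
from unittest import mock
from unittest.mock import ANY

paginated = mock.Mock()
for token in ("token-1", "token-2", "token-3"):
    paginated(logGroupName="example-group", nextToken=token)

# assert_called_with checks the most recent call; ANY matches any token value.
paginated.assert_called_with(logGroupName="example-group", nextToken=ANY)
```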
37 changes: 35 additions & 2 deletions tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import time
-from datetime import datetime as dt
+from datetime import datetime as dt, timedelta
 from unittest import mock
 from unittest.mock import call
 
@@ -31,6 +31,7 @@
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudwatchTaskHandler
+from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
@@ -60,7 +61,6 @@ def setup_tests(self, create_log_template, tmp_path_factory):
             self.local_log_location,
             f"arn:aws:logs:{self.region_name}:11111111:log-group:{self.remote_log_group}",
         )
-        self.cloudwatch_task_handler.hook
 
         date = datetime(2020, 1, 1)
         dag_id = "dag_for_testing_cloudwatch_task_handler"
@@ -154,6 +154,39 @@ def test_read(self):
             [{"end_of_log": True}],
         )
 
+    @pytest.mark.parametrize(
+        "start_date, end_date, expected_start_time, expected_end_time",
+        [
+            (None, None, 0, None),
+            (datetime(2020, 1, 1), None, datetime_to_epoch_utc_ms(datetime(2020, 1, 1)), None),
+            (
+                None,
+                datetime(2020, 1, 2),
+                0,
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
+            ),
+            (
+                datetime(2020, 1, 1),
+                datetime(2020, 1, 2),
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 1)),
+                datetime_to_epoch_utc_ms(datetime(2020, 1, 2) + timedelta(seconds=30)),
+            ),
+        ],
+    )
+    @mock.patch.object(AwsLogsHook, "get_log_events")
+    def test_get_cloudwatch_logs(
+        self, mock_get_log_events, start_date, end_date, expected_start_time, expected_end_time
+    ):
+        self.ti.start_date = start_date
+        self.ti.end_date = end_date
+        self.cloudwatch_task_handler.get_cloudwatch_logs(self.remote_log_stream, self.ti)
+        mock_get_log_events.assert_called_once_with(
+            log_group=self.remote_log_group,
+            log_stream_name=self.remote_log_stream,
+            start_time=expected_start_time,
+            end_time=expected_end_time,
+        )
+
     def test_close_prevents_duplicate_calls(self):
         with mock.patch("watchtower.CloudWatchLogHandler.close") as mock_log_handler_close:
             with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.set_context"):
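The handler test mirrors the hook test's approach but patches at the hook-method level rather than the boto3 connection, so the assertion sees exactly what the handler passed. A stripped-down sketch of that pattern; all names here are hypothetical, not from the PR:

```python
from unittest import mock

class FakeHook:
    def get_log_events(self, **kwargs):  # stands in for AwsLogsHook.get_log_events
        return iter(())

class FakeHandler:
    hook = FakeHook()

    def get_logs(self, start_time, end_time):
        return list(self.hook.get_log_events(start_time=start_time, end_time=end_time))

# Patching the class method intercepts the call made through the instance.
with mock.patch.object(FakeHook, "get_log_events", return_value=iter(())) as mocked:
    FakeHandler().get_logs(0, None)
    mocked.assert_called_once_with(start_time=0, end_time=None)
```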