Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 57 additions & 80 deletions airflow/providers/amazon/aws/operators/eventbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.eventbridge import EventBridgeHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
from airflow.utils.context import Context


class EventBridgePutEventsOperator(BaseOperator):
class EventBridgePutEventsOperator(AwsBaseOperator[EventBridgeHook]):
"""
Put Events onto Amazon EventBridge.

Expand All @@ -38,32 +38,25 @@ class EventBridgePutEventsOperator(BaseOperator):

:param entries: the list of events to be put onto EventBridge, each event is a dict (required)
:param endpoint_id: the URL subdomain of the endpoint
:param aws_conn_id: the AWS connection to use
:param region_name: the region where events are to be sent

:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.htmlt
"""

template_fields: Sequence[str] = ("entries", "endpoint_id", "aws_conn_id", "region_name")
aws_hook_class = EventBridgeHook
template_fields: Sequence[str] = aws_template_fields("entries", "endpoint_id")

def __init__(
self,
*,
entries: list[dict],
endpoint_id: str | None = None,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
**kwargs,
):
def __init__(self, *, entries: list[dict], endpoint_id: str | None = None, **kwargs):
super().__init__(**kwargs)
self.entries = entries
self.endpoint_id = endpoint_id
self.aws_conn_id = aws_conn_id
self.region_name = region_name

@cached_property
def hook(self) -> EventBridgeHook:
"""Create and return an EventBridgeHook."""
return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
response = self.hook.conn.put_events(
Expand All @@ -90,7 +83,7 @@ def execute(self, context: Context):
return [e["EventId"] for e in response["Entries"]]


class EventBridgePutRuleOperator(BaseOperator):
class EventBridgePutRuleOperator(AwsBaseOperator[EventBridgeHook]):
"""
Create or update a specified EventBridge rule.

Expand All @@ -106,12 +99,20 @@ class EventBridgePutRuleOperator(BaseOperator):
:param schedule_expression: the scheduling expression (for example, a cron or rate expression)
:param state: indicates whether rule is set to be "ENABLED" or "DISABLED"
:param tags: list of key-value pairs to associate with the rule
:param region: the region where rule is to be created or updated

:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.htmlt
"""

template_fields: Sequence[str] = (
"aws_conn_id",
aws_hook_class = EventBridgeHook
template_fields: Sequence[str] = aws_template_fields(
"name",
"description",
"event_bus_name",
Expand All @@ -120,7 +121,6 @@ class EventBridgePutRuleOperator(BaseOperator):
"schedule_expression",
"state",
"tags",
"region_name",
)

def __init__(
Expand All @@ -134,8 +134,6 @@ def __init__(
schedule_expression: str | None = None,
state: str | None = None,
tags: list | None = None,
region_name: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -144,16 +142,9 @@ def __init__(
self.event_bus_name = event_bus_name
self.event_pattern = event_pattern
self.role_arn = role_arn
self.region_name = region_name
self.schedule_expression = schedule_expression
self.state = state
self.tags = tags
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> EventBridgeHook:
"""Create and return an EventBridgeHook."""
return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
self.log.info('Sending rule "%s" to EventBridge.', self.name)
Expand All @@ -170,7 +161,7 @@ def execute(self, context: Context):
)


class EventBridgeEnableRuleOperator(BaseOperator):
class EventBridgeEnableRuleOperator(AwsBaseOperator[EventBridgeHook]):
"""
Enable an EventBridge Rule.

Expand All @@ -180,32 +171,25 @@ class EventBridgeEnableRuleOperator(BaseOperator):

:param name: the name of the rule to enable
:param event_bus_name: the name or ARN of the event bus associated with the rule (default if omitted)
:param aws_conn_id: the AWS connection to use
:param region_name: the region of the rule to be enabled

:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.htmlt
"""

template_fields: Sequence[str] = ("name", "event_bus_name", "region_name", "aws_conn_id")
aws_hook_class = EventBridgeHook
template_fields: Sequence[str] = aws_template_fields("name", "event_bus_name")

def __init__(
self,
*,
name: str,
event_bus_name: str | None = None,
region_name: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, name: str, event_bus_name: str | None = None, **kwargs):
super().__init__(**kwargs)
self.name = name
self.event_bus_name = event_bus_name
self.region_name = region_name
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> EventBridgeHook:
"""Create and return an EventBridgeHook."""
return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
self.hook.conn.enable_rule(
Expand All @@ -220,7 +204,7 @@ def execute(self, context: Context):
self.log.info('Enabled rule "%s"', self.name)


class EventBridgeDisableRuleOperator(BaseOperator):
class EventBridgeDisableRuleOperator(AwsBaseOperator[EventBridgeHook]):
"""
Disable an EventBridge Rule.

Expand All @@ -230,32 +214,25 @@ class EventBridgeDisableRuleOperator(BaseOperator):

:param name: the name of the rule to disable
:param event_bus_name: the name or ARN of the event bus associated with the rule (default if omitted)
:param aws_conn_id: the AWS connection to use
:param region_name: the region of the rule to be disabled

:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.htmlt
"""

template_fields: Sequence[str] = ("name", "event_bus_name", "region_name", "aws_conn_id")
aws_hook_class = EventBridgeHook
template_fields: Sequence[str] = aws_template_fields("name", "event_bus_name")

def __init__(
self,
*,
name: str,
event_bus_name: str | None = None,
region_name: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, name: str, event_bus_name: str | None = None, **kwargs):
super().__init__(**kwargs)
self.name = name
self.event_bus_name = event_bus_name
self.region_name = region_name
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> EventBridgeHook:
"""Create and return an EventBridgeHook."""
return EventBridgeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
self.hook.conn.disable_rule(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
92 changes: 80 additions & 12 deletions tests/providers/amazon/aws/operators/test_eventbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,28 @@

class TestEventBridgePutEventsOperator:
def test_init(self):
operator = EventBridgePutEventsOperator(
op = EventBridgePutEventsOperator(
task_id="put_events_job",
entries=ENTRIES,
aws_conn_id="fake-conn-id",
region_name="eu-central-1",
verify="/spam/egg.pem",
botocore_config={"read_timeout": 42},
)

assert operator.entries == ENTRIES
assert op.entries == ENTRIES
assert op.hook.client_type == "events"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "eu-central-1"
assert op.hook._verify == "/spam/egg.pem"
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = EventBridgePutEventsOperator(task_id="put_events_job", entries=ENTRIES)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(EventBridgeHook, "conn")
def test_execute(self, mock_conn: MagicMock):
Expand Down Expand Up @@ -83,11 +99,31 @@ def test_failed_to_send(self, mock_conn: MagicMock):

class TestEventBridgePutRuleOperator:
def test_init(self):
operator = EventBridgePutRuleOperator(
op = EventBridgePutRuleOperator(
task_id="events_put_rule_job",
name=RULE_NAME,
event_pattern=EVENT_PATTERN,
aws_conn_id="fake-conn-id",
region_name="eu-west-1",
verify="/spam/egg.pem",
botocore_config={"read_timeout": 42},
)
assert op.event_pattern == EVENT_PATTERN
assert op.hook.client_type == "events"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify == "/spam/egg.pem"
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = EventBridgePutRuleOperator(
task_id="events_put_rule_job", name=RULE_NAME, event_pattern=EVENT_PATTERN
)

assert operator.event_pattern == EVENT_PATTERN
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(EventBridgeHook, "conn")
def test_execute(self, mock_conn: MagicMock):
Expand Down Expand Up @@ -117,12 +153,28 @@ def test_put_rule_with_bad_json_fails(self):

class TestEventBridgeEnableRuleOperator:
def test_init(self):
operator = EventBridgeDisableRuleOperator(
op = EventBridgeEnableRuleOperator(
task_id="enable_rule_task",
name=RULE_NAME,
aws_conn_id="fake-conn-id",
region_name="us-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)

assert operator.name == RULE_NAME
assert op.name == RULE_NAME
assert op.hook.client_type == "events"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "us-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = EventBridgeEnableRuleOperator(task_id="enable_rule_task", name=RULE_NAME)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(EventBridgeHook, "conn")
def test_enable_rule(self, mock_conn: MagicMock):
Expand All @@ -137,12 +189,28 @@ def test_enable_rule(self, mock_conn: MagicMock):

class TestEventBridgeDisableRuleOperator:
def test_init(self):
operator = EventBridgeDisableRuleOperator(
op = EventBridgeDisableRuleOperator(
task_id="disable_rule_task",
name=RULE_NAME,
aws_conn_id="fake-conn-id",
region_name="ca-west-1",
verify=True,
botocore_config={"read_timeout": 42},
)

assert operator.name == RULE_NAME
assert op.name == RULE_NAME
assert op.hook.client_type == "events"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "ca-west-1"
assert op.hook._verify is True
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = EventBridgeDisableRuleOperator(task_id="disable_rule_task", name=RULE_NAME)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(EventBridgeHook, "conn")
def test_disable_rule(self, mock_conn: MagicMock):
Expand Down