41 changes: 41 additions & 0 deletions airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -18,6 +2,8 @@
"""This module contains AWS Glue Catalog Hook."""
from __future__ import annotations

from typing import Any

from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
@@ -42,6 +44,45 @@ class GlueCatalogHook(AwsBaseHook):
def __init__(self, *args, **kwargs):
super().__init__(client_type="glue", *args, **kwargs)

async def async_get_partitions(
self,
client: Any,
database_name: str,
table_name: str,
expression: str = "",
page_size: int | None = None,
max_items: int | None = 1,
) -> set[tuple]:
"""
Asynchronously retrieves the partition values for a table.

:param database_name: The name of the catalog database where the partitions reside.
:param table_name: The name of the partitions' table.
:param expression: An expression filtering the partitions to be returned.
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
:param page_size: pagination size
:param max_items: maximum items to return
:return: set of partition values where each value is a tuple since
a partition may be composed of multiple columns. For example:
``{('2018-01-01','1'), ('2018-01-01','2')}``
"""
config = {
"PageSize": page_size,
"MaxItems": max_items,
}

paginator = client.get_paginator("get_partitions")
partitions = set()

async for page in paginator.paginate(
DatabaseName=database_name, TableName=table_name, Expression=expression, PaginationConfig=config
):
for partition in page["Partitions"]:
partitions.add(tuple(partition["Values"]))

return partitions

def get_partitions(
self,
database_name: str,
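A quick way to exercise the new hook method outside of a trigger, as a hedged sketch: the connection id, database, table, and expression below are illustrative, and `async_conn` is the same aiobotocore client context manager the new trigger uses.

```python
import asyncio

from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook


async def partition_exists() -> bool:
    # Illustrative names; any Glue database, table, and filter expression work.
    hook = GlueCatalogHook(aws_conn_id="aws_default")
    async with hook.async_conn as client:
        partitions = await hook.async_get_partitions(
            client=client,
            database_name="my_db",
            table_name="my_tbl",
            expression="ds='2018-01-01'",
        )
    # max_items defaults to 1, so at most one partition tuple comes back,
    # which is all an existence check needs.
    return bool(partitions)


if __name__ == "__main__":
    # Needs valid AWS credentials behind the "aws_default" connection.
    print(asyncio.run(partition_exists()))
```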
32 changes: 31 additions & 1 deletion airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -17,12 +17,16 @@
# under the License.
from __future__ import annotations

from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from deprecated import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
@@ -48,6 +52,9 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
:param database_name: The name of the catalog database where the partitions reside.
:param poke_interval: Time in seconds that the job should wait
between each try
:param deferrable: If true, then the sensor will wait asynchronously for the partition to
show up in the AWS Glue Catalog.
(default: False, but can be overridden in config file by setting default_deferrable to True)
"""

template_fields: Sequence[str] = (
@@ -66,6 +73,7 @@ def __init__(
region_name: str | None = None,
database_name: str = "default",
poke_interval: int = 60 * 3,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(poke_interval=poke_interval, **kwargs)
@@ -74,6 +82,23 @@
self.table_name = table_name
self.expression = expression
self.database_name = database_name
self.deferrable = deferrable

def execute(self, context: Context) -> Any:
if self.deferrable:
self.defer(
trigger=GlueCatalogPartitionTrigger(
database_name=self.database_name,
table_name=self.table_name,
expression=self.expression,
aws_conn_id=self.aws_conn_id,
waiter_delay=int(self.poke_interval),
),
method_name="execute_complete",
timeout=timedelta(seconds=self.timeout),
)
else:
super().execute(context=context)

def poke(self, context: Context):
"""Checks for existence of the partition in the AWS Glue Catalog table."""
@@ -85,6 +110,11 @@ def poke(self, context: Context):

return self.hook.check_for_partition(self.database_name, self.table_name, self.expression)

def execute_complete(self, context: Context, event: dict | None = None) -> None:
if event is None or event["status"] != "success":
raise AirflowException(f"Trigger error: event is {event}")
self.log.info("Partition exists in the Glue Catalog")

@deprecated(reason="use `hook` property instead.")
def get_hook(self) -> GlueCatalogHook:
"""Gets the GlueCatalogHook."""
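For end-to-end testing, a minimal DAG sketch (ids and names are illustrative, not from this PR). With `deferrable=True` the sensor defers to the triggerer instead of holding a worker slot; note that `poke_interval` is forwarded as the trigger's `waiter_delay` and the sensor's `timeout` becomes the deferral timeout.

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor

with DAG(dag_id="example_glue_partition", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    wait_for_partition = GlueCatalogPartitionSensor(
        task_id="wait_for_partition",
        database_name="my_db",
        table_name="my_tbl",
        expression="ds='{{ ds }}'",
        deferrable=True,   # defer to the triggerer instead of occupying a worker
        poke_interval=60,  # passed to the trigger as waiter_delay
        timeout=60 * 60,   # deferral gives up after an hour
    )
```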
83 changes: 83 additions & 0 deletions airflow/providers/amazon/aws/triggers/glue.py
@@ -17,9 +17,12 @@

from __future__ import annotations

import asyncio
from functools import cached_property
from typing import Any, AsyncIterator

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -65,3 +68,83 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
hook = GlueJobHook(aws_conn_id=self.aws_conn_id, job_poll_interval=self.job_poll_interval)
await hook.async_job_completion(self.job_name, self.run_id, self.verbose)
yield TriggerEvent({"status": "success", "message": "Job done", "value": self.run_id})


class GlueCatalogPartitionTrigger(BaseTrigger):
"""
Asynchronously waits for a partition to show up in AWS Glue Catalog.

:param database_name: The name of the catalog database where the partitions reside.
:param table_name: The name of the table to wait for; supports dot
notation (my_database.my_table)
:param expression: The partition clause to wait for. This is passed as
is to the AWS Glue Catalog API's get_partitions function,
and supports SQL-like notation as in ``ds='2015-01-01'
AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
:param aws_conn_id: ID of the Airflow connection where
credentials and extra configuration are stored
:param region_name: Optional aws region name (example: us-east-1). Uses region from connection
if not specified.
:param waiter_delay: Number of seconds to wait between two checks. Default is 60 seconds.
"""

def __init__(
self,
database_name: str,
table_name: str,
expression: str = "",
aws_conn_id: str = "aws_default",
region_name: str | None = None,
waiter_delay: int = 60,
):
super().__init__()
self.database_name = database_name
self.table_name = table_name
self.expression = expression
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.waiter_delay = waiter_delay

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
# dynamically generate the fully qualified name of the class
self.__class__.__module__ + "." + self.__class__.__qualname__,
{
"database_name": self.database_name,
"table_name": self.table_name,
"expression": self.expression,
"aws_conn_id": self.aws_conn_id,
"region_name": self.region_name,
"waiter_delay": self.waiter_delay,
},
)

@cached_property
def hook(self) -> GlueCatalogHook:
return GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

async def poke(self, client: Any) -> bool:
if "." in self.table_name:
self.database_name, self.table_name = self.table_name.split(".")
self.log.info(
"Poking for table %s. %s, expression %s", self.database_name, self.table_name, self.expression
)
partitions = await self.hook.async_get_partitions(
client=client,
database_name=self.database_name,
table_name=self.table_name,
expression=self.expression,
)

return bool(partitions)

async def run(self) -> AsyncIterator[TriggerEvent]:
async with self.hook.async_conn as client:
while True:
result = await self.poke(client=client)
if result:
yield TriggerEvent({"status": "success"})
break
else:
await asyncio.sleep(self.waiter_delay)
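The `serialize`/`run` pair is the triggerer contract: a trigger must be reconstructible from its classpath and kwargs, and `run()` must be an async generator yielding a `TriggerEvent`. A round-trip sketch under those assumptions (names illustrative; a live run needs AWS credentials and a matching partition):

```python
import asyncio

from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger

trigger = GlueCatalogPartitionTrigger(
    database_name="my_db", table_name="my_tbl", expression="ds='2018-01-01'"
)

# The triggerer persists (classpath, kwargs) and rebuilds the trigger from them.
classpath, kwargs = trigger.serialize()
rebuilt = GlueCatalogPartitionTrigger(**kwargs)
assert rebuilt.serialize() == (classpath, kwargs)


async def first_event():
    # run() keeps polling every waiter_delay seconds until a partition matches.
    async for event in rebuilt.run():
        return event


if __name__ == "__main__":
    # Needs AWS credentials and an existing matching partition to fire.
    print(asyncio.run(first_event()))
```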
19 changes: 19 additions & 0 deletions tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
@@ -19,8 +19,10 @@

from unittest import mock

import pytest
from moto import mock_glue

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor

@@ -93,3 +95,20 @@ def test_dot_notation(self, mock_check_for_partition):
op.poke({})

mock_check_for_partition.assert_called_once_with("my_db", "my_tbl", "ds='{{ ds }}'")

def test_deferrable_mode_raises_task_deferred(self):
op = GlueCatalogPartitionSensor(task_id=self.task_id, table_name="tbl", deferrable=True)
with pytest.raises(TaskDeferred):
op.execute({})

def test_execute_complete_fails_if_status_is_not_success(self):
op = GlueCatalogPartitionSensor(task_id=self.task_id, table_name="tbl", deferrable=True)
event = {"status": "FAILED"}
with pytest.raises(AirflowException):
op.execute_complete(context={}, event=event)

def test_execute_complete_succeeds_if_status_is_success(self, caplog):
op = GlueCatalogPartitionSensor(task_id=self.task_id, table_name="tbl", deferrable=True)
event = {"status": "success"}
op.execute_complete(context={}, event=event)
assert "Partition exists in the Glue Catalog" in caplog.messages
40 changes: 39 additions & 1 deletion tests/providers/amazon/aws/triggers/test_glue.py
@@ -23,7 +23,8 @@

from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger, GlueJobCompleteTrigger
from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger


@@ -88,3 +89,40 @@ def test_serialize_recreate(self):

assert class_path == class_path2
assert args == args2


class TestGlueCatalogPartitionSensorTrigger:
def test_serialize_recreate(self):
trigger = GlueCatalogPartitionTrigger(
database_name="my_database",
table_name="my_table",
expression="my_expression",
aws_conn_id="my_conn_id",
)

class_path, args = trigger.serialize()

class_name = class_path.split(".")[-1]
clazz = globals()[class_name]
instance = clazz(**args)

class_path2, args2 = instance.serialize()

assert class_path == class_path2
assert args == args2

@pytest.mark.asyncio
@mock.patch.object(GlueCatalogHook, "async_get_partitions")
async def test_poke(self, mock_async_get_partitions):
a_mock = mock.AsyncMock()
a_mock.return_value = True
mock_async_get_partitions.return_value = a_mock
trigger = GlueCatalogPartitionTrigger(
database_name="my_database",
table_name="my_table",
expression="my_expression",
aws_conn_id="my_conn_id",
)
response = await trigger.poke(client=mock.MagicMock())

assert response is True