From ad5b29bd3ded1f005f61a4e37e1d2f2ed32b85e9 Mon Sep 17 00:00:00 2001
From: Syed Hussain
Date: Tue, 8 Aug 2023 15:56:51 -0700
Subject: [PATCH 1/2] Add Deferrable mode to GlueCatalogPartitionSensor

With ``deferrable=True`` the sensor defers to the new
GlueCatalogPartitionTrigger, which polls the catalog through
GlueCatalogHook.async_get_partitions every ``poke_interval`` seconds
until a matching partition exists. The sensor's ``region_name`` is passed
on to the trigger so both modes query the same region.

---
 .../amazon/aws/hooks/glue_catalog.py          | 42 +++++++++
 .../aws/sensors/glue_catalog_partition.py     | 33 ++++++-
 airflow/providers/amazon/aws/triggers/glue.py | 83 +++++++++++++++++++
 .../sensors/test_glue_catalog_partition.py    | 19 +++++
 .../amazon/aws/triggers/test_glue.py          | 39 ++++++++-
 5 files changed, 214 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py
index a5d434879e872..9a2d47b19e045 100644
--- a/airflow/providers/amazon/aws/hooks/glue_catalog.py
+++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -18,6 +18,8 @@
 """This module contains AWS Glue Catalog Hook."""
 from __future__ import annotations
 
+from typing import Any
+
 from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowException
@@ -42,6 +44,46 @@ class GlueCatalogHook(AwsBaseHook):
     def __init__(self, *args, **kwargs):
         super().__init__(client_type="glue", *args, **kwargs)
 
+    async def async_get_partitions(
+        self,
+        client: Any,
+        database_name: str,
+        table_name: str,
+        expression: str = "",
+        page_size: int | None = None,
+        max_items: int | None = 1,
+    ) -> set[tuple]:
+        """
+        Asynschronously retrieves the partition values for a table.
+
+        :param client: An async Glue client (e.g. from ``async_conn``) used to paginate ``get_partitions``.
+        :param database_name: The name of the catalog database where the partitions reside.
+        :param table_name: The name of the partitions' table.
+        :param expression: An expression filtering the partitions to be returned.
+            Please see official AWS documentation for further information.
+            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartitions
+        :param page_size: pagination size
+        :param max_items: maximum items to return
+        :return: set of partition values where each value is a tuple since
+            a partition may be composed of multiple columns. For example:
+            ``{('2018-01-01','1'), ('2018-01-01','2')}``
+        """
+        config = {
+            "PageSize": page_size,
+            "MaxItems": max_items,
+        }
+
+        paginator = client.get_paginator("get_partitions")
+        partitions = set()
+
+        async for page in paginator.paginate(
+            DatabaseName=database_name, TableName=table_name, Expression=expression, PaginationConfig=config
+        ):
+            for partition in page["Partitions"]:
+                partitions.add(tuple(partition["Values"]))
+
+        return partitions
+
     def get_partitions(
         self,
         database_name: str,
diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
index 2f5a00208422c..ca7a1436a6bf1 100644
--- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
+++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -17,12 +17,16 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from deprecated import deprecated
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
+from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -48,6 +52,9 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
     :param database_name: The name of the catalog database where the partitions reside.
     :param poke_interval: Time in seconds that the job should wait in
         between each tries
+    :param deferrable: If true, then the sensor will wait asynchronously for the partition to
+        show up in the AWS Glue Catalog.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
     """
 
     template_fields: Sequence[str] = (
@@ -66,6 +73,7 @@ def __init__(
         region_name: str | None = None,
         database_name: str = "default",
         poke_interval: int = 60 * 3,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(poke_interval=poke_interval, **kwargs)
@@ -74,6 +82,24 @@ def __init__(
         self.table_name = table_name
         self.expression = expression
         self.database_name = database_name
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueCatalogPartitionTrigger(
+                    database_name=self.database_name,
+                    table_name=self.table_name,
+                    expression=self.expression,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    waiter_delay=int(self.poke_interval),
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.timeout),
+            )
+        else:
+            return super().execute(context=context)
 
     def poke(self, context: Context):
         """Checks for existence of the partition in the AWS Glue Catalog table."""
@@ -85,6 +111,11 @@ def poke(self, context: Context):
 
         return self.hook.check_for_partition(self.database_name, self.table_name, self.expression)
 
+    def execute_complete(self, context: Context, event: dict | None = None) -> None:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Trigger error: event is {event}")
+        self.log.info("Partition exists in the Glue Catalog")
+
     @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> GlueCatalogHook:
         """Gets the GlueCatalogHook."""
diff --git a/airflow/providers/amazon/aws/triggers/glue.py b/airflow/providers/amazon/aws/triggers/glue.py
index 00529330a7743..b5a0a673d3e18 100644
--- a/airflow/providers/amazon/aws/triggers/glue.py
+++ b/airflow/providers/amazon/aws/triggers/glue.py
@@ -17,9 +17,12 @@
 
 from __future__ import annotations
 
+import asyncio
+from functools import cached_property
 from typing import Any, AsyncIterator
 
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -65,3 +68,83 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = GlueJobHook(aws_conn_id=self.aws_conn_id, job_poll_interval=self.job_poll_interval)
         await hook.async_job_completion(self.job_name, self.run_id, self.verbose)
         yield TriggerEvent({"status": "success", "message": "Job done", "value": self.run_id})
+
+
+class GlueCatalogPartitionTrigger(BaseTrigger):
+    """
+    Asynchronously waits for a partition to show up in AWS Glue Catalog.
+
+    :param database_name: The name of the catalog database where the partitions reside.
+    :param table_name: The name of the table to wait for, supports the dot
+        notation (my_database.my_table)
+    :param expression: The partition clause to wait for. This is passed as
+        is to the AWS Glue Catalog API's get_partitions function,
+        and supports SQL like notation as in ``ds='2015-01-01'
+        AND type='value'`` and comparison operators as in ``"ds>=2015-01-01"``.
+        See https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html
+        #aws-glue-api-catalog-partitions-GetPartitions
+    :param aws_conn_id: ID of the Airflow connection where
+        credentials and extra configuration are stored
+    :param region_name: Optional aws region name (example: us-east-1). Uses region from connection
+        if not specified.
+    :param waiter_delay: Number of seconds to wait between two checks. Default is 60 seconds.
+    """
+
+    def __init__(
+        self,
+        database_name: str,
+        table_name: str,
+        expression: str = "",
+        aws_conn_id: str = "aws_default",
+        region_name: str | None = None,
+        waiter_delay: int = 60,
+    ):
+        self.database_name = database_name
+        self.table_name = table_name
+        self.expression = expression
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.waiter_delay = waiter_delay
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            # dynamically generate the fully qualified name of the class
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "database_name": self.database_name,
+                "table_name": self.table_name,
+                "expression": self.expression,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+                "waiter_delay": self.waiter_delay,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> GlueCatalogHook:
+        return GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+    async def poke(self, client: Any) -> bool:
+        if "." in self.table_name:
+            self.database_name, self.table_name = self.table_name.split(".")
+        self.log.info(
+            "Poking for table %s. %s, expression %s", self.database_name, self.table_name, self.expression
+        )
+        partitions = await self.hook.async_get_partitions(
+            client=client,
+            database_name=self.database_name,
+            table_name=self.table_name,
+            expression=self.expression,
+        )
+
+        return bool(partitions)
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        async with self.hook.async_conn as client:
+            while True:
+                result = await self.poke(client=client)
+                if result:
+                    yield TriggerEvent({"status": "success"})
+                    break
+                else:
+                    await asyncio.sleep(self.waiter_delay)
diff --git a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
index e5726c58a25f7..6903dd09ebc25 100644
--- a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
+++ b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py
@@ -19,8 +19,10 @@
 
 from unittest import mock
 
+import pytest
 from moto import mock_glue
 
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
 from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor
 
@@ -93,3 +95,20 @@ def test_dot_notation(self, mock_check_for_partition):
         op.poke({})
 
         mock_check_for_partition.assert_called_once_with("my_db", "my_tbl", "ds='{{ ds }}'")
+
+    def test_deferrable_mode_raises_task_deferred(self):
+        op = GlueCatalogPartitionSensor(task_id=self.task_id, table_name="tbl", deferrable=True)
+        with pytest.raises(TaskDeferred):
+            op.execute({})
+
+    def test_execute_complete_fails_if_status_is_not_success(self):
+        op = GlueCatalogPartitionSensor(task_id=self.task_id, table_name="tbl", deferrable=True)
+        event = {"status": "FAILED"}
+        with pytest.raises(AirflowException):
+            op.execute_complete(context={}, event=event)
+
+    def test_execute_complete_succeeds_if_status_is_success(self, caplog):
+        op = GlueCatalogPartitionSensor(task_id=self.task_id, table_name="tbl", deferrable=True)
+        event = {"status": "success"}
+        op.execute_complete(context={}, event=event)
+        assert "Partition exists in the Glue Catalog" in caplog.messages
diff --git a/tests/providers/amazon/aws/triggers/test_glue.py b/tests/providers/amazon/aws/triggers/test_glue.py
index 9c7cd61a712ee..9e6c6652b2177 100644
--- a/tests/providers/amazon/aws/triggers/test_glue.py
+++ b/tests/providers/amazon/aws/triggers/test_glue.py
@@ -23,7 +23,8 @@
 
 from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
-from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
+from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
+from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger, GlueJobCompleteTrigger
 from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
 
 
@@ -88,3 +89,39 @@ def test_serialize_recreate(self):
 
         assert class_path == class_path2
         assert args == args2
+
+
+class TestGlueCatalogPartitionSensorTrigger:
+    def test_serialize_recreate(self):
+        trigger = GlueCatalogPartitionTrigger(
+            database_name="my_database",
+            table_name="my_table",
+            expression="my_expression",
+            aws_conn_id="my_conn_id",
+        )
+
+        class_path, args = trigger.serialize()
+
+        class_name = class_path.split(".")[-1]
+        clazz = globals()[class_name]
+        instance = clazz(**args)
+
+        class_path2, args2 = instance.serialize()
+
+        assert class_path == class_path2
+        assert args == args2
+
+    @pytest.mark.asyncio
+    @mock.patch.object(GlueCatalogHook, "async_get_partitions")
+    async def test_poke(self, mock_async_get_partitions):
+        # The patched coroutine method resolves to ``return_value`` when awaited.
+        mock_async_get_partitions.return_value = {("2023-08-08",)}
+        trigger = GlueCatalogPartitionTrigger(
+            database_name="my_database",
+            table_name="my_table",
+            expression="my_expression",
+            aws_conn_id="my_conn_id",
+        )
+        response = await trigger.poke(client=mock.MagicMock())
+
+        assert response is True

From c04597191a47fa87bf455d6390b2ef27559ab4d4 Mon Sep 17 00:00:00 2001
From: Syed Hussain
Date: Tue, 8 Aug 2023 19:14:49 -0700
Subject: [PATCH 2/2] Fix spelling mistake

---
 airflow/providers/amazon/aws/hooks/glue_catalog.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py
index 9a2d47b19e045..ba40954ec6a7a 100644
--- a/airflow/providers/amazon/aws/hooks/glue_catalog.py
+++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -54,7 +54,7 @@ async def async_get_partitions(
         max_items: int | None = 1,
     ) -> set[tuple]:
         """
-        Asynschronously retrieves the partition values for a table.
+        Asynchronously retrieves the partition values for a table.
 
         :param client: An async Glue client (e.g. from ``async_conn``) used to paginate ``get_partitions``.
         :param database_name: The name of the catalog database where the partitions reside.