83 changes: 66 additions & 17 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -26,6 +26,21 @@

if TYPE_CHECKING:
from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa
from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef

FINISHED_STATE = "FINISHED"
FAILED_STATE = "FAILED"
ABORTED_STATE = "ABORTED"
FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}


class RedshiftDataQueryFailedError(ValueError):
"""Raise an error that redshift data query failed."""


class RedshiftDataQueryAbortedError(ValueError):
"""Raise an error that redshift data query was aborted."""


class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
@@ -108,27 +123,40 @@ def execute_query(

return statement_id

def wait_for_results(self, statement_id, poll_interval):
def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
while True:
self.log.info("Polling statement %s", statement_id)
resp = self.conn.describe_statement(
Id=statement_id,
)
status = resp["Status"]
if status == "FINISHED":
num_rows = resp.get("ResultRows")
if num_rows is not None:
self.log.info("Processed %s rows", num_rows)
return status
elif status in ("FAILED", "ABORTED"):
raise ValueError(
f"Statement {statement_id!r} terminated with status {status}. "
f"Response details: {pformat(resp)}"
)
else:
self.log.info("Query %s", status)
is_finished = self.check_query_is_finished(statement_id)
if is_finished:
return FINISHED_STATE

time.sleep(poll_interval)

def check_query_is_finished(self, statement_id: str) -> bool:
"""Check whether query finished, raise exception is failed."""
resp = self.conn.describe_statement(Id=statement_id)
return self.parse_statement_response(resp)

def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
"""Parse the response of describe_statement."""
status = resp["Status"]
if status == FINISHED_STATE:
num_rows = resp.get("ResultRows")
if num_rows is not None:
self.log.info("Processed %s rows", num_rows)
return True
elif status in FAILURE_STATES:
exception_cls = (
RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
)
raise exception_cls(
f"Statement {resp['Id']} terminated with status {status}. "
f"Response details: {pformat(resp)}"
)

self.log.info("Query status: %s", status)
return False

def get_table_primary_key(
self,
table: str,
@@ -201,3 +229,24 @@
break

return pk_columns or None

async def is_still_running(self, statement_id: str) -> bool:
"""Async function to check whether the query is still running.

:param statement_id: the UUID of the statement
"""
async with self.async_conn as client:
desc = await client.describe_statement(Id=statement_id)
return desc["Status"] in RUNNING_STATES

async def check_query_is_finished_async(self, statement_id: str) -> bool:
"""Async function to check statement is finished.

It takes statement_id, makes async connection to redshift data to get the query status
by statement_id and returns the query status.

:param statement_id: the UUID of the statement
"""
async with self.async_conn as client:
resp = await client.describe_statement(Id=statement_id)
return self.parse_statement_response(resp)
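
For orientation, here is a minimal sketch (not part of the diff) of how the hook's new synchronous and asynchronous checks could be driven directly; the connection ID, database, cluster, and SQL below are illustrative assumptions:

```python
# Hypothetical usage of the new hook helpers; database/cluster/conn values are assumptions.
import asyncio

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook(aws_conn_id="aws_default")

# Submit the query without blocking, then check it with the new helper.
statement_id = hook.execute_query(
    database="dev",                   # assumed database name
    sql="SELECT 1",                   # placeholder SQL
    cluster_identifier="my-cluster",  # assumed cluster
    wait_for_completion=False,
)

# Synchronous check: True once FINISHED, raises on FAILED/ABORTED.
print(hook.check_query_is_finished(statement_id))

# The same check from async code, as the trigger performs it.
print(asyncio.run(hook.check_query_is_finished_async(statement_id)))
```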
55 changes: 53 additions & 2 deletions airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,10 +17,13 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
@@ -92,6 +95,7 @@ def __init__(
poll_interval: int = 10,
return_sql_result: bool = False,
workgroup_name: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -114,11 +118,17 @@ def __init__(
)
self.return_sql_result = return_sql_result
self.statement_id: str | None = None
self.deferrable = deferrable

def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
"""Execute a statement against Amazon Redshift."""
self.log.info("Executing statement: %s", self.sql)

# In deferrable mode, submit without waiting; the deferred trigger polls for the status instead.
wait_for_completion = self.wait_for_completion
if self.deferrable and self.wait_for_completion:
wait_for_completion = False

self.statement_id = self.hook.execute_query(
database=self.database,
sql=self.sql,
@@ -129,17 +139,58 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
secret_arn=self.secret_arn,
statement_name=self.statement_name,
with_event=self.with_event,
wait_for_completion=self.wait_for_completion,
wait_for_completion=wait_for_completion,
poll_interval=self.poll_interval,
)

if self.deferrable:
is_finished = self.hook.check_query_is_finished(self.statement_id)
if not is_finished:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftDataTrigger(
statement_id=self.statement_id,
task_id=self.task_id,
poll_interval=self.poll_interval,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)

if self.return_sql_result:
result = self.hook.conn.get_statement_result(Id=self.statement_id)
self.log.debug("Statement result: %s", result)
return result
else:
return self.statement_id

def execute_complete(
self, context: Context, event: dict[str, Any] | None = None
) -> GetStatementResultResponseTypeDef | str:
if event is None:
err_msg = "Trigger error: event is None"
self.log.info(err_msg)
raise AirflowException(err_msg)

if event["status"] == "error":
msg = f"context: {context}, error message: {event['message']}"
raise AirflowException(msg)

statement_id = event["statement_id"]
if not statement_id:
raise AirflowException("statement_id should not be empty.")

self.log.info("%s completed successfully.", self.task_id)
if self.return_sql_result:
result = self.hook.conn.get_statement_result(Id=statement_id)
self.log.debug("Statement result: %s", result)
return result

return statement_id

def on_kill(self) -> None:
"""Cancel the submitted redshift query."""
if self.statement_id:
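
As a usage sketch (not part of the diff; the DAG ID, database, cluster, and SQL are placeholder assumptions), the new flag is switched on like other deferrable operators:

```python
# Hypothetical DAG showing the new deferrable mode; identifiers are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

with DAG("example_redshift_data_deferrable", start_date=datetime(2023, 1, 1), schedule=None):
    run_query = RedshiftDataOperator(
        task_id="run_query",
        database="dev",                          # assumed database
        sql="SELECT * FROM my_table LIMIT 10;",  # placeholder SQL
        cluster_identifier="my-cluster",         # assumed cluster
        wait_for_completion=True,
        deferrable=True,  # frees the worker slot; RedshiftDataTrigger polls in the triggerer
        poll_interval=30,
    )
```

With deferrable=True the operator submits the statement, defers to RedshiftDataTrigger, and resumes in execute_complete once the trigger reports success or failure.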
113 changes: 113 additions & 0 deletions airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import asyncio
from functools import cached_property
from typing import Any, AsyncIterator

from airflow.providers.amazon.aws.hooks.redshift_data import (
ABORTED_STATE,
FAILED_STATE,
RedshiftDataHook,
RedshiftDataQueryAbortedError,
RedshiftDataQueryFailedError,
)
from airflow.triggers.base import BaseTrigger, TriggerEvent


class RedshiftDataTrigger(BaseTrigger):
"""
Trigger that polls the status of a Redshift Data API statement from the triggerer while the task is deferred.

:param statement_id: the UUID of the statement
:param task_id: task ID of the deferring task
:param poll_interval: polling period in seconds to check for the status
:param aws_conn_id: AWS connection ID for Redshift
:param region_name: AWS region to use
:param verify: whether or how to verify SSL certificates for the API client
:param botocore_config: configuration dictionary (key-values) for the botocore client
"""

def __init__(
self,
statement_id: str,
task_id: str,
poll_interval: int,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
):
super().__init__()
self.statement_id = statement_id
self.task_id = task_id
self.poll_interval = poll_interval

self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes RedshiftDataTrigger arguments and classpath."""
return (
"airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
{
"statement_id": self.statement_id,
"task_id": self.task_id,
"aws_conn_id": self.aws_conn_id,
"poll_interval": self.poll_interval,
"region_name": self.region_name,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

@cached_property
def hook(self) -> RedshiftDataHook:
return RedshiftDataHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

async def run(self) -> AsyncIterator[TriggerEvent]:
try:
while await self.hook.is_still_running(self.statement_id):
await asyncio.sleep(self.poll_interval)

is_finished = await self.hook.check_query_is_finished_async(self.statement_id)
if is_finished:
response = {"status": "success", "statement_id": self.statement_id}
else:
response = {
"status": "error",
"statement_id": self.statement_id,
"message": f"{self.task_id} failed",
}
yield TriggerEvent(response)
except (RedshiftDataQueryFailedError, RedshiftDataQueryAbortedError) as error:
response = {
"status": "error",
"statement_id": self.statement_id,
"message": str(error),
"type": FAILED_STATE if isinstance(error, RedshiftDataQueryFailedError) else ABORTED_STATE,
}
yield TriggerEvent(response)
except Exception as error:
yield TriggerEvent({"status": "error", "statement_id": self.statement_id, "message": str(error)})
1 change: 1 addition & 0 deletions airflow/providers/amazon/provider.yaml
@@ -621,6 +621,7 @@ triggers:
- integration-name: Amazon Redshift
python-modules:
- airflow.providers.amazon.aws.triggers.redshift_cluster
- airflow.providers.amazon.aws.triggers.redshift_data
- integration-name: Amazon SageMaker
python-modules:
- airflow.providers.amazon.aws.triggers.sagemaker
61 changes: 60 additions & 1 deletion tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -22,7 +22,11 @@

import pytest

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.hooks.redshift_data import (
RedshiftDataHook,
RedshiftDataQueryAbortedError,
RedshiftDataQueryFailedError,
)

SQL = "sql"
DATABASE = "database"
@@ -292,3 +296,58 @@ def test_result_num_rows(self, mock_conn, caplog):
wait_for_completion=True,
)
assert "Processed " not in caplog.text

@pytest.mark.asyncio
@pytest.mark.parametrize(
"describe_statement_response, expected_result",
[
({"Status": "PICKED"}, True),
({"Status": "STARTED"}, True),
({"Status": "SUBMITTED"}, True),
({"Status": "FINISHED"}, False),
({"Status": "FAILED"}, False),
({"Status": "ABORTED"}, False),
],
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
async def test_is_still_running(self, mock_conn, describe_statement_response, expected_result):
hook = RedshiftDataHook()
mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
response = await hook.is_still_running("uuid")
assert response == expected_result

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
async def test_check_query_is_finished_async(self, mock_is_still_running, mock_conn):
hook = RedshiftDataHook()
mock_is_still_running.return_value = False
mock_conn.describe_statement = mock.AsyncMock()
mock_conn.__aenter__.return_value.describe_statement.return_value = {
"Id": "uuid",
"Status": "FINISHED",
}
is_finished = await hook.check_query_is_finished_async(statement_id="uuid")
assert is_finished is True

@pytest.mark.asyncio
@pytest.mark.parametrize(
"describe_statement_response, expected_exception",
(
(
{"Id": "uuid", "Status": "FAILED", "QueryString": "select 1", "Error": "Test error"},
RedshiftDataQueryFailedError,
),
({"Id": "uuid", "Status": "ABORTED"}, RedshiftDataQueryAbortedError),
),
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
async def test_check_query_is_finished_async_exception(
self, mock_is_still_running, mock_conn, describe_statement_response, expected_exception
):
hook = RedshiftDataHook()
mock_is_still_running.return_value = False
mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
with pytest.raises(expected_exception):
await hook.check_query_is_finished_async(statement_id="uuid")
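
A hedged companion sketch (not in this PR) of how the trigger's run() could be unit-tested in the same style as the hook tests above, with the new async hook methods mocked out:

```python
# Hypothetical trigger test; mirrors the mocking style of the hook tests above.
from unittest import mock

import pytest

from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
from airflow.triggers.base import TriggerEvent


@pytest.mark.asyncio
@mock.patch(
    "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async"
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
async def test_trigger_success(mock_is_still_running, mock_check_finished):
    # patch() substitutes AsyncMocks for the async hook methods (Python 3.8+).
    mock_is_still_running.return_value = False
    mock_check_finished.return_value = True
    trigger = RedshiftDataTrigger(statement_id="uuid", task_id="task", poll_interval=1)
    event = await trigger.run().asend(None)  # drain the single yielded event
    assert event == TriggerEvent({"status": "success", "statement_id": "uuid"})
```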