diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index 126c585e3fcfb..b454ad76ec440 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,11 +17,11 @@
 # under the License.
 from __future__ import annotations
 
-from functools import cached_property
 from typing import TYPE_CHECKING
 
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from mypy_boto3_redshift_data.type_defs import GetStatementResultResponseTypeDef
@@ -29,7 +29,7 @@
     from airflow.utils.context import Context
 
 
-class RedshiftDataOperator(BaseOperator):
+class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
     """
     Executes SQL Statements against an Amazon Redshift cluster using Redshift Data.
 
@@ -49,22 +49,29 @@ class RedshiftDataOperator(BaseOperator):
     :param poll_interval: how often in seconds to check the query status
     :param return_sql_result: if True will return the result of an SQL statement,
         if False (default) will return statement ID
-    :param aws_conn_id: aws connection to use
-    :param region: aws region to use
     :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
         `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
         https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = (
+    aws_hook_class = RedshiftDataHook
+    template_fields = aws_template_fields(
         "cluster_identifier",
         "database",
         "sql",
         "db_user",
         "parameters",
         "statement_name",
-        "aws_conn_id",
-        "region",
         "workgroup_name",
     )
     template_ext = (".sql",)
@@ -84,8 +91,6 @@ def __init__(
         wait_for_completion: bool = True,
         poll_interval: int = 10,
         return_sql_result: bool = False,
-        aws_conn_id: str = "aws_default",
-        region: str | None = None,
         workgroup_name: str | None = None,
         **kwargs,
     ) -> None:
@@ -108,15 +113,8 @@ def __init__(
             poll_interval,
         )
         self.return_sql_result = return_sql_result
-        self.aws_conn_id = aws_conn_id
-        self.region = region
         self.statement_id: str | None = None
 
-    @cached_property
-    def hook(self) -> RedshiftDataHook:
-        """Create and return an RedshiftDataHook."""
-        return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
-
     def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
         """Execute a statement against Amazon Redshift."""
         self.log.info("Executing statement: %s", self.sql)
diff --git a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
index 782d0cf6992fb..0b314d34f3193 100644
--- a/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
+++ b/docs/apache-airflow-providers-amazon/operators/redshift/redshift_data.rst
@@ -29,6 +29,11 @@ Prerequisite Tasks
 
 .. include:: ../../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index e5a851fe737e5..4b921b71423b0 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -19,6 +19,9 @@
 
 from unittest import mock
 
+import pytest
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
 
 CONN_ID = "aws_conn_test"
@@ -29,6 +32,33 @@
 
 
 class TestRedshiftDataOperator:
+    def test_init(self):
+        op = RedshiftDataOperator(
+            task_id="fake_task_id",
+            database="fake-db",
+            sql="SELECT 1",
+            aws_conn_id="fake-conn-id",
+            region_name="eu-central-1",
+            verify="/spam/egg.pem",
+            botocore_config={"read_timeout": 42},
+        )
+        with pytest.warns(AirflowProviderDeprecationWarning):
+            # Check deprecated region argument
+            assert op.region == "eu-central-1"
+        assert op.hook.client_type == "redshift-data"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "eu-central-1"
+        assert op.hook._verify == "/spam/egg.pem"
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = RedshiftDataOperator(task_id="fake_task_id", database="fake-db", sql="SELECT 1")
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
     def test_execute(self, mock_exec_query):
         cluster_identifier = "cluster_identifier"
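
For reference, a minimal usage sketch (not part of the change above) showing the generic AWS parameters that RedshiftDataOperator accepts after the migration to AwsBaseOperator. The DAG id, schedule, cluster, and connection values are hypothetical; only the parameter names come from the docstring in this diff.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

with DAG(
    dag_id="example_redshift_data_generic_params",  # hypothetical DAG id
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    run_query = RedshiftDataOperator(
        task_id="run_query",
        cluster_identifier="my-redshift-cluster",  # hypothetical cluster
        database="dev",
        db_user="awsuser",
        sql="SELECT 1;",
        # Generic AWS parameters provided by AwsBaseOperator:
        aws_conn_id="aws_default",
        region_name="eu-central-1",  # replaces the removed ``region`` argument
        verify=True,
        botocore_config={"read_timeout": 42},
        wait_for_completion=True,
        poll_interval=10,
    )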