diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index e0de3122b0ecf..95340d6f8aaae 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -43,6 +43,7 @@ class AthenaOperator(BaseOperator): :param query: Presto to be run on athena. (templated) :param database: Database to select. (templated) + :param catalog: Catalog to select. (templated) :param output_location: s3 path to write the query results into. (templated) :param aws_conn_id: aws connection to use :param client_request_token: Unique token created by user to avoid multiple executions of same query @@ -57,7 +58,7 @@ class AthenaOperator(BaseOperator): """ ui_color = "#44b5e2" - template_fields: Sequence[str] = ("query", "database", "output_location", "workgroup") + template_fields: Sequence[str] = ("query", "database", "output_location", "workgroup", "catalog") template_ext: Sequence[str] = (".sql",) template_fields_renderers = {"query": "sql"} @@ -76,6 +77,7 @@ def __init__( max_polling_attempts: int | None = None, log_query: bool = True, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + catalog: str = "AwsDataCatalog", **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -92,6 +94,7 @@ def __init__( self.query_execution_id: str | None = None self.log_query: bool = log_query self.deferrable = deferrable + self.catalog: str = catalog @cached_property def hook(self) -> AthenaHook: @@ -101,6 +104,7 @@ def hook(self) -> AthenaHook: def execute(self, context: Context) -> str | None: """Run Presto Query on Athena.""" self.query_execution_context["Database"] = self.database + self.query_execution_context["Catalog"] = self.catalog self.result_configuration["OutputLocation"] = self.output_location self.query_execution_id = self.hook.run_query( self.query, diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 3eb6349850799..5208c7c6832d6 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -36,12 +36,13 @@ "task_id": "test_athena_operator", "query": "SELECT * FROM TEST_TABLE", "database": "TEST_DATABASE", + "catalog": "AwsDataCatalog", "outputLocation": "s3://test_s3_bucket/", "client_request_token": "eac427d0-1c6d-4dfb-96aa-2835d3ac6595", "workgroup": "primary", } -query_context = {"Database": MOCK_DATA["database"]} +query_context = {"Database": MOCK_DATA["database"], "Catalog": MOCK_DATA["catalog"]} result_configuration = {"OutputLocation": MOCK_DATA["outputLocation"]} @@ -69,10 +70,27 @@ def test_init(self): assert self.athena.task_id == MOCK_DATA["task_id"] assert self.athena.query == MOCK_DATA["query"] assert self.athena.database == MOCK_DATA["database"] + assert self.athena.catalog == MOCK_DATA["catalog"] assert self.athena.aws_conn_id == "aws_default" assert self.athena.client_request_token == MOCK_DATA["client_request_token"] assert self.athena.sleep_time == 0 + @mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",)) + @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) + @mock.patch.object(AthenaHook, "get_conn") + def test_hook_run_override_catalog(self, mock_conn, mock_run_query, mock_check_query_status): + query_context_catalog = {"Database": MOCK_DATA["database"], "Catalog": "MyCatalog"} + self.athena.catalog = "MyCatalog" + self.athena.execute({}) + mock_run_query.assert_called_once_with( + MOCK_DATA["query"], + query_context_catalog, + result_configuration, + MOCK_DATA["client_request_token"], + MOCK_DATA["workgroup"], + ) + assert mock_check_query_status.call_count == 1 + @mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",)) @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) @mock.patch.object(AthenaHook, "get_conn")