diff --git a/airflow/providers/amazon/aws/example_dags/example_athena.py b/airflow/providers/amazon/aws/example_dags/example_athena.py index 3ae6e91e44534..80d30c2edd1bd 100644 --- a/airflow/providers/amazon/aws/example_dags/example_athena.py +++ b/airflow/providers/amazon/aws/example_dags/example_athena.py @@ -21,7 +21,7 @@ from airflow import DAG from airflow.decorators import task from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator +from airflow.providers.amazon.aws.operators.athena import AthenaOperator from airflow.providers.amazon.aws.sensors.athena import AthenaSensor # [START howto_operator_athena_env_variables] @@ -91,7 +91,7 @@ def read_results_from_s3(query_execution_id): # Using a task-decorated function to create a CSV file in S3 add_sample_data_to_s3 = add_sample_data_to_s3() - create_table = AWSAthenaOperator( + create_table = AthenaOperator( task_id='setup__create_table', query=QUERY_CREATE_TABLE, database=ATHENA_DATABASE, @@ -100,7 +100,7 @@ def read_results_from_s3(query_execution_id): max_tries=None, ) - read_table = AWSAthenaOperator( + read_table = AthenaOperator( task_id='query__read_table', query=QUERY_READ_TABLE, database=ATHENA_DATABASE, @@ -119,7 +119,7 @@ def read_results_from_s3(query_execution_id): # Using a task-decorated function to read the results from S3 read_results_from_s3 = read_results_from_s3(read_table.output) - drop_table = AWSAthenaOperator( + drop_table = AthenaOperator( task_id='teardown__drop_table', query=QUERY_DROP_TABLE, database=ATHENA_DATABASE, diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py index 9bb58fd7345c5..c7a91a5870c4c 100644 --- a/airflow/providers/amazon/aws/hooks/athena.py +++ b/airflow/providers/amazon/aws/hooks/athena.py @@ -17,6 +17,7 @@ # under the License. """This module contains AWS Athena hook""" +import warnings from time import sleep from typing import Any, Dict, Optional @@ -25,7 +26,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -class AWSAthenaHook(AwsBaseHook): +class AthenaHook(AwsBaseHook): """ Interact with AWS Athena to run, poll queries and return query results @@ -260,3 +261,18 @@ def stop_query(self, query_execution_id: str) -> Dict: :return: dict """ return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id) + + +class AWSAthenaHook(AthenaHook): + """ + This hook is deprecated. + Please use :class:`airflow.providers.amazon.aws.hooks.athena.AthenaHook`. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "This hook is deprecated. Please use `airflow.providers.amazon.aws.hooks.athena.AthenaHook`.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 19b96398889c8..dabf5e6c0dfb5 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -17,6 +17,7 @@ # under the License. # import sys +import warnings from typing import Any, Dict, Optional from uuid import uuid4 @@ -26,16 +27,16 @@ from cached_property import cached_property from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook +from airflow.providers.amazon.aws.hooks.athena import AthenaHook -class AWSAthenaOperator(BaseOperator): +class AthenaOperator(BaseOperator): """ An operator that submits a presto query to athena. .. seealso:: For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AWSAthenaOperator` + :ref:`howto/operator:AthenaOperator` :param query: Presto to be run on athena. (templated) :type query: str @@ -93,9 +94,9 @@ def __init__( self.query_execution_id = None # type: Optional[str] @cached_property - def hook(self) -> AWSAthenaHook: - """Create and return an AWSAthenaHook.""" - return AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time) + def hook(self) -> AthenaHook: + """Create and return an AthenaHook.""" + return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time) def execute(self, context: dict) -> Optional[str]: """Run Presto Query on Athena""" @@ -110,13 +111,13 @@ def execute(self, context: dict) -> Optional[str]: ) query_status = self.hook.poll_query_status(self.query_execution_id, self.max_tries) - if query_status in AWSAthenaHook.FAILURE_STATES: + if query_status in AthenaHook.FAILURE_STATES: error_message = self.hook.get_state_change_reason(self.query_execution_id) raise Exception( f'Final state of Athena job is {query_status}, query_execution_id is ' f'{self.query_execution_id}. Error: {error_message}' ) - elif not query_status or query_status in AWSAthenaHook.INTERMEDIATE_STATES: + elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES: raise Exception( f'Final state of Athena job is {query_status}. Max tries of poll status exceeded, ' f'query_execution_id is {self.query_execution_id}.' @@ -143,3 +144,19 @@ def on_kill(self) -> None: 'Polling Athena for query with id %s to reach final state', self.query_execution_id ) self.hook.poll_query_status(self.query_execution_id) + + +class AWSAthenaOperator(AthenaOperator): + """ + This operator is deprecated. + Please use :class:`airflow.providers.amazon.aws.operators.athena.AthenaOperator`. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + "This operator is deprecated. Please use " + "`airflow.providers.amazon.aws.operators.athena.AthenaOperator`.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py index 232bfdb5859f0..449688a996285 100644 --- a/airflow/providers/amazon/aws/sensors/athena.py +++ b/airflow/providers/amazon/aws/sensors/athena.py @@ -24,7 +24,7 @@ from cached_property import cached_property from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook +from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.sensors.base import BaseSensorOperator @@ -85,6 +85,6 @@ def poke(self, context: dict) -> bool: return True @cached_property - def hook(self) -> AWSAthenaHook: - """Create and return an AWSAthenaHook""" - return AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time) + def hook(self) -> AthenaHook: + """Create and return an AthenaHook""" + return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time) diff --git a/docs/apache-airflow-providers-amazon/operators/athena.rst b/docs/apache-airflow-providers-amazon/operators/athena.rst index 8f150a9065aaf..47e62d160e380 100644 --- a/docs/apache-airflow-providers-amazon/operators/athena.rst +++ b/docs/apache-airflow-providers-amazon/operators/athena.rst @@ -16,7 +16,7 @@ under the License. -.. _howto/operator:AWSAthenaOperator: +.. _howto/operator:AthenaOperator: Amazon Athena Operator ====================== @@ -33,7 +33,7 @@ Prerequisite Tasks Using Operator -------------- Use the -:class:`~airflow.providers.amazon.aws.operators.athena.AWSAthenaOperator` +:class:`~airflow.providers.amazon.aws.operators.athena.AthenaOperator` to run a query in Amazon Athena. To get started with Amazon Athena please visit `aws.amazon.com/athena `_ diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py index a3cf521383f2c..5b73b0f244936 100644 --- a/tests/providers/amazon/aws/hooks/test_athena.py +++ b/tests/providers/amazon/aws/hooks/test_athena.py @@ -18,7 +18,7 @@ import unittest from unittest import mock -from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook +from airflow.providers.amazon.aws.hooks.athena import AthenaHook MOCK_DATA = { 'query': 'SELECT * FROM TEST_TABLE', @@ -47,15 +47,15 @@ } -class TestAWSAthenaHook(unittest.TestCase): +class TestAthenaHook(unittest.TestCase): def setUp(self): - self.athena = AWSAthenaHook(sleep_time=0) + self.athena = AthenaHook(sleep_time=0) def test_init(self): assert self.athena.aws_conn_id == 'aws_default' assert self.athena.sleep_time == 0 - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_run_query_without_token(self, mock_conn): mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION result = self.athena.run_query( @@ -72,7 +72,7 @@ def test_hook_run_query_without_token(self, mock_conn): mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params) assert result == MOCK_DATA['query_execution_id'] - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_run_query_with_token(self, mock_conn): mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION result = self.athena.run_query( @@ -91,20 +91,20 @@ def test_hook_run_query_with_token(self, mock_conn): mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params) assert result == MOCK_DATA['query_execution_id'] - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION result = self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id']) assert result is None - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_get_query_results_with_default_params(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id']) expected_call_params = {'QueryExecutionId': MOCK_DATA['query_execution_id'], 'MaxResults': 1000} mock_conn.return_value.get_query_results.assert_called_with(**expected_call_params) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_get_query_results_with_next_token(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION self.athena.get_query_results( @@ -117,13 +117,13 @@ def test_hook_get_query_results_with_next_token(self, mock_conn): } mock_conn.return_value.get_query_results.assert_called_with(**expected_call_params) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_get_paginator_with_non_succeeded_query(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION result = self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id']) assert result is None - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_get_paginator_with_default_params(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id']) @@ -133,7 +133,7 @@ def test_hook_get_paginator_with_default_params(self, mock_conn): } mock_conn.return_value.get_paginator.return_value.paginate.assert_called_with(**expected_call_params) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_get_paginator_with_pagination_config(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION self.athena.get_query_results_paginator( @@ -152,14 +152,14 @@ def test_hook_get_paginator_with_pagination_config(self, mock_conn): } mock_conn.return_value.get_paginator.return_value.paginate.assert_called_with(**expected_call_params) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_poll_query_when_final(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION result = self.athena.poll_query_status(query_execution_id=MOCK_DATA['query_execution_id']) mock_conn.return_value.get_query_execution.assert_called_once() assert result == 'SUCCEEDED' - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_poll_query_with_timeout(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION result = self.athena.poll_query_status( @@ -168,7 +168,7 @@ def test_hook_poll_query_with_timeout(self, mock_conn): mock_conn.return_value.get_query_execution.assert_called_once() assert result == 'RUNNING' - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_get_output_location(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT result = self.athena.get_output_location(query_execution_id=MOCK_DATA['query_execution_id']) diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 97ec4d205065d..060bfb40c182a 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -21,8 +21,8 @@ import pytest from airflow.models import DAG, DagRun, TaskInstance -from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook -from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator +from airflow.providers.amazon.aws.hooks.athena import AthenaHook +from airflow.providers.amazon.aws.operators.athena import AthenaOperator from airflow.utils import timezone from airflow.utils.timezone import datetime @@ -31,7 +31,7 @@ ATHENA_QUERY_ID = 'eac29bf8-daa1-4ffc-b19a-0db31dc3b784' MOCK_DATA = { - 'task_id': 'test_aws_athena_operator', + 'task_id': 'test_athena_operator', 'query': 'SELECT * FROM TEST_TABLE', 'database': 'TEST_DATABASE', 'outputLocation': 's3://test_s3_bucket/', @@ -43,7 +43,7 @@ result_configuration = {'OutputLocation': MOCK_DATA['outputLocation']} -class TestAWSAthenaOperator(unittest.TestCase): +class TestAthenaOperator(unittest.TestCase): def setUp(self): args = { 'owner': 'airflow', @@ -51,8 +51,8 @@ def setUp(self): } self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args, schedule_interval='@once') - self.athena = AWSAthenaOperator( - task_id='test_aws_athena_operator', + self.athena = AthenaOperator( + task_id='test_athena_operator', query='SELECT * FROM TEST_TABLE', database='TEST_DATABASE', output_location='s3://test_s3_bucket/', @@ -72,11 +72,11 @@ def test_init(self): assert self.athena.hook.sleep_time == 0 - @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",)) - @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'check_query_status', side_effect=("SUCCESS",)) + @mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_check_query_status): - self.athena.execute(None) + self.athena.execute({}) mock_run_query.assert_called_once_with( MOCK_DATA['query'], query_context, @@ -87,7 +87,7 @@ def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_chec assert mock_check_query_status.call_count == 1 @mock.patch.object( - AWSAthenaHook, + AthenaHook, 'check_query_status', side_effect=( "RUNNING", @@ -95,10 +95,10 @@ def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_chec "SUCCESS", ), ) - @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_query_status): - self.athena.execute(None) + self.athena.execute({}) mock_run_query.assert_called_once_with( MOCK_DATA['query'], query_context, @@ -109,18 +109,18 @@ def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_ assert mock_check_query_status.call_count == 3 @mock.patch.object( - AWSAthenaHook, + AthenaHook, 'check_query_status', side_effect=( None, None, ), ) - @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_check_query_status): with pytest.raises(Exception): - self.athena.execute(None) + self.athena.execute({}) mock_run_query.assert_called_once_with( MOCK_DATA['query'], query_context, @@ -130,22 +130,22 @@ def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_c ) assert mock_check_query_status.call_count == 3 - @mock.patch.object(AWSAthenaHook, 'get_state_change_reason') + @mock.patch.object(AthenaHook, 'get_state_change_reason') @mock.patch.object( - AWSAthenaHook, + AthenaHook, 'check_query_status', side_effect=( "RUNNING", "FAILED", ), ) - @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_run_failure_query( self, mock_conn, mock_run_query, mock_check_query_status, mock_get_state_change_reason ): with pytest.raises(Exception): - self.athena.execute(None) + self.athena.execute({}) mock_run_query.assert_called_once_with( MOCK_DATA['query'], query_context, @@ -157,7 +157,7 @@ def test_hook_run_failure_query( assert mock_get_state_change_reason.call_count == 1 @mock.patch.object( - AWSAthenaHook, + AthenaHook, 'check_query_status', side_effect=( "RUNNING", @@ -165,11 +165,11 @@ def test_hook_run_failure_query( "CANCELLED", ), ) - @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status): with pytest.raises(Exception): - self.athena.execute(None) + self.athena.execute({}) mock_run_query.assert_called_once_with( MOCK_DATA['query'], query_context, @@ -180,7 +180,7 @@ def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_qu assert mock_check_query_status.call_count == 3 @mock.patch.object( - AWSAthenaHook, + AthenaHook, 'check_query_status', side_effect=( "RUNNING", @@ -188,11 +188,11 @@ def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_qu "RUNNING", ), ) - @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) + @mock.patch.object(AthenaHook, 'get_conn') def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, mock_check_query_status): with pytest.raises(Exception): - self.athena.execute(None) + self.athena.execute({}) mock_run_query.assert_called_once_with( MOCK_DATA['query'], query_context, @@ -202,9 +202,9 @@ def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, m ) assert mock_check_query_status.call_count == 3 - @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",)) - @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) - @mock.patch.object(AWSAthenaHook, 'get_conn') + @mock.patch.object(AthenaHook, 'check_query_status', side_effect=("SUCCESS",)) + @mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID) + @mock.patch.object(AthenaHook, 'get_conn') def test_return_value(self, mock_conn, mock_run_query, mock_check_query_status): """Test we return the right value -- that will get put in to XCom by the execution engine""" dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=timezone.utcnow(), run_id="test") diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py index 781f94cdb190a..1ef99753700d0 100644 --- a/tests/providers/amazon/aws/sensors/test_athena.py +++ b/tests/providers/amazon/aws/sensors/test_athena.py @@ -22,7 +22,7 @@ import pytest from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook +from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.sensors.athena import AthenaSensor @@ -36,26 +36,26 @@ def setUp(self): aws_conn_id='aws_default', ) - @mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("SUCCEEDED",)) + @mock.patch.object(AthenaHook, 'poll_query_status', side_effect=("SUCCEEDED",)) def test_poke_success(self, mock_poll_query_status): - assert self.sensor.poke(None) + assert self.sensor.poke({}) - @mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("RUNNING",)) + @mock.patch.object(AthenaHook, 'poll_query_status', side_effect=("RUNNING",)) def test_poke_running(self, mock_poll_query_status): - assert not self.sensor.poke(None) + assert not self.sensor.poke({}) - @mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("QUEUED",)) + @mock.patch.object(AthenaHook, 'poll_query_status', side_effect=("QUEUED",)) def test_poke_queued(self, mock_poll_query_status): - assert not self.sensor.poke(None) + assert not self.sensor.poke({}) - @mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("FAILED",)) + @mock.patch.object(AthenaHook, 'poll_query_status', side_effect=("FAILED",)) def test_poke_failed(self, mock_poll_query_status): with pytest.raises(AirflowException) as ctx: - self.sensor.poke(None) + self.sensor.poke({}) assert 'Athena sensor failed' in str(ctx.value) - @mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=("CANCELLED",)) + @mock.patch.object(AthenaHook, 'poll_query_status', side_effect=("CANCELLED",)) def test_poke_cancelled(self, mock_poll_query_status): with pytest.raises(AirflowException) as ctx: - self.sensor.poke(None) + self.sensor.poke({}) assert 'Athena sensor failed' in str(ctx.value)