Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/example_dags/example_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from airflow import DAG
from airflow.decorators import task
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
from airflow.providers.amazon.aws.operators.athena import AthenaOperator
from airflow.providers.amazon.aws.sensors.athena import AthenaSensor

# [START howto_operator_athena_env_variables]
Expand Down Expand Up @@ -91,7 +91,7 @@ def read_results_from_s3(query_execution_id):
# Using a task-decorated function to create a CSV file in S3
add_sample_data_to_s3 = add_sample_data_to_s3()

create_table = AWSAthenaOperator(
create_table = AthenaOperator(
task_id='setup__create_table',
query=QUERY_CREATE_TABLE,
database=ATHENA_DATABASE,
Expand All @@ -100,7 +100,7 @@ def read_results_from_s3(query_execution_id):
max_tries=None,
)

read_table = AWSAthenaOperator(
read_table = AthenaOperator(
task_id='query__read_table',
query=QUERY_READ_TABLE,
database=ATHENA_DATABASE,
Expand All @@ -119,7 +119,7 @@ def read_results_from_s3(query_execution_id):
# Using a task-decorated function to read the results from S3
read_results_from_s3 = read_results_from_s3(read_table.output)

drop_table = AWSAthenaOperator(
drop_table = AthenaOperator(
task_id='teardown__drop_table',
query=QUERY_DROP_TABLE,
database=ATHENA_DATABASE,
Expand Down
18 changes: 17 additions & 1 deletion airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.

"""This module contains AWS Athena hook"""
import warnings
from time import sleep
from typing import Any, Dict, Optional

Expand All @@ -25,7 +26,7 @@
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


class AWSAthenaHook(AwsBaseHook):
class AthenaHook(AwsBaseHook):
"""
Interact with AWS Athena to run, poll queries and return query results

Expand Down Expand Up @@ -260,3 +261,18 @@ def stop_query(self, query_execution_id: str) -> Dict:
:return: dict
"""
return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)


class AWSAthenaHook(AthenaHook):
"""
This hook is deprecated.
Please use :class:`airflow.providers.amazon.aws.hooks.athena.AthenaHook`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This hook is deprecated. Please use `airflow.providers.amazon.aws.hooks.athena.AthenaHook`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
33 changes: 25 additions & 8 deletions airflow/providers/amazon/aws/operators/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
#
import sys
import warnings
from typing import Any, Dict, Optional
from uuid import uuid4

Expand All @@ -26,16 +27,16 @@
from cached_property import cached_property

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook
from airflow.providers.amazon.aws.hooks.athena import AthenaHook


class AWSAthenaOperator(BaseOperator):
class AthenaOperator(BaseOperator):
"""
An operator that submits a presto query to athena.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:AWSAthenaOperator`
:ref:`howto/operator:AthenaOperator`

:param query: Presto to be run on athena. (templated)
:type query: str
Expand Down Expand Up @@ -93,9 +94,9 @@ def __init__(
self.query_execution_id = None # type: Optional[str]

@cached_property
def hook(self) -> AWSAthenaHook:
"""Create and return an AWSAthenaHook."""
return AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)

def execute(self, context: dict) -> Optional[str]:
"""Run Presto Query on Athena"""
Expand All @@ -110,13 +111,13 @@ def execute(self, context: dict) -> Optional[str]:
)
query_status = self.hook.poll_query_status(self.query_execution_id, self.max_tries)

if query_status in AWSAthenaHook.FAILURE_STATES:
if query_status in AthenaHook.FAILURE_STATES:
error_message = self.hook.get_state_change_reason(self.query_execution_id)
raise Exception(
f'Final state of Athena job is {query_status}, query_execution_id is '
f'{self.query_execution_id}. Error: {error_message}'
)
elif not query_status or query_status in AWSAthenaHook.INTERMEDIATE_STATES:
elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES:
raise Exception(
f'Final state of Athena job is {query_status}. Max tries of poll status exceeded, '
f'query_execution_id is {self.query_execution_id}.'
Expand All @@ -143,3 +144,19 @@ def on_kill(self) -> None:
'Polling Athena for query with id %s to reach final state', self.query_execution_id
)
self.hook.poll_query_status(self.query_execution_id)


class AWSAthenaOperator(AthenaOperator):
"""
This operator is deprecated.
Please use :class:`airflow.providers.amazon.aws.operators.athena.AthenaOperator`.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"This operator is deprecated. Please use "
"`airflow.providers.amazon.aws.operators.athena.AthenaOperator`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/sensors/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from cached_property import cached_property

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.sensors.base import BaseSensorOperator


Expand Down Expand Up @@ -85,6 +85,6 @@ def poke(self, context: dict) -> bool:
return True

@cached_property
def hook(self) -> AWSAthenaHook:
"""Create and return an AWSAthenaHook"""
return AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook"""
return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
4 changes: 2 additions & 2 deletions docs/apache-airflow-providers-amazon/operators/athena.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
under the License.


.. _howto/operator:AWSAthenaOperator:
.. _howto/operator:AthenaOperator:

Amazon Athena Operator
======================
Expand All @@ -33,7 +33,7 @@ Prerequisite Tasks
Using Operator
--------------
Use the
:class:`~airflow.providers.amazon.aws.operators.athena.AWSAthenaOperator`
:class:`~airflow.providers.amazon.aws.operators.athena.AthenaOperator`
to run a query in Amazon Athena. To get started with Amazon Athena please visit
`aws.amazon.com/athena <https://aws.amazon.com/athena>`_

Expand Down
28 changes: 14 additions & 14 deletions tests/providers/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import unittest
from unittest import mock

from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook
from airflow.providers.amazon.aws.hooks.athena import AthenaHook

MOCK_DATA = {
'query': 'SELECT * FROM TEST_TABLE',
Expand Down Expand Up @@ -47,15 +47,15 @@
}


class TestAWSAthenaHook(unittest.TestCase):
class TestAthenaHook(unittest.TestCase):
def setUp(self):
self.athena = AWSAthenaHook(sleep_time=0)
self.athena = AthenaHook(sleep_time=0)

def test_init(self):
assert self.athena.aws_conn_id == 'aws_default'
assert self.athena.sleep_time == 0

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_run_query_without_token(self, mock_conn):
mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION
result = self.athena.run_query(
Expand All @@ -72,7 +72,7 @@ def test_hook_run_query_without_token(self, mock_conn):
mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params)
assert result == MOCK_DATA['query_execution_id']

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_run_query_with_token(self, mock_conn):
mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION
result = self.athena.run_query(
Expand All @@ -91,20 +91,20 @@ def test_hook_run_query_with_token(self, mock_conn):
mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params)
assert result == MOCK_DATA['query_execution_id']

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'])
assert result is None

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_get_query_results_with_default_params(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'])
expected_call_params = {'QueryExecutionId': MOCK_DATA['query_execution_id'], 'MaxResults': 1000}
mock_conn.return_value.get_query_results.assert_called_with(**expected_call_params)

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_get_query_results_with_next_token(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
self.athena.get_query_results(
Expand All @@ -117,13 +117,13 @@ def test_hook_get_query_results_with_next_token(self, mock_conn):
}
mock_conn.return_value.get_query_results.assert_called_with(**expected_call_params)

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_get_paginator_with_non_succeeded_query(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'])
assert result is None

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_get_paginator_with_default_params(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'])
Expand All @@ -133,7 +133,7 @@ def test_hook_get_paginator_with_default_params(self, mock_conn):
}
mock_conn.return_value.get_paginator.return_value.paginate.assert_called_with(**expected_call_params)

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_get_paginator_with_pagination_config(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
self.athena.get_query_results_paginator(
Expand All @@ -152,14 +152,14 @@ def test_hook_get_paginator_with_pagination_config(self, mock_conn):
}
mock_conn.return_value.get_paginator.return_value.paginate.assert_called_with(**expected_call_params)

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_poll_query_when_final(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
result = self.athena.poll_query_status(query_execution_id=MOCK_DATA['query_execution_id'])
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == 'SUCCEEDED'

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.poll_query_status(
Expand All @@ -168,7 +168,7 @@ def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == 'RUNNING'

@mock.patch.object(AWSAthenaHook, 'get_conn')
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_get_output_location(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
result = self.athena.get_output_location(query_execution_id=MOCK_DATA['query_execution_id'])
Expand Down
Loading