From 0f6c4931bff533a6b670c10b60397900227e4c05 Mon Sep 17 00:00:00 2001 From: Victor Jimenez Date: Mon, 2 Jul 2018 23:30:45 +0200 Subject: [PATCH 1/6] [AIRFLOW-2709] Improve error handling in Databricks hook --- airflow/contrib/hooks/databricks_hook.py | 54 +++++++++++---- .../contrib/operators/databricks_operator.py | 8 ++- tests/contrib/hooks/test_databricks_hook.py | 66 +++++++++++++------ .../operators/test_databricks_operator.py | 10 +-- 4 files changed, 101 insertions(+), 37 deletions(-) diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 1443ff4740b94..12df5f40457ee 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -24,6 +24,7 @@ from airflow.hooks.base_hook import BaseHook from requests import exceptions as requests_exceptions from requests.auth import AuthBase +from time import sleep from airflow.utils.log.logging_mixin import LoggingMixin @@ -47,7 +48,8 @@ def __init__( self, databricks_conn_id='databricks_default', timeout_seconds=180, - retry_limit=3): + retry_limit=3, + retry_delay=1): """ :param databricks_conn_id: The name of the databricks connection to use. :type databricks_conn_id: string @@ -57,12 +59,16 @@ def __init__( :param retry_limit: The number of times to retry the connection in case of service outages. :type retry_limit: int + :param retry_delay: The number of seconds to wait between retries (it + might be a floating point number). + :type retry_delay: float """ self.databricks_conn_id = databricks_conn_id self.databricks_conn = self.get_connection(databricks_conn_id) self.timeout_seconds = timeout_seconds assert retry_limit >= 1, 'Retry limit must be greater than equal to 1' self.retry_limit = retry_limit + self.retry_delay = retry_delay def _parse_host(self, host): """ @@ -117,7 +123,8 @@ def _do_api_call(self, endpoint_info, json): else: raise AirflowException('Unexpected HTTP Method: ' + method) - for attempt_num in range(1, self.retry_limit + 1): + attempt_num = 1 + while True: try: response = request_func( url, @@ -125,21 +132,42 @@ def _do_api_call(self, endpoint_info, json): auth=auth, headers=USER_AGENT_HEADER, timeout=self.timeout_seconds) - if response.status_code == requests.codes.ok: - return response.json() - else: + response.raise_for_status() + return response.json() + except (requests_exceptions.ConnectionError, + requests_exceptions.Timeout) as e: + self._log_request_error(attempt_num, e) + except requests_exceptions.HTTPError as e: + response = e.response + if not self._retriable_error(response): # In this case, the user probably made a mistake. # Don't retry. raise AirflowException('Response: {0}, Status Code: {1}'.format( response.content, response.status_code)) - except (requests_exceptions.ConnectionError, - requests_exceptions.Timeout) as e: - self.log.error( - 'Attempt %s API Request to Databricks failed with reason: %s', - attempt_num, e - ) - raise AirflowException(('API requests to Databricks failed {} times. ' + - 'Giving up.').format(self.retry_limit)) + + self._log_request_error(attempt_num, e) + + if attempt_num == self.retry_limit: + raise AirflowException(('API requests to Databricks failed {} times. ' + + 'Giving up.').format(self.retry_limit)) + + attempt_num += 1 + sleep(self.retry_delay) + + def _log_request_error(self, attempt_num, error): + self.log.error( + 'Attempt %s API Request to Databricks failed with reason: %s', + attempt_num, error + ) + + @staticmethod + def _retriable_error(response): + try: + error_code = response.json().get('error_code') + return error_code == 'TEMPORARILY_UNAVAILABLE' + except ValueError: + # not a valid JSON + return False def submit_run(self, json): """ diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py index 7b8d522dba85b..3245a99256502 100644 --- a/airflow/contrib/operators/databricks_operator.py +++ b/airflow/contrib/operators/databricks_operator.py @@ -146,6 +146,9 @@ class DatabricksSubmitRunOperator(BaseOperator): :param databricks_retry_limit: Amount of times retry if the Databricks backend is unreachable. Its value must be greater than or equal to 1. :type databricks_retry_limit: int + :param databricks_retry_delay: Number of seconds to wait between retries (it + might be a floating point number). + :type databricks_retry_delay: float :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. :type do_xcom_push: boolean """ @@ -168,6 +171,7 @@ def __init__( databricks_conn_id='databricks_default', polling_period_seconds=30, databricks_retry_limit=3, + databricks_retry_delay=1, do_xcom_push=False, **kwargs): """ @@ -178,6 +182,7 @@ def __init__( self.databricks_conn_id = databricks_conn_id self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay if spark_jar_task is not None: self.json['spark_jar_task'] = spark_jar_task if notebook_task is not None: @@ -232,7 +237,8 @@ def _log_run_page_url(self, url): def get_hook(self): return DatabricksHook( self.databricks_conn_id, - retry_limit=self.databricks_retry_limit) + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay) def execute(self, context): hook = self.get_hook() diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index 6052a6d54f1f8..e199f72332d2d 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -114,31 +114,51 @@ def test_init_bad_retry_limit(self): DatabricksHook(retry_limit = 0) @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_error_retry(self, mock_requests): - for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]: + @mock.patch('airflow.contrib.hooks.databricks_hook.sleep') + def test_do_api_call_with_error_retry(self, _, mock_requests): + for exception in [ + requests_exceptions.ConnectionError(), + requests_exceptions.Timeout(), + self._build_http_error('TEMPORARILY_UNAVAILABLE')]: with mock.patch.object(self.hook.log, 'error') as mock_errors: - mock_requests.reset_mock() - mock_requests.post.side_effect = exception() + self._setup_mock_requests(mock_requests, exception) with self.assertRaises(AirflowException): self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit) + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + @mock.patch('airflow.contrib.hooks.databricks_hook.sleep') + def test_do_api_call_waits_between_retries(self, mock_sleep, mock_requests): + retry_delay = 5 + self.hook = DatabricksHook(retry_delay=retry_delay) + + for exception in [ + requests_exceptions.ConnectionError(), + requests_exceptions.Timeout(), + self._build_http_error('TEMPORARILY_UNAVAILABLE')]: + with mock.patch.object(self.hook.log, 'error'): + mock_sleep.reset_mock() + self._setup_mock_requests(mock_requests, exception) + + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1) + mock_sleep.assert_called_with(retry_delay) + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_do_api_call_with_bad_status_code(self, mock_requests): - mock_requests.codes.ok = 200 - status_code_mock = mock.PropertyMock(return_value=500) - type(mock_requests.post.return_value).status_code = status_code_mock + response = mock.MagicMock() + response.raise_for_status.side_effect = self._build_http_error('ERROR') + mock_requests.post.return_value = response with self.assertRaises(AirflowException): self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_submit_run(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.post.return_value.json.return_value = {'run_id': '1'} - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock json = { 'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER @@ -158,10 +178,7 @@ def test_submit_run(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_get_run_page_url(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.get.return_value).status_code = status_code_mock run_page_url = self.hook.get_run_page_url(RUN_ID) @@ -175,10 +192,7 @@ def test_get_run_page_url(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_get_run_state(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.get.return_value).status_code = status_code_mock run_state = self.hook.get_run_state(RUN_ID) @@ -195,10 +209,7 @@ def test_get_run_state(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_cancel_run(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock self.hook.cancel_run(RUN_ID) @@ -209,6 +220,23 @@ def test_cancel_run(self, mock_requests): headers=USER_AGENT_HEADER, timeout=self.hook.timeout_seconds) + @staticmethod + def _setup_mock_requests(mock_requests, exception): + mock_requests.reset_mock() + if type(exception) in [requests_exceptions.ConnectionError, + requests_exceptions.Timeout]: + mock_requests.post.side_effect = exception + elif type(exception) == requests_exceptions.HTTPError: + mock_requests.raise_for_status.side_effect = exception + + @staticmethod + def _build_http_error(error_code): + response = mock.MagicMock() + error_info = {'error_code': error_code, 'message': ''} + response.json.return_value = error_info + response.text = json.dumps(error_info) + return requests_exceptions.HTTPError(response=response) + class DatabricksHookTokenTest(unittest.TestCase): """ diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py index f77da2ec18eda..afe1a92f28d9e 100644 --- a/tests/contrib/operators/test_databricks_operator.py +++ b/tests/contrib/operators/test_databricks_operator.py @@ -190,8 +190,9 @@ def test_exec_success(self, db_mock_class): 'run_name': TASK_ID }) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit) + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID) @@ -220,8 +221,9 @@ def test_exec_failure(self, db_mock_class): 'run_name': TASK_ID, }) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit) + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID) From 58f2ec007705bd086e21644f0eb57fc092e6054b Mon Sep 17 00:00:00 2001 From: Victor Jimenez Date: Sun, 19 Aug 2018 18:51:13 +0200 Subject: [PATCH 2/6] Use float for default value --- airflow/contrib/hooks/databricks_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 12df5f40457ee..f04c942f0d9bd 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -49,7 +49,7 @@ def __init__( databricks_conn_id='databricks_default', timeout_seconds=180, retry_limit=3, - retry_delay=1): + retry_delay=1.0): """ :param databricks_conn_id: The name of the databricks connection to use. :type databricks_conn_id: string From c283dbbdacbecff1f3aabf1a22c5240daad2ed68 Mon Sep 17 00:00:00 2001 From: Victor Jimenez Date: Fri, 24 Aug 2018 10:33:16 +0200 Subject: [PATCH 3/6] Use status code to determine whether an error is retryable --- airflow/contrib/hooks/databricks_hook.py | 25 ++-- tests/contrib/hooks/test_databricks_hook.py | 119 ++++++++++++-------- 2 files changed, 84 insertions(+), 60 deletions(-) diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index f04c942f0d9bd..79b42866e9e6c 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -134,16 +134,12 @@ def _do_api_call(self, endpoint_info, json): timeout=self.timeout_seconds) response.raise_for_status() return response.json() - except (requests_exceptions.ConnectionError, - requests_exceptions.Timeout) as e: - self._log_request_error(attempt_num, e) - except requests_exceptions.HTTPError as e: - response = e.response - if not self._retriable_error(response): + except requests_exceptions.RequestException as e: + if not _retryable_error(e): # In this case, the user probably made a mistake. # Don't retry. raise AirflowException('Response: {0}, Status Code: {1}'.format( - response.content, response.status_code)) + e.response.content, e.response.status_code)) self._log_request_error(attempt_num, e) @@ -160,15 +156,6 @@ def _log_request_error(self, attempt_num, error): attempt_num, error ) - @staticmethod - def _retriable_error(response): - try: - error_code = response.json().get('error_code') - return error_code == 'TEMPORARILY_UNAVAILABLE' - except ValueError: - # not a valid JSON - return False - def submit_run(self, json): """ Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint. @@ -201,6 +188,12 @@ def cancel_run(self, run_id): self._do_api_call(CANCEL_RUN_ENDPOINT, json) +def _retryable_error(exception): + return type(exception) == requests_exceptions.ConnectionError \ + or type(exception) == requests_exceptions.Timeout \ + or exception.response is not None and exception.response.status_code >= 500 + + RUN_LIFE_CYCLE_STATES = [ 'PENDING', 'RUNNING', diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index e199f72332d2d..e408dc2ab4a4b 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -18,15 +18,17 @@ # under the License. # +import itertools import json import unittest +from requests import exceptions as requests_exceptions + from airflow import __version__ -from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth +from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.utils import db -from requests import exceptions as requests_exceptions try: from unittest import mock @@ -79,12 +81,43 @@ def get_run_endpoint(host): """ return 'https://{}/api/2.0/jobs/runs/get'.format(host) + def cancel_run_endpoint(host): """ Utility function to generate the get run endpoint given the host. """ return 'https://{}/api/2.0/jobs/runs/cancel'.format(host) + +def create_valid_response_mock(content): + response = mock.MagicMock() + response.json.return_value = content + return response + + +def create_post_side_effect(exception, status_code=500): + if exception in [requests_exceptions.ConnectionError, + requests_exceptions.Timeout]: + return exception() + elif exception == requests_exceptions.HTTPError: + response = mock.MagicMock() + response.status_code = status_code + response.raise_for_status.side_effect = exception(response=response) + return response + + +def setup_mock_requests(mock_requests, exception, status_code=500, error_count=None, response_content=None): + side_effect = create_post_side_effect(exception, status_code) + + if error_count is None: + # POST requests will fail indefinitely + mock_requests.post.side_effect = itertools.repeat(side_effect) + else: + # POST requests will fail 'error_count' times, and then they will succeed (once) + mock_requests.post.side_effect = \ + [side_effect] * error_count + [create_valid_response_mock(response_content)] + + class DatabricksHookTest(unittest.TestCase): """ Tests for DatabricksHook. @@ -99,7 +132,7 @@ def setUp(self, session=None): conn.password = PASSWORD session.commit() - self.hook = DatabricksHook() + self.hook = DatabricksHook(retry_delay=0) def test_parse_host_with_proper_host(self): host = self.hook._parse_host(HOST) @@ -111,36 +144,59 @@ def test_parse_host_with_scheme(self): def test_init_bad_retry_limit(self): with self.assertRaises(AssertionError): - DatabricksHook(retry_limit = 0) + DatabricksHook(retry_limit=0) - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - @mock.patch('airflow.contrib.hooks.databricks_hook.sleep') - def test_do_api_call_with_error_retry(self, _, mock_requests): + def test_do_api_call_retries_with_retryable_error(self): for exception in [ - requests_exceptions.ConnectionError(), - requests_exceptions.Timeout(), - self._build_http_error('TEMPORARILY_UNAVAILABLE')]: - with mock.patch.object(self.hook.log, 'error') as mock_errors: - self._setup_mock_requests(mock_requests, exception) + requests_exceptions.ConnectionError, + requests_exceptions.Timeout, + requests_exceptions.HTTPError]: + with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error') as mock_errors: + setup_mock_requests(mock_requests, exception) with self.assertRaises(AirflowException): self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit) + self.assertEquals(mock_errors.call_count, self.hook.retry_limit) @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests): + setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=400) + + with mock.patch.object(self.hook.log, 'error') as mock_errors: + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + mock_errors.assert_not_called() + + def test_do_api_call_succeeds_after_retrying(self): + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.Timeout, + requests_exceptions.HTTPError]: + with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error') as mock_errors: + setup_mock_requests(mock_requests, exception, error_count=2, response_content={'run_id': '1'}) + + response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + self.assertEquals(mock_errors.call_count, 2) + self.assertEquals(response, {'run_id': '1'}) + @mock.patch('airflow.contrib.hooks.databricks_hook.sleep') - def test_do_api_call_waits_between_retries(self, mock_sleep, mock_requests): + def test_do_api_call_waits_between_retries(self, mock_sleep): retry_delay = 5 self.hook = DatabricksHook(retry_delay=retry_delay) for exception in [ - requests_exceptions.ConnectionError(), - requests_exceptions.Timeout(), - self._build_http_error('TEMPORARILY_UNAVAILABLE')]: - with mock.patch.object(self.hook.log, 'error'): + requests_exceptions.ConnectionError, + requests_exceptions.Timeout, + requests_exceptions.HTTPError]: + with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error'): mock_sleep.reset_mock() - self._setup_mock_requests(mock_requests, exception) + setup_mock_requests(mock_requests, exception) with self.assertRaises(AirflowException): self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) @@ -148,14 +204,6 @@ def test_do_api_call_waits_between_retries(self, mock_sleep, mock_requests): self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1) mock_sleep.assert_called_with(retry_delay) - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_bad_status_code(self, mock_requests): - response = mock.MagicMock() - response.raise_for_status.side_effect = self._build_http_error('ERROR') - mock_requests.post.return_value = response - with self.assertRaises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_submit_run(self, mock_requests): mock_requests.post.return_value.json.return_value = {'run_id': '1'} @@ -220,23 +268,6 @@ def test_cancel_run(self, mock_requests): headers=USER_AGENT_HEADER, timeout=self.hook.timeout_seconds) - @staticmethod - def _setup_mock_requests(mock_requests, exception): - mock_requests.reset_mock() - if type(exception) in [requests_exceptions.ConnectionError, - requests_exceptions.Timeout]: - mock_requests.post.side_effect = exception - elif type(exception) == requests_exceptions.HTTPError: - mock_requests.raise_for_status.side_effect = exception - - @staticmethod - def _build_http_error(error_code): - response = mock.MagicMock() - error_info = {'error_code': error_code, 'message': ''} - response.json.return_value = error_info - response.text = json.dumps(error_info) - return requests_exceptions.HTTPError(response=response) - class DatabricksHookTokenTest(unittest.TestCase): """ From e71d671b293c197b0be8c8b6b02c250f6571fe34 Mon Sep 17 00:00:00 2001 From: Victor Jimenez Date: Fri, 24 Aug 2018 14:16:53 +0200 Subject: [PATCH 4/6] Fix wrong type in assertion --- tests/contrib/hooks/test_databricks_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index e57e434460dc4..54944334b5fac 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -143,7 +143,7 @@ def test_parse_host_with_scheme(self): self.assertEquals(host, HOST) def test_init_bad_retry_limit(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): DatabricksHook(retry_limit=0) def test_do_api_call_retries_with_retryable_error(self): From b19de034f04072075a531d249d62adf7cb805cf3 Mon Sep 17 00:00:00 2001 From: Victor Jimenez Date: Fri, 24 Aug 2018 15:47:05 +0200 Subject: [PATCH 5/6] Fix style to prevent lines from exceeding 90 characters --- tests/contrib/hooks/test_databricks_hook.py | 34 ++++++++++++++++----- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index 54944334b5fac..df9967839e372 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -25,7 +25,11 @@ from requests import exceptions as requests_exceptions from airflow import __version__ -from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT +from airflow.contrib.hooks.databricks_hook import ( + DatabricksHook, + RunState, + SUBMIT_RUN_ENDPOINT +) from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.utils import db @@ -106,7 +110,13 @@ def create_post_side_effect(exception, status_code=500): return response -def setup_mock_requests(mock_requests, exception, status_code=500, error_count=None, response_content=None): +def setup_mock_requests( + mock_requests, + exception, + status_code=500, + error_count=None, + response_content=None): + side_effect = create_post_side_effect(exception, status_code) if error_count is None: @@ -151,7 +161,8 @@ def test_do_api_call_retries_with_retryable_error(self): requests_exceptions.ConnectionError, requests_exceptions.Timeout, requests_exceptions.HTTPError]: - with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ mock.patch.object(self.hook.log, 'error') as mock_errors: setup_mock_requests(mock_requests, exception) @@ -162,7 +173,9 @@ def test_do_api_call_retries_with_retryable_error(self): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests): - setup_mock_requests(mock_requests, requests_exceptions.HTTPError, status_code=400) + setup_mock_requests( + mock_requests, requests_exceptions.HTTPError, status_code=400 + ) with mock.patch.object(self.hook.log, 'error') as mock_errors: with self.assertRaises(AirflowException): @@ -175,9 +188,15 @@ def test_do_api_call_succeeds_after_retrying(self): requests_exceptions.ConnectionError, requests_exceptions.Timeout, requests_exceptions.HTTPError]: - with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ mock.patch.object(self.hook.log, 'error') as mock_errors: - setup_mock_requests(mock_requests, exception, error_count=2, response_content={'run_id': '1'}) + setup_mock_requests( + mock_requests, + exception, + error_count=2, + response_content={'run_id': '1'} + ) response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) @@ -193,7 +212,8 @@ def test_do_api_call_waits_between_retries(self, mock_sleep): requests_exceptions.ConnectionError, requests_exceptions.Timeout, requests_exceptions.HTTPError]: - with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ mock.patch.object(self.hook.log, 'error'): mock_sleep.reset_mock() setup_mock_requests(mock_requests, exception) From fdec2d00b54444ae0b4d2ebdf35aaff6db15c340 Mon Sep 17 00:00:00 2001 From: Victor Jimenez Date: Mon, 27 Aug 2018 21:16:59 +0200 Subject: [PATCH 6/6] Fix wrong way of checking exception type --- airflow/contrib/hooks/databricks_hook.py | 4 ++-- tests/contrib/hooks/test_databricks_hook.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index bb89113491426..5b97a0eba0391 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -191,8 +191,8 @@ def cancel_run(self, run_id): def _retryable_error(exception): - return type(exception) == requests_exceptions.ConnectionError \ - or type(exception) == requests_exceptions.Timeout \ + return isinstance(exception, requests_exceptions.ConnectionError) \ + or isinstance(exception, requests_exceptions.Timeout) \ or exception.response is not None and exception.response.status_code >= 500 diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index df9967839e372..a022431899b4d 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -100,10 +100,9 @@ def create_valid_response_mock(content): def create_post_side_effect(exception, status_code=500): - if exception in [requests_exceptions.ConnectionError, - requests_exceptions.Timeout]: + if exception != requests_exceptions.HTTPError: return exception() - elif exception == requests_exceptions.HTTPError: + else: response = mock.MagicMock() response.status_code = status_code response.raise_for_status.side_effect = exception(response=response) @@ -159,7 +158,9 @@ def test_init_bad_retry_limit(self): def test_do_api_call_retries_with_retryable_error(self): for exception in [ requests_exceptions.ConnectionError, + requests_exceptions.SSLError, requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, requests_exceptions.HTTPError]: with mock.patch( 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ @@ -186,7 +187,9 @@ def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests def test_do_api_call_succeeds_after_retrying(self): for exception in [ requests_exceptions.ConnectionError, + requests_exceptions.SSLError, requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, requests_exceptions.HTTPError]: with mock.patch( 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ @@ -210,7 +213,9 @@ def test_do_api_call_waits_between_retries(self, mock_sleep): for exception in [ requests_exceptions.ConnectionError, + requests_exceptions.SSLError, requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, requests_exceptions.HTTPError]: with mock.patch( 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \