4 changes: 3 additions & 1 deletion airflow/providers/databricks/hooks/databricks.py
@@ -100,6 +100,7 @@ class DatabricksHook(BaseDatabricksHook):
service outages.
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
:param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
"""

hook_name = 'Databricks'
@@ -110,8 +111,9 @@ def __init__(
timeout_seconds: int = 180,
retry_limit: int = 3,
retry_delay: float = 1.0,
retry_args: Optional[Dict[Any, Any]] = None,
) -> None:
super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay)
super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args)

def run_now(self, json: dict) -> int:
"""
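
For reference, a minimal sketch (not part of this PR) of how the new ``retry_args`` argument might be passed when constructing the hook directly. The connection id, job id, and retry values below are illustrative only; the hook copies the dict and adds its own ``retry`` predicate and ``after`` callback, as shown in the ``databricks_base.py`` changes further down.

from tenacity import stop_after_attempt, wait_fixed

from airflow.providers.databricks.hooks.databricks import DatabricksHook

# Retry up to 5 times with a fixed 10-second pause between attempts.
hook = DatabricksHook(
    databricks_conn_id='databricks_default',
    retry_args=dict(stop=stop_after_attempt(5), wait=wait_fixed(10)),
)
run_id = hook.run_now(json={'job_id': 42})  # illustrative job id
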
196 changes: 104 additions & 92 deletions airflow/providers/databricks/hooks/databricks_base.py
@@ -22,15 +22,16 @@
operators talk to the ``api/2.0/jobs/runs/submit``
`endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_.
"""
import copy
import sys
import time
from time import sleep
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse

import requests
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential

from airflow import __version__
from airflow.exceptions import AirflowException
@@ -68,6 +69,7 @@ class BaseDatabricksHook(BaseHook):
service outages.
:param retry_delay: The number of seconds to wait between retries (it
might be a floating point number).
:param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
"""

conn_name_attr = 'databricks_conn_id'
@@ -89,17 +91,33 @@ def __init__(
timeout_seconds: int = 180,
retry_limit: int = 3,
retry_delay: float = 1.0,
retry_args: Optional[Dict[Any, Any]] = None,
) -> None:
super().__init__()
self.databricks_conn_id = databricks_conn_id
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
raise ValueError('Retry limit must be greater than equal to 1')
raise ValueError('Retry limit must be greater than or equal to 1')
self.retry_limit = retry_limit
self.retry_delay = retry_delay
self.aad_tokens: Dict[str, dict] = {}
self.aad_timeout_seconds = 10

def my_after_func(retry_state):
self._log_request_error(retry_state.attempt_number, retry_state.outcome)

if retry_args:
self.retry_args = copy.copy(retry_args)
self.retry_args['retry'] = retry_if_exception(self._retryable_error)
self.retry_args['after'] = my_after_func
else:
self.retry_args = dict(
stop=stop_after_attempt(self.retry_limit),
wait=wait_exponential(min=self.retry_delay, max=(2 ** retry_limit)),
retry=retry_if_exception(self._retryable_error),
after=my_after_func,
)

@cached_property
def databricks_conn(self) -> Connection:
return self.get_connection(self.databricks_conn_id)
@@ -143,6 +161,13 @@ def _parse_host(host: str) -> str:
# In this case, host = xx.cloud.databricks.com
return host

def _get_retry_object(self) -> Retrying:
"""
Instantiates a retry object.

:return: instance of Retrying class
"""
return Retrying(**self.retry_args)

def _get_aad_token(self, resource: str) -> str:
"""
Function to get AAD token for given resource. Supports managed identity or service principal auth
@@ -154,60 +179,59 @@ def _get_aad_token(self, resource: str) -> str:
return aad_token['token']

self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
attempt_num = 1
while True:
try:
if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
params = {
"api-version": "2018-02-01",
"resource": resource,
}
resp = requests.get(
AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
headers={**USER_AGENT_HEADER, "Metadata": "true"},
timeout=self.aad_timeout_seconds,
)
else:
tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
)
resp = requests.post(
AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={**USER_AGENT_HEADER, 'Content-Type': 'application/x-www-form-urlencoded'},
timeout=self.aad_timeout_seconds,
)

resp.raise_for_status()
jsn = resp.json()
if 'access_token' not in jsn or jsn.get('token_type') != 'Bearer' or 'expires_on' not in jsn:
raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")

token = jsn['access_token']
self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}

return token
except requests_exceptions.RequestException as e:
if not self._retryable_error(e):
raise AirflowException(
f'Response: {e.response.content}, Status Code: {e.response.status_code}'
)

self._log_request_error(attempt_num, e.strerror)

if attempt_num == self.retry_limit:
raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')

attempt_num += 1
sleep(self.retry_delay)
try:
for attempt in self._get_retry_object():
with attempt:
if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
params = {
"api-version": "2018-02-01",
"resource": resource,
}
resp = requests.get(
AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
headers={**USER_AGENT_HEADER, "Metadata": "true"},
timeout=self.aad_timeout_seconds,
)
else:
tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
"resource": resource,
"client_secret": self.databricks_conn.password,
}
azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
"azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
)
resp = requests.post(
AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={
**USER_AGENT_HEADER,
'Content-Type': 'application/x-www-form-urlencoded',
},
timeout=self.aad_timeout_seconds,
)

resp.raise_for_status()
jsn = resp.json()
if (
'access_token' not in jsn
or jsn.get('token_type') != 'Bearer'
or 'expires_on' not in jsn
):
raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")

token = jsn['access_token']
self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
break
except RetryError:
raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
except requests_exceptions.HTTPError as e:
raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')

return token

def _get_aad_headers(self) -> dict:
"""
@@ -279,14 +303,6 @@ def _get_token(self, raise_error: bool = False) -> Optional[str]:

return None

@staticmethod
def _retryable_error(exception) -> bool:
return (
isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout))
or exception.response is not None
and exception.response.status_code >= 500
)

def _log_request_error(self, attempt_num: int, error: str) -> None:
self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)

@@ -327,36 +343,32 @@ def _do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str,
else:
raise AirflowException('Unexpected HTTP Method: ' + method)

attempt_num = 1
while True:
try:
response = request_func(
url,
json=json if method in ('POST', 'PATCH') else None,
params=json if method == 'GET' else None,
auth=auth,
headers=headers,
timeout=self.timeout_seconds,
)
response.raise_for_status()
return response.json()
except requests_exceptions.RequestException as e:
if not self._retryable_error(e):
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(
f'Response: {e.response.content}, Status Code: {e.response.status_code}'
try:
for attempt in self._get_retry_object():
with attempt:
response = request_func(
url,
json=json if method in ('POST', 'PATCH') else None,
params=json if method == 'GET' else None,
auth=auth,
headers=headers,
timeout=self.timeout_seconds,
)
response.raise_for_status()
return response.json()
except RetryError:
raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
except requests_exceptions.HTTPError as e:
raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')

self._log_request_error(attempt_num, str(e))

if attempt_num == self.retry_limit:
raise AirflowException(
f'API requests to Databricks failed {self.retry_limit} times. Giving up.'
)

attempt_num += 1
sleep(self.retry_delay)
@staticmethod
def _retryable_error(exception: BaseException) -> bool:
if not isinstance(exception, requests_exceptions.RequestException):
return False
return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
exception.response is not None
and (exception.response.status_code >= 500 or exception.response.status_code == 429)
)


class _TokenAuth(AuthBase):
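
As an aside, here is a standalone sketch of the tenacity pattern these changes adopt: ``Retrying(**retry_args)`` yields attempt contexts, non-retryable exceptions propagate immediately, and ``RetryError`` is raised once the stop condition is exhausted. The URL and retry values are placeholders, not part of this PR.

import requests
from requests import exceptions as requests_exceptions
from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential


def _retryable_error(exception: BaseException) -> bool:
    # Mirrors the hook's predicate: retry connection errors, timeouts, HTTP 5xx and 429.
    if not isinstance(exception, requests_exceptions.RequestException):
        return False
    return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
        exception.response is not None
        and (exception.response.status_code >= 500 or exception.response.status_code == 429)
    )


retry_args = dict(
    stop=stop_after_attempt(3),
    wait=wait_exponential(min=1, max=8),
    retry=retry_if_exception(_retryable_error),
)

try:
    # Each iteration is one attempt; the block exits the loop on success.
    for attempt in Retrying(**retry_args):
        with attempt:
            resp = requests.get('https://example.com/api/2.0/jobs/runs/get', timeout=10)
            resp.raise_for_status()
except RetryError:
    raise RuntimeError('API request failed after all retries. Giving up.')
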
8 changes: 8 additions & 0 deletions airflow/providers/databricks/operators/databricks.py
@@ -245,6 +245,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
unreachable. Its value must be greater than or equal to 1.
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
"""

@@ -274,6 +275,7 @@ def __init__(
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: Optional[Dict[Any, Any]] = None,
do_xcom_push: bool = False,
idempotency_token: Optional[str] = None,
access_control_list: Optional[List[Dict[str, str]]] = None,
Expand All @@ -287,6 +289,7 @@ def __init__(
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
if tasks is not None:
self.json['tasks'] = tasks
@@ -327,6 +330,7 @@ def _get_hook(self) -> DatabricksHook:
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
)

def execute(self, context: 'Context'):
@@ -484,6 +488,7 @@ class DatabricksRunNowOperator(BaseOperator):
this run. By default the operator will poll every 30 seconds.
:param databricks_retry_limit: Amount of times to retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
"""

@@ -508,6 +513,7 @@ def __init__(
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: Optional[Dict[Any, Any]] = None,
do_xcom_push: bool = False,
wait_for_termination: bool = True,
**kwargs,
Expand All @@ -519,6 +525,7 @@ def __init__(
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination

if job_id is not None:
@@ -546,6 +553,7 @@ def _get_hook(self) -> DatabricksHook:
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
)

def execute(self, context: 'Context'):
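
And a sketch of the operator-level usage this change enables, inside a DAG. The DAG id, task ids, job id, and notebook/cluster payloads are placeholders; ``databricks_retry_args`` is simply forwarded to the hook's ``retry_args``.

import pendulum
from tenacity import stop_after_attempt, wait_exponential

from airflow import DAG
from airflow.providers.databricks.operators.databricks import (
    DatabricksRunNowOperator,
    DatabricksSubmitRunOperator,
)

retry_args = dict(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=30))

with DAG(
    dag_id='databricks_retry_args_example',
    start_date=pendulum.datetime(2022, 1, 1, tz='UTC'),
    schedule_interval=None,
) as dag:
    submit_run = DatabricksSubmitRunOperator(
        task_id='submit_run',
        new_cluster={'spark_version': '9.1.x-scala2.12', 'num_workers': 2},
        notebook_task={'notebook_path': '/Users/someone@example.com/example-notebook'},
        databricks_retry_args=retry_args,  # passed through to the hook
    )
    run_now = DatabricksRunNowOperator(
        task_id='run_now',
        job_id=42,
        databricks_retry_args=retry_args,
    )
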
@@ -61,5 +61,7 @@ Note that there is exactly one named parameter for each top level parameter in t
- amount of times to retry if the Databricks backend is unreachable
* - databricks_retry_delay: decimal
- number of seconds to wait between retries
* - databricks_retry_args: dict
- An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
* - do_xcom_push: boolean
- whether we should push run_id and run_page_url to xcom
@@ -70,6 +70,8 @@ one named parameter for each top level parameter in the ``runs/submit`` endpoint
- amount of times to retry if the Databricks backend is unreachable
* - databricks_retry_delay: decimal
- number of seconds to wait between retries
* - databricks_retry_args: dict
- An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
* - do_xcom_push: boolean
- whether we should push run_id and run_page_url to xcom
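
When ``databricks_retry_args`` is omitted, the hook builds equivalent tenacity arguments from ``databricks_retry_limit`` and ``databricks_retry_delay`` (see the hook changes above). Roughly, and only as an approximation:

from tenacity import stop_after_attempt, wait_exponential

retry_limit = 3    # databricks_retry_limit
retry_delay = 1.0  # databricks_retry_delay

# Approximation of the defaults BaseDatabricksHook constructs itself;
# the hook additionally wires in its ``retry`` predicate and ``after`` logger.
default_retry_args = dict(
    stop=stop_after_attempt(retry_limit),
    wait=wait_exponential(min=retry_delay, max=2 ** retry_limit),
)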
