From 4c58dccf9df6137e3a51609d14149770cea0cd48 Mon Sep 17 00:00:00 2001
From: John Brandborg
Date: Tue, 1 Aug 2023 14:40:15 +0200
Subject: [PATCH 1/4] Add Service Principal OAuth for Databricks. (apache/airflow#32969)

---
 .../databricks/hooks/databricks_base.py       | 117 +++++++++++++++++-
 .../connections/databricks.rst                |   4 +
 .../databricks/hooks/test_databricks.py       |  85 +++++++++++++
 3 files changed, 201 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 9885e9a998058..74add417ca7b7 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -62,6 +62,7 @@
 TOKEN_REFRESH_LEAD_TIME = 120
 AZURE_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/"
 DEFAULT_DATABRICKS_SCOPE = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
+OIDC_TOKEN_SERVICE_URL = "{}/oidc/v1/token"


 class BaseDatabricksHook(BaseHook):
@@ -89,6 +90,7 @@ class BaseDatabricksHook(BaseHook):
         "azure_ad_endpoint",
         "azure_resource_id",
         "azure_tenant_id",
+        "service_principal_oauth",
     ]

     def __init__(
@@ -108,7 +110,8 @@ def __init__(
         self.retry_limit = retry_limit
         self.retry_delay = retry_delay
         self.aad_tokens: dict[str, dict] = {}
-        self.aad_timeout_seconds = 10
+        self.sp_token: dict[str, Any] = {}
+        self.token_timeout_seconds = 10
         self.caller = caller

         def my_after_func(retry_state):
@@ -210,6 +213,100 @@ def _a_get_retry_object(self) -> AsyncRetrying:
         """
         return AsyncRetrying(**self.retry_args)

+    def _get_sp_token(self) -> str:
+        """Function to get Service Principal token."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going to expire soon. Refreshing...")
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    resp = requests.post(
+                        OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": "application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    )
+
+                    resp.raise_for_status()
+                    jsn = resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status Code: {e.response.status_code}")
+
+        return token
+
+    async def _a_get_sp_token(self) -> str:
+        """Async version of `_get_sp_token()`."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going to expire soon. Refreshing...")
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with self._session.post(
+                        OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": "application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    ) as resp:
+                        resp.raise_for_status()
+                        jsn = await resp.json()
+                        if (
+                            "access_token" not in jsn
+                            or jsn.get("token_type") != "Bearer"
+                            or "expires_in" not in jsn
+                        ):
+                            raise AirflowException(
+                                f"Can't get necessary data from Service Principal token: {jsn}"
+                            )
+
+                        token = jsn["access_token"]
+                        self.sp_token = {"token": token, "expires_on": int(time.time() + jsn["expires_in"])}
+                        break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status Code: {e.response.status_code}")
+
+        return token
+
+    @staticmethod
+    def _is_sp_token_valid(sp_token: dict) -> bool:
+        """
+        Utility function to check Service Principal token hasn't expired yet.
+
+        :param sp_token: dict with properties of the Service Principal token
+        :return: true if token is valid, false otherwise
+        """
+        now = int(time.time())
+        if sp_token["expires_on"] > (now + TOKEN_REFRESH_LEAD_TIME):
+            return True
+        return False
+
     def _get_aad_token(self, resource: str) -> str:
         """
         Function to get AAD token for given resource.
@@ -235,7 +332,7 @@ def _get_aad_token(self, resource: str) -> str:
                         AZURE_METADATA_SERVICE_TOKEN_URL,
                         params=params,
                         headers={**self.user_agent_header, "Metadata": "true"},
-                        timeout=self.aad_timeout_seconds,
+                        timeout=self.token_timeout_seconds,
                     )
                 else:
                     tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
@@ -255,7 +352,7 @@ def _get_aad_token(self, resource: str) -> str:
                             **self.user_agent_header,
                             "Content-Type": "application/x-www-form-urlencoded",
                         },
-                        timeout=self.aad_timeout_seconds,
+                        timeout=self.token_timeout_seconds,
                     )

                     resp.raise_for_status()
@@ -301,7 +398,7 @@ async def _a_get_aad_token(self, resource: str) -> str:
                         url=AZURE_METADATA_SERVICE_TOKEN_URL,
                         params=params,
                         headers={**self.user_agent_header, "Metadata": "true"},
-                        timeout=self.aad_timeout_seconds,
+                        timeout=self.token_timeout_seconds,
                     ) as resp:
                         resp.raise_for_status()
                         jsn = await resp.json()
@@ -323,7 +420,7 @@ async def _a_get_aad_token(self, resource: str) -> str:
                             **self.user_agent_header,
                             "Content-Type": "application/x-www-form-urlencoded",
                         },
-                        timeout=self.aad_timeout_seconds,
+                        timeout=self.token_timeout_seconds,
                     ) as resp:
                         resp.raise_for_status()
                         jsn = await resp.json()
@@ -443,6 +540,11 @@ def _get_token(self, raise_error: bool = False) -> str | None:
             self.log.info("Using AAD Token for managed identity.")
             self._check_azure_metadata_service()
             return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+        elif self.databricks_conn.extra_dejson.get("service_principal_oauth", False):
+            if self.databricks_conn.login == "" or self.databricks_conn.password == "":
+                raise AirflowException("Service Principal credentials aren't provided")
+            self.log.info("Using Service Principal Token.")
+            return self._get_sp_token()
         elif raise_error:
             raise AirflowException("Token authentication isn't configured")

@@ -466,6 +568,11 @@ async def _a_get_token(self, raise_error: bool = False) -> str | None:
             self.log.info("Using AAD Token for managed identity.")
             await self._a_check_azure_metadata_service()
             return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
+        elif self.databricks_conn.extra_dejson.get("service_principal_oauth", False):
+            if self.databricks_conn.login == "" or self.databricks_conn.password == "":
+                raise AirflowException("Service Principal credentials aren't provided")
+            self.log.info("Using Service Principal Token.")
+            return await self._a_get_sp_token()
         elif raise_error:
             raise AirflowException("Token authentication isn't configured")

diff --git a/docs/apache-airflow-providers-databricks/connections/databricks.rst b/docs/apache-airflow-providers-databricks/connections/databricks.rst
index 6303702b7ec1a..86c0b7d6b4c68 100644
--- a/docs/apache-airflow-providers-databricks/connections/databricks.rst
+++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst
@@ -70,6 +70,10 @@ Extra (optional)
     * ``token``: Specify PAT to use. Consider to switch to specification of PAT in the Password field as it's
       more secure.

+    Following parameters are necessary if using Service Principal with OAuth token:
+
+    * ``service_principal_oauth``: Specify as 'true', and use Client ID and Client Secret as the Username and Password.
+
     Following parameters are necessary if using authentication with AAD token:

     * ``azure_tenant_id``: ID of the Azure Active Directory tenant
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index f55c55dfc364c..641b822d3637c 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -697,6 +697,14 @@ def test_is_aad_token_valid_returns_false(self):
         aad_token = {"token": "my_token", "expires_on": int(time.time())}
         assert not self.hook._is_aad_token_valid(aad_token)

+    def test_is_sp_token_valid_returns_true(self):
+        sp_token = {"token": "my_token", "expires_on": int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10}
+        assert self.hook._is_sp_token_valid(sp_token)
+
+    def test_is_sp_token_valid_returns_false(self):
+        sp_token = {"token": "my_token", "expires_on": int(time.time())}
+        assert not self.hook._is_sp_token_valid(sp_token)
+
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_list_jobs_success_single_page(self, mock_requests):
         mock_requests.codes.ok = 200
@@ -1448,3 +1456,80 @@ async def test_get_run_state(self, mock_get):
         assert ad_call_args[1]["url"] == AZURE_METADATA_SERVICE_INSTANCE_URL
         assert ad_call_args[1]["params"]["api-version"] > "2018-02-01"
         assert ad_call_args[1]["headers"]["Metadata"] == "true"
+
+
+def create_sp_token_for_resource() -> dict:
+    return {
+        "token_type": "Bearer",
+        "expires_in": "3600",
+        "access_token": TOKEN,
+    }
+
+
+class TestDatabricksHookSpToken:
+    """
+    Tests for DatabricksHook when auth is done with Service Principal OAuth token.
+    """
+
+    @provide_session
+    def setup_method(self, method, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.login = "c64f6d12-f6e4-45a4-846e-032b42b27758"
+        conn.password = "secret"
+        conn.extra = json.dumps({"service_principal_oauth": True})
+        session.commit()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_submit_run(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.side_effect = [
+            create_successful_response_mock(create_sp_token_for_resource()),
+            create_successful_response_mock({"run_id": "1"}),
+        ]
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+        data = {"notebook_task": NOTEBOOK_TASK, "new_cluster": NEW_CLUSTER}
+        run_id = self.hook.submit_run(data)
+
+        assert run_id == "1"
+        args = mock_requests.post.call_args
+        kwargs = args[1]
+        assert kwargs["auth"].token == TOKEN
+
+
+class TestDatabricksHookAsyncSpToken:
+    """
+    Tests for DatabricksHook using async methods when auth is done with Service
+    Principal OAuth token.
+    """
+
+    @provide_session
+    def setup_method(self, method, session=None):
+        conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+        conn.login = "c64f6d12-f6e4-45a4-846e-032b42b27758"
+        conn.password = "secret"
+        conn.extra = json.dumps({"service_principal_oauth": True})
+        session.commit()
+        self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post")
+    async def test_get_run_state(self, mock_post, mock_get):
+        mock_post.return_value.__aenter__.return_value.json = AsyncMock(
+            return_value=create_sp_token_for_resource(DEFAULT_DATABRICKS_SCOPE)
+        )
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)
+
+        async with self.hook:
+            run_state = await self.hook.a_get_run_state(RUN_ID)
+
+        assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
+        mock_get.assert_called_once_with(
+            get_run_endpoint(HOST),
+            json={"run_id": RUN_ID},
+            auth=BearerAuth(TOKEN),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
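The method added in this patch is a standard OAuth 2.0 client-credentials exchange against the workspace's own token endpoint. As a rough standalone sketch of the request `_get_sp_token()` builds (the host and service principal credentials below are hypothetical placeholders; `requests` and `HTTPBasicAuth` are the same libraries the hook itself uses):

    import requests
    from requests.auth import HTTPBasicAuth

    # Hypothetical workspace host and service principal credentials.
    host = "https://xx.cloud.databricks.com"
    client_id = "c64f6d12-f6e4-45a4-846e-032b42b27758"
    client_secret = "secret"

    # Same wire format the hook uses: client_credentials grant with scope=all-apis,
    # authenticated via HTTP Basic auth (client ID as user, client secret as password).
    resp = requests.post(
        f"{host}/oidc/v1/token",
        auth=HTTPBasicAuth(client_id, client_secret),
        data="grant_type=client_credentials&scope=all-apis",
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        timeout=10,
    )
    resp.raise_for_status()
    jsn = resp.json()
    # Expected shape: {"access_token": "...", "token_type": "Bearer", "expires_in": 3600}
    assert jsn["token_type"] == "Bearer"

The hook layers tenacity-driven retries, caching, and expiry tracking on top of this exchange; the sketch shows only the token request itself.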
From cad3fc68adf6aeeb0c95895b0ed3bc48a024c221 Mon Sep 17 00:00:00 2001
From: John Brandborg
Date: Mon, 7 Aug 2023 14:22:00 +0200
Subject: [PATCH 2/4] Consolidate OAuth validation and storage (#32969)

---
 .../databricks/hooks/databricks_base.py       | 123 ++++++------------
 .../databricks/hooks/test_databricks.py       |  53 ++++++--
 2 files changed, 82 insertions(+), 94 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 74add417ca7b7..6d0d929b5dc4e 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -109,8 +109,7 @@ def __init__(
             raise ValueError("Retry limit must be greater than or equal to 1")
         self.retry_limit = retry_limit
         self.retry_delay = retry_delay
-        self.aad_tokens: dict[str, dict] = {}
-        self.sp_token: dict[str, Any] = {}
+        self.oauth_tokens: dict[str, dict] = {}
         self.token_timeout_seconds = 10
         self.caller = caller

@@ -213,17 +212,18 @@ def _a_get_retry_object(self) -> AsyncRetrying:
         """
         return AsyncRetrying(**self.retry_args)

-    def _get_sp_token(self) -> str:
+    def _get_sp_token(self, resource: str) -> str:
         """Function to get Service Principal token."""
-        if self.sp_token and self._is_sp_token_valid(self.sp_token):
-            return self.sp_token["token"]
+        sp_token = self.oauth_tokens.get(resource)
+        if sp_token and self._is_oauth_token_valid(sp_token):
+            return sp_token["access_token"]

         self.log.info("Existing Service Principal token is expired, or going to expire soon. Refreshing...")
         try:
             for attempt in self._get_retry_object():
                 with attempt:
                     resp = requests.post(
-                        OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        resource,
                         auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
                         data="grant_type=client_credentials&scope=all-apis",
                         headers={
@@ -235,36 +235,30 @@ def _get_sp_token(self) -> str:
                         timeout=self.token_timeout_seconds,
                     )

                     resp.raise_for_status()
                     jsn = resp.json()
-                    if (
-                        "access_token" not in jsn
-                        or jsn.get("token_type") != "Bearer"
-                        or "expires_in" not in jsn
-                    ):
-                        raise AirflowException(
-                            f"Can't get necessary data from Service Principal token: {jsn}"
-                        )
+                    jsn["expires_on"] = int(time.time() + jsn["expires_in"])

-                    token = jsn["access_token"]
-                    self.sp_token = {"token": token, "expires_on": int(time.time() + jsn["expires_in"])}
+                    self._is_oauth_token_valid(jsn)
+                    self.oauth_tokens[resource] = jsn
                     break
         except RetryError:
             raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.")
         except requests_exceptions.HTTPError as e:
             raise AirflowException(f"Response: {e.response.content}, Status Code: {e.response.status_code}")

-        return token
+        return jsn["access_token"]

-    async def _a_get_sp_token(self) -> str:
+    async def _a_get_sp_token(self, resource: str) -> str:
         """Async version of `_get_sp_token()`."""
-        if self.sp_token and self._is_sp_token_valid(self.sp_token):
-            return self.sp_token["token"]
+        sp_token = self.oauth_tokens.get(resource)
+        if sp_token and self._is_oauth_token_valid(sp_token):
+            return sp_token["access_token"]

         self.log.info("Existing Service Principal token is expired, or going to expire soon. Refreshing...")
         try:
             async for attempt in self._a_get_retry_object():
                 with attempt:
                     async with self._session.post(
-                        OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        resource,
                         auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
                         data="grant_type=client_credentials&scope=all-apis",
                         headers={
@@ -275,37 +269,17 @@ async def _a_get_sp_token(self) -> str:
                         timeout=self.token_timeout_seconds,
                     ) as resp:
                         resp.raise_for_status()
                         jsn = await resp.json()
-                        if (
-                            "access_token" not in jsn
-                            or jsn.get("token_type") != "Bearer"
-                            or "expires_in" not in jsn
-                        ):
-                            raise AirflowException(
-                                f"Can't get necessary data from Service Principal token: {jsn}"
-                            )
+                        jsn["expires_on"] = int(time.time() + jsn["expires_in"])

-                        token = jsn["access_token"]
-                        self.sp_token = {"token": token, "expires_on": int(time.time() + jsn["expires_in"])}
+                        self._is_oauth_token_valid(jsn)
+                        self.oauth_tokens[resource] = jsn
                         break
         except RetryError:
             raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.")
         except requests_exceptions.HTTPError as e:
             raise AirflowException(f"Response: {e.response.content}, Status Code: {e.response.status_code}")

-        return token
-
-    @staticmethod
-    def _is_sp_token_valid(sp_token: dict) -> bool:
-        """
-        Utility function to check Service Principal token hasn't expired yet.
-
-        :param sp_token: dict with properties of the Service Principal token
-        :return: true if token is valid, false otherwise
-        """
-        now = int(time.time())
-        if sp_token["expires_on"] > (now + TOKEN_REFRESH_LEAD_TIME):
-            return True
-        return False
+        return jsn["access_token"]

     def _get_aad_token(self, resource: str) -> str:
         """
@@ -315,9 +289,9 @@ def _get_aad_token(self, resource: str) -> str:
         :param resource: resource to issue token to
         :return: AAD token, or raise an exception
         """
-        aad_token = self.aad_tokens.get(resource)
-        if aad_token and self._is_aad_token_valid(aad_token):
-            return aad_token["token"]
+        aad_token = self.oauth_tokens.get(resource)
+        if aad_token and self._is_oauth_token_valid(aad_token):
+            return aad_token["access_token"]

         self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
         try:
@@ -357,22 +331,16 @@ def _get_aad_token(self, resource: str) -> str:

                     resp.raise_for_status()
                     jsn = resp.json()
-                    if (
-                        "access_token" not in jsn
-                        or jsn.get("token_type") != "Bearer"
-                        or "expires_on" not in jsn
-                    ):
-                        raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")
-
-                    token = jsn["access_token"]
-                    self.aad_tokens[resource] = {"token": token, "expires_on": int(jsn["expires_on"])}
+
+                    self._is_oauth_token_valid(jsn)
+                    self.oauth_tokens[resource] = jsn
                     break
         except RetryError:
             raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.")
         except requests_exceptions.HTTPError as e:
             raise AirflowException(f"Response: {e.response.content}, Status Code: {e.response.status_code}")

-        return token
+        return jsn["access_token"]

@@ -381,9 +349,9 @@ async def _a_get_aad_token(self, resource: str) -> str:
         :param resource: resource to issue token to
         :return: AAD token, or raise an exception
         """
-        aad_token = self.aad_tokens.get(resource)
-        if aad_token and self._is_aad_token_valid(aad_token):
-            return aad_token["token"]
+        aad_token = self.oauth_tokens.get(resource)
+        if aad_token and self._is_oauth_token_valid(aad_token):
+            return aad_token["access_token"]

         self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
         try:
@@ -424,22 +392,16 @@ async def _a_get_aad_token(self, resource: str) -> str:
                     ) as resp:
                         resp.raise_for_status()
                         jsn = await resp.json()
-                        if (
-                            "access_token" not in jsn
-                            or jsn.get("token_type") != "Bearer"
-                            or "expires_on" not in jsn
-                        ):
-                            raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")
-
-                        token = jsn["access_token"]
-                        self.aad_tokens[resource] = {"token": token, "expires_on": int(jsn["expires_on"])}
+
+                        self._is_oauth_token_valid(jsn)
+                        self.oauth_tokens[resource] = jsn
                         break
         except RetryError:
             raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.")
         except aiohttp.ClientResponseError as err:
             raise AirflowException(f"Response: {err.message}, Status Code: {err.status}")

-        return token
+        return jsn["access_token"]

     def _get_aad_headers(self) -> dict:
         """
@@ -472,17 +434,18 @@ async def _a_get_aad_headers(self) -> dict:
         return headers

     @staticmethod
-    def _is_aad_token_valid(aad_token: dict) -> bool:
+    def _is_oauth_token_valid(token: dict, time_key="expires_on") -> bool:
         """
-        Utility function to check AAD token hasn't expired yet.
+        Utility function to check if an OAuth token is valid and hasn't expired yet.

-        :param aad_token: dict with properties of AAD token
+        :param token: dict with properties of the OAuth token
+        :param time_key: name of the key that holds the time of expiration
         :return: true if token is valid, false otherwise
         """
-        now = int(time.time())
-        if aad_token["expires_on"] > (now + TOKEN_REFRESH_LEAD_TIME):
-            return True
-        return False
+        if "access_token" not in token or token.get("token_type", "") != "Bearer" or time_key not in token:
+            raise AirflowException(f"Can't get necessary data from OAuth token: {token}")
+
+        return int(token[time_key]) > (int(time.time()) + TOKEN_REFRESH_LEAD_TIME)

     @staticmethod
     def _check_azure_metadata_service() -> None:
@@ -544,7 +507,7 @@ def _get_token(self, raise_error: bool = False) -> str | None:
             if self.databricks_conn.login == "" or self.databricks_conn.password == "":
                 raise AirflowException("Service Principal credentials aren't provided")
             self.log.info("Using Service Principal Token.")
-            return self._get_sp_token()
+            return self._get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host))
         elif raise_error:
             raise AirflowException("Token authentication isn't configured")

@@ -572,7 +535,7 @@ async def _a_get_token(self, raise_error: bool = False) -> str | None:
             if self.databricks_conn.login == "" or self.databricks_conn.password == "":
                 raise AirflowException("Service Principal credentials aren't provided")
             self.log.info("Using Service Principal Token.")
-            return await self._a_get_sp_token()
+            return await self._a_get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host))
         elif raise_error:
             raise AirflowException("Token authentication isn't configured")

diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 641b822d3637c..f768df389eda0 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -689,21 +689,43 @@ def test_uninstall_libs_on_cluster(self, mock_requests):
             timeout=self.hook.timeout_seconds,
         )

-    def test_is_aad_token_valid_returns_true(self):
-        aad_token = {"token": "my_token", "expires_on": int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10}
-        assert self.hook._is_aad_token_valid(aad_token)
+    def test_is_oauth_token_valid_returns_true(self):
+        token = {
+            "access_token": "my_token",
+            "expires_on": int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10,
+            "token_type": "Bearer",
+        }
+        assert self.hook._is_oauth_token_valid(token)

-    def test_is_aad_token_valid_returns_false(self):
-        aad_token = {"token": "my_token", "expires_on": int(time.time())}
-        assert not self.hook._is_aad_token_valid(aad_token)
+    def test_is_oauth_token_valid_returns_false(self):
+        token = {
+            "access_token": "my_token",
+            "expires_on": int(time.time()),
+            "token_type": "Bearer",
+        }
+        assert not self.hook._is_oauth_token_valid(token)

-    def test_is_sp_token_valid_returns_true(self):
-        sp_token = {"token": "my_token", "expires_on": int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10}
-        assert self.hook._is_sp_token_valid(sp_token)
+    def test_is_oauth_token_valid_raises_missing_token(self):
+        with pytest.raises(AirflowException):
+            self.hook._is_oauth_token_valid({})

-    def test_is_sp_token_valid_returns_false(self):
-        sp_token = {"token": "my_token", "expires_on": int(time.time())}
-        assert not self.hook._is_sp_token_valid(sp_token)
+    def test_is_oauth_token_valid_raises_invalid_type(self):
+        token_missing_type = {"access_token": "my_token"}
+        token_wrong_type = {"access_token": "my_token", "token_type": "not bearer"}
+
+        with pytest.raises(AirflowException):
+            self.hook._is_oauth_token_valid(token_missing_type)
+        with pytest.raises(AirflowException):
+            self.hook._is_oauth_token_valid(token_wrong_type)
+
+    def test_is_oauth_token_valid_raises_wrong_time_key(self):
+        token = {
+            "access_token": "my_token",
+            "expires_on": 0,
+            "token_type": "Bearer",
+        }
+        with pytest.raises(AirflowException):
+            self.hook._is_oauth_token_valid(token, time_key="expiration")

     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_list_jobs_success_single_page(self, mock_requests):
         mock_requests.codes.ok = 200
@@ -1482,7 +1482,7 @@
 def create_sp_token_for_resource() -> dict:
     return {
         "token_type": "Bearer",
-        "expires_in": "3600",
+        "expires_in": 3600,
         "access_token": TOKEN,
     }
@@ -1492,6 +1513,10 @@ def test_submit_run(self, mock_requests):
         data = {"notebook_task": NOTEBOOK_TASK, "new_cluster": NEW_CLUSTER}
         run_id = self.hook.submit_run(data)

+        ad_call_args = mock_requests.method_calls[0]
+        assert ad_call_args[1][0] == "xx.cloud.databricks.com/oidc/v1/token"
+        assert ad_call_args[2]["data"] == "grant_type=client_credentials&scope=all-apis"
+
         assert run_id == "1"
         args = mock_requests.post.call_args
         kwargs = args[1]
@@ -1518,7 +1543,7 @@ def setup_method(self, method, session=None):
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.post")
     async def test_get_run_state(self, mock_post, mock_get):
         mock_post.return_value.__aenter__.return_value.json = AsyncMock(
-            return_value=create_sp_token_for_resource(DEFAULT_DATABRICKS_SCOPE)
+            return_value=create_sp_token_for_resource()
         )
         mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE)

From ab06a253dceb7abfdc3ccbb0a67112195af9cdcd Mon Sep 17 00:00:00 2001
From: John Brandborg
Date: Mon, 7 Aug 2023 14:41:35 +0200
Subject: [PATCH 3/4] Formatting update and added API and Spark to Init Import

---
 tests/providers/databricks/hooks/test_databricks.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index f768df389eda0..8644d95cd9c4c 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -43,6 +43,7 @@
     AZURE_METADATA_SERVICE_INSTANCE_URL,
     AZURE_TOKEN_SERVICE_URL,
     DEFAULT_DATABRICKS_SCOPE,
+    OIDC_TOKEN_SERVICE_URL,
     TOKEN_REFRESH_LEAD_TIME,
     BearerAuth,
 )
@@ -1514,7 +1515,7 @@ def test_submit_run(self, mock_requests):
         run_id = self.hook.submit_run(data)

         ad_call_args = mock_requests.method_calls[0]
-        assert ad_call_args[1][0] == "xx.cloud.databricks.com/oidc/v1/token"
+        assert ad_call_args[1][0] == OIDC_TOKEN_SERVICE_URL.format(HOST)
         assert ad_call_args[2]["data"] == "grant_type=client_credentials&scope=all-apis"

         assert run_id == "1"
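Two details from PATCH 2/4 and PATCH 3/4 above are worth seeing in isolation: the token endpoint is derived from the connection host by plain string templating with `OIDC_TOKEN_SERVICE_URL`, and `_is_oauth_token_valid` treats a token as fresh only while its absolute `expires_on` stamp is more than `TOKEN_REFRESH_LEAD_TIME` (120 seconds) away. A minimal sketch with illustrative values (the host is hypothetical):

    import time

    TOKEN_REFRESH_LEAD_TIME = 120  # seconds, as defined in databricks_base.py
    OIDC_TOKEN_SERVICE_URL = "{}/oidc/v1/token"

    # Endpoint templating, as asserted in PATCH 3/4's test.
    endpoint = OIDC_TOKEN_SERVICE_URL.format("https://xx.cloud.databricks.com")
    assert endpoint == "https://xx.cloud.databricks.com/oidc/v1/token"

    # A token fetched with "expires_in": 3600 is stamped with an absolute expiry...
    token = {"access_token": "abc", "token_type": "Bearer", "expires_on": int(time.time()) + 3600}
    # ...and the cached copy is reused while more than 120 s of life remain.
    assert int(token["expires_on"]) > int(time.time()) + TOKEN_REFRESH_LEAD_TIME

    # With only 60 s left, the same comparison fails, which triggers a refresh.
    stale = dict(token, expires_on=int(time.time()) + 60)
    assert not int(stale["expires_on"]) > int(time.time()) + TOKEN_REFRESH_LEAD_TIME

The refresh lead time means callers never receive a token that could expire mid-request.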
From 7a44fa409cf3c00da2de2bb596d04eba17272cca Mon Sep 17 00:00:00 2001
From: John Brandborg
Date: Sun, 13 Aug 2023 20:17:55 +0200
Subject: [PATCH 4/4] Update documentation with doc ref and note this is for
 AWS deployments

---
 .../connections/databricks.rst                | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/docs/apache-airflow-providers-databricks/connections/databricks.rst b/docs/apache-airflow-providers-databricks/connections/databricks.rst
index 86c0b7d6b4c68..908b12eac7f03 100644
--- a/docs/apache-airflow-providers-databricks/connections/databricks.rst
+++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst
@@ -55,13 +55,15 @@ Host (required)

 Login (optional)
     * If authentication with *Databricks login credentials* is used then specify the ``username`` used to login to Databricks.
-    * If *authentication with Azure Service Principal* is used then specify the ID of the Azure Service Principal
+    * If authentication with *Azure Service Principal* is used then specify the ID of the Azure Service Principal
     * If authentication with *PAT* is used then either leave this field empty or use 'token' as login (both work, the only difference is that if login is empty then token will be sent in request header as Bearer token, if login is 'token' then it will be sent using Basic Auth which is allowed by Databricks API, this may be useful if you plan to reuse this connection with e.g. SimpleHttpOperator)
+    * If authentication with *Databricks Service Principal OAuth* is used then specify the ID of the Service Principal (Databricks on AWS)

 Password (optional)
-    * If authentication with *Databricks login credentials* is used then specify the ``password`` used to login to Databricks.
+    * If authentication with *Databricks login credentials* is used then specify the ``password`` used to login to Databricks.
     * If authentication with *Azure Service Principal* is used then specify the secret of the Azure Service Principal
     * If authentication with *PAT* is used, then specify PAT (recommended)
+    * If authentication with *Databricks Service Principal OAuth* is used then specify the secret of the Service Principal (Databricks on AWS)

 Extra (optional)
     Specify the extra parameter (as json dictionary) that can be used in the Databricks connection.

     Following parameters could be set:

     * ``token``: Specify PAT to use. Consider to switch to specification of PAT in the Password field as it's
       more secure.

-    Following parameters are necessary if using Service Principal with OAuth token:
+    Following parameters are necessary if using authentication with OAuth token for AWS Databricks Service Principal:

-    * ``service_principal_oauth``: Specify as 'true', and use Client ID and Client Secret as the Username and Password.
+    * ``service_principal_oauth``: required boolean flag. If specified as ``true``, use the Client ID and Client Secret as the Username and Password. See `Authentication using OAuth for service principals `_.

     Following parameters are necessary if using authentication with AAD token:
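Putting the series together, a connection matching what the documentation above describes can be sketched as follows. The host, client ID, and secret are placeholders that mirror the test fixtures, not working credentials, and in practice the secret belongs in a secrets backend:

    import json
    from airflow.models.connection import Connection

    # Placeholder values only; these mirror the setup_method fixtures in the tests above.
    conn = Connection(
        conn_id="databricks_default",
        conn_type="databricks",
        host="https://xx.cloud.databricks.com",
        login="c64f6d12-f6e4-45a4-846e-032b42b27758",  # service principal client ID
        password="secret",  # service principal client secret
        extra=json.dumps({"service_principal_oauth": True}),
    )
    print(conn.get_uri())

With ``service_principal_oauth`` set, the hook exchanges the login and password for a bearer token at ``{host}/oidc/v1/token`` instead of treating them as workspace user credentials.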