Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions airflow/providers/microsoft/azure/hooks/base_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter


class AzureBaseHook(BaseHook):
Expand Down Expand Up @@ -124,10 +125,17 @@ def get_conn(self) -> Any:
self.log.info("Getting connection using a JSON config.")
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)

self.log.info("Getting connection using specific credentials and subscription_id.")
return self.sdk_client(
credentials=ServicePrincipalCredentials(
credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter
if all([conn.login, conn.password, tenant]):
self.log.info("Getting connection using specific credentials and subscription_id.")
credentials = ServicePrincipalCredentials(
client_id=conn.login, secret=conn.password, tenant=tenant
),
)
else:
self.log.info("Using DefaultAzureCredential as credential")
credentials = AzureIdentityCredentialAdapter()

return self.sdk_client(
credentials=credentials,
subscription_id=subscription_id,
)
50 changes: 50 additions & 0 deletions airflow/providers/microsoft/azure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

import warnings

from azure.core.pipeline import PipelineContext, PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.transport import HttpRequest
from azure.identity import DefaultAzureCredential
from msrest.authentication import BasicTokenAuthentication


def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
"""Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
Expand All @@ -43,3 +49,47 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
if ret == "":
return None
return ret


class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
"""Adapt azure-identity credentials for backward compatibility.

Adapt credentials from azure-identity to be compatible with SD
that needs msrestazure or azure.common.credentials

Check https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig
"""

def __init__(self, credential=None, resource_id="https://management.azure.com/.default", **kwargs):
"""Adapt azure-identity credentials for backward compatibility.

:param credential: Any azure-identity credential (DefaultAzureCredential by default)
:param str resource_id: The scope to use to get the token (default ARM)
"""
super().__init__(None)
if credential is None:
credential = DefaultAzureCredential()
self._policy = BearerTokenCredentialPolicy(credential, resource_id, **kwargs)

def _make_request(self):
return PipelineRequest(
HttpRequest("AzureIdentityCredentialAdapter", "https://fakeurl"), PipelineContext(None)
)

def set_token(self):
"""Ask the azure-core BearerTokenCredentialPolicy policy to get a token.

Using the policy gives us for free the caching system of azure-core.
We could make this code simpler by using private method, but by definition
I can't assure they will be there forever, so mocking a fake call to the policy
to extract the token, using 100% public API.
"""
request = self._make_request()
self._policy.on_request(request)
# Read Authorization, and get the second part after Bearer
token = request.http_request.headers["Authorization"].split(" ", 1)[1]
self.token = {"access_token": token}

def signed_session(self, azure_session=None):
self.set_token()
return super().signed_session(azure_session)