diff --git a/airflow/providers/microsoft/azure/hooks/base_azure.py b/airflow/providers/microsoft/azure/hooks/base_azure.py index 214cbd5f20b29..54130b3b2e69c 100644 --- a/airflow/providers/microsoft/azure/hooks/base_azure.py +++ b/airflow/providers/microsoft/azure/hooks/base_azure.py @@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook +from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter class AzureBaseHook(BaseHook): @@ -124,10 +125,17 @@ def get_conn(self) -> Any: self.log.info("Getting connection using a JSON config.") return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json) - self.log.info("Getting connection using specific credentials and subscription_id.") - return self.sdk_client( - credentials=ServicePrincipalCredentials( + credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter + if all([conn.login, conn.password, tenant]): + self.log.info("Getting connection using specific credentials and subscription_id.") + credentials = ServicePrincipalCredentials( client_id=conn.login, secret=conn.password, tenant=tenant - ), + ) + else: + self.log.info("Using DefaultAzureCredential as credential") + credentials = AzureIdentityCredentialAdapter() + + return self.sdk_client( + credentials=credentials, subscription_id=subscription_id, ) diff --git a/airflow/providers/microsoft/azure/utils.py b/airflow/providers/microsoft/azure/utils.py index 0a8edcf7c75e0..5afc2a48ca1b6 100644 --- a/airflow/providers/microsoft/azure/utils.py +++ b/airflow/providers/microsoft/azure/utils.py @@ -19,6 +19,12 @@ import warnings +from azure.core.pipeline import PipelineContext, PipelineRequest +from azure.core.pipeline.policies import BearerTokenCredentialPolicy +from azure.core.pipeline.transport import HttpRequest +from azure.identity import DefaultAzureCredential +from msrest.authentication import BasicTokenAuthentication + def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str): """Get field from extra, first checking short name, then for backcompat we check for prefixed name.""" @@ -43,3 +49,47 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str): if ret == "": return None return ret + + +class AzureIdentityCredentialAdapter(BasicTokenAuthentication): + """Adapt azure-identity credentials for backward compatibility. + + Adapt credentials from azure-identity to be compatible with SD + that needs msrestazure or azure.common.credentials + + Check https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig + """ + + def __init__(self, credential=None, resource_id="https://management.azure.com/.default", **kwargs): + """Adapt azure-identity credentials for backward compatibility. + + :param credential: Any azure-identity credential (DefaultAzureCredential by default) + :param str resource_id: The scope to use to get the token (default ARM) + """ + super().__init__(None) + if credential is None: + credential = DefaultAzureCredential() + self._policy = BearerTokenCredentialPolicy(credential, resource_id, **kwargs) + + def _make_request(self): + return PipelineRequest( + HttpRequest("AzureIdentityCredentialAdapter", "https://fakeurl"), PipelineContext(None) + ) + + def set_token(self): + """Ask the azure-core BearerTokenCredentialPolicy policy to get a token. + + Using the policy gives us for free the caching system of azure-core. + We could make this code simpler by using private method, but by definition + I can't assure they will be there forever, so mocking a fake call to the policy + to extract the token, using 100% public API. + """ + request = self._make_request() + self._policy.on_request(request) + # Read Authorization, and get the second part after Bearer + token = request.http_request.headers["Authorization"].split(" ", 1)[1] + self.token = {"access_token": token} + + def signed_session(self, azure_session=None): + self.set_token() + return super().signed_session(azure_session)