Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 36 additions & 33 deletions airflow/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_connection_form_widgets() -> dict[str, Any]:
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
"hidden_fields": ["schema", "port"],
"relabeling": {
"login": "Blob Storage Login (optional)",
"password": "Blob Storage Key (optional)",
Expand All @@ -140,6 +140,7 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"tenant_id": "tenant",
"shared_access_key": "shared access key",
"sas_token": "account url or token",
"extra": "additional options for use with ClientSecretCredential or DefaultAzureCredential",
},
}

Expand Down Expand Up @@ -176,22 +177,11 @@ def get_conn(self) -> BlobServiceClient:
extra = conn.extra_dejson or {}
client_secret_auth_config = extra.pop("client_secret_auth_config", {})

if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
return BlobServiceClient(account_url=conn.host, **extra)

connection_string = self._get_field(extra, "connection_string")
if connection_string:
# connection_string auth takes priority
return BlobServiceClient.from_connection_string(connection_string, **extra)

shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra)

tenant = self._get_field(extra, "tenant_id")
if tenant:
# use Active Directory auth
Expand All @@ -200,22 +190,33 @@ def get_conn(self) -> BlobServiceClient:
token_credential = ClientSecretCredential(tenant, app_id, app_secret, **client_secret_auth_config)
return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra)

account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"

if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
return BlobServiceClient(account_url=account_url, **extra)

shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
return BlobServiceClient(account_url=account_url, credential=shared_access_key, **extra)

sas_token = self._get_field(extra, "sas_token")
if sas_token:
if sas_token.startswith("https"):
return BlobServiceClient(account_url=sas_token, **extra)
else:
return BlobServiceClient(
account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra
)
return BlobServiceClient(account_url=f"{account_url}/{sas_token}", **extra)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will produce an incorrect URL if account_url ends with a slash, e.g.:
account_url = https://{conn.login}.blob.core.windows.net/
then here we will end up with:
account_url = https://{conn.login}.blob.core.windows.net//sas_token

Copy link
Contributor Author

@Adaverse Adaverse Jul 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will correct it with a patch.


# Fall back to old auth (password) or use managed identity if not provided.
credential = conn.password
if not credential:
credential = DefaultAzureCredential()
self.log.info("Using DefaultAzureCredential as credential")
return BlobServiceClient(
account_url=f"https://{conn.login}.blob.core.windows.net/",
account_url=account_url,
credential=credential,
**extra,
)
Expand Down Expand Up @@ -545,13 +546,6 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
extra = conn.extra_dejson or {}
client_secret_auth_config = extra.pop("client_secret_auth_config", {})

if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
self.blob_service_client = AsyncBlobServiceClient(account_url=conn.host, **extra)
return self.blob_service_client

connection_string = self._get_field(extra, "connection_string")
if connection_string:
# connection_string auth takes priority
Expand All @@ -560,14 +554,6 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
)
return self.blob_service_client

shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
self.blob_service_client = AsyncBlobServiceClient(
account_url=conn.host, credential=shared_access_key, **extra
)
return self.blob_service_client

tenant = self._get_field(extra, "tenant_id")
if tenant:
# use Active Directory auth
Expand All @@ -581,13 +567,30 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
)
return self.blob_service_client

account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"

if self.public_read:
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
self.blob_service_client = AsyncBlobServiceClient(account_url=account_url, **extra)
return self.blob_service_client

shared_access_key = self._get_field(extra, "shared_access_key")
if shared_access_key:
# using shared access key
self.blob_service_client = AsyncBlobServiceClient(
account_url=account_url, credential=shared_access_key, **extra
)
return self.blob_service_client

sas_token = self._get_field(extra, "sas_token")
if sas_token:
if sas_token.startswith("https"):
self.blob_service_client = AsyncBlobServiceClient(account_url=sas_token, **extra)
else:
self.blob_service_client = AsyncBlobServiceClient(
account_url=f"https://{conn.login}.blob.core.windows.net/{sas_token}", **extra
account_url=f"{account_url}/{sas_token}", **extra
)
return self.blob_service_client

Expand All @@ -597,7 +600,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
credential = AsyncDefaultAzureCredential()
self.log.info("Using DefaultAzureCredential as credential")
self.blob_service_client = AsyncBlobServiceClient(
account_url=f"https://{conn.login}.blob.core.windows.net/",
account_url=account_url,
credential=credential,
**extra,
)
Expand Down
22 changes: 15 additions & 7 deletions docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,31 @@ Configuring the Connection
--------------------------

Login (optional)
Specify the login used for azure blob storage. For use with Shared Key Credential and SAS Token authentication.
Specify the login used for Azure Blob Storage. Strictly needed for Active Directory (token) authentication as a service principal credential. Optional for the rest if the host (account URL) is specified.

Password (optional)
Specify the password used for azure blob storage. For use with
Specify the password used for Azure Blob Storage. For use with
Active Directory (token credential) and shared key authentication.

Host (optional)
Specify the account url for anonymous public read, Active Directory, shared access key authentication.
Specify the account URL for Azure Blob Storage. Strictly needed for Active Directory (token) authentication as a service principal credential. Optional for the rest if the login (account name) is specified.

Blob Storage Connection String (optional)
Connection string for use with connection string authentication.

Blob Storage Shared Access Key (optional)
Specify the shared access key. Needed only for shared access key authentication.

SAS Token (optional)
SAS Token for use with SAS Token authentication.

Tenant Id (Active Directory Auth) (optional)
Specify the tenant to use. Required only for Active Directory (token) authentication.

Extra (optional)
Specify the extra parameters (as a JSON dictionary) that can be used in the Azure connection.
The following parameters are all optional:

* ``tenant_id``: Specify the tenant to use. Needed for Active Directory (token) authentication.
* ``shared_access_key``: Specify the shared access key. Needed for shared access key authentication.
* ``connection_string``: Connection string for use with connection string authentication.
* ``sas_token``: SAS Token for use with SAS Token authentication.
* ``client_secret_auth_config``: Extra config to pass while authenticating as a service principal using `ClientSecretCredential <https://learn.microsoft.com/en-in/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python>`_

When specifying the connection in an environment variable, you should specify
Expand Down
53 changes: 51 additions & 2 deletions tests/providers/microsoft/azure/hooks/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@

class TestWasbHook:
def setup_method(self):
db.merge_conn(Connection(conn_id="wasb_test_key", conn_type="wasb", login="login", password="key"))
self.login = "login"
self.wasb_test_key = "wasb_test_key"
self.connection_type = "wasb"
self.connection_string_id = "azure_test_connection_string"
self.shared_key_conn_id = "azure_shared_key_test"
self.shared_key_conn_id_without_host = "azure_shared_key_test_wihout_host"
self.ad_conn_id = "azure_AD_test"
self.sas_conn_id = "sas_token_id"
self.extra__wasb__sas_conn_id = "extra__sas_token_id"
self.http_sas_conn_id = "http_sas_token_id"
self.extra__wasb__http_sas_conn_id = "extra__http_sas_token_id"
self.public_read_conn_id = "pub_read_id"
self.public_read_conn_id_without_host = "pub_read_id_without_host"
self.managed_identity_conn_id = "managed_identity"
self.authority = "https://test_authority.com"

Expand All @@ -60,6 +63,14 @@ def setup_method(self):
"authority": self.authority,
}

db.merge_conn(
Connection(
conn_id=self.wasb_test_key,
conn_type=self.connection_type,
login=self.login,
password="key",
)
)
db.merge_conn(
Connection(
conn_id=self.public_read_conn_id,
Expand All @@ -68,7 +79,14 @@ def setup_method(self):
extra=json.dumps({"proxies": self.proxies}),
)
)

db.merge_conn(
Connection(
conn_id=self.public_read_conn_id_without_host,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"proxies": self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.connection_string_id,
Expand All @@ -84,6 +102,14 @@ def setup_method(self):
extra=json.dumps({"shared_access_key": "token", "proxies": self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.shared_key_conn_id_without_host,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"shared_access_key": "token", "proxies": self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.ad_conn_id,
Expand Down Expand Up @@ -111,13 +137,15 @@ def setup_method(self):
Connection(
conn_id=self.sas_conn_id,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"sas_token": "token", "proxies": self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.extra__wasb__sas_conn_id,
conn_type=self.connection_type,
login=self.login,
extra=json.dumps({"extra__wasb__sas_token": "token", "proxies": self.proxies}),
)
)
Expand Down Expand Up @@ -171,6 +199,23 @@ def test_azure_directory_connection(self):
assert isinstance(hook.get_conn(), BlobServiceClient)
assert isinstance(hook.get_conn().credential, ClientSecretCredential)

@pytest.mark.parametrize(
    argnames="conn_id_str",
    argvalues=[
        "wasb_test_key",
        "shared_key_conn_id_without_host",
        "public_read_conn_id_without_host",
    ],
)
def test_account_url_without_host(self, conn_id_str):
    """Check the client URL that is built for connections defined without a host.

    ``conn_id_str`` is the name of an attribute on the test instance that
    holds the connection id to exercise; each parametrized connection is
    registered in ``setup_method`` with a login but no host.
    """
    # Resolve the attribute name to the actual connection id.
    conn_id = getattr(self, conn_id_str)
    hook = WasbHook(wasb_conn_id=conn_id)
    hook_conn = hook.get_connection(hook.conn_id)
    conn = hook.get_conn()
    assert conn.url.startswith("https://")
    # The login (storage account name) must appear in the generated URL.
    assert hook_conn.login in conn.url
    assert conn.url.endswith(".blob.core.windows.net/")

@pytest.mark.parametrize(
argnames="conn_id_str, extra_key",
argvalues=[
Expand All @@ -187,6 +232,9 @@ def test_sas_token_connection(self, conn_id_str, extra_key):
hook_conn = hook.get_connection(hook.conn_id)
sas_token = hook_conn.extra_dejson[extra_key]
assert isinstance(conn, BlobServiceClient)
assert conn.url.startswith("https://")
if hook_conn.login:
assert conn.url.__contains__(hook_conn.login)
assert conn.url.endswith(sas_token + "/")

@pytest.mark.parametrize(
Expand Down Expand Up @@ -459,4 +507,5 @@ def test___ensure_prefixes(self):
"extra__wasb__tenant_id",
"extra__wasb__shared_access_key",
"extra__wasb__sas_token",
"extra",
]