Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions pyrit/models/storage_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@

import aiofiles
from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
from azure.identity.aio import DefaultAzureCredential
from azure.storage.blob import ContentSettings
from azure.storage.blob.aio import ContainerClient as AsyncContainerClient

from pyrit.auth import AzureStorageAuth

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -183,24 +182,42 @@ def __init__(
self._container_url: str = container_url
self._sas_token = sas_token
self._client_async: AsyncContainerClient = None
self._credential: DefaultAzureCredential | None = None

async def _create_container_client_async(self) -> None:
"""
Create an asynchronous ContainerClient for Azure Storage.

If a SAS token is provided via the
AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used
for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication.
for authentication. Otherwise, DefaultAzureCredential will be used directly, which requires
the caller to have a data-plane role such as Storage Blob Data Contributor.
"""
sas_token = self._sas_token
if not self._sas_token:
logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.")
sas_token = await AzureStorageAuth.get_sas_token(self._container_url)
if self._sas_token:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious: how come we don't need to make a similar change to AzureBlobStorageTarget?

self._client_async = AsyncContainerClient.from_container_url(
container_url=self._container_url,
credential=self._sas_token,
)
else:
logger.info("SAS token not provided. Using DefaultAzureCredential for direct Entra ID authentication.")
parsed_url = urlparse(self._container_url)
account_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
container_name = parsed_url.path.lstrip("/")
self._credential = DefaultAzureCredential()
self._client_async = AsyncContainerClient(
account_url=account_url,
container_name=container_name,
credential=self._credential,
)

self._client_async = AsyncContainerClient.from_container_url(
container_url=self._container_url,
credential=sas_token,
)
async def _close_client_async(self) -> None:
"""Close the container client and credential, resetting them to None."""
if self._client_async:
await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore]
self._client_async = None
if self._credential:
await self._credential.close()
self._credential = None

async def _upload_blob_async(self, file_name: str, data: bytes, content_type: str) -> None:
"""
Expand All @@ -225,10 +242,10 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st
except Exception as exc:
if isinstance(exc, ClientAuthenticationError):
logger.exception(
msg="Authentication failed. Please check that the container existence in the "
"Azure Storage Account and ensure the validity of the provided SAS token. If you "
"haven't set the SAS token as an environment variable use `az login` to "
"enable delegation-based SAS authentication to connect to the storage account"
msg="Authentication failed. Please check that the container exists in the "
"Azure Storage Account. If using a SAS token, ensure it is valid. Otherwise, "
"ensure you are logged in via `az login` and have a data-plane role such as "
"Storage Blob Data Contributor on the storage account."
)
raise
logger.exception(msg=f"An unexpected error occurred: {exc}")
Expand Down Expand Up @@ -324,8 +341,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes:
logger.exception(f"Failed to read file at {blob_name}: {exc}")
raise
finally:
await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore]
self._client_async = None
await self._close_client_async()

async def write_file(self, path: Union[Path, str], data: bytes) -> None:
"""
Expand All @@ -348,8 +364,7 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None:
logger.exception(f"Failed to write file at {blob_name}: {exc}")
raise
finally:
await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore]
self._client_async = None
await self._close_client_async()

async def path_exists(self, path: Union[Path, str]) -> bool:
"""
Expand All @@ -372,8 +387,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool:
except ResourceNotFoundError:
return False
finally:
await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore]
self._client_async = None
await self._close_client_async()

async def is_file(self, path: Union[Path, str]) -> bool:
"""
Expand All @@ -396,8 +410,7 @@ async def is_file(self, path: Union[Path, str]) -> bool:
except ResourceNotFoundError:
return False
finally:
await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore]
self._client_async = None
await self._close_client_async()

async def create_directory_if_not_exists(self, directory_path: Union[Path, str]) -> None:
"""
Expand Down
51 changes: 45 additions & 6 deletions tests/unit/models/test_storage_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,56 @@ async def test_azure_blob_storage_io_create_container_client_uses_explicit_sas_t

mock_container_client = AsyncMock()

with patch(
"pyrit.models.storage_io.AsyncContainerClient.from_container_url", return_value=mock_container_client
) as mock_from_container_url:
await azure_blob_storage_io._create_container_client_async()

mock_from_container_url.assert_called_once_with(container_url=container_url, credential=sas_token)
assert azure_blob_storage_io._client_async is mock_container_client
assert azure_blob_storage_io._credential is None


@pytest.mark.asyncio
async def test_azure_blob_storage_io_create_container_client_uses_default_credential_when_no_sas_token():
container_url = "https://youraccount.blob.core.windows.net/yourcontainer"
azure_blob_storage_io = AzureBlobStorageIO(container_url=container_url)

mock_container_client = AsyncMock()
mock_credential = AsyncMock()

with (
patch("pyrit.models.storage_io.AzureStorageAuth.get_sas_token", new_callable=AsyncMock) as mock_get_sas_token,
patch(
"pyrit.models.storage_io.AsyncContainerClient.from_container_url", return_value=mock_container_client
) as mock_from_container_url,
patch("pyrit.models.storage_io.DefaultAzureCredential", return_value=mock_credential) as mock_credential_cls,
patch("pyrit.models.storage_io.AsyncContainerClient", return_value=mock_container_client) as mock_container_cls,
):
await azure_blob_storage_io._create_container_client_async()

mock_get_sas_token.assert_not_awaited()
mock_from_container_url.assert_called_once_with(container_url=container_url, credential=sas_token)
mock_credential_cls.assert_called_once()
mock_container_cls.assert_called_once_with(
account_url="https://youraccount.blob.core.windows.net",
container_name="yourcontainer",
credential=mock_credential,
)
assert azure_blob_storage_io._client_async is mock_container_client
assert azure_blob_storage_io._credential is mock_credential


@pytest.mark.asyncio
async def test_azure_blob_storage_io_close_client_async_closes_credential_and_client():
container_url = "https://youraccount.blob.core.windows.net/yourcontainer"
azure_blob_storage_io = AzureBlobStorageIO(container_url=container_url)

mock_client = AsyncMock()
mock_credential = AsyncMock()
azure_blob_storage_io._client_async = mock_client
azure_blob_storage_io._credential = mock_credential

await azure_blob_storage_io._close_client_async()

mock_client.close.assert_called_once()
mock_credential.close.assert_called_once()
assert azure_blob_storage_io._client_async is None
assert azure_blob_storage_io._credential is None


@pytest.mark.asyncio
Expand Down
Loading