diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index e69306c07..6e1d25547 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -11,11 +11,10 @@ import aiofiles from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError +from azure.identity.aio import DefaultAzureCredential from azure.storage.blob import ContentSettings from azure.storage.blob.aio import ContainerClient as AsyncContainerClient -from pyrit.auth import AzureStorageAuth - logger = logging.getLogger(__name__) @@ -183,6 +182,7 @@ def __init__( self._container_url: str = container_url self._sas_token = sas_token self._client_async: AsyncContainerClient = None + self._credential: DefaultAzureCredential | None = None async def _create_container_client_async(self) -> None: """ @@ -190,17 +190,34 @@ async def _create_container_client_async(self) -> None: If a SAS token is provided via the AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used - for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. + for authentication. Otherwise, DefaultAzureCredential will be used directly, which requires + the caller to have a data-plane role such as Storage Blob Data Contributor. """ - sas_token = self._sas_token - if not self._sas_token: - logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.") - sas_token = await AzureStorageAuth.get_sas_token(self._container_url) + if self._sas_token: + self._client_async = AsyncContainerClient.from_container_url( + container_url=self._container_url, + credential=self._sas_token, + ) + else: + logger.info("SAS token not provided. Using DefaultAzureCredential for direct Entra ID authentication.") + parsed_url = urlparse(self._container_url) + account_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + container_name = parsed_url.path.lstrip("/") + self._credential = DefaultAzureCredential() + self._client_async = AsyncContainerClient( + account_url=account_url, + container_name=container_name, + credential=self._credential, + ) - self._client_async = AsyncContainerClient.from_container_url( - container_url=self._container_url, - credential=sas_token, - ) + async def _close_client_async(self) -> None: + """Close the container client and credential, resetting them to None.""" + if self._client_async: + await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore] + self._client_async = None + if self._credential: + await self._credential.close() + self._credential = None async def _upload_blob_async(self, file_name: str, data: bytes, content_type: str) -> None: """ @@ -225,10 +242,10 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st except Exception as exc: if isinstance(exc, ClientAuthenticationError): logger.exception( - msg="Authentication failed. Please check that the container existence in the " - "Azure Storage Account and ensure the validity of the provided SAS token. If you " - "haven't set the SAS token as an environment variable use `az login` to " - "enable delegation-based SAS authentication to connect to the storage account" + msg="Authentication failed. Please check that the container exists in the " + "Azure Storage Account. If using a SAS token, ensure it is valid. Otherwise, " + "ensure you are logged in via `az login` and have a data-plane role such as " + "Storage Blob Data Contributor on the storage account." ) raise logger.exception(msg=f"An unexpected error occurred: {exc}") @@ -324,8 +341,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: logger.exception(f"Failed to read file at {blob_name}: {exc}") raise finally: - await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore] - self._client_async = None + await self._close_client_async() async def write_file(self, path: Union[Path, str], data: bytes) -> None: """ @@ -348,8 +364,7 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: logger.exception(f"Failed to write file at {blob_name}: {exc}") raise finally: - await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore] - self._client_async = None + await self._close_client_async() async def path_exists(self, path: Union[Path, str]) -> bool: """ @@ -372,8 +387,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool: except ResourceNotFoundError: return False finally: - await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore] - self._client_async = None + await self._close_client_async() async def is_file(self, path: Union[Path, str]) -> bool: """ @@ -396,8 +410,7 @@ async def is_file(self, path: Union[Path, str]) -> bool: except ResourceNotFoundError: return False finally: - await self._client_async.close() # type: ignore[no-untyped-call, unused-ignore] - self._client_async = None + await self._close_client_async() async def create_directory_if_not_exists(self, directory_path: Union[Path, str]) -> None: """ diff --git a/tests/unit/models/test_storage_io.py b/tests/unit/models/test_storage_io.py index 0159d65b9..2a15db3dc 100644 --- a/tests/unit/models/test_storage_io.py +++ b/tests/unit/models/test_storage_io.py @@ -179,17 +179,56 @@ async def test_azure_blob_storage_io_create_container_client_uses_explicit_sas_t mock_container_client = AsyncMock() + with patch( + "pyrit.models.storage_io.AsyncContainerClient.from_container_url", return_value=mock_container_client + ) as mock_from_container_url: + await azure_blob_storage_io._create_container_client_async() + + mock_from_container_url.assert_called_once_with(container_url=container_url, credential=sas_token) + assert azure_blob_storage_io._client_async is mock_container_client + assert azure_blob_storage_io._credential is None + + +@pytest.mark.asyncio +async def test_azure_blob_storage_io_create_container_client_uses_default_credential_when_no_sas_token(): + container_url = "https://youraccount.blob.core.windows.net/yourcontainer" + azure_blob_storage_io = AzureBlobStorageIO(container_url=container_url) + + mock_container_client = AsyncMock() + mock_credential = AsyncMock() + with ( - patch("pyrit.models.storage_io.AzureStorageAuth.get_sas_token", new_callable=AsyncMock) as mock_get_sas_token, - patch( - "pyrit.models.storage_io.AsyncContainerClient.from_container_url", return_value=mock_container_client - ) as mock_from_container_url, + patch("pyrit.models.storage_io.DefaultAzureCredential", return_value=mock_credential) as mock_credential_cls, + patch("pyrit.models.storage_io.AsyncContainerClient", return_value=mock_container_client) as mock_container_cls, ): await azure_blob_storage_io._create_container_client_async() - mock_get_sas_token.assert_not_awaited() - mock_from_container_url.assert_called_once_with(container_url=container_url, credential=sas_token) + mock_credential_cls.assert_called_once() + mock_container_cls.assert_called_once_with( + account_url="https://youraccount.blob.core.windows.net", + container_name="yourcontainer", + credential=mock_credential, + ) assert azure_blob_storage_io._client_async is mock_container_client + assert azure_blob_storage_io._credential is mock_credential + + +@pytest.mark.asyncio +async def test_azure_blob_storage_io_close_client_async_closes_credential_and_client(): + container_url = "https://youraccount.blob.core.windows.net/yourcontainer" + azure_blob_storage_io = AzureBlobStorageIO(container_url=container_url) + + mock_client = AsyncMock() + mock_credential = AsyncMock() + azure_blob_storage_io._client_async = mock_client + azure_blob_storage_io._credential = mock_credential + + await azure_blob_storage_io._close_client_async() + + mock_client.close.assert_called_once() + mock_credential.close.assert_called_once() + assert azure_blob_storage_io._client_async is None + assert azure_blob_storage_io._credential is None @pytest.mark.asyncio