diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index f6dc3a9c747c..11fef6ea29b8 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -39,7 +39,6 @@ from .response_handlers import process_storage_error, PartialBatchErrorException if TYPE_CHECKING: - from azure.core.pipeline import Pipeline from azure.core.pipeline.transport import HttpRequest from azure.core.configuration import Configuration _LOGGER = logging.getLogger(__name__) @@ -67,7 +66,7 @@ async def close(self): await self._client.close() def _create_pipeline(self, credential, **kwargs): - # type: (Any, **Any) -> Tuple[Configuration, Pipeline] + # type: (Any, **Any) -> Tuple[Configuration, AsyncPipeline] self._credential_policy = None if hasattr(credential, 'get_token'): self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index e5f87365399f..840b2c89ffd5 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -8,7 +8,7 @@ import functools import warnings from typing import ( - Any, Dict, List, Optional, Union, + Any, cast, Dict, List, Optional, Union, Tuple, TYPE_CHECKING) from azure.core.async_paging import AsyncItemPaged @@ -37,7 +37,7 @@ from .._models import QueueProperties -class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncryptionMixin): +class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncryptionMixin): # type: ignore[misc] """A client to interact with a specific Queue. :param str account_url: @@ -94,14 +94,13 @@ def __init__( super(QueueClient, self).__init__( account_url, queue_name=queue_name, credential=credential, loop=loop, **kwargs ) - self._client = AzureQueueStorage(self.url, base_url=self.url, - pipeline=self._pipeline, loop=loop) + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) @distributed_trace_async - async def create_queue( + async def create_queue( # type: ignore[override] self, *, metadata: Optional[Dict[str, str]] = None, **kwargs: Any @@ -145,7 +144,7 @@ async def create_queue( process_storage_error(error) @distributed_trace_async - async def delete_queue(self, **kwargs: Any) -> None: + async def delete_queue(self, **kwargs: Any) -> None: # type: ignore[override] """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -180,7 +179,7 @@ async def delete_queue(self, **kwargs: Any) -> None: process_storage_error(error) @distributed_trace_async - async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": + async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": # type: ignore[override] """Returns all user-defined metadata for the specified queue. The data returned does not include the queue's list of messages. @@ -201,16 +200,16 @@ async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": """ timeout = kwargs.pop('timeout', None) try: - response = await self._client.queue.get_properties( + response = cast("QueueProperties", await (self._client.queue.get_properties( timeout=timeout, cls=deserialize_queue_properties, **kwargs - ) + ))) except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name return response @distributed_trace_async - async def set_queue_metadata( + async def set_queue_metadata( # type: ignore[override] self, metadata: Optional[Dict[str, str]] = None, **kwargs: Any ) -> None: @@ -242,14 +241,14 @@ async def set_queue_metadata( headers = kwargs.pop("headers", {}) headers.update(add_metadata_headers(metadata)) try: - return await self._client.queue.set_metadata( + await self._client.queue.set_metadata( timeout=timeout, headers=headers, cls=return_response_headers, **kwargs ) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: + async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: # type: ignore[override] """Returns details about any stored access policies specified on the queue that may be used with Shared Access Signatures. @@ -264,15 +263,15 @@ async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy """ timeout = kwargs.pop('timeout', None) try: - _, identifiers = await self._client.queue.get_access_policy( + _, identifiers = cast(Tuple[Dict, List], await self._client.queue.get_access_policy( timeout=timeout, cls=return_headers_and_deserialized, **kwargs - ) + )) except HttpResponseError as error: process_storage_error(error) return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace_async - async def set_queue_access_policy( + async def set_queue_access_policy( # type: ignore[override] self, signed_identifiers: Dict[str, AccessPolicy], **kwargs: Any ) -> None: @@ -329,7 +328,7 @@ async def set_queue_access_policy( process_storage_error(error) @distributed_trace_async - async def send_message( + async def send_message( # type: ignore[override] self, content: Any, *, visibility_timeout: Optional[int] = None, @@ -425,7 +424,7 @@ async def send_message( process_storage_error(error) @distributed_trace_async - async def receive_message( + async def receive_message( # type: ignore[override] self, *, visibility_timeout: Optional[int] = None, **kwargs: Any @@ -487,7 +486,7 @@ async def receive_message( process_storage_error(error) @distributed_trace - def receive_messages( + def receive_messages( # type: ignore[override] self, *, messages_per_page: Optional[int] = None, visibility_timeout: Optional[int] = None, @@ -565,7 +564,7 @@ def receive_messages( process_storage_error(error) @distributed_trace_async - async def update_message( + async def update_message( # type: ignore[override] self, message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, content: Optional[Any] = None, @@ -624,14 +623,16 @@ async def update_message( :caption: Update a message. """ timeout = kwargs.pop('timeout', None) - try: + + receipt: Optional[str] + if isinstance(message, QueueMessage): message_id = message.id message_text = content or message.content receipt = pop_receipt or message.pop_receipt inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - except AttributeError: + else: message_id = message message_text = content receipt = pop_receipt @@ -666,7 +667,7 @@ async def update_message( else: updated = None try: - response = await self._client.message_id.update( + response = cast(QueueMessage, await self._client.message_id.update( queue_message=updated, visibilitytimeout=visibility_timeout or 0, timeout=timeout, @@ -674,7 +675,7 @@ async def update_message( cls=return_response_headers, queue_message_id=message_id, **kwargs - ) + )) new_message = QueueMessage(content=message_text) new_message.id = message_id new_message.inserted_on = inserted_on @@ -687,7 +688,7 @@ async def update_message( process_storage_error(error) @distributed_trace_async - async def peek_messages( + async def peek_messages( # type: ignore[override] self, max_messages: Optional[int] = None, **kwargs: Any ) -> List[QueueMessage]: @@ -750,7 +751,7 @@ async def peek_messages( process_storage_error(error) @distributed_trace_async - async def clear_messages(self, **kwargs: Any) -> None: + async def clear_messages(self, **kwargs: Any) -> None: # type: ignore[override] """Deletes all messages from the specified queue. :keyword int timeout: @@ -776,7 +777,7 @@ async def clear_messages(self, **kwargs: Any) -> None: process_storage_error(error) @distributed_trace_async - async def delete_message( + async def delete_message( # type: ignore[override] self, message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, **kwargs: Any @@ -816,10 +817,12 @@ async def delete_message( :caption: Delete a message. """ timeout = kwargs.pop('timeout', None) - try: + + receipt: Optional[str] + if isinstance(message, QueueMessage): message_id = message.id receipt = pop_receipt or message.pop_receipt - except AttributeError: + else: message_id = message receipt = pop_receipt