Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from .response_handlers import process_storage_error, PartialBatchErrorException

if TYPE_CHECKING:
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpRequest
from azure.core.configuration import Configuration
_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,7 +66,7 @@ async def close(self):
await self._client.close()

def _create_pipeline(self, credential, **kwargs):
# type: (Any, **Any) -> Tuple[Configuration, Pipeline]
# type: (Any, **Any) -> Tuple[Configuration, AsyncPipeline]
self._credential_policy = None
if hasattr(credential, 'get_token'):
self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import functools
import warnings
from typing import (
Any, Dict, List, Optional, Union,
Any, cast, Dict, List, Optional, Union, Tuple,
TYPE_CHECKING)

from azure.core.async_paging import AsyncItemPaged
Expand Down Expand Up @@ -37,7 +37,7 @@
from .._models import QueueProperties


class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncryptionMixin):
class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncryptionMixin): # type: ignore[misc]
"""A client to interact with a specific Queue.

:param str account_url:
Expand Down Expand Up @@ -94,14 +94,13 @@ def __init__(
super(QueueClient, self).__init__(
account_url, queue_name=queue_name, credential=credential, loop=loop, **kwargs
)
self._client = AzureQueueStorage(self.url, base_url=self.url,
pipeline=self._pipeline, loop=loop)
self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop)
self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access
self._loop = loop
self._configure_encryption(kwargs)

@distributed_trace_async
async def create_queue(
async def create_queue( # type: ignore[override]
self, *,
metadata: Optional[Dict[str, str]] = None,
**kwargs: Any
Expand Down Expand Up @@ -145,7 +144,7 @@ async def create_queue(
process_storage_error(error)

@distributed_trace_async
async def delete_queue(self, **kwargs: Any) -> None:
async def delete_queue(self, **kwargs: Any) -> None: # type: ignore[override]
"""Deletes the specified queue and any messages it contains.

When a queue is successfully deleted, it is immediately marked for deletion
Expand Down Expand Up @@ -180,7 +179,7 @@ async def delete_queue(self, **kwargs: Any) -> None:
process_storage_error(error)

@distributed_trace_async
async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties":
async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": # type: ignore[override]
"""Returns all user-defined metadata for the specified queue.

The data returned does not include the queue's list of messages.
Expand All @@ -201,16 +200,16 @@ async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties":
"""
timeout = kwargs.pop('timeout', None)
try:
response = await self._client.queue.get_properties(
response = cast("QueueProperties", await (self._client.queue.get_properties(
timeout=timeout, cls=deserialize_queue_properties, **kwargs
)
)))
except HttpResponseError as error:
process_storage_error(error)
response.name = self.queue_name
return response

@distributed_trace_async
async def set_queue_metadata(
async def set_queue_metadata( # type: ignore[override]
self, metadata: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> None:
Expand Down Expand Up @@ -242,14 +241,14 @@ async def set_queue_metadata(
headers = kwargs.pop("headers", {})
headers.update(add_metadata_headers(metadata))
try:
return await self._client.queue.set_metadata(
await self._client.queue.set_metadata(
timeout=timeout, headers=headers, cls=return_response_headers, **kwargs
)
except HttpResponseError as error:
process_storage_error(error)

@distributed_trace_async
async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]:
async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: # type: ignore[override]
"""Returns details about any stored access policies specified on the
queue that may be used with Shared Access Signatures.

Expand All @@ -264,15 +263,15 @@ async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy
"""
timeout = kwargs.pop('timeout', None)
try:
_, identifiers = await self._client.queue.get_access_policy(
_, identifiers = cast(Tuple[Dict, List], await self._client.queue.get_access_policy(
timeout=timeout, cls=return_headers_and_deserialized, **kwargs
)
))
except HttpResponseError as error:
process_storage_error(error)
return {s.id: s.access_policy or AccessPolicy() for s in identifiers}

@distributed_trace_async
async def set_queue_access_policy(
async def set_queue_access_policy( # type: ignore[override]
self, signed_identifiers: Dict[str, AccessPolicy],
**kwargs: Any
) -> None:
Expand Down Expand Up @@ -329,7 +328,7 @@ async def set_queue_access_policy(
process_storage_error(error)

@distributed_trace_async
async def send_message(
async def send_message( # type: ignore[override]
self, content: Any,
*,
visibility_timeout: Optional[int] = None,
Expand Down Expand Up @@ -425,7 +424,7 @@ async def send_message(
process_storage_error(error)

@distributed_trace_async
async def receive_message(
async def receive_message( # type: ignore[override]
self, *,
visibility_timeout: Optional[int] = None,
**kwargs: Any
Expand Down Expand Up @@ -487,7 +486,7 @@ async def receive_message(
process_storage_error(error)

@distributed_trace
def receive_messages(
def receive_messages( # type: ignore[override]
self, *,
messages_per_page: Optional[int] = None,
visibility_timeout: Optional[int] = None,
Expand Down Expand Up @@ -565,7 +564,7 @@ def receive_messages(
process_storage_error(error)

@distributed_trace_async
async def update_message(
async def update_message( # type: ignore[override]
self, message: Union[str, QueueMessage],
pop_receipt: Optional[str] = None,
content: Optional[Any] = None,
Expand Down Expand Up @@ -624,14 +623,16 @@ async def update_message(
:caption: Update a message.
"""
timeout = kwargs.pop('timeout', None)
try:

receipt: Optional[str]
if isinstance(message, QueueMessage):
message_id = message.id
message_text = content or message.content
receipt = pop_receipt or message.pop_receipt
inserted_on = message.inserted_on
expires_on = message.expires_on
dequeue_count = message.dequeue_count
except AttributeError:
else:
message_id = message
message_text = content
receipt = pop_receipt
Expand Down Expand Up @@ -666,15 +667,15 @@ async def update_message(
else:
updated = None
try:
response = await self._client.message_id.update(
response = cast(QueueMessage, await self._client.message_id.update(
queue_message=updated,
visibilitytimeout=visibility_timeout or 0,
timeout=timeout,
pop_receipt=receipt,
cls=return_response_headers,
queue_message_id=message_id,
**kwargs
)
))
new_message = QueueMessage(content=message_text)
new_message.id = message_id
new_message.inserted_on = inserted_on
Expand All @@ -687,7 +688,7 @@ async def update_message(
process_storage_error(error)

@distributed_trace_async
async def peek_messages(
async def peek_messages( # type: ignore[override]
self, max_messages: Optional[int] = None,
**kwargs: Any
) -> List[QueueMessage]:
Expand Down Expand Up @@ -750,7 +751,7 @@ async def peek_messages(
process_storage_error(error)

@distributed_trace_async
async def clear_messages(self, **kwargs: Any) -> None:
async def clear_messages(self, **kwargs: Any) -> None: # type: ignore[override]
"""Deletes all messages from the specified queue.

:keyword int timeout:
Expand All @@ -776,7 +777,7 @@ async def clear_messages(self, **kwargs: Any) -> None:
process_storage_error(error)

@distributed_trace_async
async def delete_message(
async def delete_message( # type: ignore[override]
self, message: Union[str, QueueMessage],
pop_receipt: Optional[str] = None,
**kwargs: Any
Expand Down Expand Up @@ -816,10 +817,12 @@ async def delete_message(
:caption: Delete a message.
"""
timeout = kwargs.pop('timeout', None)
try:

receipt: Optional[str]
if isinstance(message, QueueMessage):
message_id = message.id
receipt = pop_receipt or message.pop_receipt
except AttributeError:
else:
message_id = message
receipt = pop_receipt

Expand Down