From b32d53510d64127805031d44f29500b6ab5d2a14 Mon Sep 17 00:00:00 2001 From: vincenttran-msft <101599632+vincenttran-msft@users.noreply.github.com> Date: Fri, 24 Feb 2023 16:34:47 -0800 Subject: [PATCH 01/71] [Storage] Update typing for `queue_service_client` (Queue) (#28949) * qsc done sync&async * Remove #type ignore, fix dict() --- .../storage/queue/_queue_service_client.py | 68 ++++++++---------- .../queue/aio/_queue_service_client_async.py | 70 +++++++++---------- 2 files changed, 65 insertions(+), 73 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 6b5422756a52..22990c6ef1b3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -35,8 +35,8 @@ from ._models import ( CorsRule, Metrics, - QueueProperties, QueueAnalyticsLogging, + QueueProperties ) @@ -111,7 +111,7 @@ def __init__( self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._configure_encryption(kwargs) - def _format_url(self, hostname): + def _format_url(self, hostname: str) -> str: """Format the endpoint URL according to the current location mode hostname. """ @@ -155,8 +155,7 @@ def from_connection_string( return cls(account_url, credential=credential, **kwargs) @distributed_trace - def get_service_stats(self, **kwargs): - # type: (Any) -> Dict[str, Any] + def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: """Retrieves statistics related to replication for the Queue service. 
It is only available when read-access geo-redundant replication is enabled for @@ -182,15 +181,14 @@ def get_service_stats(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - stats = self._client.service.get_statistics( # type: ignore + stats = self._client.service.get_statistics( timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs) return service_stats_deserialize(stats) except HttpResponseError as error: process_storage_error(error) @distributed_trace - def get_service_properties(self, **kwargs): - # type: (Any) -> Dict[str, Any] + def get_service_properties(self, **kwargs: Any) -> Dict[str, Any]: """Gets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -211,20 +209,19 @@ def get_service_properties(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - service_props = self._client.service.get_properties(timeout=timeout, **kwargs) # type: ignore + service_props = self._client.service.get_properties(timeout=timeout, **kwargs) return service_properties_deserialize(service_props) except HttpResponseError as error: process_storage_error(error) @distributed_trace - def set_service_properties( # type: ignore - self, analytics_logging=None, # type: Optional[QueueAnalyticsLogging] - hour_metrics=None, # type: Optional[Metrics] - minute_metrics=None, # type: Optional[Metrics] - cors=None, # type: Optional[List[CorsRule]] - **kwargs # type: Any - ): - # type: (...) -> None + def set_service_properties( + self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, + hour_metrics: Optional["Metrics"] = None, + minute_metrics: Optional["Metrics"] = None, + cors: Optional[List["CorsRule"]] = None, + **kwargs: Any + ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. 
@@ -268,17 +265,16 @@ def set_service_properties( # type: ignore cors=cors ) try: - return self._client.service.set_properties(props, timeout=timeout, **kwargs) # type: ignore + return self._client.service.set_properties(props, timeout=timeout, **kwargs) except HttpResponseError as error: process_storage_error(error) @distributed_trace def list_queues( - self, name_starts_with=None, # type: Optional[str] - include_metadata=False, # type: Optional[bool] - **kwargs # type: Any - ): - # type: (...) -> ItemPaged[QueueProperties] + self, name_starts_with: Optional[str] = None, + include_metadata: Optional[bool] = False, + **kwargs: Any + ) -> ItemPaged["QueueProperties"]: """Returns a generator to list the queues under the specified account. The generator will lazily follow the continuation tokens returned by @@ -328,11 +324,10 @@ def list_queues( @distributed_trace def create_queue( - self, name, # type: str - metadata=None, # type: Optional[Dict[str, str]] - **kwargs # type: Any - ): - # type: (...) -> QueueClient + self, name: str, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> QueueClient: """Creates a new queue under the specified account. If a queue with the same name already exists, the operation fails. @@ -342,7 +337,7 @@ def create_queue( :param metadata: A dict with name_value pairs to associate with the queue as metadata. Example: {'Category': 'test'} - :type metadata: dict(str, str) + :type metadata: Dict[str, str] :keyword int timeout: The timeout parameter is expressed in seconds. :rtype: ~azure.storage.queue.QueueClient @@ -366,10 +361,9 @@ def create_queue( @distributed_trace def delete_queue( self, - queue, # type: Union[QueueProperties, str] - **kwargs # type: Any - ): - # type: (...) -> None + queue: Union["QueueProperties", str], + **kwargs: Any + ) -> None: """Deletes the specified queue and any messages it contains. 
When a queue is successfully deleted, it is immediately marked for deletion @@ -402,11 +396,11 @@ def delete_queue( kwargs.setdefault('merge_span', True) queue_client.delete_queue(timeout=timeout, **kwargs) - def get_queue_client(self, - queue, # type: Union[QueueProperties, str] - **kwargs # type: Any - ): - # type: (...) -> QueueClient + def get_queue_client( + self, + queue: Union["QueueProperties", str], + **kwargs: Any + ) -> QueueClient: """Get a client to interact with the specified queue. The queue need not already exist. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 191a15f40969..db84771a840f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -36,8 +36,8 @@ from .._models import ( CorsRule, Metrics, - QueueProperties, QueueAnalyticsLogging, + QueueProperties, ) @@ -92,19 +92,18 @@ def __init__( ) -> None: kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) - super(QueueServiceClient, self).__init__( # type: ignore + super(QueueServiceClient, self).__init__( account_url, credential=credential, loop=loop, **kwargs) - self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) # type: ignore + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) @distributed_trace_async - async def get_service_stats(self, **kwargs): - # type: (Optional[Any]) -> Dict[str, Any] + async def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: """Retrieves statistics related to replication for the Queue 
service. It is only available when read-access geo-redundant replication is enabled for @@ -130,15 +129,14 @@ async def get_service_stats(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - stats = await self._client.service.get_statistics( # type: ignore + stats = await self._client.service.get_statistics( timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs) return service_stats_deserialize(stats) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def get_service_properties(self, **kwargs): - # type: (Optional[Any]) -> Dict[str, Any] + async def get_service_properties(self, **kwargs: Any) -> Dict[str, Any]: """Gets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -159,20 +157,19 @@ async def get_service_properties(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - service_props = await self._client.service.get_properties(timeout=timeout, **kwargs) # type: ignore + service_props = await self._client.service.get_properties(timeout=timeout, **kwargs) return service_properties_deserialize(service_props) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def set_service_properties( # type: ignore - self, analytics_logging=None, # type: Optional[QueueAnalyticsLogging] - hour_metrics=None, # type: Optional[Metrics] - minute_metrics=None, # type: Optional[Metrics] - cors=None, # type: Optional[List[CorsRule]] - **kwargs - ): - # type: (...) -> None + async def set_service_properties( + self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, + hour_metrics: Optional["Metrics"] = None, + minute_metrics: Optional["Metrics"] = None, + cors: Optional[List["CorsRule"]] = None, + **kwargs: Any + ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. 
@@ -216,16 +213,16 @@ async def set_service_properties( # type: ignore cors=cors ) try: - return await self._client.service.set_properties(props, timeout=timeout, **kwargs) # type: ignore + return await self._client.service.set_properties(props, timeout=timeout, **kwargs) except HttpResponseError as error: process_storage_error(error) @distributed_trace def list_queues( - self, name_starts_with=None, # type: Optional[str] - include_metadata=False, # type: Optional[bool] - **kwargs - ): # type: (...) -> AsyncItemPaged + self, name_starts_with: Optional[str] = None, + include_metadata: Optional[bool] = False, + **kwargs: Any + ) -> AsyncItemPaged: """Returns a generator to list the queues under the specified account. The generator will lazily follow the continuation tokens returned by @@ -274,12 +271,11 @@ def list_queues( ) @distributed_trace_async - async def create_queue( # type: ignore - self, name, # type: str - metadata=None, # type: Optional[Dict[str, str]] - **kwargs - ): - # type: (...) -> QueueClient + async def create_queue( + self, name: str, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> QueueClient: """Creates a new queue under the specified account. If a queue with the same name already exists, the operation fails. @@ -289,7 +285,7 @@ async def create_queue( # type: ignore :param metadata: A dict with name_value pairs to associate with the queue as metadata. Example: {'Category': 'test'} - :type metadata: dict(str, str) + :type metadata: Dict[str, str] :keyword int timeout: The timeout parameter is expressed in seconds. :rtype: ~azure.storage.queue.aio.QueueClient @@ -311,11 +307,10 @@ async def create_queue( # type: ignore return queue @distributed_trace_async - async def delete_queue( # type: ignore - self, queue, # type: Union[QueueProperties, str] - **kwargs - ): - # type: (...) 
-> None + async def delete_queue( + self, queue: Union["QueueProperties", str], + **kwargs: Any + ) -> None: """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -348,8 +343,11 @@ async def delete_queue( # type: ignore kwargs.setdefault('merge_span', True) await queue_client.delete_queue(timeout=timeout, **kwargs) - def get_queue_client(self, queue, **kwargs): - # type: (Union[QueueProperties, str], Optional[Any]) -> QueueClient + def get_queue_client( + self, + queue: Union["QueueProperties", str], + **kwargs: Any + ) -> QueueClient: """Get a client to interact with the specified queue. The queue need not already exist. From 7090c88bbb0d47dd6e4906005162c5ed04a18a8c Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Thu, 16 Mar 2023 11:48:49 -0700 Subject: [PATCH 02/71] Fix indents --- .../azure/storage/queue/_queue_service_client.py | 14 +++++++------- .../queue/aio/_queue_service_client_async.py | 12 ++++++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 22990c6ef1b3..3235fade4927 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -92,7 +92,7 @@ def __init__( self, account_url: str, credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any - ) -> None: + ) -> None: try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url @@ -122,7 +122,7 @@ def from_connection_string( cls, conn_str: str, credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: 
disable=line-too-long **kwargs: Any - ) -> Self: + ) -> Self: """Create QueueServiceClient from a Connection String. :param str conn_str: @@ -221,7 +221,7 @@ def set_service_properties( minute_metrics: Optional["Metrics"] = None, cors: Optional[List["CorsRule"]] = None, **kwargs: Any - ) -> None: + ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -274,7 +274,7 @@ def list_queues( self, name_starts_with: Optional[str] = None, include_metadata: Optional[bool] = False, **kwargs: Any - ) -> ItemPaged["QueueProperties"]: + ) -> ItemPaged["QueueProperties"]: """Returns a generator to list the queues under the specified account. The generator will lazily follow the continuation tokens returned by @@ -327,7 +327,7 @@ def create_queue( self, name: str, metadata: Optional[Dict[str, str]] = None, **kwargs: Any - ) -> QueueClient: + ) -> QueueClient: """Creates a new queue under the specified account. If a queue with the same name already exists, the operation fails. @@ -363,7 +363,7 @@ def delete_queue( self, queue: Union["QueueProperties", str], **kwargs: Any - ) -> None: + ) -> None: """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -400,7 +400,7 @@ def get_queue_client( self, queue: Union["QueueProperties", str], **kwargs: Any - ) -> QueueClient: + ) -> QueueClient: """Get a client to interact with the specified queue. The queue need not already exist. 
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index db84771a840f..37ef459eb27a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -89,7 +89,7 @@ def __init__( self, account_url: str, credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any - ) -> None: + ) -> None: kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) super(QueueServiceClient, self).__init__( @@ -169,7 +169,7 @@ async def set_service_properties( minute_metrics: Optional["Metrics"] = None, cors: Optional[List["CorsRule"]] = None, **kwargs: Any - ) -> None: + ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -222,7 +222,7 @@ def list_queues( self, name_starts_with: Optional[str] = None, include_metadata: Optional[bool] = False, **kwargs: Any - ) -> AsyncItemPaged: + ) -> AsyncItemPaged: """Returns a generator to list the queues under the specified account. The generator will lazily follow the continuation tokens returned by @@ -275,7 +275,7 @@ async def create_queue( self, name: str, metadata: Optional[Dict[str, str]] = None, **kwargs: Any - ) -> QueueClient: + ) -> QueueClient: """Creates a new queue under the specified account. If a queue with the same name already exists, the operation fails. @@ -310,7 +310,7 @@ async def create_queue( async def delete_queue( self, queue: Union["QueueProperties", str], **kwargs: Any - ) -> None: + ) -> None: """Deletes the specified queue and any messages it contains. 
When a queue is successfully deleted, it is immediately marked for deletion @@ -347,7 +347,7 @@ def get_queue_client( self, queue: Union["QueueProperties", str], **kwargs: Any - ) -> QueueClient: + ) -> QueueClient: """Get a client to interact with the specified queue. The queue need not already exist. From feb1fcd78e0b2e506f8134f3749288844699231d Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Thu, 16 Mar 2023 16:30:13 -0700 Subject: [PATCH 03/71] Fix indent again :D --- .../storage/queue/_queue_service_client.py | 46 +++++++++---------- .../queue/aio/_queue_service_client_async.py | 38 +++++++-------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 3235fade4927..52e977af86a1 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -89,9 +89,9 @@ class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): """ def __init__( - self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, account_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any ) -> None: try: if not account_url.lower().startswith('http'): @@ -119,9 +119,9 @@ def _format_url(self, hostname: str) -> str: @classmethod def from_connection_string( - cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any + cls, conn_str: str, + credential: Optional[Union[str, Dict[str, str], 
"AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any ) -> Self: """Create QueueServiceClient from a Connection String. @@ -216,11 +216,11 @@ def get_service_properties(self, **kwargs: Any) -> Dict[str, Any]: @distributed_trace def set_service_properties( - self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, - hour_metrics: Optional["Metrics"] = None, - minute_metrics: Optional["Metrics"] = None, - cors: Optional[List["CorsRule"]] = None, - **kwargs: Any + self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, + hour_metrics: Optional["Metrics"] = None, + minute_metrics: Optional["Metrics"] = None, + cors: Optional[List["CorsRule"]] = None, + **kwargs: Any ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -271,9 +271,9 @@ def set_service_properties( @distributed_trace def list_queues( - self, name_starts_with: Optional[str] = None, - include_metadata: Optional[bool] = False, - **kwargs: Any + self, name_starts_with: Optional[str] = None, + include_metadata: Optional[bool] = False, + **kwargs: Any ) -> ItemPaged["QueueProperties"]: """Returns a generator to list the queues under the specified account. @@ -324,9 +324,9 @@ def list_queues( @distributed_trace def create_queue( - self, name: str, - metadata: Optional[Dict[str, str]] = None, - **kwargs: Any + self, name: str, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any ) -> QueueClient: """Creates a new queue under the specified account. @@ -360,9 +360,9 @@ def create_queue( @distributed_trace def delete_queue( - self, - queue: Union["QueueProperties", str], - **kwargs: Any + self, + queue: Union["QueueProperties", str], + **kwargs: Any ) -> None: """Deletes the specified queue and any messages it contains. 
@@ -397,9 +397,9 @@ def delete_queue( queue_client.delete_queue(timeout=timeout, **kwargs) def get_queue_client( - self, - queue: Union["QueueProperties", str], - **kwargs: Any + self, + queue: Union["QueueProperties", str], + **kwargs: Any ) -> QueueClient: """Get a client to interact with the specified queue. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 37ef459eb27a..22bac2bdcf92 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -86,9 +86,9 @@ class QueueServiceClient(AsyncStorageAccountHostsMixin, QueueServiceClientBase, """ def __init__( - self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, account_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any ) -> None: kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) @@ -164,11 +164,11 @@ async def get_service_properties(self, **kwargs: Any) -> Dict[str, Any]: @distributed_trace_async async def set_service_properties( - self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, - hour_metrics: Optional["Metrics"] = None, - minute_metrics: Optional["Metrics"] = None, - cors: Optional[List["CorsRule"]] = None, - **kwargs: Any + self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, + hour_metrics: Optional["Metrics"] = None, + minute_metrics: Optional["Metrics"] = None, + cors: Optional[List["CorsRule"]] = None, + **kwargs: Any ) -> None: """Sets the properties of a 
storage account's Queue service, including Azure Storage Analytics. @@ -219,9 +219,9 @@ async def set_service_properties( @distributed_trace def list_queues( - self, name_starts_with: Optional[str] = None, - include_metadata: Optional[bool] = False, - **kwargs: Any + self, name_starts_with: Optional[str] = None, + include_metadata: Optional[bool] = False, + **kwargs: Any ) -> AsyncItemPaged: """Returns a generator to list the queues under the specified account. @@ -272,9 +272,9 @@ def list_queues( @distributed_trace_async async def create_queue( - self, name: str, - metadata: Optional[Dict[str, str]] = None, - **kwargs: Any + self, name: str, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any ) -> QueueClient: """Creates a new queue under the specified account. @@ -308,8 +308,8 @@ async def create_queue( @distributed_trace_async async def delete_queue( - self, queue: Union["QueueProperties", str], - **kwargs: Any + self, queue: Union["QueueProperties", str], + **kwargs: Any ) -> None: """Deletes the specified queue and any messages it contains. @@ -344,9 +344,9 @@ async def delete_queue( await queue_client.delete_queue(timeout=timeout, **kwargs) def get_queue_client( - self, - queue: Union["QueueProperties", str], - **kwargs: Any + self, + queue: Union["QueueProperties", str], + **kwargs: Any ) -> QueueClient: """Get a client to interact with the specified queue. 
From 6b563610b030a391beb61eb73e16f15c20e2fdf1 Mon Sep 17 00:00:00 2001 From: vincenttran-msft <101599632+vincenttran-msft@users.noreply.github.com> Date: Mon, 20 Mar 2023 20:24:19 -0700 Subject: [PATCH 04/71] [Storage] Update typing for `queue_client` (Queue) (#28952) * sync 80% done need *, logic * First attempt at *, * Async + unify sync spacing * Fix CI * Revert recordings * PR feedback * Revert recording, fix indents * indent, again :D * Fix content docstring --- .../azure/storage/queue/_queue_client.py | 168 +++++++++--------- .../storage/queue/aio/_queue_client_async.py | 128 +++++++------ 2 files changed, 152 insertions(+), 144 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 6898af943f2c..09e107e126a9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -78,11 +78,11 @@ class QueueClient(StorageAccountHostsMixin, StorageEncryptionMixin): :caption: Create the queue client with url and credential. 
""" def __init__( - self, account_url: str, - queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> None: + self, account_url: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url @@ -108,7 +108,7 @@ def __init__( self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._configure_encryption(kwargs) - def _format_url(self, hostname): + def _format_url(self, hostname: str) -> str: """Format the endpoint URL according to the current location mode hostname. """ @@ -121,10 +121,10 @@ def _format_url(self, hostname): @classmethod def from_queue_url( - cls, queue_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> Self: + cls, queue_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: """A client to interact with a specific Queue. :param str queue_url: The full URI to the queue, including SAS token if used. 
@@ -164,11 +164,11 @@ def from_queue_url( @classmethod def from_connection_string( - cls, conn_str: str, - queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> Self: + cls, conn_str: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: """Create QueueClient from a Connection String. :param str conn_str: @@ -200,17 +200,20 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) # type: ignore + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @distributed_trace - def create_queue(self, **kwargs): - # type: (Any) -> None + def create_queue( + self, *, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: """Creates a new queue in the storage account. If a queue with the same name already exists, the operation fails with a `ResourceExistsError`. - :keyword dict(str,str) metadata: + :keyword Dict[str,str] metadata: A dict containing name-value pairs to associate with the queue as metadata. Note that metadata names preserve the case with which they were created, but are case-insensitive when set or read. @@ -234,11 +237,10 @@ def create_queue(self, **kwargs): :caption: Create a queue. 
""" headers = kwargs.pop('headers', {}) - metadata = kwargs.pop('metadata', None) timeout = kwargs.pop('timeout', None) - headers.update(add_metadata_headers(metadata)) # type: ignore + headers.update(add_metadata_headers(metadata)) try: - return self._client.queue.create( # type: ignore + return self._client.queue.create( metadata=metadata, timeout=timeout, headers=headers, @@ -248,8 +250,7 @@ def create_queue(self, **kwargs): process_storage_error(error) @distributed_trace - def delete_queue(self, **kwargs): - # type: (Any) -> None + def delete_queue(self, **kwargs: Any) -> None: """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -284,8 +285,7 @@ def delete_queue(self, **kwargs): process_storage_error(error) @distributed_trace - def get_queue_properties(self, **kwargs): - # type: (Any) -> QueueProperties + def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": """Returns all user-defined metadata for the specified queue. The data returned does not include the queue's list of messages. @@ -313,14 +313,13 @@ def get_queue_properties(self, **kwargs): except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name - return response # type: ignore + return response @distributed_trace - def set_queue_metadata(self, - metadata=None, # type: Optional[Dict[str, Any]] - **kwargs # type: Any - ): - # type: (...) -> None + def set_queue_metadata( + self, metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: """Sets user-defined metadata on the specified queue. Metadata is associated with the queue as name-value pairs. @@ -328,7 +327,7 @@ def set_queue_metadata(self, :param metadata: A dict containing name-value pairs to associate with the queue as metadata. - :type metadata: dict(str, str) + :type metadata: Dict[str, str] :keyword int timeout: Sets the server-side timeout for the operation in seconds. 
For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. @@ -347,9 +346,9 @@ def set_queue_metadata(self, """ timeout = kwargs.pop('timeout', None) headers = kwargs.pop('headers', {}) - headers.update(add_metadata_headers(metadata)) # type: ignore + headers.update(add_metadata_headers(metadata)) try: - return self._client.queue.set_metadata( # type: ignore + return self._client.queue.set_metadata( timeout=timeout, headers=headers, cls=return_response_headers, @@ -358,8 +357,7 @@ def set_queue_metadata(self, process_storage_error(error) @distributed_trace - def get_queue_access_policy(self, **kwargs): - # type: (Any) -> Dict[str, AccessPolicy] + def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: """Returns details about any stored access policies specified on the queue that may be used with Shared Access Signatures. @@ -370,7 +368,7 @@ def get_queue_access_policy(self, **kwargs): see `here `_. :return: A dictionary of access policies associated with the queue. - :rtype: dict(str, ~azure.storage.queue.AccessPolicy) + :rtype: Dict[str, ~azure.storage.queue.AccessPolicy] """ timeout = kwargs.pop('timeout', None) try: @@ -383,11 +381,10 @@ def get_queue_access_policy(self, **kwargs): return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace - def set_queue_access_policy(self, - signed_identifiers, # type: Dict[str, AccessPolicy] - **kwargs # type: Any - ): - # type: (...) -> None + def set_queue_access_policy( + self, signed_identifiers: Dict[str, AccessPolicy], + **kwargs: Any + ) -> None: """Sets stored access policies for the queue that may be used with Shared Access Signatures. @@ -406,7 +403,7 @@ def set_queue_access_policy(self, SignedIdentifier access policies to associate with the queue. This may contain up to 5 elements. An empty dict will clear the access policies set on the service. 
- :type signed_identifiers: dict(str, ~azure.storage.queue.AccessPolicy) + :type signed_identifiers: Dict[str, ~azure.storage.queue.AccessPolicy] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. @@ -434,7 +431,7 @@ def set_queue_access_policy(self, value.start = serialize_iso(value.start) value.expiry = serialize_iso(value.expiry) identifiers.append(SignedIdentifier(id=key, access_policy=value)) - signed_identifiers = identifiers # type: ignore + signed_identifiers = identifiers try: self._client.queue.set_access_policy( queue_acl=signed_identifiers or None, @@ -445,11 +442,12 @@ def set_queue_access_policy(self, @distributed_trace def send_message( - self, - content, # type: Any - **kwargs # type: Any - ): - # type: (...) -> QueueMessage + self, content: Any, + *, + visibility_timeout: Optional[int] = None, + time_to_live: Optional[int] = None, + **kwargs: Any + ) -> "QueueMessage": """Adds a new message to the back of the message queue. The visibility timeout specifies the time that the message will be @@ -463,7 +461,7 @@ def send_message( If the key-encryption-key field is set on the local service object, this method will encrypt the content before uploading. - :param obj content: + :param Any content: Message content. Allowed type is determined by the encode_function set on the service. Default is str. The encoded message can be up to 64KB in size. @@ -499,8 +497,6 @@ def send_message( :dedent: 12 :caption: Send messages. 
""" - visibility_timeout = kwargs.pop('visibility_timeout', None) - time_to_live = kwargs.pop('time_to_live', None) timeout = kwargs.pop('timeout', None) try: self._config.message_encode_policy.configure( @@ -540,7 +536,11 @@ def send_message( process_storage_error(error) @distributed_trace - def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: + def receive_message( + self, *, + visibility_timeout: Optional[int] = None, + **kwargs: Any + ) -> Optional[QueueMessage]: """Removes one message from the front of the queue. When the message is retrieved from the queue, the response includes the message @@ -578,7 +578,6 @@ def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: :dedent: 12 :caption: Receive one message from the queue. """ - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) self._config.message_decode_policy.configure( require_encryption=self.require_encryption, @@ -599,8 +598,13 @@ def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: process_storage_error(error) @distributed_trace - def receive_messages(self, **kwargs): - # type: (Any) -> ItemPaged[QueueMessage] + def receive_messages( + self, *, + messages_per_page: Optional[int] = None, + visibility_timeout: Optional[int] = None, + max_messages: Optional[int] = None, + **kwargs: Any + ) -> ItemPaged[QueueMessage]: """Removes one or more messages from the front of the queue. When a message is retrieved from the queue, the response includes the message @@ -637,14 +641,14 @@ def receive_messages(self, **kwargs): larger than 7 days. The visibility timeout of a message cannot be set to a value later than the expiry time. visibility_timeout should be set to a value smaller than the time-to-live value. + :keyword int max_messages: + An integer that specifies the maximum number of messages to retrieve from the queue. :keyword int timeout: Sets the server-side timeout for the operation in seconds. 
For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. This value is not tracked or validated on the client. To configure client-side network timesouts see `here `_. - :keyword int max_messages: - An integer that specifies the maximum number of messages to retrieve from the queue. :return: Returns a message iterator of dict-like Message objects. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.queue.QueueMessage] @@ -658,10 +662,7 @@ def receive_messages(self, **kwargs): :dedent: 12 :caption: Receive messages from the queue. """ - messages_per_page = kwargs.pop('messages_per_page', None) - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) - max_messages = kwargs.pop('max_messages', None) self._config.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, @@ -683,13 +684,14 @@ def receive_messages(self, **kwargs): process_storage_error(error) @distributed_trace - def update_message(self, - message, # type: Any - pop_receipt=None, # type: Optional[str] - content=None, # type: Optional[Any] - **kwargs # type: Any - ): - # type: (...) -> QueueMessage + def update_message( + self, message: Union[str, QueueMessage], + pop_receipt: Optional[str] = None, + content: Optional[Any] = None, + *, + visibility_timeout: Optional[int] = None, + **kwargs: Any + ) -> QueueMessage: """Updates the visibility timeout of a message. You can also use this operation to update the contents of a message. @@ -710,7 +712,7 @@ def update_message(self, :param str pop_receipt: A valid pop receipt value returned from an earlier call to the :func:`~receive_messages` or :func:`~update_message` operation. - :param obj content: + :param Any content: Message content. Allowed type is determined by the encode_function set on the service. Default is str. 
:keyword int visibility_timeout: @@ -740,7 +742,6 @@ def update_message(self, :dedent: 12 :caption: Update a message. """ - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) try: message_id = message.id @@ -780,7 +781,7 @@ def update_message(self, encoded_message_text = self._config.message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: - updated = None # type: ignore + updated = None try: response = self._client.message_id.update( queue_message=updated, @@ -802,11 +803,10 @@ def update_message(self, process_storage_error(error) @distributed_trace - def peek_messages(self, - max_messages=None, # type: Optional[int] - **kwargs # type: Any - ): - # type: (...) -> List[QueueMessage] + def peek_messages( + self, max_messages: Optional[int] = None, + **kwargs: Any + ) -> List[QueueMessage]: """Retrieves one or more messages from the front of the queue, but does not alter the visibility of the message. @@ -867,8 +867,7 @@ def peek_messages(self, process_storage_error(error) @distributed_trace - def clear_messages(self, **kwargs): - # type: (Any) -> None + def clear_messages(self, **kwargs: Any) -> None: """Deletes all messages from the specified queue. :keyword int timeout: @@ -894,12 +893,11 @@ def clear_messages(self, **kwargs): process_storage_error(error) @distributed_trace - def delete_message(self, - message, # type: Any - pop_receipt=None, # type: Optional[str] - **kwargs # type: Any - ): - # type: (...) -> None + def delete_message( + self, message: Union[str, QueueMessage], + pop_receipt: Optional[str] = None, + **kwargs: Any + ) -> None: """Deletes the specified message. 
Normally after a client retrieves a message with the receive messages operation, diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index c165c07bf431..0eaf9010007f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -84,25 +84,28 @@ class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncrypt """ def __init__( - self, account_url: str, - queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> None: + self, account_url: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) super(QueueClient, self).__init__( account_url, queue_name=queue_name, credential=credential, loop=loop, **kwargs ) self._client = AzureQueueStorage(self.url, base_url=self.url, - pipeline=self._pipeline, loop=loop) # type: ignore + pipeline=self._pipeline, loop=loop) self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) @distributed_trace_async - async def create_queue(self, **kwargs): - # type: (Optional[Any]) -> None + async def create_queue( + self, *, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: """Creates a new queue in the storage account. 
If a queue with the same name already exists, the operation fails with @@ -131,20 +134,18 @@ async def create_queue(self, **kwargs): :dedent: 12 :caption: Create a queue. """ - metadata = kwargs.pop('metadata', None) timeout = kwargs.pop('timeout', None) headers = kwargs.pop("headers", {}) - headers.update(add_metadata_headers(metadata)) # type: ignore + headers.update(add_metadata_headers(metadata)) try: - return await self._client.queue.create( # type: ignore + return await self._client.queue.create( metadata=metadata, timeout=timeout, headers=headers, cls=deserialize_queue_creation, **kwargs ) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def delete_queue(self, **kwargs): - # type: (Optional[Any]) -> None + async def delete_queue(self, **kwargs: Any) -> None: """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -179,8 +180,7 @@ async def delete_queue(self, **kwargs): process_storage_error(error) @distributed_trace_async - async def get_queue_properties(self, **kwargs): - # type: (Optional[Any]) -> QueueProperties + async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": """Returns all user-defined metadata for the specified queue. The data returned does not include the queue's list of messages. @@ -207,11 +207,13 @@ async def get_queue_properties(self, **kwargs): except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name - return response # type: ignore + return response @distributed_trace_async - async def set_queue_metadata(self, metadata=None, **kwargs): - # type: (Optional[Dict[str, Any]], Optional[Any]) -> None + async def set_queue_metadata( + self, metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: """Sets user-defined metadata on the specified queue. Metadata is associated with the queue as name-value pairs. 
@@ -238,17 +240,16 @@ async def set_queue_metadata(self, metadata=None, **kwargs): """ timeout = kwargs.pop('timeout', None) headers = kwargs.pop("headers", {}) - headers.update(add_metadata_headers(metadata)) # type: ignore + headers.update(add_metadata_headers(metadata)) try: - return await self._client.queue.set_metadata( # type: ignore + return await self._client.queue.set_metadata( timeout=timeout, headers=headers, cls=return_response_headers, **kwargs ) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def get_queue_access_policy(self, **kwargs): - # type: (Optional[Any]) -> Dict[str, Any] + async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: """Returns details about any stored access policies specified on the queue that may be used with Shared Access Signatures. @@ -271,8 +272,10 @@ async def get_queue_access_policy(self, **kwargs): return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace_async - async def set_queue_access_policy(self, signed_identifiers, **kwargs): - # type: (Dict[str, AccessPolicy], Optional[Any]) -> None + async def set_queue_access_policy( + self, signed_identifiers: Dict[str, AccessPolicy], + **kwargs: Any + ) -> None: """Sets stored access policies for the queue that may be used with Shared Access Signatures. 
@@ -320,19 +323,20 @@ async def set_queue_access_policy(self, signed_identifiers, **kwargs): value.start = serialize_iso(value.start) value.expiry = serialize_iso(value.expiry) identifiers.append(SignedIdentifier(id=key, access_policy=value)) - signed_identifiers = identifiers # type: ignore + signed_identifiers = identifiers try: await self._client.queue.set_access_policy(queue_acl=signed_identifiers or None, timeout=timeout, **kwargs) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def send_message( # type: ignore - self, - content, # type: Any - **kwargs # type: Optional[Any] - ): - # type: (...) -> QueueMessage + async def send_message( + self, content: Any, + *, + visibility_timeout: Optional[int] = None, + time_to_live: Optional[int] = None, + **kwargs: Any + ) -> "QueueMessage": """Adds a new message to the back of the message queue. The visibility timeout specifies the time that the message will be @@ -346,7 +350,7 @@ async def send_message( # type: ignore If the key-encryption-key field is set on the local service object, this method will encrypt the content before uploading. - :param obj content: + :param Any content: Message content. Allowed type is determined by the encode_function set on the service. Default is str. The encoded message can be up to 64KB in size. @@ -382,8 +386,6 @@ async def send_message( # type: ignore :dedent: 16 :caption: Send messages. 
""" - visibility_timeout = kwargs.pop('visibility_timeout', None) - time_to_live = kwargs.pop('time_to_live', None) timeout = kwargs.pop('timeout', None) try: self._config.message_encode_policy.configure( @@ -424,7 +426,11 @@ async def send_message( # type: ignore process_storage_error(error) @distributed_trace_async - async def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: + async def receive_message( + self, *, + visibility_timeout: Optional[int] = None, + **kwargs: Any + ) -> Optional[QueueMessage]: """Removes one message from the front of the queue. When the message is retrieved from the queue, the response includes the message @@ -462,7 +468,6 @@ async def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: :dedent: 12 :caption: Receive one message from the queue. """ - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) self._config.message_decode_policy.configure( require_encryption=self.require_encryption, @@ -483,8 +488,13 @@ async def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: process_storage_error(error) @distributed_trace - def receive_messages(self, **kwargs): - # type: (Optional[Any]) -> AsyncItemPaged[QueueMessage] + def receive_messages( + self, *, + messages_per_page: Optional[int] = None, + visibility_timeout: Optional[int] = None, + max_messages: Optional[int] = None, + **kwargs: Any + ) -> AsyncItemPaged[QueueMessage]: """Removes one or more messages from the front of the queue. When a message is retrieved from the queue, the response includes the message @@ -512,14 +522,14 @@ def receive_messages(self, **kwargs): larger than 7 days. The visibility timeout of a message cannot be set to a value later than the expiry time. visibility_timeout should be set to a value smaller than the time-to-live value. + :keyword int max_messages: + An integer that specifies the maximum number of messages to retrieve from the queue. 
:keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. This value is not tracked or validated on the client. To configure client-side network timesouts see `here `_. - :keyword int max_messages: - An integer that specifies the maximum number of messages to retrieve from the queue. :return: Returns a message iterator of dict-like Message objects. :rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.storage.queue.QueueMessage] @@ -533,10 +543,7 @@ def receive_messages(self, **kwargs): :dedent: 16 :caption: Receive messages from the queue. """ - messages_per_page = kwargs.pop('messages_per_page', None) - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) - max_messages = kwargs.pop('max_messages', None) self._config.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, @@ -560,13 +567,13 @@ def receive_messages(self, **kwargs): @distributed_trace_async async def update_message( - self, - message, - pop_receipt=None, - content=None, - **kwargs - ): - # type: (Any, int, Optional[str], Optional[Any], Any) -> QueueMessage + self, message: Union[str, QueueMessage], + pop_receipt: Optional[str] = None, + content: Optional[Any] = None, + *, + visibility_timeout: Optional[int] = None, + **kwargs: Any + ) -> QueueMessage: """Updates the visibility timeout of a message. You can also use this operation to update the contents of a message. @@ -587,7 +594,7 @@ async def update_message( :param str pop_receipt: A valid pop receipt value returned from an earlier call to the :func:`~receive_messages` or :func:`~update_message` operation. - :param obj content: + :param Any content: Message content. Allowed type is determined by the encode_function set on the service. Default is str. 
:keyword int visibility_timeout: @@ -617,7 +624,6 @@ async def update_message( :dedent: 16 :caption: Update a message. """ - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) try: message_id = message.id @@ -659,7 +665,7 @@ async def update_message( encoded_message_text = self._config.message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: - updated = None # type: ignore + updated = None try: response = await self._client.message_id.update( queue_message=updated, @@ -682,8 +688,10 @@ async def update_message( process_storage_error(error) @distributed_trace_async - async def peek_messages(self, max_messages=None, **kwargs): - # type: (Optional[int], Optional[Any]) -> List[QueueMessage] + async def peek_messages( + self, max_messages: Optional[int] = None, + **kwargs: Any + ) -> List[QueueMessage]: """Retrieves one or more messages from the front of the queue, but does not alter the visibility of the message. @@ -743,8 +751,7 @@ async def peek_messages(self, max_messages=None, **kwargs): process_storage_error(error) @distributed_trace_async - async def clear_messages(self, **kwargs): - # type: (Optional[Any]) -> None + async def clear_messages(self, **kwargs: Any) -> None: """Deletes all messages from the specified queue. :keyword int timeout: @@ -770,8 +777,11 @@ async def clear_messages(self, **kwargs): process_storage_error(error) @distributed_trace_async - async def delete_message(self, message, pop_receipt=None, **kwargs): - # type: (Any, Optional[str], Any) -> None + async def delete_message( + self, message: Union[str, QueueMessage], + pop_receipt: Optional[str] = None, + **kwargs: Any + ) -> None: """Deletes the specified message. 
Normally after a client retrieves a message with the receive messages operation, From 012da69af422d645235bd8d02c5bb3d286255524 Mon Sep 17 00:00:00 2001 From: vincenttran-msft <101599632+vincenttran-msft@users.noreply.github.com> Date: Wed, 29 Mar 2023 13:19:15 -0700 Subject: [PATCH 05/71] [Storage] Update typing for miscellaneous files (queue) (#29173) --- .../azure/storage/queue/_deserialize.py | 18 +- .../azure/storage/queue/_encryption.py | 126 +++++--- .../azure/storage/queue/_message_encoding.py | 38 ++- .../azure/storage/queue/_models.py | 293 ++++++++++++------ .../storage/queue/_shared_access_signature.py | 60 ++-- .../azure/storage/queue/aio/_models.py | 66 +++- 6 files changed, 396 insertions(+), 205 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py index 50f79edb6076..3e7b240ef16f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py @@ -5,14 +5,22 @@ # -------------------------------------------------------------------------- # pylint: disable=unused-argument -from azure.core.exceptions import ResourceExistsError +from typing import Any, Dict, TYPE_CHECKING +from azure.core.exceptions import ResourceExistsError from ._shared.models import StorageErrorCode from ._models import QueueProperties from ._shared.response_handlers import deserialize_metadata +if TYPE_CHECKING: + from azure.core.pipeline import PipelineResponse + -def deserialize_queue_properties(response, obj, headers): +def deserialize_queue_properties( + response: "PipelineResponse", + obj: Any, + headers: Dict[str, Any] +) -> QueueProperties: metadata = deserialize_metadata(response, obj, headers) queue_properties = QueueProperties( metadata=metadata, @@ -21,7 +29,11 @@ def deserialize_queue_properties(response, obj, headers): return queue_properties -def 
deserialize_queue_creation(response, obj, headers): +def deserialize_queue_creation( + response: "PipelineResponse", + obj: Any, + headers: Dict[str, Any] +) -> Dict[str, Any]: response = response.http_response if response.status_code == 204: error_code = StorageErrorCode.queue_already_exists diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index d923a94ffe68..71e03b758891 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-lines # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for @@ -14,7 +15,7 @@ dumps, loads, ) -from typing import Any, BinaryIO, Dict, Optional, Tuple +from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher @@ -28,6 +29,11 @@ from ._version import VERSION from ._shared import encode_base64, decode_base64_to_bytes +if TYPE_CHECKING: + from azure.core.pipeline import PipelineResponse + from cryptography.hazmat.primitives.ciphers import AEADEncryptionContext + from cryptography.hazmat.primitives.padding import PaddingContext + _ENCRYPTION_PROTOCOL_V1 = '1.0' _ENCRYPTION_PROTOCOL_V2 = '2.0' @@ -39,12 +45,12 @@ '{0} does not define a complete interface. Value of {1} is either missing or invalid.' 
-def _validate_not_none(param_name, param): +def _validate_not_none(param_name: str, param: Any) -> None: if param is None: raise ValueError(f'{param_name} should not be None.') -def _validate_key_encryption_key_wrap(kek): +def _validate_key_encryption_key_wrap(kek: object) -> None: # Note that None is not callable and so will fail the second clause of each check. if not hasattr(kek, 'wrap_key') or not callable(kek.wrap_key): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'wrap_key')) @@ -55,7 +61,7 @@ def _validate_key_encryption_key_wrap(kek): class StorageEncryptionMixin(object): - def _configure_encryption(self, kwargs): + def _configure_encryption(self, kwargs: Any) -> None: self.require_encryption = kwargs.get("require_encryption", False) self.encryption_version = kwargs.get("encryption_version", "1.0") self.key_encryption_key = kwargs.get("key_encryption_key") @@ -80,7 +86,7 @@ class _WrappedContentKey: Represents the envelope key details stored on the service. ''' - def __init__(self, algorithm, encrypted_key, key_id): + def __init__(self, algorithm: str, encrypted_key: bytes, key_id: str) -> None: ''' :param str algorithm: The algorithm used for wrapping. @@ -105,11 +111,11 @@ class _EncryptedRegionInfo: This is only used for Encryption V2. ''' - def __init__(self, data_length, nonce_length, tag_length): + def __init__(self, data_length: int, nonce_length: int, tag_length: int) -> None: ''' :param int data_length: The length of the encryption region data (not including nonce + tag). - :param str nonce_length: + :param int nonce_length: The length of nonce used when encrypting. :param int tag_length: The length of the encryption tag. @@ -129,7 +135,7 @@ class _EncryptionAgent: It consists of the encryption protocol version and encryption algorithm used. 
''' - def __init__(self, encryption_algorithm, protocol): + def __init__(self, encryption_algorithm: _EncryptionAlgorithm, protocol: str) -> None: ''' :param _EncryptionAlgorithm encryption_algorithm: The algorithm used for encrypting the message contents. @@ -150,12 +156,12 @@ class _EncryptionData: ''' def __init__( - self, - content_encryption_IV, - encrypted_region_info, - encryption_agent, - wrapped_content_key, - key_wrapping_metadata): + self, content_encryption_IV: Optional[bytes], + encrypted_region_info: Optional[_EncryptedRegionInfo], + encryption_agent: _EncryptionAgent, + wrapped_content_key: _WrappedContentKey, + key_wrapping_metadata: Dict[str, Any] + ) -> None: ''' :param Optional[bytes] content_encryption_IV: The content encryption initialization vector. @@ -168,7 +174,7 @@ def __init__( :param _WrappedContentKey wrapped_content_key: An object that stores the wrapping algorithm, the key identifier, and the encrypted key bytes. - :param dict key_wrapping_metadata: + :param Dict[str, Any] key_wrapping_metadata: A dict containing metadata related to the key wrapping. ''' @@ -198,10 +204,9 @@ class GCMBlobEncryptionStream: nonce for each encryption region. """ def __init__( - self, - content_encryption_key: bytes, - data_stream: BinaryIO, - ): + self, content_encryption_key: bytes, + data_stream: BinaryIO, + ) -> None: """ :param bytes content_encryption_key: The encryption key to use. :param BinaryIO data_stream: The data stream to read data from. @@ -291,10 +296,11 @@ def get_adjusted_upload_size(length: int, encryption_version: str) -> int: def get_adjusted_download_range_and_offset( - start: int, - end: int, - length: int, - encryption_data: Optional[_EncryptionData]) -> Tuple[Tuple[int, int], Tuple[int, int]]: + start: int, + end: int, + length: int, + encryption_data: Optional[_EncryptionData] +) -> Tuple[Tuple[int, int], Tuple[int, int]]: """ Gets the new download range and offsets into the decrypted data for the given user-specified range. 
The new download range will include all @@ -400,7 +406,7 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp return size -def _generate_encryption_data_dict(kek, cek, iv, version): +def _generate_encryption_data_dict(kek: object, cek: bytes, iv: Optional[bytes], version: str) -> Dict[str, Any]: ''' Generates and returns the encryption metadata as a dict. @@ -452,12 +458,12 @@ def _generate_encryption_data_dict(kek, cek, iv, version): return encryption_data_dict -def _dict_to_encryption_data(encryption_data_dict): +def _dict_to_encryption_data(encryption_data_dict: Dict[str, Any]) -> _EncryptionData: ''' Converts the specified dictionary to an EncryptionData object for eventual use in decryption. - :param dict encryption_data_dict: + :param Dict[str, Any] encryption_data_dict: The dictionary containing the encryption data. :return: an _EncryptionData object built from the dictionary. :rtype: _EncryptionData @@ -504,12 +510,12 @@ def _dict_to_encryption_data(encryption_data_dict): return encryption_data -def _generate_AES_CBC_cipher(cek, iv): +def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: ''' Generates and returns an encryption cipher for AES CBC using the given cek and iv. - :param bytes[] cek: The content encryption key for the cipher. - :param bytes[] iv: The initialization vector for the cipher. + :param bytes cek: The content encryption key for the cipher. + :param bytes iv: The initialization vector for the cipher. :return: A cipher for encrypting in AES256 CBC. 
:rtype: ~cryptography.hazmat.primitives.ciphers.Cipher ''' @@ -520,20 +526,24 @@ def _generate_AES_CBC_cipher(cek, iv): return Cipher(algorithm, mode, backend) -def _validate_and_unwrap_cek(encryption_data, key_encryption_key=None, key_resolver=None): +def _validate_and_unwrap_cek( + encryption_data: _EncryptionData, + key_encryption_key: object = None, + key_resolver: Callable[[str], bytes] = None +) -> bytes: ''' Extracts and returns the content_encryption_key stored in the encryption_data object and performs necessary validation on all parameters. :param _EncryptionData encryption_data: The encryption metadata of the retrieved value. - :param obj key_encryption_key: + :param object key_encryption_key: The key_encryption_key used to unwrap the cek. Please refer to high-level service object instance variables for more details. - :param func key_resolver: + :param Callable[[str], bytes] key_resolver: A function used that, given a key_id, will return a key_encryption_key. Please refer to high-level service object instance variables for more details. :return: the content_encryption_key stored in the encryption_data object. - :rtype: bytes[] + :rtype: bytes ''' _validate_not_none('encrypted_key', encryption_data.wrapped_content_key.encrypted_key) @@ -579,7 +589,12 @@ def _validate_and_unwrap_cek(encryption_data, key_encryption_key=None, key_resol return content_encryption_key -def _decrypt_message(message, encryption_data, key_encryption_key=None, resolver=None): +def _decrypt_message( + message: str, + encryption_data: _EncryptionData, + key_encryption_key: object = None, + resolver: callable = None +) -> str: ''' Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding. Unwraps the content-encryption-key using the user-provided or resolved key-encryption-key (kek). 
@@ -639,7 +654,7 @@ def _decrypt_message(message, encryption_data, key_encryption_key=None, resolver return decrypted_data -def encrypt_blob(blob, key_encryption_key, version): +def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple[str, bytes]: ''' Encrypts the given blob using the given encryption protocol version. Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek). @@ -698,7 +713,7 @@ def encrypt_blob(blob, key_encryption_key, version): return dumps(encryption_data), encrypted_data -def generate_blob_encryption_data(key_encryption_key, version): +def generate_blob_encryption_data(key_encryption_key: object, version: str) -> Tuple[bytes, Optional[bytes], str]: ''' Generates the encryption_metadata for the blob. @@ -729,13 +744,14 @@ def generate_blob_encryption_data(key_encryption_key, version): def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements - require_encryption, - key_encryption_key, - key_resolver, - content, - start_offset, - end_offset, - response_headers): + require_encryption: bool, + key_encryption_key: object, + key_resolver: Callable[[str], bytes], + content: bytes, + start_offset: int, + end_offset: int, + response_headers: Dict[str, Any] +) -> bytes: """ Decrypts the given blob contents and returns only the requested range. @@ -746,7 +762,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements wrap_key(key)--wraps the specified key using an algorithm of the user's choice. get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key. get_kid()--returns a string key id for this key-encryption-key. - :param object key_resolver: + :param Callable[[str], bytes] key_resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. 
:param bytes content: @@ -854,7 +870,11 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements return decrypted_content[start_offset:end_offset] -def get_blob_encryptor_and_padder(cek, iv, should_pad): +def get_blob_encryptor_and_padder( + cek: bytes, + iv: bytes, + should_pad: bool +) -> Tuple["AEADEncryptionContext", "PaddingContext"]: encryptor = None padder = None @@ -866,14 +886,14 @@ def get_blob_encryptor_and_padder(cek, iv, should_pad): return encryptor, padder -def encrypt_queue_message(message, key_encryption_key, version): +def encrypt_queue_message(message: str, key_encryption_key: object, version: str) -> str: ''' Encrypts the given plain text message using the given protocol version. Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek). Returns a json-formatted string containing the encrypted message and the encryption metadata. - :param object message: - The plain text messge to be encrypted. + :param str message: + The plain text message to be encrypted. :param object key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: wrap_key(key)--wraps the specified key using an algorithm of the user's choice. @@ -933,7 +953,13 @@ def encrypt_queue_message(message, key_encryption_key, version): return dumps(queue_message) -def decrypt_queue_message(message, response, require_encryption, key_encryption_key, resolver): +def decrypt_queue_message( + message: str, + response: "PipelineResponse", + require_encryption: bool, + key_encryption_key: object, + resolver: Callable[[str], bytes] +) -> str: ''' Returns the decrypted message contents from an EncryptedQueueMessage. If no encryption metadata is present, will return the unaltered message. @@ -947,7 +973,7 @@ def decrypt_queue_message(message, response, require_encryption, key_encryption_ - returns the unwrapped form of the specified symmetric key usingthe string-specified algorithm. 
get_kid() - returns a string key id for this key-encryption-key. - :param function resolver(kid): + :param Callable[[str], bytes] resolver(kid): The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The plain text message from the queue message. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index 4c12a5fa2ed4..5e35f445d3f5 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -5,30 +5,38 @@ # -------------------------------------------------------------------------- # pylint: disable=unused-argument -import sys from base64 import b64encode, b64decode +from typing import Any, Callable, Dict, TYPE_CHECKING from azure.core.exceptions import DecodeError from ._encryption import decrypt_queue_message, encrypt_queue_message, _ENCRYPTION_PROTOCOL_V1 +if TYPE_CHECKING: + from azure.core.pipeline import PipelineResponse + class MessageEncodePolicy(object): - def __init__(self): + def __init__(self) -> None: self.require_encryption = False self.encryption_version = None self.key_encryption_key = None self.resolver = None - def __call__(self, content): + def __call__(self, content: Any) -> str: if content: content = self.encode(content) if self.key_encryption_key is not None: content = encrypt_queue_message(content, self.key_encryption_key, self.encryption_version) return content - def configure(self, require_encryption, key_encryption_key, resolver, encryption_version=_ENCRYPTION_PROTOCOL_V1): + def configure( + self, require_encryption: bool, + key_encryption_key: object, + resolver: Callable[[str], bytes], + encryption_version: str = _ENCRYPTION_PROTOCOL_V1 + ) -> None: self.require_encryption = require_encryption self.encryption_version = encryption_version 
self.key_encryption_key = key_encryption_key @@ -36,7 +44,7 @@ def configure(self, require_encryption, key_encryption_key, resolver, encryption if self.require_encryption and not self.key_encryption_key: raise ValueError("Encryption required but no key was provided.") - def encode(self, content): + def encode(self, content: Any) -> str: raise NotImplementedError("Must be implemented by child class.") @@ -47,7 +55,7 @@ def __init__(self): self.key_encryption_key = None self.resolver = None - def __call__(self, response, obj, headers): + def __call__(self, response: "PipelineResponse", obj: object, headers: Dict[str, Any]) -> object: for message in obj: if message.message_text in [None, "", b""]: continue @@ -61,12 +69,12 @@ def __call__(self, response, obj, headers): message.message_text = self.decode(content, response) return obj - def configure(self, require_encryption, key_encryption_key, resolver): + def configure(self, require_encryption: bool, key_encryption_key: object, resolver: Callable[[str], bytes]) -> None: self.require_encryption = require_encryption self.key_encryption_key = key_encryption_key self.resolver = resolver - def decode(self, content, response): + def decode(self, content: Any, response: "PipelineResponse") -> str: raise NotImplementedError("Must be implemented by child class.") @@ -77,7 +85,7 @@ class TextBase64EncodePolicy(MessageEncodePolicy): is not text, a TypeError will be raised. Input text must support UTF-8. """ - def encode(self, content): + def encode(self, content: str) -> str: if not isinstance(content, str): raise TypeError("Message content must be text for base 64 encoding.") return b64encode(content.encode('utf-8')).decode('utf-8') @@ -91,7 +99,7 @@ class TextBase64DecodePolicy(MessageDecodePolicy): support UTF-8. 
""" - def decode(self, content, response): + def decode(self, content: str, response: "PipelineResponse") -> bytes: try: return b64decode(content.encode('utf-8')).decode('utf-8') except (ValueError, TypeError) as error: @@ -109,7 +117,7 @@ class BinaryBase64EncodePolicy(MessageEncodePolicy): is not bytes, a TypeError will be raised. """ - def encode(self, content): + def encode(self, content: bytes) -> str: if not isinstance(content, bytes): raise TypeError("Message content must be bytes for base 64 encoding.") return b64encode(content).decode('utf-8') @@ -122,7 +130,7 @@ class BinaryBase64DecodePolicy(MessageDecodePolicy): is not valid base 64, a DecodeError will be raised. """ - def decode(self, content, response): + def decode(self, content: str, response: "PipelineResponse") -> bytes: response = response.http_response try: return b64decode(content.encode('utf-8')) @@ -137,8 +145,8 @@ def decode(self, content, response): class NoEncodePolicy(MessageEncodePolicy): """Bypass any message content encoding.""" - def encode(self, content): - if isinstance(content, bytes) and sys.version_info > (3,): + def encode(self, content: str) -> str: + if isinstance(content, bytes): raise TypeError( "Message content must not be bytes. Use the BinaryBase64EncodePolicy to send bytes." 
) @@ -148,5 +156,5 @@ def encode(self, content): class NoDecodePolicy(MessageDecodePolicy): """Bypass any message content decoding.""" - def decode(self, content, response): + def decode(self, content: str, response: "PipelineResponse") -> str: return content diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 95f7ea563193..5390072c3ee2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -6,7 +6,8 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called -from typing import List # pylint: disable=unused-import +import sys +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from azure.core.exceptions import HttpResponseError from azure.core.paging import PageIterator from ._shared.response_handlers import return_context_and_deserialized, process_storage_error @@ -17,6 +18,48 @@ from ._generated.models import RetentionPolicy as GeneratedRetentionPolicy from ._generated.models import CorsRule as GeneratedCorsRule +if sys.version_info >= (3, 11): + from typing import Self # pylint: disable=no-name-in-module, ungrouped-imports +else: + from typing_extensions import Self # pylint: disable=ungrouped-imports + +if TYPE_CHECKING: + from datetime import datetime + + +class RetentionPolicy(GeneratedRetentionPolicy): + """The retention policy which determines how long the associated data should + persist. + + All required parameters must be populated in order to send to Azure. + + :param bool enabled: Required. Indicates whether a retention policy is enabled + for the storage service. + :param int days: Indicates the number of days that metrics or logging or + soft-deleted data should be retained. All data older than this value will + be deleted. 
+ """ + + enabled: bool = False + """Indicates whether a retention policy is enabled for the storage service.""" + days: int = None + """Indicates the number of days that metrics or logging or soft-deleted data should be retained.""" + + def __init__(self, enabled: bool = False, days: int = None) -> None: + self.enabled = enabled + self.days = days + if self.enabled and (self.days is None): + raise ValueError("If policy is enabled, 'days' must be specified.") + + @classmethod + def _from_generated(cls, generated: Any) -> Self: + if not generated: + return cls() + return cls( + enabled=generated.enabled, + days=generated.days, + ) + class QueueAnalyticsLogging(GeneratedLogging): """Azure Analytics Logging settings. @@ -27,11 +70,21 @@ class QueueAnalyticsLogging(GeneratedLogging): :keyword bool delete: Required. Indicates whether all delete requests should be logged. :keyword bool read: Required. Indicates whether all read requests should be logged. :keyword bool write: Required. Indicates whether all write requests should be logged. - :keyword ~azure.storage.queue.RetentionPolicy retention_policy: Required. - The retention policy for the metrics. + :keyword ~azure.storage.queue.RetentionPolicy retention_policy: The retention policy for the metrics. 
""" - def __init__(self, **kwargs): + version: str = '1.0' + """The version of Storage Analytics to configure.""" + delete: bool = False + """Indicates whether all delete requests should be logged.""" + read: bool = False + """Indicates whether all read requests should be logged""" + write: bool = False + """Indicates whether all write requests should be logged.""" + retention_policy: RetentionPolicy = RetentionPolicy() + """The retention policy for the metrics.""" + + def __init__(self, **kwargs: Any) -> None: self.version = kwargs.get('version', '1.0') self.delete = kwargs.get('delete', False) self.read = kwargs.get('read', False) @@ -39,7 +92,7 @@ def __init__(self, **kwargs): self.retention_policy = kwargs.get('retention_policy') or RetentionPolicy() @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: if not generated: return cls() return cls( @@ -58,20 +111,28 @@ class Metrics(GeneratedMetrics): :keyword str version: The version of Storage Analytics to configure. :keyword bool enabled: Required. Indicates whether metrics are enabled for the service. - :keyword bool include_ap_is: Indicates whether metrics should generate summary + :keyword bool include_apis: Indicates whether metrics should generate summary statistics for called API operations. - :keyword ~azure.storage.queue.RetentionPolicy retention_policy: Required. - The retention policy for the metrics. + :keyword ~azure.storage.queue.RetentionPolicy retention_policy: The retention policy for the metrics. 
""" - def __init__(self, **kwargs): + version: str = '1.0' + """The version of Storage Analytics to configure.""" + enabled: bool = False + """Indicates whether metrics are enabled for the service.""" + include_apis: bool + """Indicates whether metrics should generate summary statistics for called API operations.""" + retention_policy: RetentionPolicy = RetentionPolicy() + """The retention policy for the metrics.""" + + def __init__(self, **kwargs: Any) -> None: self.version = kwargs.get('version', '1.0') self.enabled = kwargs.get('enabled', False) self.include_apis = kwargs.get('include_apis') self.retention_policy = kwargs.get('retention_policy') or RetentionPolicy() @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: if not generated: return cls() return cls( @@ -82,35 +143,6 @@ def _from_generated(cls, generated): ) -class RetentionPolicy(GeneratedRetentionPolicy): - """The retention policy which determines how long the associated data should - persist. - - All required parameters must be populated in order to send to Azure. - - :param bool enabled: Required. Indicates whether a retention policy is enabled - for the storage service. - :param int days: Indicates the number of days that metrics or logging or - soft-deleted data should be retained. All data older than this value will - be deleted. - """ - - def __init__(self, enabled=False, days=None): - self.enabled = enabled - self.days = days - if self.enabled and (self.days is None): - raise ValueError("If policy is enabled, 'days' must be specified.") - - @classmethod - def _from_generated(cls, generated): - if not generated: - return cls() - return cls( - enabled=generated.enabled, - days=generated.days, - ) - - class CorsRule(GeneratedCorsRule): """CORS is an HTTP feature that enables a web application running under one domain to access resources in another domain. 
Web browsers implement a @@ -120,28 +152,42 @@ class CorsRule(GeneratedCorsRule): All required parameters must be populated in order to send to Azure. - :param list(str) allowed_origins: + :param List[str] allowed_origins: A list of origin domains that will be allowed via CORS, or "*" to allow - all domains. The list of must contain at least one entry. Limited to 64 + all domains. The list must contain at least one entry. Limited to 64 origin domains. Each allowed origin can have up to 256 characters. - :param list(str) allowed_methods: + :param List[str] allowed_methods: A list of HTTP methods that are allowed to be executed by the origin. - The list of must contain at least one entry. For Azure Storage, + The list must contain at least one entry. For Azure Storage, permitted methods are DELETE, GET, HEAD, MERGE, POST, OPTIONS or PUT. :keyword int max_age_in_seconds: The number of seconds that the client/browser should cache a pre-flight response. - :keyword list(str) exposed_headers: + :keyword List[str] exposed_headers: Defaults to an empty list. A list of response headers to expose to CORS clients. Limited to 64 defined headers and two prefixed headers. Each header can be up to 256 characters. - :keyword list(str) allowed_headers: + :keyword List[str] allowed_headers: Defaults to an empty list. A list of headers allowed to be part of the cross-origin request. Limited to 64 defined headers and 2 prefixed headers. Each header can be up to 256 characters. 
""" - def __init__(self, allowed_origins, allowed_methods, **kwargs): + allowed_origins: str + """The comma-delimited string representation of the list of origin domains that will be allowed via + CORS, or "*" to allow all domains.""" + allowed_methods: str + """The comma-delimited string representation of the list HTTP methods that are allowed to be executed + by the origin.""" + max_age_in_seconds: int + """The number of seconds that the client/browser should cache a pre-flight response.""" + exposed_headers: str + """The comma-delimited string representation of the list of response headers to expose to CORS clients.""" + allowed_headers: str + """The comma-delimited string representation of the list of headers allowed to be part of the cross-origin + request.""" + + def __init__(self, allowed_origins: List[str], allowed_methods: List[str], **kwargs: Any) -> None: self.allowed_origins = ','.join(allowed_origins) self.allowed_methods = ','.join(allowed_methods) self.allowed_headers = ','.join(kwargs.get('allowed_headers', [])) @@ -149,7 +195,7 @@ def __init__(self, allowed_origins, allowed_methods, **kwargs): self.max_age_in_seconds = kwargs.get('max_age_in_seconds', 0) @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: return cls( [generated.allowed_origins], [generated.allowed_methods], @@ -202,41 +248,52 @@ class AccessPolicy(GenAccessPolicy): :type start: ~datetime.datetime or str """ - def __init__(self, permission=None, expiry=None, start=None): + permission: str = None + """The permissions associated with the shared access signature. 
The user is restricted to + operations allowed by the permissions.""" + expiry: Union["datetime", str] = None + """The time at which the shared access signature becomes invalid.""" + start: Union["datetime", str] = None + """The time at which the shared access signature becomes valid.""" + + def __init__( + self, permission: str = None, + expiry: Union["datetime", str] = None, + start: Union["datetime", str] = None + ) -> None: self.start = start self.expiry = expiry self.permission = permission class QueueMessage(DictMixin): - """Represents a queue message. + """Represents a queue message.""" - :ivar str id: - A GUID value assigned to the message by the Queue service that + id: str + """A GUID value assigned to the message by the Queue service that identifies the message in the queue. This value may be used together with the value of pop_receipt to delete a message from the queue after - it has been retrieved with the receive messages operation. - :ivar date inserted_on: - A UTC date value representing the time the messages was inserted. - :ivar date expires_on: - A UTC date value representing the time the message expires. - :ivar int dequeue_count: - Begins with a value of 1 the first time the message is received. This - value is incremented each time the message is subsequently received. - :param obj content: - The message content. Type is determined by the decode_function set on - the service. Default is str. - :ivar str pop_receipt: - A receipt str which can be used together with the message_id element to + it has been retrieved with the receive messages operation.""" + inserted_on: "datetime" + """A UTC date value representing the time the messages was inserted.""" + expires_on: "datetime" + """A UTC date value representing the time the message expires.""" + dequeue_count: int + """Begins with a value of 1 the first time the message is received. 
This + value is incremented each time the message is subsequently received.""" + content: Any = None + """The message content. Type is determined by the decode_function set on + the service. Default is str.""" + pop_receipt: str + """A receipt str which can be used together with the message_id element to delete a message from the queue after it has been retrieved with the receive messages operation. Only returned by receive messages operations. Set to - None for peek messages. - :ivar date next_visible_on: - A UTC date value representing the time the message will next be visible. - Only returned by receive messages operations. Set to None for peek messages. - """ + None for peek messages.""" + next_visible_on: "datetime" + """A UTC date value representing the time the message will next be visible. + Only returned by receive messages operations. Set to None for peek messages.""" - def __init__(self, content=None): + def __init__(self, content: Any = None) -> None: self.id = None self.inserted_on = None self.expires_on = None @@ -246,7 +303,7 @@ def __init__(self, content=None): self.next_visible_on = None @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: message = cls(content=generated.message_text) message.id = generated.message_id message.inserted_on = generated.insertion_time @@ -261,13 +318,26 @@ def _from_generated(cls, generated): class MessagesPaged(PageIterator): """An iterable of Queue Messages. - :param callable command: Function to retrieve the next page of items. + :param Callable command: Function to retrieve the next page of items. :param int results_per_page: The maximum number of messages to retrieve per call. :param int max_messages: The maximum number of messages to retrieve from the queue. 
""" - def __init__(self, command, results_per_page=None, continuation_token=None, max_messages=None): + + command: Callable + """Function to retrieve the next page of items.""" + results_per_page: int = None + """A UTC date value representing the time the message expires.""" + max_messages: int = None + """The maximum number of messages to retrieve from the queue.""" + + def __init__( + self, command: Callable, + results_per_page: int = None, + continuation_token: str = None, + max_messages: int = None + ) -> None: if continuation_token is not None: raise ValueError("This operation does not support continuation token") @@ -279,7 +349,7 @@ def __init__(self, command, results_per_page=None, continuation_token=None, max_ self.results_per_page = results_per_page self._max_messages = max_messages - def _get_next_cb(self, continuation_token): + def _get_next_cb(self, continuation_token: str) -> Any: try: if self._max_messages is not None: if self.results_per_page is None: @@ -291,7 +361,7 @@ def _get_next_cb(self, continuation_token): except HttpResponseError as error: process_storage_error(error) - def _extract_data_cb(self, messages): # pylint: disable=no-self-use + def _extract_data_cb(self, messages: Any) -> Tuple[str, List[QueueMessage]]: # There is no concept of continuation token, so raising on my own condition if not messages: raise StopIteration("End of paging") @@ -303,20 +373,24 @@ def _extract_data_cb(self, messages): # pylint: disable=no-self-use class QueueProperties(DictMixin): """Queue Properties. - :ivar str name: The name of the queue. - :keyword dict(str,str) metadata: + :keyword Dict[str, str] metadata: A dict containing name-value pairs associated with the queue as metadata. This var is set to None unless the include=metadata param was included - for the list queues operation. If this parameter was specified but the + for the list queues operation. 
""" - def __init__(self, **kwargs): + name: str + """The name of the queue.""" + metadata: Dict[str, str] + """A dict containing name-value pairs associated with the queue as metadata.""" + + def __init__(self, **kwargs: Any) -> None: self.name = None self.metadata = kwargs.get('metadata') self.approximate_message_count = kwargs.get('x-ms-approximate-messages-count') @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: props = cls() props.name = generated.name props.metadata = generated.metadata @@ -326,20 +400,36 @@ def _from_generated(cls, generated): class QueuePropertiesPaged(PageIterator): """An iterable of Queue properties. - :ivar str service_endpoint: The service URL. - :ivar str prefix: A queue name prefix being used to filter the list. - :ivar str marker: The continuation token of the current page of results. - :ivar int results_per_page: The maximum number of results retrieved per API call. - :ivar str next_marker: The continuation token to retrieve the next page of results. - :ivar str location_mode: The location mode being used to list results. The available - options include "primary" and "secondary". - :param callable command: Function to retrieve the next page of items. + :param Callable command: Function to retrieve the next page of items. :param str prefix: Filters the results to return only queues whose names begin with the specified prefix. :param int results_per_page: The maximum number of queue names to retrieve per call. :param str continuation_token: An opaque continuation token. 
""" + + service_endpoint: str + """The service URL.""" + prefix: str + """A queue name prefix being used to filter the list.""" + marker: str + """The continuation token of the current page of results.""" + results_per_page: int = None + """The maximum number of results retrieved per API call.""" + next_marker: str + """The continuation token to retrieve the next page of results.""" + location_mode: str + """The location mode being used to list results. The available options include "primary" and "secondary".""" + command: Callable + """Function to retrieve the next page of items.""" + prefix: str = None + """Filters the results to return only queues whose names begin with the specified prefix.""" + results_per_page: int + """The maximum number of queue names to retrieve per + call.""" + continuation_token: str = None + """An opaque continuation token.""" + def __init__(self, command, prefix=None, results_per_page=None, continuation_token=None): super(QueuePropertiesPaged, self).__init__( self._get_next_cb, @@ -353,7 +443,7 @@ def __init__(self, command, prefix=None, results_per_page=None, continuation_tok self.results_per_page = results_per_page self.location_mode = None - def _get_next_cb(self, continuation_token): + def _get_next_cb(self, continuation_token: str) -> Any: try: return self._command( marker=continuation_token or None, @@ -363,7 +453,7 @@ def _get_next_cb(self, continuation_token): except HttpResponseError as error: process_storage_error(error) - def _extract_data_cb(self, get_next_return): + def _extract_data_cb(self, get_next_return: str) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return self.service_endpoint = self._response.service_endpoint self.prefix = self._response.prefix @@ -388,7 +478,22 @@ class QueueSasPermissions(object): :param bool process: Get and delete messages from the queue. 
""" - def __init__(self, read=False, add=False, update=False, process=False): + + read: bool = False + """Read metadata and properties, including message count.""" + add: bool = False + """Add messages to the queue.""" + update: bool = False + """Update messages in the queue.""" + process: bool = False + """Get and delete messages from the queue.""" + + def __init__( + self, read: bool = False, + add: bool = False, + update: bool = False, + process: bool = False + ) -> None: self.read = read self.add = add self.update = update @@ -402,7 +507,7 @@ def __str__(self): return self._str @classmethod - def from_string(cls, permission): + def from_string(cls, permission: str) -> Self: """Create a QueueSasPermissions from a string. To specify read, add, update, or process permissions you need only to @@ -424,7 +529,7 @@ def from_string(cls, permission): return parsed -def service_stats_deserialize(generated): +def service_stats_deserialize(generated: Any) -> Dict[str, Any]: """Deserialize a ServiceStats objects into a dict. """ return { @@ -435,7 +540,7 @@ def service_stats_deserialize(generated): } -def service_properties_deserialize(generated): +def service_properties_deserialize(generated: Any) -> Dict[str, Any]: """Deserialize a ServiceProperties objects into a dict. """ return { diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py index 3959c9be9e61..844d88a09f7c 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py @@ -30,7 +30,7 @@ class QueueSharedAccessSignature(SharedAccessSignature): generate_*_shared_access_signature method directly. 
''' - def __init__(self, account_name, account_key): + def __init__(self, account_name: str, account_key: str) -> None: ''' :param str account_name: The storage account name used to generate the shared access signatures. @@ -39,9 +39,15 @@ def __init__(self, account_name, account_key): ''' super(QueueSharedAccessSignature, self).__init__(account_name, account_key, x_ms_version=X_MS_VERSION) - def generate_queue(self, queue_name, permission=None, - expiry=None, start=None, policy_id=None, - ip=None, protocol=None): + def generate_queue( + self, queue_name: str, + permission: "QueueSasPermissions" = None, + expiry: Union["datetime", str] = None, + start: Union["datetime", str] = None, + policy_id: str = None, + ip: str = None, + protocol: str = None + ) -> str: ''' Generates a shared access signature for the queue. Use the returned signature with the sas_token parameter of QueueService. @@ -61,14 +67,14 @@ def generate_queue(self, queue_name, permission=None, been specified in an associated stored access policy. Azure will always convert values to UTC. If a date is passed in without timezone info, it is assumed to be UTC. - :type expiry: datetime or str + :type expiry: ~datetime.datetime or str :param start: The time at which the shared access signature becomes valid. If omitted, start time for this call is assumed to be the time when the storage service receives the request. Azure will always convert values to UTC. If a date is passed in without timezone info, it is assumed to be UTC. - :type start: datetime or str + :type start: ~datetime.datetime or str :param str policy_id: A unique value up to 64 characters in length that correlates to a stored access policy. 
@@ -92,7 +98,7 @@ def generate_queue(self, queue_name, permission=None, class _QueueSharedAccessHelper(_SharedAccessHelper): - def add_resource_signature(self, account_name, account_key, path): # pylint: disable=arguments-differ + def add_resource_signature(self, account_name: str, account_key: str, path: str) -> str: # pylint: disable=arguments-differ def get_value_to_append(query): return_value = self.query_dict.get(query) or '' return return_value + '\n' @@ -123,15 +129,15 @@ def get_value_to_append(query): def generate_account_sas( - account_name, # type: str - account_key, # type: str - resource_types, # type: Union[ResourceTypes, str] - permission, # type: Union[AccountSasPermissions, str] - expiry, # type: Optional[Union[datetime, str]] - start=None, # type: Optional[Union[datetime, str]] - ip=None, # type: Optional[str] - **kwargs # type: Any - ): # type: (...) -> str + account_name: str, + account_key: str, + resource_types: Union["ResourceTypes", str], + permission: Union["AccountSasPermissions", str], + expiry: Optional[Union["datetime", str]], + start: Optional[Union["datetime", str]] = None, + ip: Optional[str] = None, + **kwargs: Any +) -> str: """Generates a shared access signature for the queue service. Use the returned signature with the credential parameter of any Queue Service. @@ -180,20 +186,20 @@ def generate_account_sas( start=start, ip=ip, **kwargs - ) # type: ignore + ) def generate_queue_sas( - account_name, # type: str - queue_name, # type: str - account_key, # type: str - permission=None, # type: Optional[Union[QueueSasPermissions, str]] - expiry=None, # type: Optional[Union[datetime, str]] - start=None, # type: Optional[Union[datetime, str]] - policy_id=None, # type: Optional[str] - ip=None, # type: Optional[str] - **kwargs # type: Any - ): # type: (...) 
-> str + account_name: str, + queue_name: str, + account_key: str, + permission: Optional[Union["QueueSasPermissions", str]] = None, + expiry: Optional[Union["datetime", str]] = None, + start: Optional[Union["datetime", str]] = None, + policy_id: Optional[str] = None, + ip: Optional[str] = None, + **kwargs: Any +) -> str: """Generates a shared access signature for a queue. Use the returned signature with the credential parameter of any Queue Service. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index dcb8d32277f4..4cf9b1877a09 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -6,7 +6,7 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called -from typing import List # pylint: disable=unused-import +from typing import Any, Callable, List, Optional, Tuple from azure.core.async_paging import AsyncPageIterator from azure.core.exceptions import HttpResponseError from .._shared.response_handlers import ( @@ -18,13 +18,26 @@ class MessagesPaged(AsyncPageIterator): """An iterable of Queue Messages. - :param callable command: Function to retrieve the next page of items. + :param Callable command: Function to retrieve the next page of items. :param int results_per_page: The maximum number of messages to retrieve per call. :param int max_messages: The maximum number of messages to retrieve from the queue. 
"""
-    def __init__(self, command, results_per_page=None, continuation_token=None, max_messages=None):
+
+    command: Callable
+    """Function to retrieve the next page of items."""
+    results_per_page: int = None
+    """The maximum number of messages to retrieve per call."""
+    max_messages: int = None
+    """The maximum number of messages to retrieve from the queue."""
+
+    def __init__(
+        self, command: Callable,
+        results_per_page: int = None,
+        continuation_token: str = None,
+        max_messages: int = None
+    ) -> None:
         if continuation_token is not None:
             raise ValueError("This operation does not support continuation token")
 
@@ -36,7 +49,7 @@ def __init__(self, command, results_per_page=None, continuation_token=None, max_
         self.results_per_page = results_per_page
         self._max_messages = max_messages
 
-    async def _get_next_cb(self, continuation_token):
+    async def _get_next_cb(self, continuation_token: str) -> Any:
         try:
             if self._max_messages is not None:
                 if self.results_per_page is None:
@@ -48,7 +61,7 @@ async def _get_next_cb(self, continuation_token):
         except HttpResponseError as error:
             process_storage_error(error)
 
-    async def _extract_data_cb(self, messages):
+    async def _extract_data_cb(self, messages: QueueMessage) -> Tuple[str, List[QueueMessage]]:
         # There is no concept of continuation token, so raising on my own condition
         if not messages:
             raise StopAsyncIteration("End of paging")
@@ -60,21 +73,42 @@ async def _extract_data_cb(self, messages):
 class QueuePropertiesPaged(AsyncPageIterator):
     """An iterable of Queue properties.
 
-    :ivar str service_endpoint: The service URL.
-    :ivar str prefix: A queue name prefix being used to filter the list.
-    :ivar str marker: The continuation token of the current page of results.
-    :ivar int results_per_page: The maximum number of results retrieved per API call.
-    :ivar str next_marker: The continuation token to retrieve the next page of results.
-    :ivar str location_mode: The location mode being used to list results. 
The available - options include "primary" and "secondary". - :param callable command: Function to retrieve the next page of items. + :param Callable command: Function to retrieve the next page of items. :param str prefix: Filters the results to return only queues whose names begin with the specified prefix. :param int results_per_page: The maximum number of queue names to retrieve per call. :param str continuation_token: An opaque continuation token. """ - def __init__(self, command, prefix=None, results_per_page=None, continuation_token=None): + + service_endpoint: str + """The service URL.""" + prefix: str + """A queue name prefix being used to filter the list.""" + marker: str + """The continuation token of the current page of results.""" + results_per_page: int = None + """The maximum number of results retrieved per API call.""" + next_marker: str + """The continuation token to retrieve the next page of results.""" + location_mode: str + """The location mode being used to list results. 
The available options include "primary" and "secondary".""" + command: Callable + """Function to retrieve the next page of items.""" + prefix: str = None + """Filters the results to return only queues whose names begin with the specified prefix.""" + results_per_page: int + """The maximum number of queue names to retrieve per + call.""" + continuation_token: str = None + """An opaque continuation token.""" + + def __init__( + self, command: Callable, + prefix: str = None, + results_per_page: int = None, + continuation_token: str = None + ) -> None: super(QueuePropertiesPaged, self).__init__( self._get_next_cb, self._extract_data_cb, @@ -87,7 +121,7 @@ def __init__(self, command, prefix=None, results_per_page=None, continuation_tok self.results_per_page = results_per_page self.location_mode = None - async def _get_next_cb(self, continuation_token): + async def _get_next_cb(self, continuation_token: str) -> Any: try: return await self._command( marker=continuation_token or None, @@ -97,7 +131,7 @@ async def _get_next_cb(self, continuation_token): except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, get_next_return): + async def _extract_data_cb(self, get_next_return: str) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return self.service_endpoint = self._response.service_endpoint self.prefix = self._response.prefix From f9290eecc01a1954ae68e2525c5a061bcc6c3a3c Mon Sep 17 00:00:00 2001 From: vincenttran-msft <101599632+vincenttran-msft@users.noreply.github.com> Date: Fri, 7 Apr 2023 16:17:44 -0700 Subject: [PATCH 06/71] [Storage] `_queue_client.py` mypy fixes (#29736) * First draft MyPy * Mypy config fix + just Config Mypy left * Fully done w/ custom StorageConfiguration * Pylint * Double import * Fix cast * PR feedback * Unused imports * Pylint * Configuration errors in async BE GONE * Fix tests that assert against _config --- .../azure/storage/queue/_models.py | 6 +- 
.../azure/storage/queue/_queue_client.py | 70 ++++++++++--------- .../queue/_shared/response_handlers.py | 6 +- .../storage/queue/aio/_queue_client_async.py | 29 ++++---- sdk/storage/azure-storage-queue/mypy.ini | 11 --- ...ageQueueEncodingtest_message_text_xml.json | 22 +++--- .../tests/test_queue_encodings.py | 8 +-- 7 files changed, 72 insertions(+), 80 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 5390072c3ee2..282c7a0def0d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -274,11 +274,11 @@ class QueueMessage(DictMixin): identifies the message in the queue. This value may be used together with the value of pop_receipt to delete a message from the queue after it has been retrieved with the receive messages operation.""" - inserted_on: "datetime" + inserted_on: Optional["datetime"] """A UTC date value representing the time the messages was inserted.""" - expires_on: "datetime" + expires_on: Optional["datetime"] """A UTC date value representing the time the message expires.""" - dequeue_count: int + dequeue_count: Optional[int] """Begins with a value of 1 the first time the message is received. 
This value is incremented each time the message is subsequently received.""" content: Any = None diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 09e107e126a9..9582f3717892 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -7,8 +7,8 @@ import functools import warnings from typing import ( # pylint: disable=unused-import - Any, Dict, List, Optional, Union, - TYPE_CHECKING) + Any, Dict, cast, List, Optional, Union, + TYPE_CHECKING, Tuple) from urllib.parse import urlparse, quote, unquote from typing_extensions import Self @@ -27,7 +27,7 @@ from ._deserialize import deserialize_queue_properties, deserialize_queue_creation from ._encryption import StorageEncryptionMixin from ._message_encoding import NoEncodePolicy, NoDecodePolicy -from ._models import QueueMessage, AccessPolicy, MessagesPaged +from ._models import AccessPolicy, MessagesPaged, QueueMessage from ._serialize import get_api_version if TYPE_CHECKING: @@ -102,8 +102,8 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) - self._config.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() - self._config.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() + self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() + self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._configure_encryption(kwargs) @@ -112,9 +112,10 @@ def _format_url(self, 
hostname: str) -> str: """Format the endpoint URL according to the current location mode hostname. """ - queue_name = self.queue_name - if isinstance(queue_name, str): - queue_name = queue_name.encode('UTF-8') + if isinstance(self.queue_name, str): + queue_name = self.queue_name.encode('UTF-8') + else: + queue_name = self.queue_name return ( f"{self.scheme}://{hostname}" f"/{quote(queue_name)}{self._query_str}") @@ -306,10 +307,10 @@ def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": """ timeout = kwargs.pop('timeout', None) try: - response = self._client.queue.get_properties( + response = cast("QueueProperties", self._client.queue.get_properties( timeout=timeout, cls=deserialize_queue_properties, - **kwargs) + **kwargs)) except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name @@ -372,10 +373,10 @@ def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: """ timeout = kwargs.pop('timeout', None) try: - _, identifiers = self._client.queue.get_access_policy( + _, identifiers = cast(Tuple[Dict, List], self._client.queue.get_access_policy( timeout=timeout, cls=return_headers_and_deserialized, - **kwargs) + **kwargs)) except HttpResponseError as error: process_storage_error(error) return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @@ -431,10 +432,9 @@ def set_queue_access_policy( value.start = serialize_iso(value.start) value.expiry = serialize_iso(value.expiry) identifiers.append(SignedIdentifier(id=key, access_policy=value)) - signed_identifiers = identifiers try: self._client.queue.set_access_policy( - queue_acl=signed_identifiers or None, + queue_acl=identifiers or None, timeout=timeout, **kwargs) except HttpResponseError as error: @@ -447,7 +447,7 @@ def send_message( visibility_timeout: Optional[int] = None, time_to_live: Optional[int] = None, **kwargs: Any - ) -> "QueueMessage": + ) -> QueueMessage: """Adds a new message to the back of the message queue. 
The visibility timeout specifies the time that the message will be @@ -499,7 +499,7 @@ def send_message( """ timeout = kwargs.pop('timeout', None) try: - self._config.message_encode_policy.configure( + self.message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function, @@ -511,11 +511,11 @@ def send_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." ) - self._config.message_encode_policy.configure( + self.message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) - encoded_content = self._config.message_encode_policy(content) + encoded_content = self.message_encode_policy(content) new_message = GenQueueMessage(message_text=encoded_content) try: @@ -579,7 +579,7 @@ def receive_message( :caption: Receive one message from the queue. """ timeout = kwargs.pop('timeout', None) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) @@ -588,7 +588,7 @@ def receive_message( number_of_messages=1, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self.message_decode_policy, **kwargs ) wrapped_message = QueueMessage._from_generated( # pylint: disable=protected-access @@ -663,7 +663,7 @@ def receive_messages( :caption: Receive messages from the queue. 
""" timeout = kwargs.pop('timeout', None) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) @@ -672,7 +672,7 @@ def receive_messages( self._client.messages.dequeue, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self.message_decode_policy, **kwargs ) if max_messages is not None and messages_per_page is not None: @@ -743,14 +743,16 @@ def update_message( :caption: Update a message. """ timeout = kwargs.pop('timeout', None) - try: + + receipt: Optional[str] + if isinstance(message, QueueMessage): message_id = message.id message_text = content or message.content receipt = pop_receipt or message.pop_receipt inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - except AttributeError: + else: message_id = message message_text = content receipt = pop_receipt @@ -762,7 +764,7 @@ def update_message( raise ValueError("pop_receipt must be present") if message_text is not None: try: - self._config.message_encode_policy.configure( + self.message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function, @@ -774,23 +776,23 @@ def update_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." 
) - self._config.message_encode_policy.configure( + self.message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function) - encoded_message_text = self._config.message_encode_policy(message_text) + encoded_message_text = self.message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: updated = None try: - response = self._client.message_id.update( + response = cast(QueueMessage, self._client.message_id.update( queue_message=updated, visibilitytimeout=visibility_timeout or 0, timeout=timeout, pop_receipt=receipt, cls=return_response_headers, queue_message_id=message_id, - **kwargs) + **kwargs)) new_message = QueueMessage(content=message_text) new_message.id = message_id new_message.inserted_on = inserted_on @@ -849,7 +851,7 @@ def peek_messages( timeout = kwargs.pop('timeout', None) if max_messages and not 1 <= max_messages <= 32: raise ValueError("Number of messages to peek should be between 1 and 32") - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) @@ -857,7 +859,7 @@ def peek_messages( messages = self._client.messages.peek( number_of_messages=max_messages, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self.message_decode_policy, **kwargs) wrapped_messages = [] for peeked in messages: @@ -933,10 +935,12 @@ def delete_message( :caption: Delete a message. 
""" timeout = kwargs.pop('timeout', None) - try: + + receipt: Optional[str] + if isinstance(message, QueueMessage): message_id = message.id receipt = pop_receipt or message.pop_receipt - except AttributeError: + else: message_id = message receipt = pop_receipt diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py index e1633487eba6..dbc82124544e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- from typing import ( # pylint: disable=unused-import - Union, Optional, Any, Iterable, Dict, List, Type, Tuple, - TYPE_CHECKING + Union, Optional, Any, Iterable, Dict, List, NoReturn, Type, + Tuple, TYPE_CHECKING ) import logging from xml.etree.ElementTree import Element @@ -90,7 +90,7 @@ def return_raw_deserialized(response, *_): return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] -def process_storage_error(storage_error): # pylint:disable=too-many-statements +def process_storage_error(storage_error) -> NoReturn: # pylint:disable=too-many-statements raise_error = HttpResponseError serialized = False if not storage_error.response or storage_error.response.status_code in [200, 204]: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 0eaf9010007f..e5f87365399f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -294,7 +294,7 @@ async def set_queue_access_policy( SignedIdentifier access policies to associate with the 
queue. This may contain up to 5 elements. An empty dict will clear the access policies set on the service. - :type signed_identifiers: dict(str, ~azure.storage.queue.AccessPolicy) + :type signed_identifiers: Dict[str, ~azure.storage.queue.AccessPolicy] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. @@ -323,9 +323,8 @@ async def set_queue_access_policy( value.start = serialize_iso(value.start) value.expiry = serialize_iso(value.expiry) identifiers.append(SignedIdentifier(id=key, access_policy=value)) - signed_identifiers = identifiers try: - await self._client.queue.set_access_policy(queue_acl=signed_identifiers or None, timeout=timeout, **kwargs) + await self._client.queue.set_access_policy(queue_acl=identifiers or None, timeout=timeout, **kwargs) except HttpResponseError as error: process_storage_error(error) @@ -388,7 +387,7 @@ async def send_message( """ timeout = kwargs.pop('timeout', None) try: - self._config.message_encode_policy.configure( + self.message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function, @@ -400,11 +399,11 @@ async def send_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." ) - self._config.message_encode_policy.configure( + self.message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) - encoded_content = self._config.message_encode_policy(content) + encoded_content = self.message_encode_policy(content) new_message = GenQueueMessage(message_text=encoded_content) try: @@ -469,7 +468,7 @@ async def receive_message( :caption: Receive one message from the queue. 
""" timeout = kwargs.pop('timeout', None) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) @@ -478,7 +477,7 @@ async def receive_message( number_of_messages=1, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self.message_decode_policy, **kwargs ) wrapped_message = QueueMessage._from_generated( # pylint: disable=protected-access @@ -544,7 +543,7 @@ def receive_messages( :caption: Receive messages from the queue. """ timeout = kwargs.pop('timeout', None) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function @@ -554,7 +553,7 @@ def receive_messages( self._client.messages.dequeue, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self.message_decode_policy, **kwargs ) if max_messages is not None and messages_per_page is not None: @@ -644,7 +643,7 @@ async def update_message( raise ValueError("pop_receipt must be present") if message_text is not None: try: - self._config.message_encode_policy.configure( + self.message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function, @@ -657,12 +656,12 @@ async def update_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." 
) - self._config.message_encode_policy.configure( + self.message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function ) - encoded_message_text = self._config.message_encode_policy(message_text) + encoded_message_text = self.message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: updated = None @@ -734,14 +733,14 @@ async def peek_messages( timeout = kwargs.pop('timeout', None) if max_messages and not 1 <= max_messages <= 32: raise ValueError("Number of messages to peek should be between 1 and 32") - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function ) try: messages = await self._client.messages.peek( - number_of_messages=max_messages, timeout=timeout, cls=self._config.message_decode_policy, **kwargs + number_of_messages=max_messages, timeout=timeout, cls=self.message_decode_policy, **kwargs ) wrapped_messages = [] for peeked in messages: diff --git a/sdk/storage/azure-storage-queue/mypy.ini b/sdk/storage/azure-storage-queue/mypy.ini index f8e299bc34f3..0d5fa8139e1d 100644 --- a/sdk/storage/azure-storage-queue/mypy.ini +++ b/sdk/storage/azure-storage-queue/mypy.ini @@ -1,13 +1,2 @@ -[mypy] -python_version = 3.7 -warn_return_any = True -warn_unused_configs = True -ignore_missing_imports = True - -# Per-module options: - [mypy-azure.storage.queue._generated.*] ignore_errors = True - -[mypy-azure.core.*] -ignore_errors = True diff --git a/sdk/storage/azure-storage-queue/tests/recordings/test_queue_encodings.pyTestStorageQueueEncodingtest_message_text_xml.json b/sdk/storage/azure-storage-queue/tests/recordings/test_queue_encodings.pyTestStorageQueueEncodingtest_message_text_xml.json index 1100637169ec..1f7d22252046 100644 --- 
a/sdk/storage/azure-storage-queue/tests/recordings/test_queue_encodings.pyTestStorageQueueEncodingtest_message_text_xml.json +++ b/sdk/storage/azure-storage-queue/tests/recordings/test_queue_encodings.pyTestStorageQueueEncodingtest_message_text_xml.json @@ -8,15 +8,15 @@ "Accept-Encoding": "gzip, deflate", "Connection": "keep-alive", "Content-Length": "0", - "User-Agent": "azsdk-python-storage-queue/12.5.1 Python/3.10.2 (Windows-10-10.0.19044-SP0)", - "x-ms-date": "Wed, 26 Oct 2022 23:31:16 GMT", + "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.1 (Windows-10-10.0.22621-SP0)", + "x-ms-date": "Fri, 07 Apr 2023 22:00:39 GMT", "x-ms-version": "2021-02-12" }, "RequestBody": null, "StatusCode": 201, "ResponseHeaders": { "Content-Length": "0", - "Date": "Wed, 26 Oct 2022 23:31:15 GMT", + "Date": "Fri, 07 Apr 2023 22:00:38 GMT", "Server": [ "Windows-Azure-Queue/1.0", "Microsoft-HTTPAPI/2.0" @@ -34,8 +34,8 @@ "Connection": "keep-alive", "Content-Length": "111", "Content-Type": "application/xml", - "User-Agent": "azsdk-python-storage-queue/12.5.1 Python/3.10.2 (Windows-10-10.0.19044-SP0)", - "x-ms-date": "Wed, 26 Oct 2022 23:31:16 GMT", + "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.1 (Windows-10-10.0.22621-SP0)", + "x-ms-date": "Fri, 07 Apr 2023 22:00:40 GMT", "x-ms-version": "2021-02-12" }, "RequestBody": [ @@ -45,7 +45,7 @@ "StatusCode": 201, "ResponseHeaders": { "Content-Type": "application/xml", - "Date": "Wed, 26 Oct 2022 23:31:15 GMT", + "Date": "Fri, 07 Apr 2023 22:00:38 GMT", "Server": [ "Windows-Azure-Queue/1.0", "Microsoft-HTTPAPI/2.0" @@ -53,7 +53,7 @@ "Transfer-Encoding": "chunked", "x-ms-version": "2021-02-12" }, - "ResponseBody": "\uFEFF\u003C?xml version=\u00221.0\u0022 encoding=\u0022utf-8\u0022?\u003E\u003CQueueMessagesList\u003E\u003CQueueMessage\u003E\u003CMessageId\u003E56a10687-3cf2-4234-a8da-16757f842147\u003C/MessageId\u003E\u003CInsertionTime\u003EWed, 26 Oct 2022 23:31:16 
GMT\u003C/InsertionTime\u003E\u003CExpirationTime\u003EWed, 02 Nov 2022 23:31:16 GMT\u003C/ExpirationTime\u003E\u003CPopReceipt\u003EAgAAAAMAAAAAAAAAhrIaC5Pp2AE=\u003C/PopReceipt\u003E\u003CTimeNextVisible\u003EWed, 26 Oct 2022 23:31:16 GMT\u003C/TimeNextVisible\u003E\u003C/QueueMessage\u003E\u003C/QueueMessagesList\u003E" + "ResponseBody": "\uFEFF\u003C?xml version=\u00221.0\u0022 encoding=\u0022utf-8\u0022?\u003E\u003CQueueMessagesList\u003E\u003CQueueMessage\u003E\u003CMessageId\u003Ef37bc876-6500-4f1e-874e-6e14128de1f1\u003C/MessageId\u003E\u003CInsertionTime\u003EFri, 07 Apr 2023 22:00:38 GMT\u003C/InsertionTime\u003E\u003CExpirationTime\u003EFri, 14 Apr 2023 22:00:38 GMT\u003C/ExpirationTime\u003E\u003CPopReceipt\u003EAgAAAAMAAAAAAAAAPOtiY5xp2QE=\u003C/PopReceipt\u003E\u003CTimeNextVisible\u003EFri, 07 Apr 2023 22:00:38 GMT\u003C/TimeNextVisible\u003E\u003C/QueueMessage\u003E\u003C/QueueMessagesList\u003E" }, { "RequestUri": "https://storagename.queue.core.windows.net/mytestqueue9b732da4/messages", @@ -62,8 +62,8 @@ "Accept": "application/xml", "Accept-Encoding": "gzip, deflate", "Connection": "keep-alive", - "User-Agent": "azsdk-python-storage-queue/12.5.1 Python/3.10.2 (Windows-10-10.0.19044-SP0)", - "x-ms-date": "Wed, 26 Oct 2022 23:31:17 GMT", + "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.1 (Windows-10-10.0.22621-SP0)", + "x-ms-date": "Fri, 07 Apr 2023 22:00:40 GMT", "x-ms-version": "2021-02-12" }, "RequestBody": null, @@ -71,7 +71,7 @@ "ResponseHeaders": { "Cache-Control": "no-cache", "Content-Type": "application/xml", - "Date": "Wed, 26 Oct 2022 23:31:15 GMT", + "Date": "Fri, 07 Apr 2023 22:00:38 GMT", "Server": [ "Windows-Azure-Queue/1.0", "Microsoft-HTTPAPI/2.0" @@ -80,7 +80,7 @@ "Vary": "Origin", "x-ms-version": "2021-02-12" }, - "ResponseBody": "\uFEFF\u003C?xml version=\u00221.0\u0022 
encoding=\u0022utf-8\u0022?\u003E\u003CQueueMessagesList\u003E\u003CQueueMessage\u003E\u003CMessageId\u003E56a10687-3cf2-4234-a8da-16757f842147\u003C/MessageId\u003E\u003CInsertionTime\u003EWed, 26 Oct 2022 23:31:16 GMT\u003C/InsertionTime\u003E\u003CExpirationTime\u003EWed, 02 Nov 2022 23:31:16 GMT\u003C/ExpirationTime\u003E\u003CPopReceipt\u003EAgAAAAMAAAAAAAAAVC4OHZPp2AE=\u003C/PopReceipt\u003E\u003CTimeNextVisible\u003EWed, 26 Oct 2022 23:31:46 GMT\u003C/TimeNextVisible\u003E\u003CDequeueCount\u003E1\u003C/DequeueCount\u003E\u003CMessageText\u003E\u0026lt;message1\u0026gt;\u003C/MessageText\u003E\u003C/QueueMessage\u003E\u003C/QueueMessagesList\u003E" + "ResponseBody": "\uFEFF\u003C?xml version=\u00221.0\u0022 encoding=\u0022utf-8\u0022?\u003E\u003CQueueMessagesList\u003E\u003CQueueMessage\u003E\u003CMessageId\u003Ef37bc876-6500-4f1e-874e-6e14128de1f1\u003C/MessageId\u003E\u003CInsertionTime\u003EFri, 07 Apr 2023 22:00:38 GMT\u003C/InsertionTime\u003E\u003CExpirationTime\u003EFri, 14 Apr 2023 22:00:38 GMT\u003C/ExpirationTime\u003E\u003CPopReceipt\u003EAgAAAAMAAAAAAAAAvZZSdZxp2QE=\u003C/PopReceipt\u003E\u003CTimeNextVisible\u003EFri, 07 Apr 2023 22:01:08 GMT\u003C/TimeNextVisible\u003E\u003CDequeueCount\u003E1\u003C/DequeueCount\u003E\u003CMessageText\u003E\u0026lt;message1\u0026gt;\u003C/MessageText\u003E\u003C/QueueMessage\u003E\u003C/QueueMessagesList\u003E" } ], "Variables": {} diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py b/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py index 86c47c7dce1f..6e7eea541b07 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py @@ -70,8 +70,8 @@ def test_message_text_xml(self, **kwargs): queue = qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX)) # Asserts - assert isinstance(queue._config.message_encode_policy, NoEncodePolicy) - assert isinstance(queue._config.message_decode_policy, 
NoDecodePolicy) + assert isinstance(queue.message_encode_policy, NoEncodePolicy) + assert isinstance(queue.message_decode_policy, NoDecodePolicy) self._validate_encoding(queue, message) @QueuePreparer() @@ -225,8 +225,8 @@ def test_message_no_encoding(self): message_decode_policy=None) # Asserts - assert isinstance(queue._config.message_encode_policy, NoEncodePolicy) - assert isinstance(queue._config.message_decode_policy, NoDecodePolicy) + assert isinstance(queue.message_encode_policy, NoEncodePolicy) + assert isinstance(queue.message_decode_policy, NoDecodePolicy) # ------------------------------------------------------------------------------ From 03b65c0031598c1a66b30e2fe48527129b97275e Mon Sep 17 00:00:00 2001 From: vincenttran-msft <101599632+vincenttran-msft@users.noreply.github.com> Date: Wed, 3 May 2023 16:51:21 -0700 Subject: [PATCH 07/71] [Storage] _queue_client_async.py mypy fixes (#29824) --- .../queue/_shared/base_client_async.py | 3 +- .../storage/queue/aio/_queue_client_async.py | 59 ++++++++++--------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index f6dc3a9c747c..11fef6ea29b8 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -39,7 +39,6 @@ from .response_handlers import process_storage_error, PartialBatchErrorException if TYPE_CHECKING: - from azure.core.pipeline import Pipeline from azure.core.pipeline.transport import HttpRequest from azure.core.configuration import Configuration _LOGGER = logging.getLogger(__name__) @@ -67,7 +66,7 @@ async def close(self): await self._client.close() def _create_pipeline(self, credential, **kwargs): - # type: (Any, **Any) -> Tuple[Configuration, Pipeline] + # type: (Any, **Any) -> 
Tuple[Configuration, AsyncPipeline] self._credential_policy = None if hasattr(credential, 'get_token'): self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index e5f87365399f..840b2c89ffd5 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -8,7 +8,7 @@ import functools import warnings from typing import ( - Any, Dict, List, Optional, Union, + Any, cast, Dict, List, Optional, Union, Tuple, TYPE_CHECKING) from azure.core.async_paging import AsyncItemPaged @@ -37,7 +37,7 @@ from .._models import QueueProperties -class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncryptionMixin): +class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncryptionMixin): # type: ignore[misc] """A client to interact with a specific Queue. 
:param str account_url: @@ -94,14 +94,13 @@ def __init__( super(QueueClient, self).__init__( account_url, queue_name=queue_name, credential=credential, loop=loop, **kwargs ) - self._client = AzureQueueStorage(self.url, base_url=self.url, - pipeline=self._pipeline, loop=loop) + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) @distributed_trace_async - async def create_queue( + async def create_queue( # type: ignore[override] self, *, metadata: Optional[Dict[str, str]] = None, **kwargs: Any @@ -145,7 +144,7 @@ async def create_queue( process_storage_error(error) @distributed_trace_async - async def delete_queue(self, **kwargs: Any) -> None: + async def delete_queue(self, **kwargs: Any) -> None: # type: ignore[override] """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -180,7 +179,7 @@ async def delete_queue(self, **kwargs: Any) -> None: process_storage_error(error) @distributed_trace_async - async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": + async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": # type: ignore[override] """Returns all user-defined metadata for the specified queue. The data returned does not include the queue's list of messages. 
@@ -201,16 +200,16 @@ async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": """ timeout = kwargs.pop('timeout', None) try: - response = await self._client.queue.get_properties( + response = cast("QueueProperties", await (self._client.queue.get_properties( timeout=timeout, cls=deserialize_queue_properties, **kwargs - ) + ))) except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name return response @distributed_trace_async - async def set_queue_metadata( + async def set_queue_metadata( # type: ignore[override] self, metadata: Optional[Dict[str, str]] = None, **kwargs: Any ) -> None: @@ -242,14 +241,14 @@ async def set_queue_metadata( headers = kwargs.pop("headers", {}) headers.update(add_metadata_headers(metadata)) try: - return await self._client.queue.set_metadata( + await self._client.queue.set_metadata( timeout=timeout, headers=headers, cls=return_response_headers, **kwargs ) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: + async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: # type: ignore[override] """Returns details about any stored access policies specified on the queue that may be used with Shared Access Signatures. 
@@ -264,15 +263,15 @@ async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy """ timeout = kwargs.pop('timeout', None) try: - _, identifiers = await self._client.queue.get_access_policy( + _, identifiers = cast(Tuple[Dict, List], await self._client.queue.get_access_policy( timeout=timeout, cls=return_headers_and_deserialized, **kwargs - ) + )) except HttpResponseError as error: process_storage_error(error) return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace_async - async def set_queue_access_policy( + async def set_queue_access_policy( # type: ignore[override] self, signed_identifiers: Dict[str, AccessPolicy], **kwargs: Any ) -> None: @@ -329,7 +328,7 @@ async def set_queue_access_policy( process_storage_error(error) @distributed_trace_async - async def send_message( + async def send_message( # type: ignore[override] self, content: Any, *, visibility_timeout: Optional[int] = None, @@ -425,7 +424,7 @@ async def send_message( process_storage_error(error) @distributed_trace_async - async def receive_message( + async def receive_message( # type: ignore[override] self, *, visibility_timeout: Optional[int] = None, **kwargs: Any @@ -487,7 +486,7 @@ async def receive_message( process_storage_error(error) @distributed_trace - def receive_messages( + def receive_messages( # type: ignore[override] self, *, messages_per_page: Optional[int] = None, visibility_timeout: Optional[int] = None, @@ -565,7 +564,7 @@ def receive_messages( process_storage_error(error) @distributed_trace_async - async def update_message( + async def update_message( # type: ignore[override] self, message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, content: Optional[Any] = None, @@ -624,14 +623,16 @@ async def update_message( :caption: Update a message. 
""" timeout = kwargs.pop('timeout', None) - try: + + receipt: Optional[str] + if isinstance(message, QueueMessage): message_id = message.id message_text = content or message.content receipt = pop_receipt or message.pop_receipt inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - except AttributeError: + else: message_id = message message_text = content receipt = pop_receipt @@ -666,7 +667,7 @@ async def update_message( else: updated = None try: - response = await self._client.message_id.update( + response = cast(QueueMessage, await self._client.message_id.update( queue_message=updated, visibilitytimeout=visibility_timeout or 0, timeout=timeout, @@ -674,7 +675,7 @@ async def update_message( cls=return_response_headers, queue_message_id=message_id, **kwargs - ) + )) new_message = QueueMessage(content=message_text) new_message.id = message_id new_message.inserted_on = inserted_on @@ -687,7 +688,7 @@ async def update_message( process_storage_error(error) @distributed_trace_async - async def peek_messages( + async def peek_messages( # type: ignore[override] self, max_messages: Optional[int] = None, **kwargs: Any ) -> List[QueueMessage]: @@ -750,7 +751,7 @@ async def peek_messages( process_storage_error(error) @distributed_trace_async - async def clear_messages(self, **kwargs: Any) -> None: + async def clear_messages(self, **kwargs: Any) -> None: # type: ignore[override] """Deletes all messages from the specified queue. :keyword int timeout: @@ -776,7 +777,7 @@ async def clear_messages(self, **kwargs: Any) -> None: process_storage_error(error) @distributed_trace_async - async def delete_message( + async def delete_message( # type: ignore[override] self, message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, **kwargs: Any @@ -816,10 +817,12 @@ async def delete_message( :caption: Delete a message. 
""" timeout = kwargs.pop('timeout', None) - try: + + receipt: Optional[str] + if isinstance(message, QueueMessage): message_id = message.id receipt = pop_receipt or message.pop_receipt - except AttributeError: + else: message_id = message receipt = pop_receipt From 2876b3d3982f74e6a426f6792b0c199720ce5bdc Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Tue, 20 Jun 2023 16:32:57 -0700 Subject: [PATCH 08/71] Init done for _queue_client --- .../azure/storage/queue/_queue_client.py | 26 +------ .../storage/queue/_queue_client_helpers.py | 68 +++++++++++++++++++ .../storage/queue/aio/_queue_client_async.py | 15 ++-- 3 files changed, 76 insertions(+), 33 deletions(-) create mode 100644 sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 9582f3717892..3324954346e6 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -29,6 +29,7 @@ from ._message_encoding import NoEncodePolicy, NoDecodePolicy from ._models import AccessPolicy, MessagesPaged, QueueMessage from ._serialize import get_api_version +from ._queue_client_helpers import _initialize_client if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential @@ -83,30 +84,7 @@ def __init__( credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: - try: - if not account_url.lower().startswith('http'): - account_url = "https://" + account_url - except AttributeError: - raise ValueError("Account URL must be a string.") - parsed_url = urlparse(account_url.rstrip('/')) - if not queue_name: - raise ValueError("Please specify a queue name.") - if not 
parsed_url.netloc: - raise ValueError(f"Invalid URL: {parsed_url}") - - _, sas_token = parse_query(parsed_url.query) - if not sas_token and not credential: - raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") - - self.queue_name = queue_name - self._query_str, credential = self._format_query_string(sas_token, credential) - super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) - - self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() - self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() - self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access - self._configure_encryption(kwargs) + _initialize_client(self, account_url=account_url, queue_name=queue_name, credential=credential, **kwargs) def _format_url(self, hostname: str) -> str: """Format the endpoint URL according to the current location diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py new file mode 100644 index 000000000000..59dfba88324a --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -0,0 +1,68 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import sys +from datetime import datetime, timezone +from ._generated import AzureQueueStorage +from urllib.parse import urlparse, quote, unquote +from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query +from ._message_encoding import NoEncodePolicy, NoDecodePolicy +from ._serialize import get_api_version + +def _initialize_client(self, account_url, queue_name, credential, **kwargs): + try: + if not account_url.lower().startswith('http'): + account_url = "https://" + account_url + except AttributeError: + raise ValueError("Account URL must be a string.") + parsed_url = urlparse(account_url.rstrip('/')) + if not queue_name: + raise ValueError("Please specify a queue name.") + if not parsed_url.netloc: + raise ValueError(f"Invalid URL: {parsed_url}") + + _, sas_token = parse_query(parsed_url.query) + if not sas_token and not credential: + raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") + + self.queue_name = queue_name + self._query_str, credential = self._format_query_string(sas_token, credential) + loop = kwargs.pop('loop', None) + StorageAccountHostsMixin.__init__(self, parsed_url, service='queue', credential=credential, loop=loop, **kwargs) + + self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() + self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) + self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + self._configure_encryption(kwargs) + +def _rfc_1123_to_datetime(rfc_1123: str) -> datetime: + """Converts an RFC 1123 date string to a UTC datetime. 
+ """ + if not rfc_1123: + return None + + return datetime.strptime(rfc_1123, "%a, %d %b %Y %H:%M:%S %Z") + +def _filetime_to_datetime(filetime: str) -> datetime: + """Converts an MS filetime string to a UTC datetime. "0" indicates None. + If parsing MS Filetime fails, tries RFC 1123 as backup. + """ + if not filetime: + return None + + # Try to convert to MS Filetime + try: + filetime = int(filetime) + if filetime == 0: + return None + + return datetime.fromtimestamp((filetime - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS, tz=timezone.utc) + except ValueError: + pass + + # Try RFC 1123 as backup + return _rfc_1123_to_datetime(filetime) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 840b2c89ffd5..ce24dad3a27a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -31,9 +31,11 @@ from .._models import QueueMessage, AccessPolicy from .._queue_client import QueueClient as QueueClientBase from ._models import MessagesPaged +from .._queue_client_helpers import _initialize_client if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential + from azure.core.credentials_async import AsyncTokenCredential from .._models import QueueProperties @@ -86,18 +88,13 @@ class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncrypt def __init__( self, account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # 
pylint: disable=line-too-long **kwargs: Any ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) - loop = kwargs.pop('loop', None) - super(QueueClient, self).__init__( - account_url, queue_name=queue_name, credential=credential, loop=loop, **kwargs - ) - self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + loop = kwargs.get('loop', None) + _initialize_client(self, account_url=account_url, queue_name=queue_name, credential=credential, **kwargs) self._loop = loop - self._configure_encryption(kwargs) @distributed_trace_async async def create_queue( # type: ignore[override] From c4c56be7ba25807f4b0a914afdac712ac18c85aa Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Tue, 20 Jun 2023 20:20:34 -0700 Subject: [PATCH 09/71] Fix async calling into sync --- .../azure/storage/queue/_queue_client.py | 2 +- .../azure/storage/queue/_queue_client_helpers.py | 12 ++++++++++-- .../azure/storage/queue/aio/_queue_client_async.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 3324954346e6..5ad75f752b64 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -84,7 +84,7 @@ def __init__( credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: - _initialize_client(self, account_url=account_url, queue_name=queue_name, credential=credential, **kwargs) + _initialize_client(self, account_url=account_url, queue_name=queue_name, credential=credential, client_type='sync', **kwargs) def _format_url(self, hostname: str) -> str: 
"""Format the endpoint URL according to the current location diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index 59dfba88324a..c7be87e509b3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -7,12 +7,13 @@ import sys from datetime import datetime, timezone from ._generated import AzureQueueStorage +from ._generated.aio import AzureQueueStorage as AsyncAzureQueueStorage from urllib.parse import urlparse, quote, unquote from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query from ._message_encoding import NoEncodePolicy, NoDecodePolicy from ._serialize import get_api_version -def _initialize_client(self, account_url, queue_name, credential, **kwargs): +def _initialize_client(self, account_url, queue_name, credential, client_type, **kwargs): try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url @@ -35,7 +36,14 @@ def _initialize_client(self, account_url, queue_name, credential, **kwargs): self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() - self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) + + if client_type == 'sync': + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) + elif client_type == 'async': + self._client = AsyncAzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) + else: + raise ValueError("Unknown client_type provided, valid options are 'sync' and 'async'.") + self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._configure_encryption(kwargs) diff 
--git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index ce24dad3a27a..bee100996b3a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -93,7 +93,7 @@ def __init__( ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) loop = kwargs.get('loop', None) - _initialize_client(self, account_url=account_url, queue_name=queue_name, credential=credential, **kwargs) + _initialize_client(self, account_url=account_url, queue_name=queue_name, credential=credential, client_type='async', **kwargs) self._loop = loop @distributed_trace_async From 86d626affa2d1180cfd9b703da63df1236ac5dc0 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 21 Jun 2023 23:02:57 -0700 Subject: [PATCH 10/71] Back to the drawing board --- .../azure/storage/queue/_queue_client.py | 28 +++++++++++++++++-- .../storage/queue/aio/_queue_client_async.py | 17 ++++++----- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 5ad75f752b64..c34ca57f7e59 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -29,7 +29,6 @@ from ._message_encoding import NoEncodePolicy, NoDecodePolicy from ._models import AccessPolicy, MessagesPaged, QueueMessage from ._serialize import get_api_version -from ._queue_client_helpers import _initialize_client if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential @@ -84,7 +83,30 @@ def __init__( credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", 
"TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: - _initialize_client(self, account_url=account_url, queue_name=queue_name, credential=credential, client_type='sync', **kwargs) + try: + if not account_url.lower().startswith('http'): + account_url = "https://" + account_url + except AttributeError: + raise ValueError("Account URL must be a string.") + parsed_url = urlparse(account_url.rstrip('/')) + if not queue_name: + raise ValueError("Please specify a queue name.") + if not parsed_url.netloc: + raise ValueError(f"Invalid URL: {parsed_url}") + + _, sas_token = parse_query(parsed_url.query) + if not sas_token and not credential: + raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") + + self.queue_name = queue_name + self._query_str, credential = self._format_query_string(sas_token, credential) + super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) + + self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() + self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) + self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + self._configure_encryption(kwargs) def _format_url(self, hostname: str) -> str: """Format the endpoint URL according to the current location @@ -932,4 +954,4 @@ def delete_message( **kwargs ) except HttpResponseError as error: - process_storage_error(error) + process_storage_error(error) \ No newline at end of file diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index bee100996b3a..729f46a9e639 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ 
b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -31,11 +31,9 @@ from .._models import QueueMessage, AccessPolicy from .._queue_client import QueueClient as QueueClientBase from ._models import MessagesPaged -from .._queue_client_helpers import _initialize_client if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential - from azure.core.credentials_async import AsyncTokenCredential + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential from .._models import QueueProperties @@ -88,13 +86,18 @@ class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncrypt def __init__( self, account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) - loop = kwargs.get('loop', None) - _initialize_client(self, account_url=account_url, queue_name=queue_name, credential=credential, client_type='async', **kwargs) + loop = kwargs.pop('loop', None) + super(QueueClient, self).__init__( + account_url, queue_name=queue_name, credential=credential, loop=loop, **kwargs + ) + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) + self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._loop = loop + self._configure_encryption(kwargs) @distributed_trace_async async def create_queue( # type: ignore[override] @@ -830,4 +833,4 @@ async def delete_message( # type: ignore[override] pop_receipt=receipt, timeout=timeout, queue_message_id=message_id, **kwargs ) except HttpResponseError 
as error: - process_storage_error(error) + process_storage_error(error) \ No newline at end of file From 6681a3745c232fc743055dece2efb263f65772fc Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 21 Jun 2023 23:30:08 -0700 Subject: [PATCH 11/71] Finished __init__ --- .../azure/storage/queue/_queue_client.py | 16 +----- .../storage/queue/_queue_client_helpers.py | 50 ++----------------- .../storage/queue/aio/_queue_client_async.py | 31 ++++++++++-- 3 files changed, 32 insertions(+), 65 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index c34ca57f7e59..522699ef5ffe 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -29,6 +29,7 @@ from ._message_encoding import NoEncodePolicy, NoDecodePolicy from ._models import AccessPolicy, MessagesPaged, QueueMessage from ._serialize import get_api_version +from ._queue_client_helpers import _initialize_client if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential @@ -83,20 +84,7 @@ def __init__( credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: - try: - if not account_url.lower().startswith('http'): - account_url = "https://" + account_url - except AttributeError: - raise ValueError("Account URL must be a string.") - parsed_url = urlparse(account_url.rstrip('/')) - if not queue_name: - raise ValueError("Please specify a queue name.") - if not parsed_url.netloc: - raise ValueError(f"Invalid URL: {parsed_url}") - - _, sas_token = parse_query(parsed_url.query) - if not sas_token and not credential: - raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") + parsed_url, 
sas_token = _initialize_client(account_url=account_url, queue_name=queue_name, credential=credential) self.queue_name = queue_name self._query_str, credential = self._format_query_string(sas_token, credential) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index c7be87e509b3..f1deb92a033d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -13,7 +13,7 @@ from ._message_encoding import NoEncodePolicy, NoDecodePolicy from ._serialize import get_api_version -def _initialize_client(self, account_url, queue_name, credential, client_type, **kwargs): +def _initialize_client(account_url, queue_name, credential): try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url @@ -28,49 +28,5 @@ def _initialize_client(self, account_url, queue_name, credential, client_type, * _, sas_token = parse_query(parsed_url.query) if not sas_token and not credential: raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") - - self.queue_name = queue_name - self._query_str, credential = self._format_query_string(sas_token, credential) - loop = kwargs.pop('loop', None) - StorageAccountHostsMixin.__init__(self, parsed_url, service='queue', credential=credential, loop=loop, **kwargs) - - self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() - self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() - - if client_type == 'sync': - self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) - elif client_type == 'async': - self._client = AsyncAzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) - else: - raise ValueError("Unknown client_type 
provided, valid options are 'sync' and 'async'.") - - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access - self._configure_encryption(kwargs) - -def _rfc_1123_to_datetime(rfc_1123: str) -> datetime: - """Converts an RFC 1123 date string to a UTC datetime. - """ - if not rfc_1123: - return None - - return datetime.strptime(rfc_1123, "%a, %d %b %Y %H:%M:%S %Z") - -def _filetime_to_datetime(filetime: str) -> datetime: - """Converts an MS filetime string to a UTC datetime. "0" indicates None. - If parsing MS Filetime fails, tries RFC 1123 as backup. - """ - if not filetime: - return None - - # Try to convert to MS Filetime - try: - filetime = int(filetime) - if filetime == 0: - return None - - return datetime.fromtimestamp((filetime - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS, tz=timezone.utc) - except ValueError: - pass - - # Try RFC 1123 as backup - return _rfc_1123_to_datetime(filetime) + + return parsed_url, sas_token diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 729f46a9e639..72129ecde7ce 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -10,6 +10,7 @@ from typing import ( Any, cast, Dict, List, Optional, Union, Tuple, TYPE_CHECKING) +from urllib.parse import urlparse, quote, unquote from azure.core.async_paging import AsyncItemPaged from azure.core.exceptions import HttpResponseError @@ -28,16 +29,20 @@ from .._generated.models import SignedIdentifier, QueueMessage as GenQueueMessage from .._deserialize import deserialize_queue_properties, deserialize_queue_creation from .._encryption import StorageEncryptionMixin +from .._message_encoding import NoEncodePolicy, NoDecodePolicy +from .._shared.base_client import StorageAccountHostsMixin +from 
.._shared.base_client_async import AsyncStorageAccountHostsMixin from .._models import QueueMessage, AccessPolicy from .._queue_client import QueueClient as QueueClientBase from ._models import MessagesPaged +from .._queue_client_helpers import _initialize_client if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential from .._models import QueueProperties -class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncryptionMixin): # type: ignore[misc] +class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore[misc] """A client to interact with a specific Queue. :param str account_url: @@ -91,14 +96,32 @@ def __init__( ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) - super(QueueClient, self).__init__( - account_url, queue_name=queue_name, credential=credential, loop=loop, **kwargs - ) + + parsed_url, sas_token = _initialize_client(account_url=account_url, queue_name=queue_name, credential=credential) + + self.queue_name = queue_name + self._query_str, credential = self._format_query_string(sas_token, credential) + super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) + + self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() + self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) + def _format_url(self, hostname: str) -> str: + """Format the endpoint URL according to the current location + mode hostname. 
+ """ + if isinstance(self.queue_name, str): + queue_name = self.queue_name.encode('UTF-8') + else: + queue_name = self.queue_name + return ( + f"{self.scheme}://{hostname}" + f"/{quote(queue_name)}{self._query_str}") + @distributed_trace_async async def create_queue( # type: ignore[override] self, *, From 16f2a4e17eec61a5682046383a9dfbc62a5bd6c8 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Thu, 22 Jun 2023 12:27:12 -0700 Subject: [PATCH 12/71] Fix failing tests --- .../storage/queue/aio/_queue_client_async.py | 92 ++++++++++++++++++- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 72129ecde7ce..fa9490cef6c9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -12,6 +12,8 @@ TYPE_CHECKING) from urllib.parse import urlparse, quote, unquote +from typing_extensions import Self + from azure.core.async_paging import AsyncItemPaged from azure.core.exceptions import HttpResponseError from azure.core.tracing.decorator import distributed_trace @@ -30,7 +32,7 @@ from .._deserialize import deserialize_queue_properties, deserialize_queue_creation from .._encryption import StorageEncryptionMixin from .._message_encoding import NoEncodePolicy, NoDecodePolicy -from .._shared.base_client import StorageAccountHostsMixin +from .._shared.base_client import StorageAccountHostsMixin, parse_connection_str from .._shared.base_client_async import AsyncStorageAccountHostsMixin from .._models import QueueMessage, AccessPolicy from .._queue_client import QueueClient as QueueClientBase @@ -38,7 +40,8 @@ from .._queue_client_helpers import _initialize_client if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from 
azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential + from azure.core.credentials_async import AsyncTokenCredential from .._models import QueueProperties @@ -91,7 +94,7 @@ class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, Stora def __init__( self, account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) @@ -122,6 +125,89 @@ def _format_url(self, hostname: str) -> str: f"{self.scheme}://{hostname}" f"/{quote(queue_name)}{self._query_str}") + @classmethod + def from_queue_url( + cls, queue_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: + """A client to interact with a specific Queue. + + :param str queue_url: The full URI to the queue, including SAS token if used. + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + If the resource URI already contains a SAS token, this will be ignored in favor of an explicit credential + - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. 
+ :returns: A queue client. + :rtype: ~azure.storage.queue.QueueClient + """ + try: + if not queue_url.lower().startswith('http'): + queue_url = "https://" + queue_url + except AttributeError: + raise ValueError("Queue URL must be a string.") + parsed_url = urlparse(queue_url.rstrip('/')) + + if not parsed_url.netloc: + raise ValueError(f"Invalid URL: {queue_url}") + + queue_path = parsed_url.path.lstrip('/').split('/') + account_path = "" + if len(queue_path) > 1: + account_path = "/" + "/".join(queue_path[:-1]) + account_url = ( + f"{parsed_url.scheme}://{parsed_url.netloc.rstrip('/')}" + f"{account_path}?{parsed_url.query}") + queue_name = unquote(queue_path[-1]) + if not queue_name: + raise ValueError("Invalid URL. Please provide a URL with a valid queue name") + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + + @classmethod + def from_connection_string( + cls, conn_str: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: + """Create QueueClient from a Connection String. + + :param str conn_str: + A connection string to an Azure Storage account. + :param queue_name: The queue name. + :type queue_name: str + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token, or the connection string already has shared + access key values. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + Credentials provided here will take precedence over those in the connection string. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. 
+ :returns: A queue client. + :rtype: ~azure.storage.queue.QueueClient + + .. admonition:: Example: + + .. literalinclude:: ../samples/queue_samples_message.py + :start-after: [START create_queue_client_from_connection_string] + :end-before: [END create_queue_client_from_connection_string] + :language: python + :dedent: 8 + :caption: Create the queue client from connection string. + """ + account_url, secondary, credential = parse_connection_str( + conn_str, credential, 'queue') + if 'secondary_hostname' not in kwargs: + kwargs['secondary_hostname'] = secondary + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + @distributed_trace_async async def create_queue( # type: ignore[override] self, *, From 5921e48d4d08c3522dc9d96d985b0d3d73c55628 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Fri, 23 Jun 2023 14:59:30 -0700 Subject: [PATCH 13/71] queue_client done and decoupled --- .../azure/storage/queue/_queue_client_helpers.py | 1 + .../azure/storage/queue/_shared/base_client.py | 3 ++- .../azure/storage/queue/aio/_queue_client_async.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index f1deb92a033d..7679b8b61b73 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -13,6 +13,7 @@ from ._message_encoding import NoEncodePolicy, NoDecodePolicy from ._serialize import get_api_version +# Rename this to like parse URL def _initialize_client(account_url, queue_name, credential): try: if not account_url.lower().startswith('http'): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 842bce21f801..77dc65bf7733 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -56,6 +56,7 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential + from azure.core.credentials_async import AsyncTokenCredential _LOGGER = logging.getLogger(__name__) _SERVICE_PARAMS = { @@ -71,7 +72,7 @@ def __init__( self, parsed_url, # type: Any service, # type: str - credential=None, # type: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long + credential=None, # type: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long **kwargs # type: Any ): # type: (...) -> None diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index fa9490cef6c9..5c619a032103 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- # pylint: disable=invalid-overridden-method +# mypy: disable-error-code="misc" import functools import warnings @@ -45,7 +46,7 @@ from .._models import QueueProperties -class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore[misc] +class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): """A client to interact with a specific Queue. 
:param str account_url: From 6aabf33862683da44dade0650bd4419e92fe8e31 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Fri, 23 Jun 2023 18:37:49 -0700 Subject: [PATCH 14/71] Queue specific files done --- .../azure/storage/queue/_queue_client.py | 35 ++++--- .../storage/queue/_queue_client_helpers.py | 13 +-- .../storage/queue/_queue_service_client.py | 49 ++++------ .../queue/_queue_service_client_helpers.py | 24 +++++ .../storage/queue/aio/_queue_client_async.py | 33 +++---- .../queue/aio/_queue_service_client_async.py | 98 +++++++++++++------ 6 files changed, 146 insertions(+), 106 deletions(-) create mode 100644 sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 522699ef5ffe..c15f88f282af 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -6,30 +6,31 @@ import functools import warnings -from typing import ( # pylint: disable=unused-import - Any, Dict, cast, List, Optional, Union, - TYPE_CHECKING, Tuple) -from urllib.parse import urlparse, quote, unquote +from typing import ( + Any, cast, Dict, List, Optional, + TYPE_CHECKING, Tuple, Union +) +from urllib.parse import quote, unquote, urlparse from typing_extensions import Self from azure.core.exceptions import HttpResponseError from azure.core.paging import ItemPaged from azure.core.tracing.decorator import distributed_trace -from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query -from ._shared.request_handlers import add_metadata_headers, serialize_iso -from ._shared.response_handlers import ( - process_storage_error, - return_response_headers, - return_headers_and_deserialized) -from ._generated import AzureQueueStorage -from ._generated.models import SignedIdentifier, 
QueueMessage as GenQueueMessage -from ._deserialize import deserialize_queue_properties, deserialize_queue_creation +from ._deserialize import deserialize_queue_creation, deserialize_queue_properties from ._encryption import StorageEncryptionMixin -from ._message_encoding import NoEncodePolicy, NoDecodePolicy +from ._generated import AzureQueueStorage +from ._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier +from ._message_encoding import NoDecodePolicy, NoEncodePolicy from ._models import AccessPolicy, MessagesPaged, QueueMessage +from ._queue_client_helpers import _parse_url from ._serialize import get_api_version -from ._queue_client_helpers import _initialize_client +from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin +from ._shared.request_handlers import add_metadata_headers, serialize_iso +from ._shared.response_handlers import ( + process_storage_error, + return_headers_and_deserialized, + return_response_headers) if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential @@ -84,12 +85,10 @@ def __init__( credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: - parsed_url, sas_token = _initialize_client(account_url=account_url, queue_name=queue_name, credential=credential) - + parsed_url, sas_token = _parse_url(account_url=account_url, queue_name=queue_name, credential=credential) self.queue_name = queue_name self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) - self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, 
pipeline=self._pipeline) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index 7679b8b61b73..24dcccd61cbb 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -4,17 +4,10 @@ # license information. # -------------------------------------------------------------------------- -import sys -from datetime import datetime, timezone -from ._generated import AzureQueueStorage -from ._generated.aio import AzureQueueStorage as AsyncAzureQueueStorage -from urllib.parse import urlparse, quote, unquote -from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query -from ._message_encoding import NoEncodePolicy, NoDecodePolicy -from ._serialize import get_api_version +from urllib.parse import urlparse +from ._shared.base_client import parse_query -# Rename this to like parse URL -def _initialize_client(account_url, queue_name, credential): +def _parse_url(account_url, queue_name, credential): try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 52e977af86a1..daffa132b0be 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -5,10 +5,10 @@ # -------------------------------------------------------------------------- import functools -from typing import ( # pylint: disable=unused-import - Any, Dict, List, Optional, Union, - TYPE_CHECKING) -from urllib.parse import urlparse +from typing import ( + Any, Dict, List, Optional, + TYPE_CHECKING, Union +) from typing_extensions import Self 
@@ -16,28 +16,26 @@ from azure.core.paging import ItemPaged from azure.core.pipeline import Pipeline from azure.core.tracing.decorator import distributed_trace -from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query -from ._shared.models import LocationMode -from ._shared.response_handlers import process_storage_error +from ._encryption import StorageEncryptionMixin from ._generated import AzureQueueStorage from ._generated.models import StorageServiceProperties -from ._encryption import StorageEncryptionMixin from ._models import ( + QueueProperties, QueuePropertiesPaged, - service_stats_deserialize, service_properties_deserialize, + service_stats_deserialize, ) -from ._serialize import get_api_version from ._queue_client import QueueClient +from ._queue_service_client_helpers import _parse_url +from ._serialize import get_api_version +from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin, TransportWrapper +from ._shared.models import LocationMode +from ._shared.response_handlers import process_storage_error if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential - from ._models import ( - CorsRule, - Metrics, - QueueAnalyticsLogging, - QueueProperties - ) + from ._generated.models import CorsRule + from ._models import Metrics, QueueAnalyticsLogging class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): @@ -93,18 +91,7 @@ def __init__( credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: - try: - if not account_url.lower().startswith('http'): - account_url = "https://" + account_url - except AttributeError: - raise ValueError("Account URL must be a string.") - parsed_url = urlparse(account_url.rstrip('/')) - if not parsed_url.netloc: - raise ValueError(f"Invalid URL: 
{account_url}") - - _, sas_token = parse_query(parsed_url.query) - if not sas_token and not credential: - raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") + parsed_url, sas_token = _parse_url(account_url=account_url, credential=credential) self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueServiceClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) @@ -421,14 +408,14 @@ def get_queue_client( :dedent: 8 :caption: Get the queue client. """ - try: + if isinstance(queue, QueueProperties): queue_name = queue.name - except AttributeError: + else: queue_name = queue _pipeline = Pipeline( transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access - policies=self._pipeline._impl_policies # pylint: disable = protected-access + policies=self._pipeline._impl_policies # type: ignore # pylint: disable = protected-access ) return QueueClient( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py new file mode 100644 index 000000000000..b45d02de4e34 --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py @@ -0,0 +1,24 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +from urllib.parse import urlparse +from ._shared.base_client import parse_query + +def _parse_url(account_url, credential): + try: + if not account_url.lower().startswith('http'): + account_url = "https://" + account_url + except AttributeError: + raise ValueError("Account URL must be a string.") + parsed_url = urlparse(account_url.rstrip('/')) + if not parsed_url.netloc: + raise ValueError(f"Invalid URL: {account_url}") + + _, sas_token = parse_query(parsed_url.query) + if not sas_token and not credential: + raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") + + return parsed_url, sas_token diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 5c619a032103..a92987f95980 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -9,9 +9,10 @@ import functools import warnings from typing import ( - Any, cast, Dict, List, Optional, Union, Tuple, - TYPE_CHECKING) -from urllib.parse import urlparse, quote, unquote + Any, cast, Dict, List, + Optional, Tuple, TYPE_CHECKING, Union +) +from urllib.parse import quote, unquote, urlparse from typing_extensions import Self @@ -19,26 +20,24 @@ from azure.core.exceptions import HttpResponseError from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async +from ._models import MessagesPaged +from .._deserialize import deserialize_queue_creation, deserialize_queue_properties +from .._encryption import StorageEncryptionMixin +from .._generated.aio import AzureQueueStorage +from .._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier +from .._message_encoding import 
NoDecodePolicy, NoEncodePolicy +from .._models import AccessPolicy, QueueMessage +from .._queue_client_helpers import _parse_url from .._serialize import get_api_version +from .._shared.base_client import parse_connection_str, StorageAccountHostsMixin from .._shared.base_client_async import AsyncStorageAccountHostsMixin from .._shared.policies_async import ExponentialRetry from .._shared.request_handlers import add_metadata_headers, serialize_iso from .._shared.response_handlers import ( - return_response_headers, process_storage_error, return_headers_and_deserialized, + return_response_headers ) -from .._generated.aio import AzureQueueStorage -from .._generated.models import SignedIdentifier, QueueMessage as GenQueueMessage -from .._deserialize import deserialize_queue_properties, deserialize_queue_creation -from .._encryption import StorageEncryptionMixin -from .._message_encoding import NoEncodePolicy, NoDecodePolicy -from .._shared.base_client import StorageAccountHostsMixin, parse_connection_str -from .._shared.base_client_async import AsyncStorageAccountHostsMixin -from .._models import QueueMessage, AccessPolicy -from .._queue_client import QueueClient as QueueClientBase -from ._models import MessagesPaged -from .._queue_client_helpers import _initialize_client if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential @@ -100,9 +99,7 @@ def __init__( ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) - - parsed_url, sas_token = _initialize_client(account_url=account_url, queue_name=queue_name, credential=credential) - + parsed_url, sas_token = _parse_url(account_url=account_url, queue_name=queue_name, credential=credential) self.queue_name = queue_name self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 22bac2bdcf92..c1f04f804b5f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -4,44 +4,43 @@ # license information. # -------------------------------------------------------------------------- # pylint: disable=invalid-overridden-method +# mypy: disable-error-code="misc" import functools from typing import ( - Any, Dict, List, Optional, Union, - TYPE_CHECKING) + Any, Dict, List, Optional, + TYPE_CHECKING, Union +) + +from typing_extensions import Self from azure.core.async_paging import AsyncItemPaged from azure.core.exceptions import HttpResponseError from azure.core.pipeline import AsyncPipeline from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async +from ._models import QueuePropertiesPaged +from ._queue_client_async import QueueClient +from .._encryption import StorageEncryptionMixin +from .._generated.aio import AzureQueueStorage +from .._generated.models import StorageServiceProperties +from .._models import QueueProperties, service_properties_deserialize, service_stats_deserialize +from .._queue_service_client_helpers import _parse_url from .._serialize import get_api_version +from .._shared.base_client import parse_connection_str, StorageAccountHostsMixin from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper from .._shared.policies_async import ExponentialRetry from .._shared.models import LocationMode from .._shared.response_handlers import process_storage_error -from .._generated.aio import AzureQueueStorage -from .._generated.models import StorageServiceProperties -from .._encryption import StorageEncryptionMixin -from 
.._models import ( - service_stats_deserialize, - service_properties_deserialize, -) -from .._queue_service_client import QueueServiceClient as QueueServiceClientBase -from ._models import QueuePropertiesPaged -from ._queue_client_async import QueueClient if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential - from .._models import ( - CorsRule, - Metrics, - QueueAnalyticsLogging, - QueueProperties, - ) + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential + from azure.core.credentials_async import AsyncTokenCredential + from .._generated.models import CorsRule + from .._models import Metrics, QueueAnalyticsLogging -class QueueServiceClient(AsyncStorageAccountHostsMixin, QueueServiceClientBase, StorageEncryptionMixin): +class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): """A client to interact with the Queue Service at the account level. This client provides operations to retrieve and configure the account properties @@ -87,21 +86,62 @@ class QueueServiceClient(AsyncStorageAccountHostsMixin, QueueServiceClientBase, def __init__( self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) - super(QueueServiceClient, self).__init__( - account_url, - credential=credential, - loop=loop, - **kwargs) + parsed_url, sas_token = _parse_url(account_url=account_url, credential=credential) + self._query_str, credential = self._format_query_string(sas_token, credential) + super(QueueServiceClient, 
self).__init__(parsed_url, service='queue', credential=credential, **kwargs) self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) + def _format_url(self, hostname: str) -> str: + """Format the endpoint URL according to the current location + mode hostname. + """ + return f"{self.scheme}://{hostname}/{self._query_str}" + + @classmethod + def from_connection_string( + cls, conn_str: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: + """Create QueueServiceClient from a Connection String. + + :param str conn_str: + A connection string to an Azure Storage account. + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token, or the connection string already has shared + access key values. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + Credentials provided here will take precedence over those in the connection string. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :returns: A Queue service client. + :rtype: ~azure.storage.queue.QueueClient + + .. admonition:: Example: + + .. literalinclude:: ../samples/queue_samples_authentication.py + :start-after: [START auth_from_connection_string] + :end-before: [END auth_from_connection_string] + :language: python + :dedent: 8 + :caption: Creating the QueueServiceClient with a connection string. 
+ """ + account_url, secondary, credential = parse_connection_str( + conn_str, credential, 'queue') + if 'secondary_hostname' not in kwargs: + kwargs['secondary_hostname'] = secondary + return cls(account_url, credential=credential, **kwargs) + @distributed_trace_async async def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: """Retrieves statistics related to replication for the Queue service. @@ -368,14 +408,14 @@ def get_queue_client( :dedent: 8 :caption: Get the queue client. """ - try: + if isinstance(queue, QueueProperties): queue_name = queue.name - except AttributeError: + else: queue_name = queue _pipeline = AsyncPipeline( transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access - policies=self._pipeline._impl_policies # pylint: disable = protected-access + policies=self._pipeline._impl_policies # type: ignore # pylint: disable = protected-access ) return QueueClient( From 426a2a1ad13f3b29848e57b56b1e3872adf8d71c Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Fri, 30 Jun 2023 17:55:49 -0700 Subject: [PATCH 15/71] Shared, some base_client leftovers, some model leftovers --- .../azure/storage/queue/_deserialize.py | 6 +- .../azure/storage/queue/_encryption.py | 123 +++++++++++------- .../azure/storage/queue/_message_encoding.py | 40 +++--- .../azure/storage/queue/_models.py | 102 ++++++++------- .../storage/queue/_queue_service_client.py | 3 +- .../storage/queue/_shared/authentication.py | 4 +- .../azure/storage/queue/_shared/parser.py | 11 +- .../storage/queue/_shared_access_signature.py | 23 ++-- .../azure/storage/queue/aio/_models.py | 67 +++++----- .../queue/aio/_queue_service_client_async.py | 3 +- 10 files changed, 219 insertions(+), 163 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py index 3e7b240ef16f..c46b9d3f9731 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py @@ -35,7 +35,7 @@ def deserialize_queue_creation( headers: Dict[str, Any] ) -> Dict[str, Any]: response = response.http_response - if response.status_code == 204: + if response.status_code == 204: # type: ignore error_code = StorageErrorCode.queue_already_exists error = ResourceExistsError( message=( @@ -44,7 +44,7 @@ def deserialize_queue_creation( f"Time:{headers['Date']}\n" f"ErrorCode:{error_code}"), response=response) - error.error_code = error_code - error.additional_info = {} + error.error_code = error_code # type: ignore + error.additional_info = {} # type: ignore raise error return headers diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 71e03b758891..0ae580f0a992 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -15,7 +15,7 @@ dumps, loads, ) -from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING +from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher @@ -266,7 +266,7 @@ def _encrypt_region(self, data: bytes) -> bytes: return nonce + cipertext_with_tag -def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: +def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> Optional[Union[_EncryptionData, bool]]: """ Determine whether the given encryption data signifies version 2.0. 
@@ -344,10 +344,13 @@ def get_adjusted_download_range_and_offset( elif encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: start_offset, end_offset = 0, end - - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length + if encryption_data.encrypted_region_info is not None: + if hasattr(encryption_data.encrypted_region_info, 'nonce_length'): + nonce_length = encryption_data.encrypted_region_info.nonce_length + if hasattr(encryption_data.encrypted_region_info, 'data_length'): + data_length = encryption_data.encrypted_region_info.data_length + if hasattr(encryption_data.encrypted_region_info, 'tag_length'): + tag_length = encryption_data.encrypted_region_info.tag_length region_length = nonce_length + data_length + tag_length requested_length = end - start @@ -394,9 +397,14 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp :param Optional[_EncryptionData] encryption_data: The encryption data to determine version and sizes. 
""" if is_encryption_v2(encryption_data): - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length + if encryption_data is not None: + if encryption_data.encrypted_region_info is not None: + if hasattr(encryption_data.encrypted_region_info, 'nonce_length'): + nonce_length = encryption_data.encrypted_region_info.nonce_length + if hasattr(encryption_data.encrypted_region_info, 'data_length'): + data_length = encryption_data.encrypted_region_info.data_length + if hasattr(encryption_data.encrypted_region_info, 'tag_length'): + tag_length = encryption_data.encrypted_region_info.tag_length region_length = nonce_length + data_length + tag_length num_regions = math.ceil(size / region_length) @@ -419,19 +427,23 @@ def _generate_encryption_data_dict(kek: object, cek: bytes, iv: Optional[bytes], ''' # Encrypt the cek. if version == _ENCRYPTION_PROTOCOL_V1: - wrapped_cek = kek.wrap_key(cek) + if hasattr(kek, 'wrap_key'): + wrapped_cek = kek.wrap_key(cek) # For V2, we include the encryption version in the wrapped key. elif version == _ENCRYPTION_PROTOCOL_V2: # We must pad the version to 8 bytes for AES Keywrap algorithms to_wrap = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') + cek - wrapped_cek = kek.wrap_key(to_wrap) + if hasattr(kek, 'wrap_key'): + wrapped_cek = kek.wrap_key(to_wrap) # Build the encryption_data dict. # Use OrderedDict to comply with Java's ordering requirement. 
wrapped_content_key = OrderedDict() - wrapped_content_key['KeyId'] = kek.get_kid() + if hasattr(kek, 'get_kid'): + wrapped_content_key['KeyId'] = kek.get_kid() wrapped_content_key['EncryptedKey'] = encode_base64(wrapped_cek) - wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() + if hasattr(kek, 'get_key_wrap_algorithm'): + wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() encryption_agent = OrderedDict() encryption_agent['Protocol'] = version @@ -453,7 +465,7 @@ def _generate_encryption_data_dict(kek: object, cek: bytes, iv: Optional[bytes], encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info - encryption_data_dict['KeyWrappingMetadata'] = {'EncryptionLibrary': 'Python ' + VERSION} + encryption_data_dict['KeyWrappingMetadata'] = OrderedDict([('EncryptionLibrary', 'Python ' + VERSION)]) return encryption_data_dict @@ -529,7 +541,7 @@ def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: def _validate_and_unwrap_cek( encryption_data: _EncryptionData, key_encryption_key: object = None, - key_resolver: Callable[[str], bytes] = None + key_resolver: Optional[Callable[[str], bytes]] = None ) -> bytes: ''' Extracts and returns the content_encryption_key stored in the encryption_data object @@ -539,7 +551,7 @@ def _validate_and_unwrap_cek( :param object key_encryption_key: The key_encryption_key used to unwrap the cek. Please refer to high-level service object instance variables for more details. - :param Callable[[str], bytes] key_resolver: + :param Optional[Callable[[str], bytes]] key_resolver: A function used that, given a key_id, will return a key_encryption_key. Please refer to high-level service object instance variables for more details. :return: the content_encryption_key stored in the encryption_data object. 
@@ -556,7 +568,7 @@ def _validate_and_unwrap_cek( else: raise ValueError('Specified encryption version is not supported.') - content_encryption_key = None + content_encryption_key: Optional[bytes] = None # If the resolver exists, give priority to the key it finds. if key_resolver is not None: @@ -570,30 +582,35 @@ def _validate_and_unwrap_cek( if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid(): raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.') # Will throw an exception if the specified algorithm is not supported. - content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, + if hasattr(key_encryption_key, 'unwrap_key'): + content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, encryption_data.wrapped_content_key.algorithm) # For V2, the version is included with the cek. We need to validate it # and remove it from the actual cek. if encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: version_2_bytes = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') - cek_version_bytes = content_encryption_key[:len(version_2_bytes)] - if cek_version_bytes != version_2_bytes: - raise ValueError('The encryption metadata is not valid and may have been modified.') + if content_encryption_key is not None: + cek_version_bytes = content_encryption_key[:len(version_2_bytes)] + if cek_version_bytes != version_2_bytes: + raise ValueError('The encryption metadata is not valid and may have been modified.') - # Remove version from the start of the cek. - content_encryption_key = content_encryption_key[len(version_2_bytes):] + # Remove version from the start of the cek. 
+ content_encryption_key = content_encryption_key[len(version_2_bytes):] _validate_not_none('content_encryption_key', content_encryption_key) - return content_encryption_key + if isinstance(content_encryption_key, bytes): + validated_cek: bytes = content_encryption_key + + return validated_cek def _decrypt_message( message: str, encryption_data: _EncryptionData, key_encryption_key: object = None, - resolver: callable = None + resolver: Optional[Callable] = None ) -> str: ''' Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding. @@ -610,7 +627,7 @@ def _decrypt_message( - returns the unwrapped form of the specified symmetric key using the string-specified algorithm. get_kid() - returns a string key id for this key-encryption-key. - :param function resolver(kid): + :param Optional[Callable] resolver(kid): The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The decrypted plaintext. @@ -626,7 +643,7 @@ def _decrypt_message( cipher = _generate_AES_CBC_cipher(content_encryption_key, encryption_data.content_encryption_IV) # decrypt data - decrypted_data = message + decrypted_data: Union[bytes, str] = message decryptor = cipher.decryptor() decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) @@ -639,19 +656,23 @@ def _decrypt_message( if not block_info or not block_info.nonce_length: raise ValueError("Missing required metadata for decryption.") - nonce_length = encryption_data.encrypted_region_info.nonce_length + if encryption_data.encrypted_region_info is not None: + if hasattr(encryption_data.encryption_agent, 'nonce_length'): + nonce_length = encryption_data.encrypted_region_info.nonce_length # First bytes are the nonce nonce = message[:nonce_length] ciphertext_with_tag = message[nonce_length:] aesgcm = AESGCM(content_encryption_key) - decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) + decrypted_data = str(aesgcm.decrypt(nonce, 
ciphertext_with_tag, None)) #type: ignore else: raise ValueError('Specified encryption version is not supported.') - return decrypted_data + if isinstance(decrypted_data, str): + decrypted_data_as_str = decrypted_data + return decrypted_data_as_str def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple[str, bytes]: @@ -733,14 +754,18 @@ def generate_blob_encryption_data(key_encryption_key: object, version: str) -> T # Initialization vector only needed for V1 if version == _ENCRYPTION_PROTOCOL_V1: initialization_vector = os.urandom(16) - encryption_data = _generate_encryption_data_dict(key_encryption_key, + encryption_data_dict = _generate_encryption_data_dict(key_encryption_key, content_encryption_key, initialization_vector, version) - encryption_data['EncryptionMode'] = 'FullBlob' - encryption_data = dumps(encryption_data) + encryption_data_dict['EncryptionMode'] = 'FullBlob' + encryption_data = dumps(encryption_data_dict) + if isinstance(encryption_data, str): + serialized_encryption_data = encryption_data + if content_encryption_key is not None: + validated_content_encyption_key = content_encryption_key - return content_encryption_key, initialization_vector, encryption_data + return validated_content_encyption_key, initialization_vector, serialized_encryption_data def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements @@ -830,7 +855,8 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements if blob_type == 'PageBlob': unpad = False - cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) + if isinstance(iv, bytes): + cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) decryptor = cipher.decryptor() content = decryptor.update(content) + decryptor.finalize() @@ -845,9 +871,13 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements total_size = len(content) offset = 0 - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = 
encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length + if encryption_data.encrypted_region_info is not None: + if hasattr(encryption_data.encrypted_region_info, 'nonce_length'): + nonce_length = encryption_data.encrypted_region_info.nonce_length + if hasattr(encryption_data.encrypted_region_info, 'data_length'): + data_length = encryption_data.encrypted_region_info.data_length + if hasattr(encryption_data.encrypted_region_info, 'tag_length'): + tag_length = encryption_data.encrypted_region_info.tag_length region_length = nonce_length + data_length + tag_length decrypted_content = bytearray() @@ -868,6 +898,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements # Read the caller requested data from the decrypted content return decrypted_content[start_offset:end_offset] + raise ValueError('Specified encryption version is not supported.') def get_blob_encryptor_and_padder( @@ -910,7 +941,7 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str # Queue encoding functions all return unicode strings, and encryption should # operate on binary strings. - message = message.encode('utf-8') + message_as_bytes = message.encode('utf-8') if version == _ENCRYPTION_PROTOCOL_V1: # AES256 CBC uses 256 bit (32 byte) keys and always with 16 byte blocks @@ -921,7 +952,7 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str # PKCS7 with 16 byte blocks ensures compatibility with AES. padder = PKCS7(128).padder() - padded_data = padder.update(message) + padder.finalize() + padded_data = padder.update(message_as_bytes) + padder.finalize() # Encrypt the data. 
encryptor = cipher.encryptor() @@ -937,7 +968,7 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str aesgcm = AESGCM(content_encryption_key) # Returns ciphertext + tag - cipertext_with_tag = aesgcm.encrypt(nonce, message, None) + cipertext_with_tag = aesgcm.encrypt(nonce, message_as_bytes, None) encrypted_data = nonce + cipertext_with_tag else: @@ -982,10 +1013,10 @@ def decrypt_queue_message( response = response.http_response try: - message = loads(message) + message_dict = loads(message) - encryption_data = _dict_to_encryption_data(message['EncryptionData']) - decoded_data = decode_base64_to_bytes(message['EncryptedMessageContents']) + encryption_data = _dict_to_encryption_data(message_dict['EncryptionData']) + decoded_data = decode_base64_to_bytes(message_dict['EncryptedMessageContents']) except (KeyError, ValueError): # Message was not json formatted and so was not encrypted # or the user provided a json formatted message @@ -995,9 +1026,9 @@ def decrypt_queue_message( 'Encryption required, but received message does not contain appropriate metatadata. 
' + \ 'Message was either not encrypted or metadata was incorrect.') - return message + return message_dict try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') #type: ignore except Exception as error: raise HttpResponseError( message="Decryption failed.", diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index 5e35f445d3f5..8a9246bf5da6 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -6,7 +6,7 @@ # pylint: disable=unused-argument from base64 import b64encode, b64decode -from typing import Any, Callable, Dict, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union from azure.core.exceptions import DecodeError @@ -18,6 +18,15 @@ class MessageEncodePolicy(object): + require_encryption: Optional[bool] = None + """Indicates whether a retention policy is enabled for the storage service.""" + encryption_version: Optional[str] = None + """Indicates whether a retention policy is enabled for the storage service.""" + key_encryption_key: Optional[object] = None + """Indicates whether a retention policy is enabled for the storage service.""" + resolver: Optional[Callable[[str], bytes]] = None + """Indicates whether a retention policy is enabled for the storage service.""" + def __init__(self) -> None: self.require_encryption = False self.encryption_version = None @@ -27,7 +36,7 @@ def __init__(self) -> None: def __call__(self, content: Any) -> str: if content: content = self.encode(content) - if self.key_encryption_key is not None: + if self.key_encryption_key is not None and isinstance(self.encryption_version, str): content = encrypt_queue_message(content, 
self.key_encryption_key, self.encryption_version) return content @@ -56,17 +65,18 @@ def __init__(self): self.resolver = None def __call__(self, response: "PipelineResponse", obj: object, headers: Dict[str, Any]) -> object: - for message in obj: - if message.message_text in [None, "", b""]: - continue - content = message.message_text - if (self.key_encryption_key is not None) or (self.resolver is not None): - content = decrypt_queue_message( - content, response, - self.require_encryption, - self.key_encryption_key, - self.resolver) - message.message_text = self.decode(content, response) + if hasattr(obj, '__iter__'): + for message in obj: + if message.message_text in [None, "", b""]: + continue + content = message.message_text + if (self.key_encryption_key is not None) or (self.resolver is not None): + content = decrypt_queue_message( + content, response, + self.require_encryption, + self.key_encryption_key, + self.resolver) + message.message_text = self.decode(content, response) return obj def configure(self, require_encryption: bool, key_encryption_key: object, resolver: Callable[[str], bytes]) -> None: @@ -74,7 +84,7 @@ def configure(self, require_encryption: bool, key_encryption_key: object, resolv self.key_encryption_key = key_encryption_key self.resolver = resolver - def decode(self, content: Any, response: "PipelineResponse") -> str: + def decode(self, content: Any, response: "PipelineResponse") -> Union[bytes, str]: raise NotImplementedError("Must be implemented by child class.") @@ -99,7 +109,7 @@ class TextBase64DecodePolicy(MessageDecodePolicy): support UTF-8. 
""" - def decode(self, content: str, response: "PipelineResponse") -> bytes: + def decode(self, content: str, response: "PipelineResponse") -> str: try: return b64decode(content.encode('utf-8')).decode('utf-8') except (ValueError, TypeError) as error: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 282c7a0def0d..97ccaf863411 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -42,10 +42,10 @@ class RetentionPolicy(GeneratedRetentionPolicy): enabled: bool = False """Indicates whether a retention policy is enabled for the storage service.""" - days: int = None + days: Optional[int] = None """Indicates the number of days that metrics or logging or soft-deleted data should be retained.""" - def __init__(self, enabled: bool = False, days: int = None) -> None: + def __init__(self, enabled: bool = False, days: Optional[int] = None) -> None: self.enabled = enabled self.days = days if self.enabled and (self.days is None): @@ -120,7 +120,7 @@ class Metrics(GeneratedMetrics): """The version of Storage Analytics to configure.""" enabled: bool = False """Indicates whether metrics are enabled for the service.""" - include_apis: bool + include_apis: Optional[bool] """Indicates whether metrics should generate summary statistics for called API operations.""" retention_policy: RetentionPolicy = RetentionPolicy() """The retention policy for the metrics.""" @@ -248,18 +248,18 @@ class AccessPolicy(GenAccessPolicy): :type start: ~datetime.datetime or str """ - permission: str = None + permission: Optional[str] = None """The permissions associated with the shared access signature. 
The user is restricted to operations allowed by the permissions.""" - expiry: Union["datetime", str] = None + expiry: Optional[Union["datetime", str]] = None # type: ignore """The time at which the shared access signature becomes invalid.""" - start: Union["datetime", str] = None + start: Optional[Union["datetime", str]] = None # type: ignore """The time at which the shared access signature becomes valid.""" def __init__( - self, permission: str = None, - expiry: Union["datetime", str] = None, - start: Union["datetime", str] = None + self, permission: Optional[str] = None, + expiry: Optional[Union["datetime", str]] = None, + start: Optional[Union["datetime", str]] = None ) -> None: self.start = start self.expiry = expiry @@ -269,7 +269,7 @@ def __init__( class QueueMessage(DictMixin): """Represents a queue message.""" - id: str + id: Optional[str] """A GUID value assigned to the message by the Queue service that identifies the message in the queue. This value may be used together with the value of pop_receipt to delete a message from the queue after @@ -284,12 +284,12 @@ class QueueMessage(DictMixin): content: Any = None """The message content. Type is determined by the decode_function set on the service. Default is str.""" - pop_receipt: str + pop_receipt: Optional[str] """A receipt str which can be used together with the message_id element to delete a message from the queue after it has been retrieved with the receive messages operation. Only returned by receive messages operations. Set to None for peek messages.""" - next_visible_on: "datetime" + next_visible_on: Optional["datetime"] """A UTC date value representing the time the message will next be visible. Only returned by receive messages operations. Set to None for peek messages.""" @@ -319,24 +319,24 @@ class MessagesPaged(PageIterator): """An iterable of Queue Messages. :param Callable command: Function to retrieve the next page of items. 
- :param int results_per_page: The maximum number of messages to retrieve per + :param Optional[int] results_per_page: The maximum number of messages to retrieve per call. - :param int max_messages: The maximum number of messages to retrieve from + :param Optional[int] max_messages: The maximum number of messages to retrieve from the queue. """ command: Callable """Function to retrieve the next page of items.""" - results_per_page: int = None - """A UTC date value representing the time the message expires.""" - max_messages: int = None + results_per_page: Optional[int] = None + """The maximum number of messages to retrieve per call.""" + max_messages: Optional[int] = None """The maximum number of messages to retrieve from the queue.""" def __init__( self, command: Callable, - results_per_page: int = None, - continuation_token: str = None, - max_messages: int = None + results_per_page: Optional[int] = None, + continuation_token: Optional[str] = None, + max_messages: Optional[int] = None ) -> None: if continuation_token is not None: raise ValueError("This operation does not support continuation token") @@ -349,7 +349,7 @@ def __init__( self.results_per_page = results_per_page self._max_messages = max_messages - def _get_next_cb(self, continuation_token: str) -> Any: + def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: if self._max_messages is not None: if self.results_per_page is None: @@ -373,15 +373,17 @@ def _extract_data_cb(self, messages: Any) -> Tuple[str, List[QueueMessage]]: class QueueProperties(DictMixin): """Queue Properties. - :keyword Dict[str, str] metadata: + :keyword Optional[str] name: + The name of the queue. + :keyword Optional[Dict[str, str]] metadata: A dict containing name-value pairs associated with the queue as metadata. This var is set to None unless the include=metadata param was included for the list queues operation. 
""" - name: str + name: Optional[str] """The name of the queue.""" - metadata: Dict[str, str] + metadata: Optional[Dict[str, str]] """A dict containing name-value pairs associated with the queue as metadata.""" def __init__(self, **kwargs: Any) -> None: @@ -401,36 +403,34 @@ class QueuePropertiesPaged(PageIterator): """An iterable of Queue properties. :param Callable command: Function to retrieve the next page of items. - :param str prefix: Filters the results to return only queues whose names + :param Optional[str] prefix: Filters the results to return only queues whose names begin with the specified prefix. - :param int results_per_page: The maximum number of queue names to retrieve per + :param Optional[int] results_per_page: The maximum number of queue names to retrieve per call. :param str continuation_token: An opaque continuation token. """ - service_endpoint: str + service_endpoint: Optional[str] """The service URL.""" - prefix: str + prefix: Optional[str] """A queue name prefix being used to filter the list.""" - marker: str + marker: Optional[str] """The continuation token of the current page of results.""" - results_per_page: int = None + results_per_page: Optional[int] = None """The maximum number of results retrieved per API call.""" next_marker: str """The continuation token to retrieve the next page of results.""" - location_mode: str + location_mode: Optional[str] """The location mode being used to list results. 
The available options include "primary" and "secondary".""" command: Callable """Function to retrieve the next page of items.""" - prefix: str = None - """Filters the results to return only queues whose names begin with the specified prefix.""" - results_per_page: int - """The maximum number of queue names to retrieve per - call.""" - continuation_token: str = None - """An opaque continuation token.""" - - def __init__(self, command, prefix=None, results_per_page=None, continuation_token=None): + + def __init__( + self, command: Callable, + prefix: Optional[str] = None, + results_per_page: Optional[int] = None, + continuation_token: Optional[str] = None + ) -> None: super(QueuePropertiesPaged, self).__init__( self._get_next_cb, self._extract_data_cb, @@ -453,14 +453,22 @@ def _get_next_cb(self, continuation_token: str) -> Any: except HttpResponseError as error: process_storage_error(error) - def _extract_data_cb(self, get_next_return: str) -> Tuple[Optional[str], List[QueueProperties]]: + def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return - self.service_endpoint = self._response.service_endpoint - self.prefix = self._response.prefix - self.marker = self._response.marker - self.results_per_page = self._response.max_results - props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access - return self._response.next_marker or None, props_list + if self._response is not None: + if hasattr(self._response, 'service_endpoint'): + self.service_endpoint = self._response.service_endpoint + if hasattr(self._response, 'prefix'): + self.prefix = self._response.prefix + if hasattr(self._response, 'marker'): + self.marker = self._response.marker + if hasattr(self._response, 'max_results'): + self.results_per_page = self._response.max_results + if hasattr(self._response, 'queue_items'): + props_list = 
[QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access + if hasattr(self._response, 'next_marker'): + next_marker = self._response.next_marker + return next_marker or None, props_list class QueueSasPermissions(object): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index daffa132b0be..46ba0696028b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -409,7 +409,8 @@ def get_queue_client( :caption: Get the queue client. """ if isinstance(queue, QueueProperties): - queue_name = queue.name + if queue.name is not None: + queue_name = queue.name else: queue_name = queue diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py index 71d103cac92b..89d1f17e725f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py @@ -6,7 +6,7 @@ import logging import re -from typing import List, Tuple +from typing import List, Optional, Tuple from urllib.parse import unquote, urlparse try: @@ -35,7 +35,7 @@ def _wrap_exception(ex, desired_type): return desired_type(msg) # This method attempts to emulate the sorting done by the service -def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]: +def _storage_header_sort(input_headers: List[Tuple[str, Optional[str]]]) -> List[Tuple[str, Optional[str]]]: # Define the custom alphabet for weights custom_weights = "-!#$%&*.^_|~+\"\'(),/`~0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz{}" diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py index a4f9da94cc24..acb1a6a2cd93 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py @@ -6,6 +6,7 @@ import sys from datetime import datetime, timezone +from typing import Optional EPOCH_AS_FILETIME = 116444736000000000 # January 1, 1970 as MS filetime HUNDREDS_OF_NANOSECONDS = 10000000 @@ -23,7 +24,7 @@ def _str(value): def _to_utc_datetime(value): return value.strftime('%Y-%m-%dT%H:%M:%SZ') -def _rfc_1123_to_datetime(rfc_1123: str) -> datetime: +def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]: """Converts an RFC 1123 date string to a UTC datetime. """ if not rfc_1123: @@ -31,7 +32,7 @@ def _rfc_1123_to_datetime(rfc_1123: str) -> datetime: return datetime.strptime(rfc_1123, "%a, %d %b %Y %H:%M:%S %Z") -def _filetime_to_datetime(filetime: str) -> datetime: +def _filetime_to_datetime(filetime: str) -> Optional[datetime]: """Converts an MS filetime string to a UTC datetime. "0" indicates None. If parsing MS Filetime fails, tries RFC 1123 as backup. 
""" @@ -40,11 +41,11 @@ def _filetime_to_datetime(filetime: str) -> datetime: # Try to convert to MS Filetime try: - filetime = int(filetime) - if filetime == 0: + temp_filetime = int(filetime) + if temp_filetime == 0: return None - return datetime.fromtimestamp((filetime - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS, tz=timezone.utc) + return datetime.fromtimestamp((temp_filetime - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS, tz=timezone.utc) except ValueError: pass diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py index 844d88a09f7c..3ddcdff37fc3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py @@ -41,25 +41,26 @@ def __init__(self, account_name: str, account_key: str) -> None: def generate_queue( self, queue_name: str, - permission: "QueueSasPermissions" = None, - expiry: Union["datetime", str] = None, - start: Union["datetime", str] = None, - policy_id: str = None, - ip: str = None, - protocol: str = None + permission: Optional[Union["QueueSasPermissions", str]] = None, + expiry: Optional[Union["datetime", str]] = None, + start: Optional[Union["datetime", str]] = None, + policy_id: Optional[str] = None, + ip: Optional[str] = None, + protocol: Optional[str] = None ) -> str: ''' Generates a shared access signature for the queue. Use the returned signature with the sas_token parameter of QueueService. :param str queue_name: Name of queue. - :param QueueSasPermissions permission: + :param permission: The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions. Permissions must be ordered read, add, update, process. Required unless an id is given referencing a stored access policy which contains this field. 
This field must be omitted if it has been specified in an associated stored access policy. + :type permission: ~azure.storage.queue.QueueSasPermissions or str :param expiry: The time at which the shared access signature becomes invalid. Required unless an id is given referencing a stored access policy @@ -98,7 +99,7 @@ def generate_queue( class _QueueSharedAccessHelper(_SharedAccessHelper): - def add_resource_signature(self, account_name: str, account_key: str, path: str) -> str: # pylint: disable=arguments-differ + def add_resource_signature(self, account_name: str, account_key: str, path: str): # pylint: disable=arguments-differ def get_value_to_append(query): return_value = self.query_dict.get(query) or '' return return_value + '\n' @@ -148,9 +149,10 @@ def generate_account_sas( The account key, also called shared key or access key, to generate the shared access signature. :param ~azure.storage.queue.ResourceTypes resource_types: Specifies the resource types that are accessible with the account SAS. - :param ~azure.storage.queue.AccountSasPermissions permission: + :param permission: The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions. + :type permission: ~azure.storage.queue.AccountSasPermissions or str :param expiry: The time at which the shared access signature becomes invalid. Required unless an id is given referencing a stored access policy @@ -210,12 +212,13 @@ def generate_queue_sas( The name of the queue. :param str account_key: The account key, also called shared key or access key, to generate the shared access signature. - :param ~azure.storage.queue.QueueSasPermissions permission: + :param permission: The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions. Required unless a policy_id is given referencing a stored access policy which contains this field. 
This field must be omitted if it has been specified in an associated stored access policy. + :type permission: ~azure.storage.queue.QueueSasPermissions or str :param expiry: The time at which the shared access signature becomes invalid. Required unless a policy_id is given referencing a stored access policy diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index 4cf9b1877a09..66d517ff8091 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -6,7 +6,7 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, AsyncIterator, Callable, List, Optional, Tuple from azure.core.async_paging import AsyncPageIterator from azure.core.exceptions import HttpResponseError from .._shared.response_handlers import ( @@ -19,24 +19,24 @@ class MessagesPaged(AsyncPageIterator): """An iterable of Queue Messages. :param Callable command: Function to retrieve the next page of items. - :param int results_per_page: The maximum number of messages to retrieve per + :param Optional[int] results_per_page: The maximum number of messages to retrieve per call. - :param int max_messages: The maximum number of messages to retrieve from + :param Optional[int] max_messages: The maximum number of messages to retrieve from the queue. 
""" command: Callable """Function to retrieve the next page of items.""" - results_per_page: int = None + results_per_page: Optional[int] = None """A UTC date value representing the time the message expires.""" - max_messages: int = None + max_messages: Optional[int] = None """The maximum number of messages to retrieve from the queue.""" def __init__( self, command: Callable, - results_per_page: int = None, - continuation_token: str = None, - max_messages: int = None + results_per_page: Optional[int] = None, + continuation_token: Optional[str] = None, + max_messages: Optional[int] = None ) -> None: if continuation_token is not None: raise ValueError("This operation does not support continuation token") @@ -49,7 +49,7 @@ def __init__( self.results_per_page = results_per_page self._max_messages = max_messages - async def _get_next_cb(self, continuation_token: str) -> Any: + async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: if self._max_messages is not None: if self.results_per_page is None: @@ -61,7 +61,7 @@ async def _get_next_cb(self, continuation_token: str) -> Any: except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, messages: QueueMessage) -> Tuple[str, List[QueueMessage]]: + async def _extract_data_cb(self, messages: Any) -> Tuple[str, AsyncIterator[List[QueueMessage]]]: # There is no concept of continuation token, so raising on my own condition if not messages: raise StopAsyncIteration("End of paging") @@ -76,38 +76,31 @@ class QueuePropertiesPaged(AsyncPageIterator): :param Callable command: Function to retrieve the next page of items. :param str prefix: Filters the results to return only queues whose names begin with the specified prefix. - :param int results_per_page: The maximum number of queue names to retrieve per + :param Optional[int] results_per_page: The maximum number of queue names to retrieve per call. :param str continuation_token: An opaque continuation token. 
""" - service_endpoint: str + service_endpoint: Optional[str] """The service URL.""" - prefix: str + prefix: Optional[str] """A queue name prefix being used to filter the list.""" - marker: str + marker: Optional[str] """The continuation token of the current page of results.""" - results_per_page: int = None + results_per_page: Optional[int] = None """The maximum number of results retrieved per API call.""" next_marker: str """The continuation token to retrieve the next page of results.""" - location_mode: str + location_mode: Optional[str] """The location mode being used to list results. The available options include "primary" and "secondary".""" command: Callable """Function to retrieve the next page of items.""" - prefix: str = None - """Filters the results to return only queues whose names begin with the specified prefix.""" - results_per_page: int - """The maximum number of queue names to retrieve per - call.""" - continuation_token: str = None - """An opaque continuation token.""" def __init__( self, command: Callable, - prefix: str = None, - results_per_page: int = None, - continuation_token: str = None + prefix: Optional[str] = None, + results_per_page: Optional[int] = None, + continuation_token: Optional[str] = None ) -> None: super(QueuePropertiesPaged, self).__init__( self._get_next_cb, @@ -131,11 +124,19 @@ async def _get_next_cb(self, continuation_token: str) -> Any: except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, get_next_return: str) -> Tuple[Optional[str], List[QueueProperties]]: + async def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return - self.service_endpoint = self._response.service_endpoint - self.prefix = self._response.prefix - self.marker = self._response.marker - self.results_per_page = self._response.max_results - props_list = [QueueProperties._from_generated(q) for q in 
self._response.queue_items] # pylint: disable=protected-access - return self._response.next_marker or None, props_list + if self._response is not None: + if hasattr(self._response, 'service_endpoint'): + self.service_endpoint = self._response.service_endpoint + if hasattr(self._response, 'prefix'): + self.prefix = self._response.prefix + if hasattr(self._response, 'marker'): + self.marker = self._response.marker + if hasattr(self._response, 'max_results'): + self.results_per_page = self._response.max_results + if hasattr(self._response, 'queue_items'): + props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access + if hasattr(self._response, 'next_marker'): + next_marker = self._response.next_marker + return next_marker or None, props_list diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index c1f04f804b5f..35c3819acc0f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -409,7 +409,8 @@ def get_queue_client( :caption: Get the queue client. 
""" if isinstance(queue, QueueProperties): - queue_name = queue.name + if queue.name is not None: + queue_name = queue.name else: queue_name = queue From e15e1953b3c31e769b067ca50b4f2eb432c78d22 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Fri, 14 Jul 2023 14:44:35 -0700 Subject: [PATCH 16/71] 3 _model 1 response_handlers --- .../azure-storage-queue/azure/__init__.py | 2 +- .../azure/storage/__init__.py | 2 +- .../azure/storage/queue/_models.py | 6 +-- .../storage/queue/_shared/base_client.py | 46 ++++++++++++------- .../queue/_shared/base_client_async.py | 41 +++++++++++------ .../azure/storage/queue/_shared/policies.py | 25 ++++------ .../storage/queue/_shared/policies_async.py | 9 ++-- .../queue/_shared/response_handlers.py | 2 +- .../azure/storage/queue/aio/_models.py | 6 +-- .../samples/network_activity_logging.py | 5 +- .../samples/queue_samples_authentication.py | 39 ++++++++++++---- .../queue_samples_authentication_async.py | 37 +++++++++++---- .../samples/queue_samples_hello_world.py | 6 ++- .../queue_samples_hello_world_async.py | 6 ++- .../samples/queue_samples_message.py | 39 ++++++++++------ .../samples/queue_samples_message_async.py | 34 +++++++++----- .../samples/queue_samples_service.py | 17 ++++--- .../samples/queue_samples_service_async.py | 11 +++-- 18 files changed, 213 insertions(+), 120 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/__init__.py b/sdk/storage/azure-storage-queue/azure/__init__.py index 59cb70146572..0d1f7edf5dc6 100644 --- a/sdk/storage/azure-storage-queue/azure/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: str +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-queue/azure/storage/__init__.py b/sdk/storage/azure-storage-queue/azure/storage/__init__.py index 59cb70146572..0d1f7edf5dc6 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/storage/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: str +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 97ccaf863411..236afdcbd399 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -443,7 +443,7 @@ def __init__( self.results_per_page = results_per_page self.location_mode = None - def _get_next_cb(self, continuation_token: str) -> Any: + def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: return self._command( marker=continuation_token or None, @@ -453,7 +453,7 @@ def _get_next_cb(self, continuation_token: str) -> Any: except HttpResponseError as error: process_storage_error(error) - def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[Optional[str], List[QueueProperties]]: + def _extract_data_cb(self, get_next_return: Any) -> Tuple[str, List[QueueProperties]]: self.location_mode, self._response = get_next_return if self._response is not None: if hasattr(self._response, 'service_endpoint'): @@ -468,7 +468,7 @@ def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[Optional[s props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access if hasattr(self._response, 'next_marker'): next_marker = self._response.next_marker - return next_marker or None, props_list + return next_marker, props_list class QueueSasPermissions(object): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 77dc65bf7733..cb303b2916da 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -57,6 +57,7 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential + from azure.core.rest import HttpRequest _LOGGER = logging.getLogger(__name__) _SERVICE_PARAMS = { @@ -217,7 +218,7 @@ def _format_query_string(self, sas_token, credential, snapshot=None, share_snaps def _create_pipeline(self, credential, **kwargs): # type: (Any, **Any) -> Tuple[Configuration, Pipeline] - self._credential_policy = None + self._credential_policy: Any = None if hasattr(credential, "get_token"): self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) elif isinstance(credential, SharedKeyCredentialPolicy): @@ -233,8 +234,8 @@ def _create_pipeline(self, credential, **kwargs): config.transport = kwargs.get("transport") # type: ignore kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) - if not config.transport: - config.transport = RequestsTransport(**kwargs) + if not hasattr(config, 'transport'): + config.transport = RequestsTransport(**kwargs) # type: ignore policies = [ QueueMessagePolicy(), config.proxy_policy, @@ -253,21 +254,22 @@ def _create_pipeline(self, credential, **kwargs): HttpLoggingPolicy(**kwargs) ] if kwargs.get("_additional_pipeline_policies"): - policies = policies + kwargs.get("_additional_pipeline_policies") + policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore return config, Pipeline(config.transport, policies=policies) def _batch_send( self, - *reqs, # type: HttpRequest + *reqs: HttpRequest, **kwargs - ): + ) -> None: """Given a series of request, do a Storage batch call. 
""" # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) batch_id = str(uuid.uuid1()) - request = self._client._client.post( # pylint: disable=protected-access + if hasattr(self, '_client'): + request = self._client._client.post( # pylint: disable=protected-access url=( f'{self.scheme}://{self.primary_hostname}/' f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" @@ -417,26 +419,36 @@ def create_configuration(**kwargs): config.proxy_policy = ProxyPolicy(**kwargs) # Storage settings - config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024) - config.copy_polling_interval = 15 + if hasattr(config, 'max_single_put_size'): + config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024) + if hasattr(config, 'copy_polling_interval'): + config.copy_polling_interval = 15 # Block blob uploads - config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024) - config.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) - config.use_byte_buffer = kwargs.get("use_byte_buffer", False) + if hasattr(config, 'max_block_size'): + config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024) + if hasattr(config, 'min_large_block_upload_threshold'): + config.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) + if hasattr(config, 'use_byte_buffer'): + config.use_byte_buffer = kwargs.get("use_byte_buffer", False) # Page blob uploads - config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024) + if hasattr(config, 'max_page_size'): + config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024) # Datalake file uploads - config.min_large_chunk_upload_threshold = kwargs.get("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) + if hasattr(config, 'min_large_chunk_upload_threshold'): + config.min_large_chunk_upload_threshold = 
kwargs.get("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) # Blob downloads - config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024) - config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024) + if hasattr(config, 'max_single_get_size'): + config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024) + if hasattr(config, 'max_chunk_get_size'): + config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024) # File uploads - config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024) + if hasattr(config, 'max_range_size'): + config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024) return config diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 11fef6ea29b8..a995be176cc7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -67,7 +67,7 @@ async def close(self): def _create_pipeline(self, credential, **kwargs): # type: (Any, **Any) -> Tuple[Configuration, AsyncPipeline] - self._credential_policy = None + self._credential_policy: Optional[Union[AsyncBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy]] = None if hasattr(credential, 'get_token'): self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) elif isinstance(credential, SharedKeyCredentialPolicy): @@ -79,15 +79,17 @@ def _create_pipeline(self, credential, **kwargs): config = kwargs.get('_configuration') or create_configuration(**kwargs) if kwargs.get('_pipeline'): return config, kwargs['_pipeline'] - config.transport = kwargs.get('transport') # type: ignore + transport = kwargs.get('transport') kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) 
kwargs.setdefault("read_timeout", READ_TIMEOUT) - if not config.transport: + if not transport: try: from azure.core.pipeline.transport import AioHttpTransport except ImportError: raise ImportError("Unable to create async transport. Please check aiohttp is installed.") - config.transport = AioHttpTransport(**kwargs) + transport = AioHttpTransport(**kwargs) + if hasattr(self, '_hosts'): + hosts = self._hosts policies = [ QueueMessagePolicy(), config.headers_policy, @@ -98,7 +100,7 @@ def _create_pipeline(self, credential, **kwargs): self._credential_policy, ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), - StorageHosts(hosts=self._hosts, **kwargs), # type: ignore + StorageHosts(hosts=hosts, **kwargs), config.retry_policy, config.logging_policy, AsyncStorageResponseHook(**kwargs), @@ -106,8 +108,10 @@ def _create_pipeline(self, credential, **kwargs): HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): - policies = policies + kwargs.get("_additional_pipeline_policies") - return config, AsyncPipeline(config.transport, policies=policies) + policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore + if hasattr(config, 'transport'): + config.transport = transport + return config, AsyncPipeline(transport, policies=policies) async def _batch_send( self, @@ -118,20 +122,28 @@ async def _batch_send( """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) - request = self._client._client.post( # pylint: disable=protected-access + if hasattr(self, '_client'): + client = self._client + if hasattr(self, 'scheme'): + scheme = self.scheme + if hasattr(self, 'primary_hostname'): + primary_hostname = self.primary_hostname + if hasattr(self, 'api_version'): + api_version = self.api_version + request = client._client.post( url=( - f'{self.scheme}://{self.primary_hostname}/' + f'{scheme}://{primary_hostname}/' f"{kwargs.pop('path', 
'')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), headers={ - 'x-ms-version': self.api_version + 'x-ms-version': api_version } ) policies = [StorageHeadersPolicy()] if self._credential_policy: - policies.append(self._credential_policy) + policies.append(self._credential_policy) # type: ignore request.set_multipart_mixed( *reqs, @@ -139,9 +151,10 @@ async def _batch_send( enforce_https=False ) - pipeline_response = await self._pipeline.run( - request, **kwargs - ) + if hasattr(self, '_pipeline'): + pipeline_response = await self._pipeline.run( + request, **kwargs + ) response = pipeline_response.http_response try: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 0cd4eaeb7664..73efc1695775 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -135,8 +135,7 @@ def on_request(self, request): class StorageHeadersPolicy(HeadersPolicy): request_id_header_name = 'x-ms-client-request-id' - def on_request(self, request): - # type: (PipelineRequest, Any) -> None + def on_request(self, request: PipelineRequest)-> None: super(StorageHeadersPolicy, self).on_request(request) current_time = format_date_time(time()) request.http_request.headers['x-ms-date'] = current_time @@ -166,8 +165,7 @@ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument self.hosts = hosts super(StorageHosts, self).__init__() - def on_request(self, request): - # type: (PipelineRequest, Any) -> None + def on_request(self, request: PipelineRequest) -> None: request.context.options['hosts'] = self.hosts parsed_url = urlparse(request.http_request.url) @@ -202,8 +200,7 @@ def __init__(self, logging_enable=False, **kwargs): self.logging_body = kwargs.pop("logging_body", False) super(StorageLoggingPolicy, 
self).__init__(logging_enable=logging_enable, **kwargs) - def on_request(self, request): - # type: (PipelineRequest, Any) -> None + def on_request(self, request: PipelineRequest) -> None: http_request = request.http_request options = request.context.options self.logging_body = self.logging_body or options.pop("logging_body", False) @@ -243,8 +240,7 @@ def on_request(self, request): except Exception as err: # pylint: disable=broad-except _LOGGER.debug("Failed to log request: %r", err) - def on_response(self, request, response): - # type: (PipelineRequest, PipelineResponse, Any) -> None + def on_response(self, request: PipelineRequest, response: PipelineResponse) -> None: if response.context.pop("logging_enable", self.enable_http_logger): if not _LOGGER.isEnabledFor(logging.DEBUG): return @@ -287,8 +283,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument self._request_callback = kwargs.get('raw_request_hook') super(StorageRequestHook, self).__init__() - def on_request(self, request): - # type: (PipelineRequest, **Any) -> PipelineResponse + def on_request(self, request: PipelineRequest) -> None: request_callback = request.context.options.pop('raw_request_hook', self._request_callback) if request_callback: request_callback(request) @@ -334,9 +329,10 @@ def send(self, request): elif should_update_counts and upload_stream_current is not None: upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) for pipeline_obj in [request, response]: - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, 'context'): + pipeline_obj.context['data_stream_total'] = data_stream_total + pipeline_obj.context['download_stream_current'] = download_stream_current + pipeline_obj.context['upload_stream_current'] = upload_stream_current if response_callback: 
response_callback(response) request.context['response_callback'] = response_callback @@ -379,8 +375,7 @@ def get_content_md5(data): return md5.digest() - def on_request(self, request): - # type: (PipelineRequest, Any) -> None + def on_request(self, request: PipelineRequest) -> None: validate_content = request.context.options.pop('validate_content', False) if validate_content and request.http_request.method != 'GET': computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index b0eae9f1c421..54de66a095c9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -80,12 +80,13 @@ async def send(self, request): elif should_update_counts and upload_stream_current is not None: upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) for pipeline_obj in [request, response]: - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, 'context'): + pipeline_obj.context['data_stream_total'] = data_stream_total + pipeline_obj.context['download_stream_current'] = download_stream_current + pipeline_obj.context['upload_stream_current'] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): - await response_callback(response) + await response_callback(response) #type: ignore else: response_callback(response) request.context['response_callback'] = response_callback diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py index dbc82124544e..056d6aacd282 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py @@ -90,7 +90,7 @@ def return_raw_deserialized(response, *_): return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] -def process_storage_error(storage_error) -> NoReturn: # pylint:disable=too-many-statements +def process_storage_error(storage_error) -> NoReturn: # pylint:disable=too-many-statements raise_error = HttpResponseError serialized = False if not storage_error.response or storage_error.response.status_code in [200, 204]: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index 66d517ff8091..a616bf7d920f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -61,7 +61,7 @@ async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, messages: Any) -> Tuple[str, AsyncIterator[List[QueueMessage]]]: + async def _extract_data_cb(self, messages: Any) -> Tuple[str, AsyncIterator[QueueMessage]]: # There is no concept of continuation token, so raising on my own condition if not messages: raise StopAsyncIteration("End of paging") @@ -124,7 +124,7 @@ async def _get_next_cb(self, continuation_token: str) -> Any: except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[Optional[str], List[QueueProperties]]: + async def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[str, List[QueueProperties]]: self.location_mode, self._response = 
get_next_return if self._response is not None: if hasattr(self._response, 'service_endpoint'): @@ -139,4 +139,4 @@ async def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[Opti props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access if hasattr(self._response, 'next_marker'): next_marker = self._response.next_marker - return next_marker or None, props_list + return next_marker, props_list diff --git a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py index 3be395e90c39..f1c7e5111319 100644 --- a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py +++ b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py @@ -59,10 +59,11 @@ queues = service_client.list_queues(logging_enable=True) for queue in queues: print('Queue: {}'.format(queue.name)) - queue_client = service_client.get_queue_client(queue.name) + if isinstance(queue.name, str): + queue_client = service_client.get_queue_client(queue.name) messages = queue_client.peek_messages(max_messages=20, logging_enable=True) for message in messages: try: - print(' Message: {}'.format(base64.b64decode(message.content))) + print(' Message: {!r}'.format(base64.b64decode(message.content))) except binascii.Error: print(' Message: {}'.format(message.content)) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py index 76cc089c1e20..3a8fcd0651c9 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py @@ -49,7 +49,8 @@ def authentication_by_connection_string(self): # Instantiate a QueueServiceClient using a connection string # [START auth_from_connection_string] from azure.storage.queue import QueueServiceClient - 
queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [END auth_from_connection_string] # Get information for the Queue Service @@ -59,7 +60,8 @@ def authentication_by_shared_key(self): # Instantiate a QueueServiceClient using a shared access key # [START create_queue_service_client] from azure.storage.queue import QueueServiceClient - queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) + if self.account_url is not None: + queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) # [END create_queue_service_client] # Get information for the Queue Service @@ -69,15 +71,23 @@ def authentication_by_active_directory(self): # [START create_queue_service_client_token] # Get a token credential for authentication from azure.identity import ClientSecretCredential + if self.active_directory_tenant_id is not None: + ad_tenant_id = self.active_directory_tenant_id + if self.active_directory_application_id is not None: + ad_application_id = self.active_directory_application_id + if self.active_directory_application_secret is not None: + ad_application_secret = self.active_directory_application_secret + token_credential = ClientSecretCredential( - self.active_directory_tenant_id, - self.active_directory_application_id, - self.active_directory_application_secret + ad_tenant_id, + ad_application_id, + ad_application_secret ) # Instantiate a QueueServiceClient using a token credential from azure.storage.queue import QueueServiceClient - queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) + if self.account_url is not None: + queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) # [END create_queue_service_client_token] # Get information for the Queue Service 
@@ -86,20 +96,29 @@ def authentication_by_active_directory(self): def authentication_by_shared_access_signature(self): # Instantiate a QueueServiceClient using a connection string from azure.storage.queue import QueueServiceClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # Create a SAS token to use for authentication of a client from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions + if self.account_name is not None: + account_name = self.account_name + if self.access_key is not None: + access_key = self.access_key + sas_token = generate_account_sas( - self.account_name, - self.access_key, + account_name, + access_key, resource_types=ResourceTypes(service=True), permission=AccountSasPermissions(read=True), expiry=datetime.utcnow() + timedelta(hours=1) ) - token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) + if self.account_url is not None: + account_url = self.account_url + + token_auth_queue_service = QueueServiceClient(account_url=account_url, credential=sas_token) # Get information for the Queue Service properties = token_auth_queue_service.get_service_properties() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py index 5b7f97818cc7..14abc6ce59ce 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py @@ -50,7 +50,8 @@ async def authentication_by_connection_string_async(self): # Instantiate a QueueServiceClient using a connection string # [START async_auth_from_connection_string] from azure.storage.queue.aio import QueueServiceClient - queue_service = 
QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [END async_auth_from_connection_string] # Get information for the Queue Service @@ -61,7 +62,8 @@ async def authentication_by_shared_key_async(self): # Instantiate a QueueServiceClient using a shared access key # [START async_create_queue_service_client] from azure.storage.queue.aio import QueueServiceClient - queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) + if self.account_url is not None: + queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) # [END async_create_queue_service_client] # Get information for the Queue Service @@ -72,15 +74,23 @@ async def authentication_by_active_directory_async(self): # [START async_create_queue_service_client_token] # Get a token credential for authentication from azure.identity.aio import ClientSecretCredential + if self.active_directory_tenant_id is not None: + ad_tenant_id = self.active_directory_tenant_id + if self.active_directory_application_id is not None: + ad_application_id = self.active_directory_application_id + if self.active_directory_application_secret is not None: + ad_application_secret = self.active_directory_application_secret + token_credential = ClientSecretCredential( - self.active_directory_tenant_id, - self.active_directory_application_id, - self.active_directory_application_secret + ad_tenant_id, + ad_application_id, + ad_application_secret ) # Instantiate a QueueServiceClient using a token credential from azure.storage.queue.aio import QueueServiceClient - queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) + if self.account_url is not None: + queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) # [END 
async_create_queue_service_client_token] # Get information for the Queue Service @@ -90,20 +100,27 @@ async def authentication_by_active_directory_async(self): async def authentication_by_shared_access_signature_async(self): # Instantiate a QueueServiceClient using a connection string from azure.storage.queue.aio import QueueServiceClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # Create a SAS token to use for authentication of a client from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions + if self.account_name is not None: + account_name = self.account_name + if self.access_key is not None: + access_key = self.access_key + sas_token = generate_account_sas( - queue_service.account_name, - queue_service.credential.account_key, + account_name, + access_key, resource_types=ResourceTypes(service=True), permission=AccountSasPermissions(read=True), expiry=datetime.utcnow() + timedelta(hours=1) ) - token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) + if self.account_url is not None: + token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) # Get information for the Queue Service async with token_auth_queue_service: diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py index 5bc6cf41f6b1..5c9efcc02c0d 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py @@ -31,7 +31,8 @@ class QueueHelloWorldSamples(object): def create_client_with_connection_string(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient 
- queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # Get queue service properties properties = queue_service.get_service_properties() @@ -39,7 +40,8 @@ def create_client_with_connection_string(self): def queue_and_messages_example(self): # Instantiate the QueueClient from a connection string from azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") # Create the queue # [START create_queue] diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py index bc33a36a29a9..694ad61bd55e 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py @@ -32,7 +32,8 @@ class QueueHelloWorldSamplesAsync(object): async def create_client_with_connection_string_async(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # Get queue service properties async with queue_service: @@ -41,7 +42,8 @@ async def create_client_with_connection_string_async(self): async def queue_and_messages_example_async(self): # Instantiate the QueueClient from a connection string from azure.storage.queue.aio import QueueClient - queue = QueueClient.from_connection_string(conn_str=self.connection_string, 
queue_name="myqueue") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") async with queue: # Create the queue diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py index 04f5139b3fdd..a231085824d3 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py @@ -34,7 +34,8 @@ class QueueMessageSamples(object): def set_access_policy(self): # [START create_queue_client_from_connection_string] from azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") # [END create_queue_client_from_connection_string] # Create the queue @@ -50,7 +51,7 @@ def set_access_policy(self): access_policy = AccessPolicy() access_policy.start = datetime.utcnow() - timedelta(hours=1) access_policy.expiry = datetime.utcnow() + timedelta(hours=1) - access_policy.permission = QueueSasPermissions(read=True) + access_policy.permission = QueueSasPermissions(read=True) # type: ignore identifiers = {'my-access-policy-id': access_policy} # Set the access policy @@ -86,7 +87,8 @@ def set_access_policy(self): def queue_metadata(self): # Instantiate a queue client from azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") # Create the queue queue.create_queue() @@ -108,7 +110,8 @@ def queue_metadata(self): def send_and_receive_messages(self): # Instantiate a queue client from azure.storage.queue import QueueClient - queue = 
QueueClient.from_connection_string(self.connection_string, "myqueue3") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") # Create the queue queue.create_queue() @@ -149,7 +152,8 @@ def send_and_receive_messages(self): def list_message_pages(self): # Instantiate a queue client from azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") # Create the queue queue.create_queue() @@ -182,7 +186,8 @@ def list_message_pages(self): def receive_one_message_from_queue(self): # Instantiate a queue client from azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") # Create the queue queue.create_queue() @@ -199,8 +204,10 @@ def receive_one_message_from_queue(self): # We should see message 3 if we peek message3 = queue.peek_messages()[0] - print(message1.content) - print(message2.content) + if message1 is not None and hasattr(message1, 'content'): + print(message1.content) + if message2 is not None and hasattr(message2, 'content'): + print(message2.content) print(message3.content) # [END receive_one_message] @@ -210,7 +217,8 @@ def receive_one_message_from_queue(self): def delete_and_clear_messages(self): # Instantiate a queue client from azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") # Create the queue queue.create_queue() @@ -242,7 +250,8 @@ def delete_and_clear_messages(self): def peek_messages(self): # Instantiate a queue client from 
azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") # Create the queue queue.create_queue() @@ -274,7 +283,8 @@ def peek_messages(self): def update_message(self): # Instantiate a queue client from azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") # Create the queue queue.create_queue() @@ -289,8 +299,10 @@ def update_message(self): # Update the message list_result = next(messages) + if list_result.id is not None: + id = list_result.id message = queue.update_message( - list_result.id, + id, pop_receipt=list_result.pop_receipt, visibility_timeout=0, content="updated") @@ -303,7 +315,8 @@ def update_message(self): def receive_messages_with_max_messages(self): # Instantiate a queue client from azure.storage.queue import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue9") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue9") # Create the queue queue.create_queue() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py index f0d8affca3f1..f26816b57da6 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py @@ -34,7 +34,8 @@ class QueueMessageSamplesAsync(object): async def set_access_policy_async(self): # [START async_create_queue_client_from_connection_string] from azure.storage.queue.aio import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, 
"myqueue1") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") # [END async_create_queue_client_from_connection_string] # Create the queue @@ -51,7 +52,7 @@ async def set_access_policy_async(self): access_policy = AccessPolicy() access_policy.start = datetime.utcnow() - timedelta(hours=1) access_policy.expiry = datetime.utcnow() + timedelta(hours=1) - access_policy.permission = QueueSasPermissions(read=True) + access_policy.permission = QueueSasPermissions(read=True) # type: ignore identifiers = {'my-access-policy-id': access_policy} # Set the access policy @@ -85,7 +86,8 @@ async def set_access_policy_async(self): async def queue_metadata_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") # Create the queue async with queue: @@ -108,7 +110,8 @@ async def queue_metadata_async(self): async def send_and_receive_messages_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") # Create the queue async with queue: @@ -134,7 +137,7 @@ async def send_and_receive_messages_async(self): # Receive messages by batch messages = queue.receive_messages(messages_per_page=5) async for msg_batch in messages.by_page(): - for msg in msg_batch: + for msg in msg_batch: # type: ignore print(msg.content) await queue.delete_message(msg) # [END async_receive_messages] @@ -152,7 +155,8 @@ async def send_and_receive_messages_async(self): async def receive_one_message_from_queue(self): # Instantiate a queue client from azure.storage.queue.aio import 
QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") # Create the queue async with queue: @@ -171,8 +175,10 @@ async def receive_one_message_from_queue(self): # We should see message 3 if we peek message3 = await queue.peek_messages() - print(message1.content) - print(message2.content) + if message1 is not None and hasattr(message1, 'content'): + print(message1.content) + if message2 is not None and hasattr(message2, 'content'): + print(message2.content) print(message3[0].content) # [END receive_one_message] @@ -182,7 +188,8 @@ async def receive_one_message_from_queue(self): async def delete_and_clear_messages_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") # Create the queue async with queue: @@ -218,7 +225,8 @@ async def delete_and_clear_messages_async(self): async def peek_messages_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") # Create the queue async with queue: @@ -253,7 +261,8 @@ async def peek_messages_async(self): async def update_message_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") # Create the queue async with queue: @@ -283,7 +292,8 @@ async def 
update_message_async(self): async def receive_messages_with_max_messages(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient - queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") + if self.connection_string is not None: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") # Create the queue async with queue: diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py index 5b6d24eb5cbd..990498d71025 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py @@ -25,15 +25,18 @@ class QueueServiceSamples(object): - connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") + if os.getenv("AZURE_STORAGE_CONNECTION_STRING") is not None: + connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") def queue_service_properties(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [START set_queue_service_properties] # Create service properties + from typing import List, Optional from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy # Create logging settings @@ -58,10 +61,10 @@ def queue_service_properties(self): allowed_headers=allowed_headers ) - cors = [cors_rule1, cors_rule2] + cors: Optional[List[CorsRule]] = [cors_rule1, cors_rule2] # Set the service properties - queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) + queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) # type: ignore # [END 
set_queue_service_properties] # [START get_queue_service_properties] @@ -71,7 +74,8 @@ def queue_service_properties(self): def queues_in_account(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [START qsc_create_queue] queue_service.create_queue("myqueue1") @@ -98,7 +102,8 @@ def queues_in_account(self): def get_queue_client(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient, QueueClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [START get_queue_client] # Get the queue client to interact with a specific queue diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py index 01d79059ae2f..409379af9a45 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py @@ -32,7 +32,8 @@ class QueueServiceSamplesAsync(object): async def queue_service_properties_async(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) async with queue_service: # [START async_set_queue_service_properties] @@ -64,7 +65,7 @@ async def 
queue_service_properties_async(self): cors = [cors_rule1, cors_rule2] # Set the service properties - await queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) + await queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) # type: ignore # [END async_set_queue_service_properties] # [START async_get_queue_service_properties] @@ -74,7 +75,8 @@ async def queue_service_properties_async(self): async def queues_in_account_async(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) async with queue_service: # [START async_qsc_create_queue] @@ -102,7 +104,8 @@ async def queues_in_account_async(self): async def get_queue_client_async(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient, QueueClient - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + if self.connection_string is not None: + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [START async_get_queue_client] # Get the queue client to interact with a specific queue From d9252588747420e7c1efcf7cbc66107b2508b136 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 9 Aug 2023 19:47:59 -0700 Subject: [PATCH 17/71] Mypy congrats, time to fight pylint & CI :D --- .../azure/storage/queue/_deserialize.py | 2 +- .../azure/storage/queue/_encryption.py | 6 ++--- .../azure/storage/queue/_message_encoding.py | 4 ++-- .../azure/storage/queue/_queue_client.py | 2 +- .../storage/queue/_queue_service_client.py | 2 +- .../storage/queue/_shared/base_client.py | 2 +- .../queue/_shared/base_client_async.py | 2 +- 
.../azure/storage/queue/_shared/policies.py | 18 +++++++-------- .../queue/_shared/response_handlers.py | 2 +- .../azure/storage/queue/aio/_models.py | 23 ++++++++++++++----- .../storage/queue/aio/_queue_client_async.py | 2 +- .../queue/aio/_queue_service_client_async.py | 2 +- 12 files changed, 39 insertions(+), 28 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py index c46b9d3f9731..ff3f207b95e4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py @@ -43,7 +43,7 @@ def deserialize_queue_creation( f"RequestId:{headers['x-ms-request-id']}\n" f"Time:{headers['Date']}\n" f"ErrorCode:{error_code}"), - response=response) + response=response) # type: ignore error.error_code = error_code # type: ignore error.additional_info = {} # type: ignore raise error diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index e846487347ec..984762a94b1e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -1028,8 +1028,8 @@ def decrypt_queue_message( try: message_dict = loads(message) - encryption_data = _dict_to_encryption_data(message['EncryptionData']) - decoded_data = decode_base64_to_bytes(message['EncryptedMessageContents']) + encryption_data = _dict_to_encryption_data(message_dict['EncryptionData']) + decoded_data = decode_base64_to_bytes(message_dict['EncryptedMessageContents']) except (KeyError, ValueError) as exc: # Message was not json formatted and so was not encrypted # or the user provided a json formatted message @@ -1045,5 +1045,5 @@ def decrypt_queue_message( except Exception as error: raise HttpResponseError( message="Decryption failed.", - response=response, 
+ response=response, #type: ignore error=error) from error diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index 0959eece2952..523288afe8a8 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -116,7 +116,7 @@ def decode(self, content: str, response: "PipelineResponse") -> str: # ValueError for Python 3, TypeError for Python 2 raise DecodeError( message="Message content is not valid base 64.", - response=response, + response=response, #type: ignore error=error) from error @@ -148,7 +148,7 @@ def decode(self, content: str, response: "PipelineResponse") -> bytes: # ValueError for Python 3, TypeError for Python 2 raise DecodeError( message="Message content is not valid base 64.", - response=response, + response=response, #type: ignore error=error) from error diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 9f40564ff0da..85ecba8f66b4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -92,7 +92,7 @@ def __init__( self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._configure_encryption(kwargs) def _format_url(self, hostname: str) -> str: diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 5efa84bb8926..153e12bce1f2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -95,7 +95,7 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueServiceClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._configure_encryption(kwargs) def _format_url(self, hostname: str) -> str: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 03a117df549e..624576b6bece 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -413,7 +413,7 @@ def parse_connection_str(conn_str, credential, service): def create_configuration(**kwargs): # type: (**Any) -> Configuration - config = Configuration(**kwargs) + config: Configuration = Configuration(**kwargs) config.headers_policy = StorageHeadersPolicy(**kwargs) config.user_agent_policy = UserAgentPolicy( sdk_moniker=f"storage-{kwargs.pop('storage_sdk')}/{VERSION}", **kwargs) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index b99ca5738478..3de5f26921f7 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -111,7 +111,7 @@ def _create_pipeline(self, credential, **kwargs): policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore if hasattr(config, 'transport'): config.transport = transport - return config, AsyncPipeline(transport, policies=policies) + return config, AsyncPipeline(transport, policies=policies) #type: ignore # Given a series of request, do a Storage batch call. async def _batch_send( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index e5bff085bcfd..700515d7c210 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -134,7 +134,7 @@ def on_request(self, request): class StorageHeadersPolicy(HeadersPolicy): request_id_header_name = 'x-ms-client-request-id' - def on_request(self, request: PipelineRequest)-> None: + def on_request(self, request: "PipelineRequest")-> None: super(StorageHeadersPolicy, self).on_request(request) current_time = format_date_time(time()) request.http_request.headers['x-ms-date'] = current_time @@ -164,7 +164,7 @@ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument self.hosts = hosts super(StorageHosts, self).__init__() - def on_request(self, request: PipelineRequest) -> None: + def on_request(self, request: "PipelineRequest") -> None: request.context.options['hosts'] = self.hosts parsed_url = urlparse(request.http_request.url) @@ -199,7 +199,7 @@ def __init__(self, logging_enable=False, **kwargs): self.logging_body = kwargs.pop("logging_body", False) super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) - def on_request(self, request: PipelineRequest) -> None: + def 
on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request options = request.context.options self.logging_body = self.logging_body or options.pop("logging_body", False) @@ -239,7 +239,7 @@ def on_request(self, request: PipelineRequest) -> None: except Exception as err: # pylint: disable=broad-except _LOGGER.debug("Failed to log request: %r", err) - def on_response(self, request: PipelineRequest, response: PipelineResponse) -> None: + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: if response.context.pop("logging_enable", self.enable_http_logger): if not _LOGGER.isEnabledFor(logging.DEBUG): return @@ -282,7 +282,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument self._request_callback = kwargs.get('raw_request_hook') super(StorageRequestHook, self).__init__() - def on_request(self, request: PipelineRequest) -> None: + def on_request(self, request: "PipelineRequest") -> None: request_callback = request.context.options.pop('raw_request_hook', self._request_callback) if request_callback: request_callback(request) @@ -295,7 +295,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument super(StorageResponseHook, self).__init__() def send(self, request): - # type: (PipelineRequest) -> PipelineResponse + # type: ("PipelineRequest") -> "PipelineResponse" # Values could be 0 data_stream_total = request.context.get('data_stream_total') if data_stream_total is None: @@ -374,7 +374,7 @@ def get_content_md5(data): return md5.digest() - def on_request(self, request: PipelineRequest) -> None: + def on_request(self, request: "PipelineRequest") -> None: validate_content = request.context.options.pop('validate_content', False) if validate_content and request.http_request.method != 'GET': computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) @@ -412,7 +412,7 @@ def _set_next_host_location(self, settings, request): A function which sets the 
next host location on the request, if applicable. :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the next host location. - :param PipelineRequest request: A pipeline request object. + :param "PipelineRequest" request: A pipeline request object. """ if settings['hosts'] and all(settings['hosts'].values()): url = urlparse(request.url) @@ -653,7 +653,7 @@ def __init__(self, credential, **kwargs): super(StorageBearerTokenCredentialPolicy, self).__init__(credential, STORAGE_OAUTH_SCOPE, **kwargs) def on_challenge(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> bool + # type: ("PipelineRequest", "PipelineResponse") -> bool try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py index 1f67cb8c72f4..0103c7a3d0ed 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py @@ -85,7 +85,7 @@ def return_raw_deserialized(response, *_): return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] -def process_storage_error(storage_error) -> NoReturn: # pylint:disable=too-many-statements +def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements raise_error = HttpResponseError serialized = False if not storage_error.response or storage_error.response.status_code in [200, 204]: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index a616bf7d920f..3bb9c7d3efc3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ 
b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -6,6 +6,8 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called +import asyncio + from typing import Any, AsyncIterator, Callable, List, Optional, Tuple from azure.core.async_paging import AsyncPageIterator from azure.core.exceptions import HttpResponseError @@ -15,6 +17,14 @@ from .._models import QueueMessage, QueueProperties +async def async_queue_msg_generator(messages: Any): + for q in messages: + yield QueueMessage._from_generated(q) + +async def async_queue_items_generator(items: Any): + for q in items: + yield QueueProperties._from_generated(q) + class MessagesPaged(AsyncPageIterator): """An iterable of Queue Messages. @@ -48,7 +58,7 @@ def __init__( self._command = command self.results_per_page = results_per_page self._max_messages = max_messages - + async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: if self._max_messages is not None: @@ -67,7 +77,8 @@ async def _extract_data_cb(self, messages: Any) -> Tuple[str, AsyncIterator[Queu raise StopAsyncIteration("End of paging") if self._max_messages is not None: self._max_messages = self._max_messages - len(messages) - return "TOKEN_IGNORED", [QueueMessage._from_generated(q) for q in messages] # pylint: disable=protected-access + queue_msg_iterator = await async_queue_msg_generator(messages) # pylint: disable=protected-access + return "TOKEN_IGNORED", queue_msg_iterator class QueuePropertiesPaged(AsyncPageIterator): @@ -114,7 +125,7 @@ def __init__( self.results_per_page = results_per_page self.location_mode = None - async def _get_next_cb(self, continuation_token: str) -> Any: + async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: return await self._command( marker=continuation_token or None, @@ -124,7 +135,7 @@ async def _get_next_cb(self, continuation_token: str) -> Any: except HttpResponseError as error: 
process_storage_error(error) - async def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[str, List[QueueProperties]]: + async def _extract_data_cb(self, get_next_return: Any) -> Tuple[str, AsyncIterator[Any]]: self.location_mode, self._response = get_next_return if self._response is not None: if hasattr(self._response, 'service_endpoint'): @@ -136,7 +147,7 @@ async def _extract_data_cb(self, get_next_return: Tuple[str, Any]) -> Tuple[str, if hasattr(self._response, 'max_results'): self.results_per_page = self._response.max_results if hasattr(self._response, 'queue_items'): - props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access + queue_items_iterator = await async_queue_items_generator(self._response.queue_items) # pylint: disable=protected-access, line-too-long if hasattr(self._response, 'next_marker'): next_marker = self._response.next_marker - return next_marker, props_list + return next_marker, queue_items_iterator diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index d19305917acb..679a1f03377c 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -107,7 +107,7 @@ def __init__( self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index acefc2ae466e..6d5351d17b9b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -95,7 +95,7 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueServiceClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) From 9c3d99f808a99af8346e171bcab77ebae0452514 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 9 Aug 2023 19:51:27 -0700 Subject: [PATCH 18/71] unused import asyncio --- .../azure-storage-queue/azure/storage/queue/aio/_models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index 3bb9c7d3efc3..aef2f9da4d00 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -6,8 +6,6 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called -import asyncio - from typing import Any, AsyncIterator, Callable, List, Optional, Tuple from azure.core.async_paging import AsyncPageIterator from azure.core.exceptions import HttpResponseError From e00a8257d243b0ee4208e0f361582fe48a75f9fb Mon Sep 17 00:00:00 2001 From: Vincent 
Tran Date: Thu, 10 Aug 2023 20:26:37 -0700 Subject: [PATCH 19/71] New PyLint rules fixes --- .../azure/storage/queue/_encryption.py | 8 +++----- .../azure/storage/queue/_models.py | 8 ++++++++ .../azure/storage/queue/_queue_client.py | 6 +++++- .../storage/queue/_queue_service_client.py | 5 +++++ .../queue/_queue_service_client_helpers.py | 6 +++--- .../azure/storage/queue/_shared/base_client.py | 5 ++++- .../storage/queue/_shared/base_client_async.py | 4 ++-- .../storage/queue/_shared_access_signature.py | 18 +++++++++--------- .../azure/storage/queue/aio/_models.py | 12 ++++++------ .../storage/queue/aio/_queue_client_async.py | 12 +++++++++--- .../queue/aio/_queue_service_client_async.py | 5 +++++ 11 files changed, 59 insertions(+), 30 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 984762a94b1e..8ca20ccdcbd6 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -638,7 +638,7 @@ def _decrypt_message( - returns the unwrapped form of the specified symmetric key using the string-specified algorithm. get_kid() - returns a string key id for this key-encryption-key. - :param Optional[Callable] resolver(kid): + :param Optional[Callable] resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The decrypted plaintext. 
@@ -910,8 +910,6 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements return decrypted_content[start_offset:end_offset] raise ValueError('Specified encryption version is not supported.') - raise ValueError('Specified encryption version is not supported.') - def get_blob_encryptor_and_padder( cek: bytes, @@ -1017,7 +1015,7 @@ def decrypt_queue_message( - returns the unwrapped form of the specified symmetric key usingthe string-specified algorithm. get_kid() - returns a string key id for this key-encryption-key. - :param Callable[[str], bytes] resolver(kid): + :param Callable[[str], bytes] resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The plain text message from the queue message. @@ -1041,7 +1039,7 @@ def decrypt_queue_message( return message_dict try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') #type: ignore + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') #type: ignore # pylint: disable=line-too-long except Exception as error: raise HttpResponseError( message="Decryption failed.", diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 236afdcbd399..dcf0f037751a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -539,6 +539,10 @@ def from_string(cls, permission: str) -> Self: def service_stats_deserialize(generated: Any) -> Dict[str, Any]: """Deserialize a ServiceStats objects into a dict. + + :param Any generated: The service stats returned from the generated code. + :returns: The deserialized ServiceStats as a Dict. 
+ :rtype: Dict[str, Any] """ return { 'geo_replication': { @@ -550,6 +554,10 @@ def service_stats_deserialize(generated: Any) -> Dict[str, Any]: def service_properties_deserialize(generated: Any) -> Dict[str, Any]: """Deserialize a ServiceProperties objects into a dict. + + :param Any generated: The service properties returned from the generated code. + :returns: The deserialized ServiceProperties as a Dict. + :rtype: Dict[str, Any] """ return { 'analytics_logging': QueueAnalyticsLogging._from_generated(generated.logging), # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 85ecba8f66b4..fd6f236d4a92 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -98,6 +98,10 @@ def __init__( def _format_url(self, hostname: str) -> str: """Format the endpoint URL according to the current location mode hostname. + + :param str hostname: The current location mode hostname. + :returns: The formatted endpoint URL according to the specified location mode hostname. 
+ :rtype: str """ if isinstance(self.queue_name, str): queue_name = self.queue_name.encode('UTF-8') @@ -946,4 +950,4 @@ def delete_message( **kwargs ) except HttpResponseError as error: - process_storage_error(error) \ No newline at end of file + process_storage_error(error) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 153e12bce1f2..313a91a75c61 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -101,6 +101,10 @@ def __init__( def _format_url(self, hostname: str) -> str: """Format the endpoint URL according to the current location mode hostname. + + :param str hostname: The current location mode hostname. + :returns: The formatted endpoint URL according to the specified location mode hostname. + :rtype: str """ return f"{self.scheme}://{hostname}/{self._query_str}" @@ -234,6 +238,7 @@ def set_service_properties( :type cors: list(~azure.storage.queue.CorsRule) :keyword int timeout: The timeout parameter is expressed in seconds. + :rtype: None .. 
admonition:: Example: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py index b45d02de4e34..24506fb909e7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py @@ -11,8 +11,8 @@ def _parse_url(account_url, credential): try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url - except AttributeError: - raise ValueError("Account URL must be a string.") + except AttributeError as exc: + raise ValueError("Account URL must be a string.") from exc parsed_url = urlparse(account_url.rstrip('/')) if not parsed_url.netloc: raise ValueError(f"Invalid URL: {account_url}") @@ -20,5 +20,5 @@ def _parse_url(account_url, credential): _, sas_token = parse_query(parsed_url.query) if not sas_token and not credential: raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") - + return parsed_url, sas_token diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 624576b6bece..743ec96564f7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -262,10 +262,13 @@ def _create_pipeline(self, credential, **kwargs): # Given a series of request, do a Storage batch call. def _batch_send( self, - *reqs: HttpRequest, + *reqs: "HttpRequest", **kwargs ) -> None: """Given a series of request, do a Storage batch call. + + :param HttpRequest reqs: A collection of HttpRequest objects. 
+ :rtype: None """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 3de5f26921f7..8b15859324e9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -67,7 +67,7 @@ async def close(self): def _create_pipeline(self, credential, **kwargs): # type: (Any, **Any) -> Tuple[Configuration, AsyncPipeline] - self._credential_policy: Optional[Union[AsyncBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy]] = None + self._credential_policy: Optional[Union[AsyncBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy]] = None # pylint: disable=line-too-long if hasattr(credential, 'get_token'): self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) elif isinstance(credential, SharedKeyCredentialPolicy): @@ -129,7 +129,7 @@ async def _batch_send( primary_hostname = self.primary_hostname if hasattr(self, 'api_version'): api_version = self.api_version - request = client._client.post( + request = client._client.post( # pylint: disable=protected-access url=( f'{scheme}://{primary_hostname}/' f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py index 70279de9a5ee..6cfbca0f04d4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py @@ -4,24 +4,24 @@ # license information. 
# -------------------------------------------------------------------------- -from typing import ( # pylint: disable=unused-import - Union, Optional, TYPE_CHECKING -) +from typing import Any, Optional, TYPE_CHECKING, Union from azure.storage.queue._shared import sign_string from azure.storage.queue._shared.constants import X_MS_VERSION from azure.storage.queue._shared.models import Services -from azure.storage.queue._shared.shared_access_signature import SharedAccessSignature, _SharedAccessHelper, \ - QueryStringConstants +from azure.storage.queue._shared.shared_access_signature import ( + QueryStringConstants, + SharedAccessSignature, + _SharedAccessHelper +) if TYPE_CHECKING: - from datetime import datetime from azure.storage.queue import ( - ResourceTypes, AccountSasPermissions, - QueueSasPermissions + QueueSasPermissions, + ResourceTypes ) - from typing import Any + from datetime import datetime class QueueSharedAccessSignature(SharedAccessSignature): ''' diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index aef2f9da4d00..323b660b7026 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -6,7 +6,7 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called -from typing import Any, AsyncIterator, Callable, List, Optional, Tuple +from typing import Any, AsyncIterator, Callable, Optional, Tuple from azure.core.async_paging import AsyncPageIterator from azure.core.exceptions import HttpResponseError from .._shared.response_handlers import ( @@ -16,12 +16,12 @@ async def async_queue_msg_generator(messages: Any): - for q in messages: - yield QueueMessage._from_generated(q) + for q in messages: + yield QueueMessage._from_generated(q) # pylint: disable=protected-access async def async_queue_items_generator(items: 
Any): - for q in items: - yield QueueProperties._from_generated(q) + for q in items: + yield QueueProperties._from_generated(q) # pylint: disable=protected-access class MessagesPaged(AsyncPageIterator): """An iterable of Queue Messages. @@ -56,7 +56,7 @@ def __init__( self._command = command self.results_per_page = results_per_page self._max_messages = max_messages - + async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: if self._max_messages is not None: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 679a1f03377c..8b0342c500f3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -114,6 +114,10 @@ def __init__( def _format_url(self, hostname: str) -> str: """Format the endpoint URL according to the current location mode hostname. + + :param str hostname: The current location mode hostname. + :returns: The formatted endpoint URL according to the specified location mode hostname. + :rtype: str """ if isinstance(self.queue_name, str): queue_name = self.queue_name.encode('UTF-8') @@ -141,14 +145,15 @@ def from_queue_url( - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient """ try: if not queue_url.lower().startswith('http'): queue_url = "https://" + queue_url - except AttributeError: - raise ValueError("Queue URL must be a string.") + except AttributeError as exc: + raise ValueError("Queue URL must be a string.") from exc parsed_url = urlparse(queue_url.rstrip('/')) if not parsed_url.netloc: @@ -188,6 +193,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient @@ -940,4 +946,4 @@ async def delete_message( # type: ignore[override] pop_receipt=receipt, timeout=timeout, queue_message_id=message_id, **kwargs ) except HttpResponseError as error: - process_storage_error(error) \ No newline at end of file + process_storage_error(error) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 6d5351d17b9b..ef2aaf307bc5 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -102,6 +102,10 @@ def __init__( def _format_url(self, hostname: str) -> str: """Format the endpoint URL according to the current location mode hostname. + + :param str hostname: The current location mode hostname. + :returns: The formatted endpoint URL according to the specified location mode hostname. 
+ :rtype: str """ return f"{self.scheme}://{hostname}/{self._query_str}" @@ -124,6 +128,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A Queue service client. :rtype: ~azure.storage.queue.QueueClient From 5a8409a7f76b70190dcc329d57212b634f007488 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Fri, 11 Aug 2023 17:58:11 -0700 Subject: [PATCH 20/71] Pylint 2 --- .../azure/storage/queue/_queue_client_helpers.py | 6 +++--- .../azure/storage/queue/_queue_service_client.py | 1 + .../azure/storage/queue/_shared/base_client.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index 24dcccd61cbb..7c77dcbda1bb 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -11,8 +11,8 @@ def _parse_url(account_url, queue_name, credential): try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url - except AttributeError: - raise ValueError("Account URL must be a string.") + except AttributeError as exc: + raise ValueError("Account URL must be a string.") from exc parsed_url = urlparse(account_url.rstrip('/')) if not queue_name: raise ValueError("Please specify a queue name.") @@ -22,5 +22,5 @@ def _parse_url(account_url, queue_name, credential): _, sas_token = parse_query(parsed_url.query) if not sas_token and not credential: raise ValueError("You need to provide either a SAS token or an 
account shared key to authenticate.") - + return parsed_url, sas_token diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 313a91a75c61..33a969c4941a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -238,6 +238,7 @@ def set_service_properties( :type cors: list(~azure.storage.queue.CorsRule) :keyword int timeout: The timeout parameter is expressed in seconds. + :returns: None :rtype: None .. admonition:: Example: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 743ec96564f7..4e5dfe770718 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -269,6 +269,7 @@ def _batch_send( :param HttpRequest reqs: A collection of HttpRequest objects. 
:rtype: None + :returns: None """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) From f51bca933105b6958733276b69f4377afc38b601 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Mon, 14 Aug 2023 14:39:53 -0700 Subject: [PATCH 21/71] Fix failing test cases --- .../azure/storage/queue/aio/_models.py | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index 323b660b7026..f9788a252a25 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -6,7 +6,7 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called -from typing import Any, AsyncIterator, Callable, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple from azure.core.async_paging import AsyncPageIterator from azure.core.exceptions import HttpResponseError from .._shared.response_handlers import ( @@ -15,14 +15,6 @@ from .._models import QueueMessage, QueueProperties -async def async_queue_msg_generator(messages: Any): - for q in messages: - yield QueueMessage._from_generated(q) # pylint: disable=protected-access - -async def async_queue_items_generator(items: Any): - for q in items: - yield QueueProperties._from_generated(q) # pylint: disable=protected-access - class MessagesPaged(AsyncPageIterator): """An iterable of Queue Messages. 
@@ -51,7 +43,7 @@ def __init__( super(MessagesPaged, self).__init__( self._get_next_cb, - self._extract_data_cb, + self._extract_data_cb, # type: ignore [arg-type] ) self._command = command self.results_per_page = results_per_page @@ -69,14 +61,13 @@ async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, messages: Any) -> Tuple[str, AsyncIterator[QueueMessage]]: + async def _extract_data_cb(self, messages: Any) -> Tuple[str, List[QueueMessage]]: # There is no concept of continuation token, so raising on my own condition if not messages: raise StopAsyncIteration("End of paging") if self._max_messages is not None: self._max_messages = self._max_messages - len(messages) - queue_msg_iterator = await async_queue_msg_generator(messages) # pylint: disable=protected-access - return "TOKEN_IGNORED", queue_msg_iterator + return "TOKEN_IGNORED", [QueueMessage._from_generated(q) for q in messages] # pylint: disable=protected-access class QueuePropertiesPaged(AsyncPageIterator): @@ -113,7 +104,7 @@ def __init__( ) -> None: super(QueuePropertiesPaged, self).__init__( self._get_next_cb, - self._extract_data_cb, + self._extract_data_cb, # type: ignore [arg-type] continuation_token=continuation_token or "" ) self._command = command @@ -133,7 +124,7 @@ async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, get_next_return: Any) -> Tuple[str, AsyncIterator[Any]]: + async def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[Any], List[QueueProperties]]: self.location_mode, self._response = get_next_return if self._response is not None: if hasattr(self._response, 'service_endpoint'): @@ -145,7 +136,7 @@ async def _extract_data_cb(self, get_next_return: Any) -> Tuple[str, AsyncIterat if hasattr(self._response, 'max_results'): 
self.results_per_page = self._response.max_results if hasattr(self._response, 'queue_items'): - queue_items_iterator = await async_queue_items_generator(self._response.queue_items) # pylint: disable=protected-access, line-too-long + props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access if hasattr(self._response, 'next_marker'): next_marker = self._response.next_marker - return next_marker, queue_items_iterator + return next_marker or None, props_list From 655caa727b043030ebf73108f97cb2fce92ee987 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Mon, 14 Aug 2023 17:29:29 -0700 Subject: [PATCH 22/71] Be gone test failures --- .../azure/storage/queue/_shared/base_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 4e5dfe770718..3e61e2df7916 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -236,7 +236,7 @@ def _create_pipeline(self, credential, **kwargs): config.transport = kwargs.get("transport") # type: ignore kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) - if not hasattr(config, 'transport'): + if not config.transport: config.transport = RequestsTransport(**kwargs) # type: ignore policies = [ QueueMessagePolicy(), From 8c687ead943972e807c2568bb24e226367ab18ca Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Tue, 15 Aug 2023 18:31:23 -0700 Subject: [PATCH 23/71] Finally found root cause of infinite loop --- sdk/storage/azure-storage-queue/azure/storage/queue/_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index dcf0f037751a..ff5d827364ed 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -468,7 +468,7 @@ def _extract_data_cb(self, get_next_return: Any) -> Tuple[str, List[QueuePropert props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access if hasattr(self._response, 'next_marker'): next_marker = self._response.next_marker - return next_marker, props_list + return next_marker or None, props_list class QueueSasPermissions(object): From 1b304daab29024acfc353dc292219e488c28a75e Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 16 Aug 2023 19:30:49 -0700 Subject: [PATCH 24/71] All encryption tests passing locally, 32 pylint errors remaining --- .../azure/storage/queue/_encryption.py | 170 ++++++++---------- .../azure/storage/queue/_models.py | 4 +- .../storage/queue/_shared/base_client.py | 9 +- .../queue/_shared/base_client_async.py | 3 +- 4 files changed, 80 insertions(+), 106 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 8ca20ccdcbd6..cbeef8392cba 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -46,12 +46,12 @@ '{0} does not define a complete interface. Value of {1} is either missing or invalid.' -def _validate_not_none(param_name: str, param: Any) -> None: +def _validate_not_none(param_name: str, param: Any): if param is None: raise ValueError(f'{param_name} should not be None.') -def _validate_key_encryption_key_wrap(kek: object) -> None: +def _validate_key_encryption_key_wrap(kek: object): # Note that None is not callable and so will fail the second clause of each check. 
if not hasattr(kek, 'wrap_key') or not callable(kek.wrap_key): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'wrap_key')) @@ -62,7 +62,7 @@ def _validate_key_encryption_key_wrap(kek: object) -> None: class StorageEncryptionMixin(object): - def _configure_encryption(self, kwargs: Any) -> None: + def _configure_encryption(self, kwargs: Any): self.require_encryption = kwargs.get("require_encryption", False) self.encryption_version = kwargs.get("encryption_version", "1.0") self.key_encryption_key = kwargs.get("key_encryption_key") @@ -75,9 +75,9 @@ def _configure_encryption(self, kwargs: Any) -> None: class _EncryptionAlgorithm(object): - ''' + """ Specifies which client encryption algorithm is used. - ''' + """ AES_CBC_256 = 'AES_CBC_256' AES_GCM_256 = 'AES_GCM_256' @@ -268,7 +268,7 @@ def _encrypt_region(self, data: bytes) -> bytes: return nonce + cipertext_with_tag -def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> Optional[Union[_EncryptionData, bool]]: +def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> Optional[Union[_EncryptionData, bool]]: """ Determine whether the given encryption data signifies version 2.0. @@ -302,11 +302,10 @@ def get_adjusted_upload_size(length: int, encryption_version: str) -> int: def get_adjusted_download_range_and_offset( - start: int, - end: int, - length: int, - encryption_data: Optional[_EncryptionData] -) -> Tuple[Tuple[int, int], Tuple[int, int]]: + start: int, + end: int, + length: int, + encryption_data: Optional[_EncryptionData]) -> Tuple[Tuple[int, int], Tuple[int, int]]: """ Gets the new download range and offsets into the decrypted data for the given user-specified range. 
The new download range will include all @@ -351,13 +350,10 @@ def get_adjusted_download_range_and_offset( elif encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: start_offset, end_offset = 0, end - if encryption_data.encrypted_region_info is not None: - if hasattr(encryption_data.encrypted_region_info, 'nonce_length'): - nonce_length = encryption_data.encrypted_region_info.nonce_length - if hasattr(encryption_data.encrypted_region_info, 'data_length'): - data_length = encryption_data.encrypted_region_info.data_length - if hasattr(encryption_data.encrypted_region_info, 'tag_length'): - tag_length = encryption_data.encrypted_region_info.tag_length + + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length region_length = nonce_length + data_length + tag_length requested_length = end - start @@ -410,14 +406,9 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp :rtype: int """ if is_encryption_v2(encryption_data): - if encryption_data is not None: - if encryption_data.encrypted_region_info is not None: - if hasattr(encryption_data.encrypted_region_info, 'nonce_length'): - nonce_length = encryption_data.encrypted_region_info.nonce_length - if hasattr(encryption_data.encrypted_region_info, 'data_length'): - data_length = encryption_data.encrypted_region_info.data_length - if hasattr(encryption_data.encrypted_region_info, 'tag_length'): - tag_length = encryption_data.encrypted_region_info.tag_length + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length region_length = nonce_length + data_length + tag_length num_regions = math.ceil(size / region_length) @@ -440,23 +431,19 @@ def _generate_encryption_data_dict(kek: object, cek: bytes, iv: 
Optional[bytes], ''' # Encrypt the cek. if version == _ENCRYPTION_PROTOCOL_V1: - if hasattr(kek, 'wrap_key'): - wrapped_cek = kek.wrap_key(cek) + wrapped_cek = kek.wrap_key(cek) # For V2, we include the encryption version in the wrapped key. elif version == _ENCRYPTION_PROTOCOL_V2: # We must pad the version to 8 bytes for AES Keywrap algorithms to_wrap = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') + cek - if hasattr(kek, 'wrap_key'): - wrapped_cek = kek.wrap_key(to_wrap) + wrapped_cek = kek.wrap_key(to_wrap) # Build the encryption_data dict. # Use OrderedDict to comply with Java's ordering requirement. wrapped_content_key = OrderedDict() - if hasattr(kek, 'get_kid'): - wrapped_content_key['KeyId'] = kek.get_kid() + wrapped_content_key['KeyId'] = kek.get_kid() wrapped_content_key['EncryptedKey'] = encode_base64(wrapped_cek) - if hasattr(kek, 'get_key_wrap_algorithm'): - wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() + wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() encryption_agent = OrderedDict() encryption_agent['Protocol'] = version @@ -478,7 +465,7 @@ def _generate_encryption_data_dict(kek: object, cek: bytes, iv: Optional[bytes], encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info - encryption_data_dict['KeyWrappingMetadata'] = OrderedDict([('EncryptionLibrary', 'Python ' + VERSION)]) + encryption_data_dict['KeyWrappingMetadata'] = {'EncryptionLibrary': 'Python ' + VERSION} return encryption_data_dict @@ -488,7 +475,7 @@ def _dict_to_encryption_data(encryption_data_dict: Dict[str, Any]) -> _Encryptio Converts the specified dictionary to an EncryptionData object for eventual use in decryption. - :param Dict[str, Any] encryption_data_dict: + :param dict encryption_data_dict: The dictionary containing the encryption data. :return: an _EncryptionData object built from the dictionary. 
:rtype: _EncryptionData @@ -539,11 +526,12 @@ def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: ''' Generates and returns an encryption cipher for AES CBC using the given cek and iv. - :param bytes cek: The content encryption key for the cipher. - :param bytes iv: The initialization vector for the cipher. + :param bytes[] cek: The content encryption key for the cipher. + :param bytes[] iv: The initialization vector for the cipher. :return: A cipher for encrypting in AES256 CBC. :rtype: ~cryptography.hazmat.primitives.ciphers.Cipher ''' + backend = default_backend() algorithm = AES(cek) mode = CBC(iv) @@ -560,15 +548,16 @@ def _validate_and_unwrap_cek( and performs necessary validation on all parameters. :param _EncryptionData encryption_data: The encryption metadata of the retrieved value. - :param object key_encryption_key: + :param obj key_encryption_key: The key_encryption_key used to unwrap the cek. Please refer to high-level service object instance variables for more details. - :param Optional[Callable[[str], bytes]] key_resolver: + :param func key_resolver: A function used that, given a key_id, will return a key_encryption_key. Please refer to high-level service object instance variables for more details. :return: the content_encryption_key stored in the encryption_data object. - :rtype: bytes + :rtype: bytes[] ''' + _validate_not_none('encrypted_key', encryption_data.wrapped_content_key.encrypted_key) # Validate we have the right info for the specified version @@ -579,7 +568,7 @@ def _validate_and_unwrap_cek( else: raise ValueError('Specified encryption version is not supported.') - content_encryption_key: Optional[bytes] = None + content_encryption_key = None # If the resolver exists, give priority to the key it finds. 
if key_resolver is not None: @@ -593,28 +582,23 @@ def _validate_and_unwrap_cek( if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid(): raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.') # Will throw an exception if the specified algorithm is not supported. - if hasattr(key_encryption_key, 'unwrap_key'): - content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, + content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, encryption_data.wrapped_content_key.algorithm) # For V2, the version is included with the cek. We need to validate it # and remove it from the actual cek. if encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: version_2_bytes = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') - if content_encryption_key is not None: - cek_version_bytes = content_encryption_key[:len(version_2_bytes)] - if cek_version_bytes != version_2_bytes: - raise ValueError('The encryption metadata is not valid and may have been modified.') + cek_version_bytes = content_encryption_key[:len(version_2_bytes)] + if cek_version_bytes != version_2_bytes: + raise ValueError('The encryption metadata is not valid and may have been modified.') - # Remove version from the start of the cek. - content_encryption_key = content_encryption_key[len(version_2_bytes):] + # Remove version from the start of the cek. 
+ content_encryption_key = content_encryption_key[len(version_2_bytes):] _validate_not_none('content_encryption_key', content_encryption_key) - if isinstance(content_encryption_key, bytes): - validated_cek: bytes = content_encryption_key - - return validated_cek + return content_encryption_key def _decrypt_message( @@ -623,7 +607,7 @@ def _decrypt_message( key_encryption_key: object = None, resolver: Optional[Callable] = None ) -> str: - ''' + """ Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding. Unwraps the content-encryption-key using the user-provided or resolved key-encryption-key (kek). Returns the original plaintex. @@ -638,12 +622,12 @@ def _decrypt_message( - returns the unwrapped form of the specified symmetric key using the string-specified algorithm. get_kid() - returns a string key id for this key-encryption-key. - :param Optional[Callable] resolver: + :param Callable resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The decrypted plaintext. 
:rtype: str - ''' + """ _validate_not_none('message', message) content_encryption_key = _validate_and_unwrap_cek(encryption_data, key_encryption_key, resolver) @@ -654,7 +638,7 @@ def _decrypt_message( cipher = _generate_AES_CBC_cipher(content_encryption_key, encryption_data.content_encryption_IV) # decrypt data - decrypted_data: Union[bytes, str] = message + decrypted_data = message decryptor = cipher.decryptor() decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) @@ -667,23 +651,19 @@ def _decrypt_message( if not block_info or not block_info.nonce_length: raise ValueError("Missing required metadata for decryption.") - if encryption_data.encrypted_region_info is not None: - if hasattr(encryption_data.encryption_agent, 'nonce_length'): - nonce_length = encryption_data.encrypted_region_info.nonce_length + nonce_length = encryption_data.encrypted_region_info.nonce_length # First bytes are the nonce nonce = message[:nonce_length] ciphertext_with_tag = message[nonce_length:] aesgcm = AESGCM(content_encryption_key) - decrypted_data = str(aesgcm.decrypt(nonce, ciphertext_with_tag, None)) #type: ignore + decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) else: raise ValueError('Specified encryption version is not supported.') - if isinstance(decrypted_data, str): - decrypted_data_as_str = decrypted_data - return decrypted_data_as_str + return decrypted_data def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple[str, bytes]: @@ -705,6 +685,7 @@ def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple :return: A tuple of json-formatted string containing the encryption metadata and the encrypted blob data. 
:rtype: (str, bytes) ''' + _validate_not_none('blob', blob) _validate_not_none('key_encryption_key', key_encryption_key) _validate_key_encryption_key_wrap(key_encryption_key) @@ -764,18 +745,14 @@ def generate_blob_encryption_data(key_encryption_key: object, version: str) -> T # Initialization vector only needed for V1 if version == _ENCRYPTION_PROTOCOL_V1: initialization_vector = os.urandom(16) - encryption_data_dict = _generate_encryption_data_dict(key_encryption_key, + encryption_data = _generate_encryption_data_dict(key_encryption_key, content_encryption_key, initialization_vector, version) - encryption_data_dict['EncryptionMode'] = 'FullBlob' - encryption_data = dumps(encryption_data_dict) - if isinstance(encryption_data, str): - serialized_encryption_data = encryption_data - if content_encryption_key is not None: - validated_content_encyption_key = content_encryption_key + encryption_data['EncryptionMode'] = 'FullBlob' + encryption_data = dumps(encryption_data) - return validated_content_encyption_key, initialization_vector, serialized_encryption_data + return content_encryption_key, initialization_vector, encryption_data def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements @@ -797,7 +774,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements wrap_key(key)--wraps the specified key using an algorithm of the user's choice. get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key. get_kid()--returns a string key id for this key-encryption-key. - :param Callable[[str], bytes] key_resolver: + :param object key_resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. 
:param bytes content: @@ -865,8 +842,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements if blob_type == 'PageBlob': unpad = False - if isinstance(iv, bytes): - cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) + cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) decryptor = cipher.decryptor() content = decryptor.update(content) + decryptor.finalize() @@ -881,13 +857,9 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements total_size = len(content) offset = 0 - if encryption_data.encrypted_region_info is not None: - if hasattr(encryption_data.encrypted_region_info, 'nonce_length'): - nonce_length = encryption_data.encrypted_region_info.nonce_length - if hasattr(encryption_data.encrypted_region_info, 'data_length'): - data_length = encryption_data.encrypted_region_info.data_length - if hasattr(encryption_data.encrypted_region_info, 'tag_length'): - tag_length = encryption_data.encrypted_region_info.tag_length + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length region_length = nonce_length + data_length + tag_length decrypted_content = bytearray() @@ -908,6 +880,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements # Read the caller requested data from the decrypted content return decrypted_content[start_offset:end_offset] + raise ValueError('Specified encryption version is not supported.') @@ -933,8 +906,8 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek). Returns a json-formatted string containing the encrypted message and the encryption metadata. - :param str message: - The plain text message to be encrypted. + :param object message: + The plain text messge to be encrypted. 
:param object key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: wrap_key(key)--wraps the specified key using an algorithm of the user's choice. @@ -944,13 +917,14 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str :return: A json-formatted string containing the encrypted message and the encryption metadata. :rtype: str ''' + _validate_not_none('message', message) _validate_not_none('key_encryption_key', key_encryption_key) _validate_key_encryption_key_wrap(key_encryption_key) # Queue encoding functions all return unicode strings, and encryption should # operate on binary strings. - message_as_bytes = message.encode('utf-8') + message = message.encode('utf-8') if version == _ENCRYPTION_PROTOCOL_V1: # AES256 CBC uses 256 bit (32 byte) keys and always with 16 byte blocks @@ -961,7 +935,7 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str # PKCS7 with 16 byte blocks ensures compatibility with AES. padder = PKCS7(128).padder() - padded_data = padder.update(message_as_bytes) + padder.finalize() + padded_data = padder.update(message) + padder.finalize() # Encrypt the data. encryptor = cipher.encryptor() @@ -977,7 +951,7 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str aesgcm = AESGCM(content_encryption_key) # Returns ciphertext + tag - cipertext_with_tag = aesgcm.encrypt(nonce, message_as_bytes, None) + cipertext_with_tag = aesgcm.encrypt(nonce, message, None) encrypted_data = nonce + cipertext_with_tag else: @@ -1000,7 +974,7 @@ def decrypt_queue_message( key_encryption_key: object, resolver: Callable[[str], bytes] ) -> str: - ''' + """ Returns the decrypted message contents from an EncryptedQueueMessage. If no encryption metadata is present, will return the unaltered message. 
:param str message: @@ -1015,19 +989,19 @@ def decrypt_queue_message( - returns the unwrapped form of the specified symmetric key usingthe string-specified algorithm. get_kid() - returns a string key id for this key-encryption-key. - :param Callable[[str], bytes] resolver: + :param Callable resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The plain text message from the queue message. :rtype: str - ''' + """ response = response.http_response try: - message_dict = loads(message) + message = loads(message) - encryption_data = _dict_to_encryption_data(message_dict['EncryptionData']) - decoded_data = decode_base64_to_bytes(message_dict['EncryptedMessageContents']) + encryption_data = _dict_to_encryption_data(message['EncryptionData']) + decoded_data = decode_base64_to_bytes(message['EncryptedMessageContents']) except (KeyError, ValueError) as exc: # Message was not json formatted and so was not encrypted # or the user provided a json formatted message @@ -1037,11 +1011,11 @@ def decrypt_queue_message( 'Encryption required, but received message does not contain appropriate metatadata. 
' + \ 'Message was either not encrypted or metadata was incorrect.') from exc - return message_dict + return message try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') #type: ignore # pylint: disable=line-too-long + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') except Exception as error: raise HttpResponseError( message="Decryption failed.", - response=response, #type: ignore - error=error) from error + response=response, + error=error) from error \ No newline at end of file diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index ff5d827364ed..61ad5602c02e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -433,7 +433,7 @@ def __init__( ) -> None: super(QueuePropertiesPaged, self).__init__( self._get_next_cb, - self._extract_data_cb, + self._extract_data_cb, #type: ignore continuation_token=continuation_token or "" ) self._command = command @@ -453,7 +453,7 @@ def _get_next_cb(self, continuation_token: Optional[str]) -> Any: except HttpResponseError as error: process_storage_error(error) - def _extract_data_cb(self, get_next_return: Any) -> Tuple[str, List[QueueProperties]]: + def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return if self._response is not None: if hasattr(self._response, 'service_endpoint'): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 3e61e2df7916..290a501f592d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ 
-233,11 +233,11 @@ def _create_pipeline(self, credential, **kwargs): config = kwargs.get("_configuration") or create_configuration(**kwargs) if kwargs.get("_pipeline"): return config, kwargs["_pipeline"] - config.transport = kwargs.get("transport") # type: ignore + transport = kwargs.get("transport") kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) - if not config.transport: - config.transport = RequestsTransport(**kwargs) # type: ignore + if not transport: + transport = RequestsTransport(**kwargs) policies = [ QueueMessagePolicy(), config.proxy_policy, @@ -257,7 +257,8 @@ def _create_pipeline(self, credential, **kwargs): ] if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore - return config, Pipeline(config.transport, policies=policies) + config.transport = transport #type: ignore + return config, Pipeline(transport, policies=policies) # Given a series of request, do a Storage batch call. def _batch_send( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 8b15859324e9..0c40b18c7349 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -109,8 +109,7 @@ def _create_pipeline(self, credential, **kwargs): ] if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore - if hasattr(config, 'transport'): - config.transport = transport + config.transport = transport #type: ignore return config, AsyncPipeline(transport, policies=policies) #type: ignore # Given a series of request, do a Storage batch call. 
From 774dbf235f3a48402c06b65b2382fe470ba0c63f Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 16 Aug 2023 20:35:02 -0700 Subject: [PATCH 25/71] All tests passing locally, mypy green locally --- .../azure/storage/queue/_encryption.py | 110 +++++++++++------- 1 file changed, 66 insertions(+), 44 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index cbeef8392cba..8286c17b2722 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -351,11 +351,13 @@ def get_adjusted_download_range_and_offset( elif encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: start_offset, end_offset = 0, end - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length - region_length = nonce_length + data_length + tag_length - requested_length = end - start + if encryption_data.encrypted_region_info is not None: + if isinstance(encryption_data.encrypted_region_info, _EncryptedRegionInfo): + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length + region_length = nonce_length + data_length + tag_length + requested_length = end - start if start is not None: # Find which data region the start is in @@ -406,14 +408,17 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp :rtype: int """ if is_encryption_v2(encryption_data): - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length - region_length = nonce_length + data_length 
+ tag_length - - num_regions = math.ceil(size / region_length) - metadata_size = num_regions * (nonce_length + tag_length) - return size - metadata_size + if encryption_data is not None: + if encryption_data.encrypted_region_info is not None: + if isinstance(encryption_data.encrypted_region_info, _EncryptedRegionInfo): + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length + region_length = nonce_length + data_length + tag_length + + num_regions = math.ceil(size / region_length) + metadata_size = num_regions * (nonce_length + tag_length) + return size - metadata_size return size @@ -431,19 +436,23 @@ def _generate_encryption_data_dict(kek: object, cek: bytes, iv: Optional[bytes], ''' # Encrypt the cek. if version == _ENCRYPTION_PROTOCOL_V1: - wrapped_cek = kek.wrap_key(cek) + if hasattr(kek, 'wrap_key'): + wrapped_cek = kek.wrap_key(cek) # For V2, we include the encryption version in the wrapped key. elif version == _ENCRYPTION_PROTOCOL_V2: # We must pad the version to 8 bytes for AES Keywrap algorithms to_wrap = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') + cek - wrapped_cek = kek.wrap_key(to_wrap) + if hasattr(kek, 'wrap_key'): + wrapped_cek = kek.wrap_key(to_wrap) # Build the encryption_data dict. # Use OrderedDict to comply with Java's ordering requirement. 
wrapped_content_key = OrderedDict() - wrapped_content_key['KeyId'] = kek.get_kid() + if hasattr(kek, 'get_kid'): + wrapped_content_key['KeyId'] = kek.get_kid() wrapped_content_key['EncryptedKey'] = encode_base64(wrapped_cek) - wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() + if hasattr(kek, 'get_key_wrap_algorithm'): + wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() encryption_agent = OrderedDict() encryption_agent['Protocol'] = version @@ -465,7 +474,7 @@ def _generate_encryption_data_dict(kek: object, cek: bytes, iv: Optional[bytes], encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info - encryption_data_dict['KeyWrappingMetadata'] = {'EncryptionLibrary': 'Python ' + VERSION} + encryption_data_dict['KeyWrappingMetadata'] = OrderedDict([('EncryptionLibrary', 'Python ' + VERSION)]) return encryption_data_dict @@ -548,14 +557,14 @@ def _validate_and_unwrap_cek( and performs necessary validation on all parameters. :param _EncryptionData encryption_data: The encryption metadata of the retrieved value. - :param obj key_encryption_key: + :param object key_encryption_key: The key_encryption_key used to unwrap the cek. Please refer to high-level service object instance variables for more details. - :param func key_resolver: + :param Optional[Callable[[str], bytes]] key_resolver: A function used that, given a key_id, will return a key_encryption_key. Please refer to high-level service object instance variables for more details. :return: the content_encryption_key stored in the encryption_data object. 
- :rtype: bytes[] + :rtype: bytes ''' _validate_not_none('encrypted_key', encryption_data.wrapped_content_key.encrypted_key) @@ -582,23 +591,24 @@ def _validate_and_unwrap_cek( if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid(): raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.') # Will throw an exception if the specified algorithm is not supported. - content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, - encryption_data.wrapped_content_key.algorithm) + if hasattr(key_encryption_key, 'unwrap_key'): + content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, + encryption_data.wrapped_content_key.algorithm) # For V2, the version is included with the cek. We need to validate it # and remove it from the actual cek. if encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: version_2_bytes = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') - cek_version_bytes = content_encryption_key[:len(version_2_bytes)] + cek_version_bytes = content_encryption_key[:len(version_2_bytes)] #type: ignore [index] if cek_version_bytes != version_2_bytes: raise ValueError('The encryption metadata is not valid and may have been modified.') # Remove version from the start of the cek. 
- content_encryption_key = content_encryption_key[len(version_2_bytes):] + content_encryption_key = content_encryption_key[len(version_2_bytes):] #type: ignore [index] _validate_not_none('content_encryption_key', content_encryption_key) - return content_encryption_key + return content_encryption_key #type: ignore[return-value] def _decrypt_message( @@ -651,14 +661,16 @@ def _decrypt_message( if not block_info or not block_info.nonce_length: raise ValueError("Missing required metadata for decryption.") - nonce_length = encryption_data.encrypted_region_info.nonce_length + if encryption_data.encrypted_region_info is not None: + if isinstance(encryption_data.encrypted_region_info, _EncryptedRegionInfo): + nonce_length = encryption_data.encrypted_region_info.nonce_length # First bytes are the nonce nonce = message[:nonce_length] ciphertext_with_tag = message[nonce_length:] aesgcm = AESGCM(content_encryption_key) - decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) + decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) #type: ignore else: raise ValueError('Specified encryption version is not supported.') @@ -725,7 +737,7 @@ def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple return dumps(encryption_data), encrypted_data -def generate_blob_encryption_data(key_encryption_key: object, version: str) -> Tuple[bytes, Optional[bytes], str]: +def generate_blob_encryption_data(key_encryption_key: object, version: str) -> Tuple[bytes, Optional[bytes], Dict[str, Any]]: ''' Generates the encryption_metadata for the blob. 
@@ -750,9 +762,15 @@ def generate_blob_encryption_data(key_encryption_key: object, version: str) -> T initialization_vector, version) encryption_data['EncryptionMode'] = 'FullBlob' - encryption_data = dumps(encryption_data) + encryption_data = dumps(encryption_data) # type: ignore + + if content_encryption_key is not None: + valid_cek = content_encryption_key + + if encryption_data is not None: + valid_encryption_data = encryption_data - return content_encryption_key, initialization_vector, encryption_data + return valid_cek, initialization_vector, valid_encryption_data def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements @@ -842,7 +860,8 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements if blob_type == 'PageBlob': unpad = False - cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) + if isinstance(iv, bytes): + cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) decryptor = cipher.decryptor() content = decryptor.update(content) + decryptor.finalize() @@ -857,10 +876,12 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements total_size = len(content) offset = 0 - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length - region_length = nonce_length + data_length + tag_length + if encryption_data.encrypted_region_info is not None: + if isinstance(encryption_data.encrypted_region_info, _EncryptedRegionInfo): + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length + region_length = nonce_length + data_length + tag_length decrypted_content = bytearray() while offset < total_size: @@ -924,7 +945,8 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str # Queue encoding functions all 
return unicode strings, and encryption should # operate on binary strings. - message = message.encode('utf-8') + message_as_bytes: bytes = message.encode('utf-8') + if version == _ENCRYPTION_PROTOCOL_V1: # AES256 CBC uses 256 bit (32 byte) keys and always with 16 byte blocks @@ -935,7 +957,7 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str # PKCS7 with 16 byte blocks ensures compatibility with AES. padder = PKCS7(128).padder() - padded_data = padder.update(message) + padder.finalize() + padded_data = padder.update(message_as_bytes) + padder.finalize() # Encrypt the data. encryptor = cipher.encryptor() @@ -951,7 +973,7 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str aesgcm = AESGCM(content_encryption_key) # Returns ciphertext + tag - cipertext_with_tag = aesgcm.encrypt(nonce, message, None) + cipertext_with_tag = aesgcm.encrypt(nonce, message_as_bytes, None) encrypted_data = nonce + cipertext_with_tag else: @@ -998,10 +1020,10 @@ def decrypt_queue_message( response = response.http_response try: - message = loads(message) + deserialized_message: Dict[str, Any] = loads(message) - encryption_data = _dict_to_encryption_data(message['EncryptionData']) - decoded_data = decode_base64_to_bytes(message['EncryptedMessageContents']) + encryption_data = _dict_to_encryption_data(deserialized_message['EncryptionData']) + decoded_data = decode_base64_to_bytes(deserialized_message['EncryptedMessageContents']) except (KeyError, ValueError) as exc: # Message was not json formatted and so was not encrypted # or the user provided a json formatted message @@ -1013,9 +1035,9 @@ def decrypt_queue_message( return message try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') #type: ignore except Exception as error: raise HttpResponseError( message="Decryption 
failed.", - response=response, + response=response, #type: ignore [arg-type] error=error) from error \ No newline at end of file From 8a8ffada508b22ac374d7b43e8125034a4e41c4e Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Thu, 17 Aug 2023 16:44:52 -0700 Subject: [PATCH 26/71] Pylint --- .../azure/storage/queue/_encryption.py | 8 +++---- ...ageQueueEncodingtest_message_text_xml.json | 22 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 8286c17b2722..70b0e38c265c 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -737,7 +737,7 @@ def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple return dumps(encryption_data), encrypted_data -def generate_blob_encryption_data(key_encryption_key: object, version: str) -> Tuple[bytes, Optional[bytes], Dict[str, Any]]: +def generate_blob_encryption_data(key_encryption_key: object, version: str) -> Tuple[bytes, Optional[bytes], Dict[str, Any]]: # pylint: disable=line-too-long ''' Generates the encryption_metadata for the blob. @@ -946,7 +946,7 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str # Queue encoding functions all return unicode strings, and encryption should # operate on binary strings. 
message_as_bytes: bytes = message.encode('utf-8') - + if version == _ENCRYPTION_PROTOCOL_V1: # AES256 CBC uses 256 bit (32 byte) keys and always with 16 byte blocks @@ -1035,9 +1035,9 @@ def decrypt_queue_message( return message try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') #type: ignore + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') #type: ignore # pylint: disable=line-too-long except Exception as error: raise HttpResponseError( message="Decryption failed.", response=response, #type: ignore [arg-type] - error=error) from error \ No newline at end of file + error=error) from error diff --git a/sdk/storage/azure-storage-queue/tests/recordings/test_queue_encodings.pyTestStorageQueueEncodingtest_message_text_xml.json b/sdk/storage/azure-storage-queue/tests/recordings/test_queue_encodings.pyTestStorageQueueEncodingtest_message_text_xml.json index 1f7d22252046..71a42a3c487c 100644 --- a/sdk/storage/azure-storage-queue/tests/recordings/test_queue_encodings.pyTestStorageQueueEncodingtest_message_text_xml.json +++ b/sdk/storage/azure-storage-queue/tests/recordings/test_queue_encodings.pyTestStorageQueueEncodingtest_message_text_xml.json @@ -8,15 +8,15 @@ "Accept-Encoding": "gzip, deflate", "Connection": "keep-alive", "Content-Length": "0", - "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.1 (Windows-10-10.0.22621-SP0)", - "x-ms-date": "Fri, 07 Apr 2023 22:00:39 GMT", + "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.4 (Windows-10-10.0.22621-SP0)", + "x-ms-date": "Thu, 17 Aug 2023 23:43:01 GMT", "x-ms-version": "2021-02-12" }, "RequestBody": null, "StatusCode": 201, "ResponseHeaders": { "Content-Length": "0", - "Date": "Fri, 07 Apr 2023 22:00:38 GMT", + "Date": "Thu, 17 Aug 2023 23:43:01 GMT", "Server": [ "Windows-Azure-Queue/1.0", "Microsoft-HTTPAPI/2.0" @@ -34,8 +34,8 @@ "Connection": "keep-alive", "Content-Length": "111", 
"Content-Type": "application/xml", - "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.1 (Windows-10-10.0.22621-SP0)", - "x-ms-date": "Fri, 07 Apr 2023 22:00:40 GMT", + "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.4 (Windows-10-10.0.22621-SP0)", + "x-ms-date": "Thu, 17 Aug 2023 23:43:02 GMT", "x-ms-version": "2021-02-12" }, "RequestBody": [ @@ -45,7 +45,7 @@ "StatusCode": 201, "ResponseHeaders": { "Content-Type": "application/xml", - "Date": "Fri, 07 Apr 2023 22:00:38 GMT", + "Date": "Thu, 17 Aug 2023 23:43:01 GMT", "Server": [ "Windows-Azure-Queue/1.0", "Microsoft-HTTPAPI/2.0" @@ -53,7 +53,7 @@ "Transfer-Encoding": "chunked", "x-ms-version": "2021-02-12" }, - "ResponseBody": "\uFEFF\u003C?xml version=\u00221.0\u0022 encoding=\u0022utf-8\u0022?\u003E\u003CQueueMessagesList\u003E\u003CQueueMessage\u003E\u003CMessageId\u003Ef37bc876-6500-4f1e-874e-6e14128de1f1\u003C/MessageId\u003E\u003CInsertionTime\u003EFri, 07 Apr 2023 22:00:38 GMT\u003C/InsertionTime\u003E\u003CExpirationTime\u003EFri, 14 Apr 2023 22:00:38 GMT\u003C/ExpirationTime\u003E\u003CPopReceipt\u003EAgAAAAMAAAAAAAAAPOtiY5xp2QE=\u003C/PopReceipt\u003E\u003CTimeNextVisible\u003EFri, 07 Apr 2023 22:00:38 GMT\u003C/TimeNextVisible\u003E\u003C/QueueMessage\u003E\u003C/QueueMessagesList\u003E" + "ResponseBody": "\uFEFF\u003C?xml version=\u00221.0\u0022 encoding=\u0022utf-8\u0022?\u003E\u003CQueueMessagesList\u003E\u003CQueueMessage\u003E\u003CMessageId\u003Ed735bde8-a7c1-4fb0-939d-bf3497fb0270\u003C/MessageId\u003E\u003CInsertionTime\u003EThu, 17 Aug 2023 23:43:02 GMT\u003C/InsertionTime\u003E\u003CExpirationTime\u003EThu, 24 Aug 2023 23:43:02 GMT\u003C/ExpirationTime\u003E\u003CPopReceipt\u003EAgAAAAMAAAAAAAAAU8GZj2TR2QE=\u003C/PopReceipt\u003E\u003CTimeNextVisible\u003EThu, 17 Aug 2023 23:43:02 GMT\u003C/TimeNextVisible\u003E\u003C/QueueMessage\u003E\u003C/QueueMessagesList\u003E" }, { "RequestUri": "https://storagename.queue.core.windows.net/mytestqueue9b732da4/messages", @@ 
-62,8 +62,8 @@ "Accept": "application/xml", "Accept-Encoding": "gzip, deflate", "Connection": "keep-alive", - "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.1 (Windows-10-10.0.22621-SP0)", - "x-ms-date": "Fri, 07 Apr 2023 22:00:40 GMT", + "User-Agent": "azsdk-python-storage-queue/12.7.0b1 Python/3.11.4 (Windows-10-10.0.22621-SP0)", + "x-ms-date": "Thu, 17 Aug 2023 23:43:02 GMT", "x-ms-version": "2021-02-12" }, "RequestBody": null, @@ -71,7 +71,7 @@ "ResponseHeaders": { "Cache-Control": "no-cache", "Content-Type": "application/xml", - "Date": "Fri, 07 Apr 2023 22:00:38 GMT", + "Date": "Thu, 17 Aug 2023 23:43:01 GMT", "Server": [ "Windows-Azure-Queue/1.0", "Microsoft-HTTPAPI/2.0" @@ -80,7 +80,7 @@ "Vary": "Origin", "x-ms-version": "2021-02-12" }, - "ResponseBody": "\uFEFF\u003C?xml version=\u00221.0\u0022 encoding=\u0022utf-8\u0022?\u003E\u003CQueueMessagesList\u003E\u003CQueueMessage\u003E\u003CMessageId\u003Ef37bc876-6500-4f1e-874e-6e14128de1f1\u003C/MessageId\u003E\u003CInsertionTime\u003EFri, 07 Apr 2023 22:00:38 GMT\u003C/InsertionTime\u003E\u003CExpirationTime\u003EFri, 14 Apr 2023 22:00:38 GMT\u003C/ExpirationTime\u003E\u003CPopReceipt\u003EAgAAAAMAAAAAAAAAvZZSdZxp2QE=\u003C/PopReceipt\u003E\u003CTimeNextVisible\u003EFri, 07 Apr 2023 22:01:08 GMT\u003C/TimeNextVisible\u003E\u003CDequeueCount\u003E1\u003C/DequeueCount\u003E\u003CMessageText\u003E\u0026lt;message1\u0026gt;\u003C/MessageText\u003E\u003C/QueueMessage\u003E\u003C/QueueMessagesList\u003E" + "ResponseBody": "\uFEFF\u003C?xml version=\u00221.0\u0022 encoding=\u0022utf-8\u0022?\u003E\u003CQueueMessagesList\u003E\u003CQueueMessage\u003E\u003CMessageId\u003Ed735bde8-a7c1-4fb0-939d-bf3497fb0270\u003C/MessageId\u003E\u003CInsertionTime\u003EThu, 17 Aug 2023 23:43:02 GMT\u003C/InsertionTime\u003E\u003CExpirationTime\u003EThu, 24 Aug 2023 23:43:02 GMT\u003C/ExpirationTime\u003E\u003CPopReceipt\u003EAgAAAAMAAAAAAAAA/PuYoWTR2QE=\u003C/PopReceipt\u003E\u003CTimeNextVisible\u003EThu, 17 Aug 2023 
23:43:32 GMT\u003C/TimeNextVisible\u003E\u003CDequeueCount\u003E1\u003C/DequeueCount\u003E\u003CMessageText\u003E\u0026lt;message1\u0026gt;\u003C/MessageText\u003E\u003C/QueueMessage\u003E\u003C/QueueMessagesList\u003E" } ], "Variables": {} From 20d348d677f90d65ffb7a58e536c72ea073e19a8 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Mon, 28 Aug 2023 17:32:57 -0700 Subject: [PATCH 27/71] Encryption feedback done --- .../azure/storage/queue/_encryption.py | 324 ++++++++++-------- .../azure/storage/queue/_queue_client.py | 2 +- .../storage/queue/aio/_queue_client_async.py | 2 +- 3 files changed, 178 insertions(+), 150 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 3224e381fdc6..bba92353f827 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -15,7 +15,8 @@ dumps, loads, ) -from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing_extensions import Protocol from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher @@ -46,23 +47,40 @@ '{0} does not define a complete interface. Value of {1} is either missing or invalid.' +class KeyEncryptionKey(Protocol): + """Protocol that defines what calling functions should be defined for a user-provided key-encryption-key (kek).""" + + def wrap_key(self, key): + ... + + def unwrap_key(self, key, algorithm): + ... + + def get_kid(self): + ... + + def get_key_wrap_algorithm(self): + ... 
+ + def _validate_not_none(param_name: str, param: Any): if param is None: raise ValueError(f'{param_name} should not be None.') -def _validate_key_encryption_key_wrap(kek: object): +def _validate_key_encryption_key_wrap(kek: KeyEncryptionKey): # Note that None is not callable and so will fail the second clause of each check. if not hasattr(kek, 'wrap_key') or not callable(kek.wrap_key): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'wrap_key')) if not hasattr(kek, 'get_kid') or not callable(kek.get_kid): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid')) - if not hasattr(kek, 'get_key_wrap_algorithm') or not callable(kek.get_key_wrap_algorithm): + if not (hasattr(kek, 'get_key_wrap_algorithm') or + not callable(kek.get_key_wrap_algorithm)): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_key_wrap_algorithm')) class StorageEncryptionMixin(object): - def _configure_encryption(self, kwargs: Any): + def _configure_encryption(self, kwargs: Dict[str, Any]): self.require_encryption = kwargs.get("require_encryption", False) self.encryption_version = kwargs.get("encryption_version", "1.0") self.key_encryption_key = kwargs.get("key_encryption_key") @@ -83,19 +101,19 @@ class _EncryptionAlgorithm(object): class _WrappedContentKey: - ''' + """ Represents the envelope key details stored on the service. - ''' + """ def __init__(self, algorithm: str, encrypted_key: bytes, key_id: str) -> None: - ''' + """ :param str algorithm: The algorithm used for wrapping. :param bytes encrypted_key: The encrypted content-encryption-key. :param str key_id: The key-encryption-key identifier string. 
- ''' + """ _validate_not_none('algorithm', algorithm) _validate_not_none('encrypted_key', encrypted_key) _validate_not_none('key_id', key_id) @@ -106,20 +124,20 @@ def __init__(self, algorithm: str, encrypted_key: bytes, key_id: str) -> None: class _EncryptedRegionInfo: - ''' + """ Represents the length of encryption elements. This is only used for Encryption V2. - ''' + """ def __init__(self, data_length: int, nonce_length: int, tag_length: int) -> None: - ''' + """ :param int data_length: The length of the encryption region data (not including nonce + tag). :param int nonce_length: The length of nonce used when encrypting. :param int tag_length: The length of the encryption tag. - ''' + """ _validate_not_none('data_length', data_length) _validate_not_none('nonce_length', nonce_length) _validate_not_none('tag_length', tag_length) @@ -130,18 +148,18 @@ def __init__(self, data_length: int, nonce_length: int, tag_length: int) -> None class _EncryptionAgent: - ''' + """ Represents the encryption agent stored on the service. It consists of the encryption protocol version and encryption algorithm used. - ''' + """ def __init__(self, encryption_algorithm: _EncryptionAlgorithm, protocol: str) -> None: - ''' + """ :param _EncryptionAlgorithm encryption_algorithm: The algorithm used for encrypting the message contents. :param str protocol: The protocol version used for encryption. - ''' + """ _validate_not_none('encryption_algorithm', encryption_algorithm) _validate_not_none('protocol', protocol) @@ -150,9 +168,9 @@ def __init__(self, encryption_algorithm: _EncryptionAlgorithm, protocol: str) -> class _EncryptionData: - ''' + """ Represents the encryption data that is stored on the service. 
- ''' + """ def __init__( self, content_encryption_IV: Optional[bytes], @@ -161,7 +179,7 @@ def __init__( wrapped_content_key: _WrappedContentKey, key_wrapping_metadata: Dict[str, Any] ) -> None: - ''' + """ :param Optional[bytes] content_encryption_IV: The content encryption initialization vector. Required for AES-CBC (V1). @@ -175,7 +193,7 @@ def __init__( and the encrypted key bytes. :param Dict[str, Any] key_wrapping_metadata: A dict containing metadata related to the key wrapping. - ''' + """ _validate_not_none('encryption_agent', encryption_agent) _validate_not_none('wrapped_content_key', wrapped_content_key) @@ -195,20 +213,20 @@ def __init__( class GCMBlobEncryptionStream: - ''' + """ A stream that performs AES-GCM encryption on the given data as it's streamed. Data is read and encrypted in regions. The stream will use the same encryption key and will generate a guaranteed unique nonce for each encryption region. - ''' + """ def __init__( self, content_encryption_key: bytes, data_stream: BinaryIO, ) -> None: - ''' + """ :param bytes content_encryption_key: The encryption key to use. :param BinaryIO data_stream: The data stream to read data from. - ''' + """ self.content_encryption_key = content_encryption_key self.data_stream = data_stream @@ -268,7 +286,7 @@ def _encrypt_region(self, data: bytes) -> bytes: return nonce + ciphertext_with_tag -def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> Optional[Union[_EncryptionData, bool]]: +def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: """ Determine whether the given encryption data signifies version 2.0. 
@@ -277,7 +295,10 @@ def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> Optional[Un :rtype: bool """ # If encryption_data is None, assume no encryption - return encryption_data and encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2 + if encryption_data and (encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2): + return True + + return False def modify_user_agent_for_encryption( @@ -376,12 +397,11 @@ def get_adjusted_download_range_and_offset( start_offset, end_offset = 0, end if encryption_data.encrypted_region_info is not None: - if isinstance(encryption_data.encrypted_region_info, _EncryptedRegionInfo): - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length - region_length = nonce_length + data_length + tag_length - requested_length = end - start + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length + region_length = nonce_length + data_length + tag_length + requested_length = end - start if start is not None: # Find which data region the start is in @@ -421,62 +441,60 @@ def parse_encryption_data(metadata: Dict[str, Any]) -> Optional[_EncryptionData] return None -def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_EncryptionData]) -> int: +def adjust_blob_size_for_encryption(size: int, encryption_data: _EncryptionData) -> int: """ Adjusts the given blob size for encryption by subtracting the size of the encryption data (nonce + tag). This only has an affect for encryption V2. :param int size: The original blob size. - :param Optional[_EncryptionData] encryption_data: The encryption data to determine version and sizes. 
+ :param _EncryptionData encryption_data: The encryption data to determine version and sizes. :return: The new blob size. :rtype: int """ - if is_encryption_v2(encryption_data): - if encryption_data is not None: - if encryption_data.encrypted_region_info is not None: - if isinstance(encryption_data.encrypted_region_info, _EncryptedRegionInfo): - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length - region_length = nonce_length + data_length + tag_length - - num_regions = math.ceil(size / region_length) - metadata_size = num_regions * (nonce_length + tag_length) - return size - metadata_size + if is_encryption_v2(encryption_data) and encryption_data.encrypted_region_info is not None: + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length + region_length = nonce_length + data_length + tag_length + + num_regions = math.ceil(size / region_length) + metadata_size = num_regions * (nonce_length + tag_length) + return size - metadata_size return size -def _generate_encryption_data_dict(kek: object, cek: bytes, iv: Optional[bytes], version: str) -> Dict[str, Any]: - ''' +def _generate_encryption_data_dict( + kek: KeyEncryptionKey, + cek: bytes, + iv: Optional[bytes], + version: str + ) -> Dict[str, Any]: + """ Generates and returns the encryption metadata as a dict. - :param object kek: The key encryption key. See calling functions for more information. + :param KeyEncryptionKey kek: The key encryption key. See calling functions for more information. :param bytes cek: The content encryption key. :param Optional[bytes] iv: The initialization vector. Only required for AES-CBC. :param str version: The client encryption version used. :return: A dict containing all the encryption metadata. 
:rtype: dict - ''' + """ # Encrypt the cek. if version == _ENCRYPTION_PROTOCOL_V1: - if hasattr(kek, 'wrap_key'): - wrapped_cek = kek.wrap_key(cek) + wrapped_cek = kek.wrap_key(cek) # For V2, we include the encryption version in the wrapped key. elif version == _ENCRYPTION_PROTOCOL_V2: # We must pad the version to 8 bytes for AES Keywrap algorithms to_wrap = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') + cek - if hasattr(kek, 'wrap_key'): - wrapped_cek = kek.wrap_key(to_wrap) + wrapped_cek = kek.wrap_key(to_wrap) # Build the encryption_data dict. # Use OrderedDict to comply with Java's ordering requirement. wrapped_content_key = OrderedDict() - if hasattr(kek, 'get_kid'): - wrapped_content_key['KeyId'] = kek.get_kid() + wrapped_content_key['KeyId'] = kek.get_kid() wrapped_content_key['EncryptedKey'] = encode_base64(wrapped_cek) - if hasattr(kek, 'get_key_wrap_algorithm'): - wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() + wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() encryption_agent = OrderedDict() encryption_agent['Protocol'] = version @@ -498,13 +516,13 @@ def _generate_encryption_data_dict(kek: object, cek: bytes, iv: Optional[bytes], encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info - encryption_data_dict['KeyWrappingMetadata'] = OrderedDict([('EncryptionLibrary', 'Python ' + VERSION)]) + encryption_data_dict['KeyWrappingMetadata'] = OrderedDict({'EncryptionLibrary': 'Python ' + VERSION}) return encryption_data_dict def _dict_to_encryption_data(encryption_data_dict: Dict[str, Any]) -> _EncryptionData: - ''' + """ Converts the specified dictionary to an EncryptionData object for eventual use in decryption. @@ -512,7 +530,7 @@ def _dict_to_encryption_data(encryption_data_dict: Dict[str, Any]) -> _Encryptio The dictionary containing the encryption data. 
:return: an _EncryptionData object built from the dictionary. :rtype: _EncryptionData - ''' + """ try: protocol = encryption_data_dict['EncryptionAgent']['Protocol'] if protocol not in [_ENCRYPTION_PROTOCOL_V1, _ENCRYPTION_PROTOCOL_V2]: @@ -556,14 +574,14 @@ def _dict_to_encryption_data(encryption_data_dict: Dict[str, Any]) -> _Encryptio def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: - ''' + """ Generates and returns an encryption cipher for AES CBC using the given cek and iv. :param bytes[] cek: The content encryption key for the cipher. :param bytes[] iv: The initialization vector for the cipher. :return: A cipher for encrypting in AES256 CBC. :rtype: ~cryptography.hazmat.primitives.ciphers.Cipher - ''' + """ backend = default_backend() algorithm = AES(cek) @@ -573,23 +591,28 @@ def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: def _validate_and_unwrap_cek( encryption_data: _EncryptionData, - key_encryption_key: object = None, + key_encryption_key: KeyEncryptionKey, key_resolver: Optional[Callable[[str], bytes]] = None ) -> bytes: - ''' + """ Extracts and returns the content_encryption_key stored in the encryption_data object and performs necessary validation on all parameters. :param _EncryptionData encryption_data: The encryption metadata of the retrieved value. - :param object key_encryption_key: - The key_encryption_key used to unwrap the cek. Please refer to high-level service object - instance variables for more details. + :param KeyEncryptionKey key_encryption_key: + The user-provided key-encryption-key. Must implement the following methods: + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. + get_kid() + - Returns a string key id for this key-encryption-key. :param Optional[Callable[[str], bytes]] key_resolver: A function used that, given a key_id, will return a key_encryption_key. 
Please refer to high-level service object instance variables for more details. - :return: the content_encryption_key stored in the encryption_data object. + :return: The content_encryption_key stored in the encryption_data object. :rtype: bytes - ''' + """ _validate_not_none('encrypted_key', encryption_data.wrapped_content_key.encrypted_key) @@ -601,11 +624,11 @@ def _validate_and_unwrap_cek( else: raise ValueError('Specified encryption version is not supported.') - content_encryption_key = None + content_encryption_key = b'' # If the resolver exists, give priority to the key it finds. if key_resolver is not None: - key_encryption_key = key_resolver(encryption_data.wrapped_content_key.key_id) + key_encryption_key = key_resolver(encryption_data.wrapped_content_key.key_id) #type: ignore [assignment] _validate_not_none('key_encryption_key', key_encryption_key) if not hasattr(key_encryption_key, 'get_kid') or not callable(key_encryption_key.get_kid): @@ -615,47 +638,48 @@ def _validate_and_unwrap_cek( if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid(): raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.') # Will throw an exception if the specified algorithm is not supported. - if hasattr(key_encryption_key, 'unwrap_key'): - content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, + content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, encryption_data.wrapped_content_key.algorithm) # For V2, the version is included with the cek. We need to validate it # and remove it from the actual cek. 
if encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: version_2_bytes = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') - cek_version_bytes = content_encryption_key[:len(version_2_bytes)] #type: ignore [index] + cek_version_bytes = content_encryption_key[:len(version_2_bytes)] if cek_version_bytes != version_2_bytes: raise ValueError('The encryption metadata is not valid and may have been modified.') # Remove version from the start of the cek. - content_encryption_key = content_encryption_key[len(version_2_bytes):] #type: ignore [index] + content_encryption_key = content_encryption_key[len(version_2_bytes):] _validate_not_none('content_encryption_key', content_encryption_key) - return content_encryption_key #type: ignore[return-value] + return content_encryption_key def _decrypt_message( message: str, encryption_data: _EncryptionData, - key_encryption_key: object = None, + key_encryption_key: KeyEncryptionKey, resolver: Optional[Callable] = None ) -> str: """ Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding. Unwraps the content-encryption-key using the user-provided or resolved key-encryption-key (kek). - Returns the original plaintex. + Returns the original plaintext. :param str message: The ciphertext to be decrypted. :param _EncryptionData encryption_data: The metadata associated with this ciphertext. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: - unwrap_key(key, algorithm) - - returns the unwrapped form of the specified symmetric key using the string-specified algorithm. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. get_kid() - - returns a string key id for this key-encryption-key. + - Returns a string key id for this key-encryption-key. 
:param Callable resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. @@ -686,15 +710,15 @@ def _decrypt_message( raise ValueError("Missing required metadata for decryption.") if encryption_data.encrypted_region_info is not None: - if isinstance(encryption_data.encrypted_region_info, _EncryptedRegionInfo): - nonce_length = encryption_data.encrypted_region_info.nonce_length + nonce_length = encryption_data.encrypted_region_info.nonce_length # First bytes are the nonce - nonce = message[:nonce_length] - ciphertext_with_tag = message[nonce_length:] + message_as_bytes = message.encode('utf-8') + nonce = message_as_bytes[:nonce_length] + ciphertext_with_tag = message_as_bytes[nonce_length:] aesgcm = AESGCM(content_encryption_key) - decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) #type: ignore + decrypted_data = (aesgcm.decrypt(nonce, ciphertext_with_tag, None)).decode() else: raise ValueError('Specified encryption version is not supported.') @@ -702,8 +726,8 @@ def _decrypt_message( return decrypted_data -def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple[str, bytes]: - ''' +def encrypt_blob(blob: bytes, key_encryption_key: KeyEncryptionKey, version: str) -> Tuple[str, bytes]: + """ Encrypts the given blob using the given encryption protocol version. Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek). Returns a json-formatted string containing the encryption metadata. This method should @@ -712,15 +736,18 @@ def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple :param bytes blob: The blob to be encrypted. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: - wrap_key(key)--wraps the specified key using an algorithm of the user's choice. 
- get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key. - get_kid()--returns a string key id for this key-encryption-key. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. + get_kid() + - Returns a string key id for this key-encryption-key. :param str version: The client encryption version to use. :return: A tuple of json-formatted string containing the encryption metadata and the encrypted blob data. :rtype: (str, bytes) - ''' + """ _validate_not_none('blob', blob) _validate_not_none('key_encryption_key', key_encryption_key) @@ -761,45 +788,41 @@ def encrypt_blob(blob: bytes, key_encryption_key: object, version: str) -> Tuple return dumps(encryption_data), encrypted_data -def generate_blob_encryption_data(key_encryption_key: object, version: str) -> Tuple[bytes, Optional[bytes], Dict[str, Any]]: # pylint: disable=line-too-long - ''' +def generate_blob_encryption_data( + key_encryption_key: KeyEncryptionKey, + version: str + ) -> Tuple[bytes, Optional[bytes], str]: + """ Generates the encryption_metadata for the blob. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The key-encryption-key used to wrap the cek associate with this blob. :param str version: The client encryption version to use. :return: A tuple containing the cek and iv for this blob as well as the serialized encryption metadata for the blob. 
:rtype: (bytes, Optional[bytes], str) - ''' - encryption_data = None - content_encryption_key = None - initialization_vector = None - if key_encryption_key: - _validate_key_encryption_key_wrap(key_encryption_key) - content_encryption_key = os.urandom(32) - # Initialization vector only needed for V1 - if version == _ENCRYPTION_PROTOCOL_V1: - initialization_vector = os.urandom(16) - encryption_data = _generate_encryption_data_dict(key_encryption_key, - content_encryption_key, - initialization_vector, - version) - encryption_data['EncryptionMode'] = 'FullBlob' - encryption_data = dumps(encryption_data) # type: ignore + """ - if content_encryption_key is not None: - valid_cek = content_encryption_key + initialization_vector = None - if encryption_data is not None: - valid_encryption_data = encryption_data + _validate_key_encryption_key_wrap(key_encryption_key) + content_encryption_key = os.urandom(32) + # Initialization vector only needed for V1 + if version == _ENCRYPTION_PROTOCOL_V1: + initialization_vector = os.urandom(16) + encryption_data = _generate_encryption_data_dict(key_encryption_key, + content_encryption_key, + initialization_vector, + version) + encryption_data['EncryptionMode'] = 'FullBlob' + encryption_data_dump = dumps(encryption_data) - return valid_cek, initialization_vector, valid_encryption_data + return content_encryption_key, initialization_vector, encryption_data_dump def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements require_encryption: bool, - key_encryption_key: object, + key_encryption_key: KeyEncryptionKey, key_resolver: Callable[[str], bytes], content: bytes, start_offset: int, @@ -811,11 +834,14 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements :param bool require_encryption: Whether the calling blob service requires objects to be decrypted. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. 
Must implement the following methods: - wrap_key(key)--wraps the specified key using an algorithm of the user's choice. - get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key. - get_kid()--returns a string key id for this key-encryption-key. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. + get_kid() + - Returns a string key id for this key-encryption-key. :param object key_resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. @@ -854,7 +880,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements if version == _ENCRYPTION_PROTOCOL_V1: blob_type = response_headers['x-ms-blob-type'] - iv = None + iv = b'' unpad = False if 'content-range' in response_headers: content_range = response_headers['content-range'] @@ -873,19 +899,18 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements content = content[16:] start_offset -= 16 else: - iv = encryption_data.content_encryption_IV + iv = encryption_data.content_encryption_IV #type: ignore [assignment] if end_range == blob_size - 1: unpad = True else: unpad = True - iv = encryption_data.content_encryption_IV + iv = encryption_data.content_encryption_IV #type: ignore [assignment] if blob_type == 'PageBlob': unpad = False - if isinstance(iv, bytes): - cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) + cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) decryptor = cipher.decryptor() content = decryptor.update(content) + decryptor.finalize() @@ -901,11 +926,10 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements offset = 0 if encryption_data.encrypted_region_info is not None: - if isinstance(encryption_data.encrypted_region_info, _EncryptedRegionInfo): - nonce_length = 
encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length - region_length = nonce_length + data_length + tag_length + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length + region_length = nonce_length + data_length + tag_length decrypted_content = bytearray() while offset < total_size: @@ -945,23 +969,26 @@ def get_blob_encryptor_and_padder( return encryptor, padder -def encrypt_queue_message(message: str, key_encryption_key: object, version: str) -> str: - ''' +def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, version: str) -> str: + """ Encrypts the given plain text message using the given protocol version. Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek). Returns a json-formatted string containing the encrypted message and the encryption metadata. :param object message: The plain text message to be encrypted. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: - wrap_key(key)--wraps the specified key using an algorithm of the user's choice. - get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key. - get_kid()--returns a string key id for this key-encryption-key. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. + get_kid() + - Returns a string key id for this key-encryption-key. :param str version: The client encryption version to use. :return: A json-formatted string containing the encrypted message and the encryption metadata. 
:rtype: str - ''' + """ _validate_not_none('message', message) _validate_not_none('key_encryption_key', key_encryption_key) @@ -971,7 +998,6 @@ def encrypt_queue_message(message: str, key_encryption_key: object, version: str # operate on binary strings. message_as_bytes: bytes = message.encode('utf-8') - if version == _ENCRYPTION_PROTOCOL_V1: # AES256 CBC uses 256 bit (32 byte) keys and always with 16 byte blocks content_encryption_key = os.urandom(32) @@ -1017,7 +1043,7 @@ def decrypt_queue_message( message: str, response: "PipelineResponse", require_encryption: bool, - key_encryption_key: object, + key_encryption_key: KeyEncryptionKey, resolver: Callable[[str], bytes] ) -> str: """ @@ -1029,12 +1055,14 @@ def decrypt_queue_message( The pipeline response used to generate an error with. :param bool require_encryption: If set, will enforce that the retrieved messages are encrypted and decrypt them. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: - unwrap_key(key, algorithm) - - returns the unwrapped form of the specified symmetric key usingthe string-specified algorithm. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. get_kid() - - returns a string key id for this key-encryption-key. + - Returns a string key id for this key-encryption-key. :param Callable resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. 
@@ -1059,7 +1087,7 @@ def decrypt_queue_message( return message try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') #type: ignore # pylint: disable=line-too-long + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver) except Exception as error: raise HttpResponseError( message="Decryption failed.", diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 2ddae459e20c..1ef19c08a9c0 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -775,7 +775,7 @@ def update_message( inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - else: + except AttributeError: message_id = message message_text = content receipt = pop_receipt diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 2e1b632a8194..6c4c6d7b6973 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -772,7 +772,7 @@ async def update_message( # type: ignore[override] inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - else: + except AttributeError: message_id = message message_text = content receipt = pop_receipt From f87e1be5e57ffcbcce8dddcf85e481d23d5a8eef Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 30 Aug 2023 10:56:36 -0700 Subject: [PATCH 28/71] Shared, sample, and lint output left --- .../azure/storage/queue/_encryption.py | 6 ++-- .../azure/storage/queue/_message_encoding.py | 35 +++++++++---------- .../azure/storage/queue/_models.py | 21 +++++------ 
.../azure/storage/queue/_queue_client.py | 2 +- .../storage/queue/_queue_client_helpers.py | 31 ++++++++++++++-- .../storage/queue/_queue_service_client.py | 5 ++- .../queue/_queue_service_client_helpers.py | 27 +++++++++++++- .../storage/queue/_shared/authentication.py | 6 ++-- .../storage/queue/_shared/base_client.py | 2 +- .../azure/storage/queue/aio/_models.py | 23 +++++------- .../storage/queue/aio/_queue_client_async.py | 5 ++- .../queue/aio/_queue_service_client_async.py | 3 +- 12 files changed, 102 insertions(+), 64 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index bba92353f827..39218b7b8928 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -15,7 +15,7 @@ dumps, loads, ) -from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING from typing_extensions import Protocol from cryptography.hazmat.backends import default_backend @@ -52,7 +52,7 @@ class KeyEncryptionKey(Protocol): def wrap_key(self, key): ... - + def unwrap_key(self, key, algorithm): ... 
@@ -297,7 +297,7 @@ def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: # If encryption_data is None, assume no encryption if encryption_data and (encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2): return True - + return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index 523288afe8a8..daf73f8ac731 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -6,11 +6,11 @@ # pylint: disable=unused-argument from base64 import b64encode, b64decode -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union from azure.core.exceptions import DecodeError -from ._encryption import decrypt_queue_message, encrypt_queue_message, _ENCRYPTION_PROTOCOL_V1 +from ._encryption import decrypt_queue_message, encrypt_queue_message, _ENCRYPTION_PROTOCOL_V1, KeyEncryptionKey if TYPE_CHECKING: from azure.core.pipeline import PipelineResponse @@ -22,7 +22,7 @@ class MessageEncodePolicy(object): """Indicates whether a retention policy is enabled for the storage service.""" encryption_version: Optional[str] = None """Indicates whether a retention policy is enabled for the storage service.""" - key_encryption_key: Optional[object] = None + key_encryption_key: Optional[KeyEncryptionKey] = None """Indicates whether a retention policy is enabled for the storage service.""" resolver: Optional[Callable[[str], bytes]] = None """Indicates whether a retention policy is enabled for the storage service.""" @@ -42,7 +42,7 @@ def __call__(self, content: Any) -> str: def configure( self, require_encryption: bool, - key_encryption_key: object, + key_encryption_key: KeyEncryptionKey, resolver: Callable[[str], bytes], encryption_version: str = 
_ENCRYPTION_PROTOCOL_V1 ) -> None: @@ -64,22 +64,21 @@ def __init__(self): self.key_encryption_key = None self.resolver = None - def __call__(self, response: "PipelineResponse", obj: object, headers: Dict[str, Any]) -> object: - if hasattr(obj, '__iter__'): - for message in obj: - if message.message_text in [None, "", b""]: - continue - content = message.message_text - if (self.key_encryption_key is not None) or (self.resolver is not None): - content = decrypt_queue_message( - content, response, - self.require_encryption, - self.key_encryption_key, - self.resolver) - message.message_text = self.decode(content, response) + def __call__(self, response: "PipelineResponse", obj: Iterable, headers: Dict[str, Any]) -> object: + for message in obj: + if message.message_text in [None, "", b""]: + continue + content = message.message_text + if (self.key_encryption_key is not None) or (self.resolver is not None): + content = decrypt_queue_message( + content, response, + self.require_encryption, + self.key_encryption_key, + self.resolver) + message.message_text = self.decode(content, response) return obj - def configure(self, require_encryption: bool, key_encryption_key: object, resolver: Callable[[str], bytes]) -> None: + def configure(self, require_encryption: bool, key_encryption_key: KeyEncryptionKey, resolver: Callable[[str], bytes]) -> None: self.require_encryption = require_encryption self.key_encryption_key = key_encryption_key self.resolver = resolver diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 61ad5602c02e..715e75bb2055 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -424,6 +424,8 @@ class QueuePropertiesPaged(PageIterator): """The location mode being used to list results. 
The available options include "primary" and "secondary".""" command: Callable """Function to retrieve the next page of items.""" + _response: Any + """Function to retrieve the next page of items.""" def __init__( self, command: Callable, @@ -455,19 +457,12 @@ def _get_next_cb(self, continuation_token: Optional[str]) -> Any: def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return - if self._response is not None: - if hasattr(self._response, 'service_endpoint'): - self.service_endpoint = self._response.service_endpoint - if hasattr(self._response, 'prefix'): - self.prefix = self._response.prefix - if hasattr(self._response, 'marker'): - self.marker = self._response.marker - if hasattr(self._response, 'max_results'): - self.results_per_page = self._response.max_results - if hasattr(self._response, 'queue_items'): - props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access - if hasattr(self._response, 'next_marker'): - next_marker = self._response.next_marker + self.service_endpoint = self._response.service_endpoint + self.prefix = self._response.prefix + self.marker = self._response.marker + self.results_per_page = self._response.max_results + props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access + next_marker = self._response.next_marker return next_marker or None, props_list diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 1ef19c08a9c0..b151871cc637 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -92,7 +92,7 @@ def __init__( self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() 
self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) - self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access + self._client._config.version = get_api_version(kwargs) self._configure_encryption(kwargs) def _format_url(self, hostname: str) -> str: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index 7c77dcbda1bb..0377d439f3b2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -4,10 +4,37 @@ # license information. # -------------------------------------------------------------------------- +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union from urllib.parse import urlparse from ._shared.base_client import parse_query -def _parse_url(account_url, queue_name, credential): +if TYPE_CHECKING: + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from urllib.parse import ParseResult + + +def _parse_url( + account_url: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] # pylint: disable=line-too-long +) -> Tuple["ParseResult", Any]: + """Performs initial input validation and returns the parsed URL and SAS token. + + :param str account_url: The URL to the storage account. + :param str queue_name: The name of the queue. + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token. 
The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + If the resource URI already contains a SAS token, this will be ignored in favor of an explicit credential + - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long + :returns: The parsed URL and SAS token. + :rtype: Tuple[ParseResult, Any] + """ try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url @@ -23,4 +50,4 @@ def _parse_url(account_url, queue_name, credential): if not sas_token and not credential: raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") - return parsed_url, sas_token + return parsed_url, sas_token \ No newline at end of file diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 33a969c4941a..9104d56d7a2f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -235,11 +235,10 @@ def set_service_properties( You can include up to five CorsRule elements in the list. If an empty list is specified, all CORS rules will be deleted, and CORS will be disabled for the service. - :type cors: list(~azure.storage.queue.CorsRule) + :type cors: List[~azure.storage.queue.CorsRule] :keyword int timeout: The timeout parameter is expressed in seconds. 
:returns: None - :rtype: None .. admonition:: Example: @@ -258,7 +257,7 @@ def set_service_properties( cors=cors ) try: - return self._client.service.set_properties(props, timeout=timeout, **kwargs) + self._client.service.set_properties(props, timeout=timeout, **kwargs) except HttpResponseError as error: process_storage_error(error) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py index 24506fb909e7..a9a44407a94e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py @@ -4,10 +4,35 @@ # license information. # -------------------------------------------------------------------------- +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union from urllib.parse import urlparse from ._shared.base_client import parse_query -def _parse_url(account_url, credential): +if TYPE_CHECKING: + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from urllib.parse import ParseResult + + +def _parse_url( + account_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] # pylint: disable=line-too-long +) -> Tuple["ParseResult", Any]: + """Performs initial input validation and returns the parsed URL and SAS token. + + :param str account_url: The URL to the storage account. + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. 
+ If the resource URI already contains a SAS token, this will be ignored in favor of an explicit credential + - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long + :returns: The parsed URL and SAS token. + :rtype: Tuple[ParseResult, Any] + """ try: if not account_url.lower().startswith('http'): account_url = "https://" + account_url diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py index 89cbe4315d43..b2a1b1d775dd 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py @@ -6,7 +6,7 @@ import logging import re -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from urllib.parse import unquote, urlparse try: @@ -35,7 +35,7 @@ def _wrap_exception(ex, desired_type): return desired_type(msg) # This method attempts to emulate the sorting done by the service -def _storage_header_sort(input_headers: List[Tuple[str, Optional[str]]]) -> List[Tuple[str, Optional[str]]]: +def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]: # Define the custom alphabet for weights custom_weights = "-!#$%&*.^_|~+\"\'(),/`~0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz{}" @@ -56,7 +56,7 @@ def _storage_header_sort(input_headers: List[Tuple[str, Optional[str]]]) -> List sorted_headers = [] for key in header_keys: sorted_headers.append((key, header_dict.get(key))) - return sorted_headers + return sorted_headers # type: 
ignore [return-value] class AzureSigningError(ClientAuthenticationError): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 67aa85367ee9..92a814b2c566 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -258,7 +258,7 @@ def _create_pipeline(self, credential, **kwargs): ] if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore - config.transport = transport #type: ignore + config.transport = transport # type: ignore return config, Pipeline(transport, policies=policies) # Given a series of request, do a Storage batch call. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index f9788a252a25..027ab28d372f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -95,6 +95,8 @@ class QueuePropertiesPaged(AsyncPageIterator): """The location mode being used to list results. 
The available options include "primary" and "secondary".""" command: Callable """Function to retrieve the next page of items.""" + _response: Any + """Function to retrieve the next page of items.""" def __init__( self, command: Callable, @@ -124,19 +126,12 @@ async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[Any], List[QueueProperties]]: + async def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return - if self._response is not None: - if hasattr(self._response, 'service_endpoint'): - self.service_endpoint = self._response.service_endpoint - if hasattr(self._response, 'prefix'): - self.prefix = self._response.prefix - if hasattr(self._response, 'marker'): - self.marker = self._response.marker - if hasattr(self._response, 'max_results'): - self.results_per_page = self._response.max_results - if hasattr(self._response, 'queue_items'): - props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access - if hasattr(self._response, 'next_marker'): - next_marker = self._response.next_marker + self.service_endpoint = self._response.service_endpoint + self.prefix = self._response.prefix + self.marker = self._response.marker + self.results_per_page = self._response.max_results + props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access + next_marker = self._response.next_marker return next_marker or None, props_list diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 6c4c6d7b6973..34c8a88f0662 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py 
+++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- # pylint: disable=invalid-overridden-method -# mypy: disable-error-code="misc" import functools import warnings @@ -45,7 +44,7 @@ from .._models import QueueProperties -class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): +class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] """A client to interact with a specific Queue. :param str account_url: @@ -213,7 +212,7 @@ def from_connection_string( return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @distributed_trace_async - async def create_queue( # type: ignore[override] + async def create_queue( self, *, metadata: Optional[Dict[str, str]] = None, **kwargs: Any diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index ef2aaf307bc5..c4497332203f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- # pylint: disable=invalid-overridden-method -# mypy: disable-error-code="misc" import functools from typing import ( @@ -40,7 +39,7 @@ from .._models import Metrics, QueueAnalyticsLogging -class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): +class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] """A client to interact with the Queue Service at the account level. 
This client provides operations to retrieve and configure the account properties From e317e544d1ade331110ab0dd7d95adde77beedfb Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Fri, 1 Sep 2023 17:00:43 -0700 Subject: [PATCH 29/71] More comments addressed, need bc_async, policies, and samples + output --- .../azure/storage/queue/_encryption.py | 26 ++-- .../azure/storage/queue/_queue_client.py | 2 +- .../storage/queue/_shared/authentication.py | 4 +- .../storage/queue/_shared/base_client.py | 126 +++++++++--------- .../queue/_shared/base_client_async.py | 22 ++- .../azure/storage/queue/_shared/models.py | 50 +++++++ 6 files changed, 139 insertions(+), 91 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 39218b7b8928..77f383ce5699 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -396,10 +396,12 @@ def get_adjusted_download_range_and_offset( elif encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: start_offset, end_offset = 0, end - if encryption_data.encrypted_region_info is not None: - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length + if encryption_data.encrypted_region_info is None: + raise ValueError("Field required for V2 encryption is missing (encrypted_region_info).") + + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length region_length = nonce_length + data_length + tag_length requested_length = end - start @@ -709,8 +711,10 @@ def _decrypt_message( if not block_info or not block_info.nonce_length: raise ValueError("Missing required 
metadata for decryption.") - if encryption_data.encrypted_region_info is not None: - nonce_length = encryption_data.encrypted_region_info.nonce_length + if encryption_data.encrypted_region_info is None: + raise ValueError("Field required for V2 encryption is missing (encrypted_region_info).") + + nonce_length = encryption_data.encrypted_region_info.nonce_length # First bytes are the nonce message_as_bytes = message.encode('utf-8') @@ -925,10 +929,12 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements total_size = len(content) offset = 0 - if encryption_data.encrypted_region_info is not None: - nonce_length = encryption_data.encrypted_region_info.nonce_length - data_length = encryption_data.encrypted_region_info.data_length - tag_length = encryption_data.encrypted_region_info.tag_length + if encryption_data.encrypted_region_info is None: + raise ValueError("Field required for V2 encryption is missing (encrypted_region_info).") + + nonce_length = encryption_data.encrypted_region_info.nonce_length + data_length = encryption_data.encrypted_region_info.data_length + tag_length = encryption_data.encrypted_region_info.tag_length region_length = nonce_length + data_length + tag_length decrypted_content = bytearray() diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index b151871cc637..1ef19c08a9c0 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -92,7 +92,7 @@ def __init__( self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) - self._client._config.version = get_api_version(kwargs) + self._client._config.version = 
get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._configure_encryption(kwargs) def _format_url(self, hostname: str) -> str: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py index b2a1b1d775dd..7491965f984b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py @@ -55,8 +55,8 @@ def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str # Build list of sorted tuples sorted_headers = [] for key in header_keys: - sorted_headers.append((key, header_dict.get(key))) - return sorted_headers # type: ignore [return-value] + sorted_headers.append((key, header_dict.pop(key))) + return sorted_headers class AzureSigningError(ClientAuthenticationError): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 92a814b2c566..6fd0e57fa21b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -20,7 +20,6 @@ from urlparse import parse_qs # type: ignore from urllib2 import quote # type: ignore -from azure.core.configuration import Configuration from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import Pipeline @@ -37,7 +36,7 @@ ) from .constants import CONNECTION_TIMEOUT, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE -from .models import LocationMode +from .models import LocationMode, StorageConfiguration from .authentication import SharedKeyCredentialPolicy from .shared_access_signature import QueryStringConstants from .request_handlers import 
serialize_batch_body, _get_batch_request_delimiter @@ -71,12 +70,11 @@ class StorageAccountHostsMixin(object): # pylint: disable=too-many-instance-attributes def __init__( self, - parsed_url, # type: Any - service, # type: str - credential=None, # type: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long - **kwargs # type: Any - ): - # type: (...) -> None + parsed_url: Any, + service: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts") self.scheme = parsed_url.scheme @@ -96,7 +94,7 @@ def __init__( raise ValueError("Token credential is only supported with HTTPS.") secondary_hostname = None - if hasattr(self.credential, "account_name"): + if hasattr(self.credential, "account_name") and self.credential is not None: self.account_name = self.credential.account_name secondary_hostname = f"{self.credential.account_name}-secondary.{service_name}.{SERVICE_HOST_BASE}" @@ -139,7 +137,7 @@ def url(self): def primary_endpoint(self): """The full primary endpoint URL. - :type: str + :rtype: str """ return self._format_url(self._hosts[LocationMode.PRIMARY]) @@ -147,7 +145,7 @@ def primary_endpoint(self): def primary_hostname(self): """The hostname of the primary endpoint. - :type: str + :rtype: str """ return self._hosts[LocationMode.PRIMARY] @@ -158,7 +156,7 @@ def secondary_endpoint(self): If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. 
- :type: str + :rtype: str :raise ValueError: """ if not self._hosts[LocationMode.SECONDARY]: @@ -172,7 +170,7 @@ def secondary_hostname(self): If not available this will be None. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. - :type: str or None + :rtype: Optional[str] """ return self._hosts[LocationMode.SECONDARY] @@ -182,7 +180,7 @@ def location_mode(self): By default this will be "primary". Options include "primary" and "secondary". - :type: str + :rtype: str """ return self._location_mode @@ -199,11 +197,16 @@ def location_mode(self, value): def api_version(self): """The version of the Storage API used for requests. - :type: str + :rtype: str """ return self._client._config.version # pylint: disable=protected-access - def _format_query_string(self, sas_token, credential, snapshot=None, share_snapshot=None): + def _format_query_string( + self, sas_token: Optional[str], + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]], # pylint: disable=line-too-long + snapshot: Optional[str] = None, + share_snapshot: Optional[str] = None + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]]]: # pylint: disable=line-too-long query_str = "?" 
if snapshot: query_str += f"snapshot={self.snapshot}&" @@ -213,17 +216,19 @@ def _format_query_string(self, sas_token, credential, snapshot=None, share_snaps raise ValueError( "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") if is_credential_sastoken(credential): - query_str += credential.lstrip("?") + query_str += credential.lstrip("?") # type: ignore [union-attr] credential = None elif sas_token: query_str += sas_token return query_str.rstrip("?&"), credential - def _create_pipeline(self, credential, **kwargs): - # type: (Any, **Any) -> Tuple[Configuration, Pipeline] + def _create_pipeline( + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None if hasattr(credential, "get_token"): - self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) + self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) # type: ignore [arg-type] elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): @@ -265,29 +270,26 @@ def _create_pipeline(self, credential, **kwargs): def _batch_send( self, *reqs: "HttpRequest", - **kwargs + **kwargs: Any ) -> None: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. 
- :rtype: None - :returns: None """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) batch_id = str(uuid.uuid1()) - if hasattr(self, '_client'): - request = self._client._client.post( # pylint: disable=protected-access - url=( - f'{self.scheme}://{self.primary_hostname}/' - f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" - f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" - ), - headers={ - 'x-ms-version': self.api_version, - "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False) - } + request = self._client._client.post( # pylint: disable=protected-access + url=( + f'{self.scheme}://{self.primary_hostname}/' + f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" + f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" + ), + headers={ + 'x-ms-version': self.api_version, + "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False) + } ) policies = [StorageHeadersPolicy()] @@ -354,7 +356,10 @@ def __exit__(self, *args): # pylint: disable=arguments-differ pass -def _format_shared_key_credential(account_name, credential): +def _format_shared_key_credential( + account_name: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long +) -> Any: if isinstance(credential, str): if not account_name: raise ValueError("Unable to determine account name for shared key credential.") @@ -370,12 +375,16 @@ def _format_shared_key_credential(account_name, credential): return credential -def parse_connection_str(conn_str, credential, service): +def parse_connection_str( + conn_str: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]], # pylint: disable=line-too-long + 
service: str +) -> Tuple[Optional[str], Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]]]: # pylint: disable=line-too-long conn_str = conn_str.rstrip(";") - conn_settings = [s.split("=", 1) for s in conn_str.split(";")] - if any(len(tup) != 2 for tup in conn_settings): + conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] + if any(len(tup) != 2 for tup in conn_settings_list): raise ValueError("Connection string is either blank or malformed.") - conn_settings = dict((key.upper(), val) for key, val in conn_settings) + conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) endpoints = _SERVICE_PARAMS[service] primary = None secondary = None @@ -418,9 +427,8 @@ def parse_connection_str(conn_str, credential, service): return primary, secondary, credential -def create_configuration(**kwargs): - # type: (**Any) -> Configuration - config: Configuration = Configuration(**kwargs) +def create_configuration(**kwargs: Any) -> StorageConfiguration: + config: StorageConfiguration = StorageConfiguration(**kwargs) config.headers_policy = StorageHeadersPolicy(**kwargs) config.user_agent_policy = UserAgentPolicy(sdk_moniker=kwargs.pop('sdk_moniker'), **kwargs) config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) @@ -428,40 +436,30 @@ def create_configuration(**kwargs): config.proxy_policy = ProxyPolicy(**kwargs) # Storage settings - if hasattr(config, 'max_single_put_size'): - config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024) - if hasattr(config, 'copy_polling_interval'): - config.copy_polling_interval = 15 + config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024) + config.copy_polling_interval = 15 # Block blob uploads - if hasattr(config, 'max_block_size'): - config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024) - if hasattr(config, 
'min_large_block_upload_threshold'): - config.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) - if hasattr(config, 'use_byte_buffer'): - config.use_byte_buffer = kwargs.get("use_byte_buffer", False) + config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024) + config.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) + config.use_byte_buffer = kwargs.get("use_byte_buffer", False) # Page blob uploads - if hasattr(config, 'max_page_size'): - config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024) + config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024) # Datalake file uploads - if hasattr(config, 'min_large_chunk_upload_threshold'): - config.min_large_chunk_upload_threshold = kwargs.get("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) + config.min_large_chunk_upload_threshold = kwargs.get("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) # Blob downloads - if hasattr(config, 'max_single_get_size'): - config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024) - if hasattr(config, 'max_chunk_get_size'): - config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024) + config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024) + config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024) # File uploads - if hasattr(config, 'max_range_size'): - config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024) + config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024) return config -def parse_query(query_str): +def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]: sas_values = QueryStringConstants.to_list() parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()} sas_params = [f"{k}={quote(v, safe='')}" for k, v in parsed_query.items() if k in sas_values] @@ -473,7 +471,7 @@ def 
parse_query(query_str): return snapshot, sas_token -def is_credential_sastoken(credential): +def is_credential_sastoken(credential: Any) -> bool: if not credential or not isinstance(credential, str): return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 0c40b18c7349..b0b3a31b59e7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -88,8 +88,7 @@ def _create_pipeline(self, credential, **kwargs): except ImportError as exc: raise ImportError("Unable to create async transport. Please check aiohttp is installed.") from exc transport = AioHttpTransport(**kwargs) - if hasattr(self, '_hosts'): - hosts = self._hosts + hosts = self._hosts policies = [ QueueMessagePolicy(), config.headers_policy, @@ -120,14 +119,10 @@ async def _batch_send( ): # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) - if hasattr(self, '_client'): - client = self._client - if hasattr(self, 'scheme'): - scheme = self.scheme - if hasattr(self, 'primary_hostname'): - primary_hostname = self.primary_hostname - if hasattr(self, 'api_version'): - api_version = self.api_version + client = self._client + scheme = self.scheme + primary_hostname = self.primary_hostname + api_version = self.api_version request = client._client.post( # pylint: disable=protected-access url=( f'{scheme}://{primary_hostname}/' @@ -149,10 +144,9 @@ async def _batch_send( enforce_https=False ) - if hasattr(self, '_pipeline'): - pipeline_response = await self._pipeline.run( - request, **kwargs - ) + pipeline_response = await self._pipeline.run( + request, **kwargs + ) response = pipeline_response.http_response try: diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index b6519ebe9f81..06b76b0a2b17 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -7,6 +7,7 @@ from enum import Enum from azure.core import CaseInsensitiveEnumMeta +from azure.core.configuration import Configuration def get_enum_value(value): @@ -487,3 +488,52 @@ def __init__(self): self.signed_service = None self.signed_version = None self.value = None + + +class StorageConfiguration(Configuration): + """ + Specifies the configurable values used in Azure Storage. + + :param int max_single_put_size: If the blob size is less than or equal max_single_put_size, then the blob will be + uploaded with only one http PUT request. If the blob size is larger than max_single_put_size, + the blob will be uploaded in chunks. Defaults to 64*1024*1024, or 64MB. + :param int copy_polling_interval: The interval in seconds for polling copy operations. + :param int max_block_size: The maximum chunk size for uploading a block blob in chunks. + Defaults to 4*1024*1024, or 4MB. + :param int min_large_block_upload_threshold: The minimum chunk size required to use the memory efficient + algorithm when uploading a block blob. + :param bool use_byte_buffer: Use a byte buffer for block blob uploads. Defaults to False. + :param int max_page_size: The maximum chunk size for uploading a page blob. Defaults to 4*1024*1024, or 4MB. + :param int min_large_chunk_upload_threshold: The max size for a single put operation. + :param int max_single_get_size: The maximum size for a blob to be downloaded in a single call, + the exceeded part will be downloaded in chunks (could be parallel). Defaults to 32*1024*1024, or 32MB. + :param int max_chunk_get_size: The maximum chunk size used for downloading a blob. Defaults to 4*1024*1024, + or 4MB. 
+ :param int max_range_size: The max range size for file upload. + + """ + def __init__( + self, + max_single_put_size = 64 * 1024 * 1024, + copy_polling_interval = 15, + max_block_size = 4 * 1024 * 1024, + min_large_block_upload_threshold = 4 * 1024 * 1024 + 1, + use_byte_buffer = False, + max_page_size = 4 * 1024 * 1024, + min_large_chunk_upload_threshold = 100 * 1024 * 1024 + 1, + max_single_get_size = 32 * 1024 * 1024, + max_chunk_get_size = 4 * 1024 * 1024, + max_range_size = 4 * 1024 * 1024, + **kwargs, + ): + super(StorageConfiguration, self).__init__(**kwargs) + self.max_single_put_size = max_single_put_size + self.copy_polling_interval = copy_polling_interval + self.max_block_size = max_block_size + self.min_large_block_upload_threshold = min_large_block_upload_threshold + self.use_byte_buffer = use_byte_buffer + self.max_page_size = max_page_size + self.min_large_chunk_upload_threshold = min_large_chunk_upload_threshold + self.max_single_get_size = max_single_get_size + self.max_chunk_get_size = max_chunk_get_size + self.max_range_size = max_range_size \ No newline at end of file From 512263ad1b72f04ce06c8beb34c51dcbf55d5fd9 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Tue, 5 Sep 2023 14:54:08 -0700 Subject: [PATCH 30/71] Just samples, added more helpers --- .../azure/storage/queue/_message_encoding.py | 8 +-- .../azure/storage/queue/_queue_client.py | 32 +-------- .../storage/queue/_queue_client_helpers.py | 72 +++++++++++++++++-- .../queue/_shared/base_client_async.py | 11 ++- .../azure/storage/queue/_shared/policies.py | 27 +++---- .../storage/queue/aio/_queue_client_async.py | 32 +-------- 6 files changed, 100 insertions(+), 82 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index daf73f8ac731..60b2a262860a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -19,13 +19,13 @@ class MessageEncodePolicy(object): require_encryption: Optional[bool] = None - """Indicates whether a retention policy is enabled for the storage service.""" + """Indicates whether encryption is required or not.""" encryption_version: Optional[str] = None - """Indicates whether a retention policy is enabled for the storage service.""" + """Indicates the version of encryption being used.""" key_encryption_key: Optional[KeyEncryptionKey] = None - """Indicates whether a retention policy is enabled for the storage service.""" + """The user-provided key-encryption-key.""" resolver: Optional[Callable[[str], bytes]] = None - """Indicates whether a retention policy is enabled for the storage service.""" + """The user-provided key resolver.""" def __init__(self) -> None: self.require_encryption = False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 5ee94846128c..2f6fde464f9a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -23,7 +23,7 @@ from ._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier from ._message_encoding import NoDecodePolicy, NoEncodePolicy from ._models import AccessPolicy, MessagesPaged, QueueMessage -from ._queue_client_helpers import _parse_url +from ._queue_client_helpers import _parse_url, _format_url_helper, _from_queue_url_helper from ._serialize import get_api_version from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin from ._shared.request_handlers import add_metadata_headers, serialize_iso @@ -103,13 +103,7 @@ def _format_url(self, hostname: str) -> str: :returns: The formatted endpoint URL according to the specified location mode hostname. 
:rtype: str """ - if isinstance(self.queue_name, str): - queue_name = self.queue_name.encode('UTF-8') - else: - queue_name = self.queue_name - return ( - f"{self.scheme}://{hostname}" - f"/{quote(queue_name)}{self._query_str}") + return _format_url_helper(queue_name=self.queue_name, hostname=hostname, scheme=self.scheme, query_str=self._query_str) # pylint: disable=line-too-long @classmethod def from_queue_url( @@ -133,27 +127,7 @@ def from_queue_url( :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient """ - try: - if not queue_url.lower().startswith('http'): - queue_url = "https://" + queue_url - except AttributeError as exc: - raise ValueError("Queue URL must be a string.") from exc - parsed_url = urlparse(queue_url.rstrip('/')) - - if not parsed_url.netloc: - raise ValueError(f"Invalid URL: {queue_url}") - - queue_path = parsed_url.path.lstrip('/').split('/') - account_path = "" - if len(queue_path) > 1: - account_path = "/" + "/".join(queue_path[:-1]) - account_url = ( - f"{parsed_url.scheme}://{parsed_url.netloc.rstrip('/')}" - f"{account_path}?{parsed_url.query}") - queue_name = unquote(queue_path[-1]) - if not queue_name: - raise ValueError("Invalid URL. 
Please provide a URL with a valid queue name") - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + return _from_queue_url_helper(cls=cls, queue_url=queue_url, credential=credential, **kwargs) @classmethod def from_connection_string( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index 0377d439f3b2..2bf6846c38fd 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -5,18 +5,19 @@ # -------------------------------------------------------------------------- from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union -from urllib.parse import urlparse +from urllib.parse import quote, unquote, urlparse from ._shared.base_client import parse_query if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials import AsyncTokenCredential, AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.storage.queue import QueueClient from urllib.parse import ParseResult def _parse_url( account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long ) -> Tuple["ParseResult", Any]: """Performs initial input validation and returns the parsed URL and SAS token. @@ -31,7 +32,7 @@ def _parse_url( - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. 
If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long :returns: The parsed URL and SAS token. :rtype: Tuple[ParseResult, Any] """ @@ -50,4 +51,65 @@ def _parse_url( if not sas_token and not credential: raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") - return parsed_url, sas_token \ No newline at end of file + return parsed_url, sas_token + +def _format_url_helper(queue_name: Union[bytes, str], hostname: str, scheme: str, query_str: str) -> str: + """Format the endpoint URL according to the current location mode hostname. + + :param Union[bytes, str] queue_name: The name of the queue. + :param str hostname: The current location mode hostname. + :param str scheme: The scheme for the current location mode hostname. + :param str query_str: The query string of the endpoint URL being formatted. + :returns: The formatted endpoint URL according to the specified location mode hostname. + :rtype: str + """ + if isinstance(queue_name, str): + queue_name = queue_name.encode('UTF-8') + else: + pass + return ( + f"{scheme}://{hostname}" + f"/{quote(queue_name)}{query_str}") + +def _from_queue_url_helper( + cls, queue_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any +) -> "QueueClient": + """Creates a QueueClient from the given queue URL. + + :param str queue_url: The full URI to the queue, including SAS token if used. 
+ :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + If the resource URI already contains a SAS token, this will be ignored in favor of an explicit credential + - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long + :returns: A queue client. + :rtype: QueueClient + """ + try: + if not queue_url.lower().startswith('http'): + queue_url = "https://" + queue_url + except AttributeError as exc: + raise ValueError("Queue URL must be a string.") from exc + parsed_url = urlparse(queue_url.rstrip('/')) + + if not parsed_url.netloc: + raise ValueError(f"Invalid URL: {queue_url}") + + queue_path = parsed_url.path.lstrip('/').split('/') + account_path = "" + if len(queue_path) > 1: + account_path = "/" + "/".join(queue_path[:-1]) + account_url = ( + f"{parsed_url.scheme}://{parsed_url.netloc.rstrip('/')}" + f"{account_path}?{parsed_url.query}") + queue_name = unquote(queue_path[-1]) + if not queue_name: + raise ValueError("Invalid URL. 
Please provide a URL with a valid queue name") + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index b0b3a31b59e7..bcbdfd88e9a0 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -10,7 +10,7 @@ ) import logging -from azure.core.credentials import AzureSasCredential +from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential from azure.core.pipeline import AsyncPipeline from azure.core.async_paging import AsyncList from azure.core.exceptions import HttpResponseError @@ -35,11 +35,14 @@ StorageRequestHook, ) from .policies_async import AsyncStorageResponseHook +from .models import StorageConfiguration from .response_handlers import process_storage_error, PartialBatchErrorException if TYPE_CHECKING: from azure.core.pipeline.transport import HttpRequest + from azure.core.credentials import TokenCredential + from azure.core.credentials_async import AsyncTokenCredential from azure.core.configuration import Configuration _LOGGER = logging.getLogger(__name__) @@ -65,8 +68,10 @@ async def close(self): """ await self._client.close() - def _create_pipeline(self, credential, **kwargs): - # type: (Any, **Any) -> Tuple[Configuration, AsyncPipeline] + def _create_pipeline( + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Tuple[StorageConfiguration, AsyncPipeline]: self._credential_policy: Optional[Union[AsyncBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy]] = None # pylint: disable=line-too-long if hasattr(credential, 'get_token'): 
self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 700515d7c210..d42eade5e1d3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -12,7 +12,7 @@ from io import SEEK_SET, UnsupportedOperation import logging import uuid -from typing import Any, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING from wsgiref.handlers import format_date_time try: from urllib.parse import ( @@ -294,8 +294,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument self._response_callback = kwargs.get('raw_response_hook') super(StorageResponseHook, self).__init__() - def send(self, request): - # type: ("PipelineRequest") -> "PipelineResponse" + def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 data_stream_total = request.context.get('data_stream_total') if data_stream_total is None: @@ -382,7 +381,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.context['validate_content_md5'] = computed_md5 request.context['validate_content'] = validate_content - def on_response(self, request, response): + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = request.context.get('validate_content_md5') or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) @@ -424,7 +423,7 @@ def _set_next_host_location(self, settings, request): updated = url._replace(netloc=settings['hosts'].get(settings['mode'])) request.url = updated.geturl() - def configure_retries(self, request): + def configure_retries(self, request: 
"PipelineRequest") -> Dict[str, Any]: body_position = None if hasattr(request.http_request.body, 'read'): try: @@ -605,13 +604,19 @@ def get_backoff_time(self, settings): class LinearRetry(StorageRetryPolicy): """Linear retry.""" - def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_jitter_range=3, **kwargs): + def __init__( + self, backoff: int = 15, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, + **kwargs: Any + ) -> None: """ Constructs a Linear retry object. :param int backoff: The backoff interval, in seconds, between retries. - :param int max_attempts: + :param int retry_total: The maximum number of retry attempts. :param bool retry_to_secondary: Whether the request should be retried to secondary, if able. This should @@ -626,7 +631,7 @@ def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_j super(LinearRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings): + def get_backoff_time(self, settings: Optional[Dict[str, Any]]) -> Optional[int]: """ Calculates how long to sleep before retrying. 
@@ -648,12 +653,10 @@ def get_backoff_time(self, settings): class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """ Custom Bearer token credential policy for following Storage Bearer challenges """ - def __init__(self, credential, **kwargs): - # type: (TokenCredential, **Any) -> None + def __init__(self, credential: "TokenCredential", **kwargs) -> None: super(StorageBearerTokenCredentialPolicy, self).__init__(credential, STORAGE_OAUTH_SCOPE, **kwargs) - def on_challenge(self, request, response): - # type: ("PipelineRequest", "PipelineResponse") -> bool + def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index d1ed37c2d22e..53e808e4b483 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -26,7 +26,7 @@ from .._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier from .._message_encoding import NoDecodePolicy, NoEncodePolicy from .._models import AccessPolicy, QueueMessage -from .._queue_client_helpers import _parse_url +from .._queue_client_helpers import _parse_url, _format_url_helper, _from_queue_url_helper from .._serialize import get_api_version from .._shared.base_client import parse_connection_str, StorageAccountHostsMixin from .._shared.base_client_async import AsyncStorageAccountHostsMixin @@ -118,13 +118,7 @@ def _format_url(self, hostname: str) -> str: :returns: The formatted endpoint URL according to the specified location mode hostname. 
:rtype: str """ - if isinstance(self.queue_name, str): - queue_name = self.queue_name.encode('UTF-8') - else: - queue_name = self.queue_name - return ( - f"{self.scheme}://{hostname}" - f"/{quote(queue_name)}{self._query_str}") + return _format_url_helper(queue_name=self.queue_name, hostname=hostname, scheme=self.scheme, query_str=self._query_str) # pylint: disable=line-too-long @classmethod def from_queue_url( @@ -148,27 +142,7 @@ def from_queue_url( :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient """ - try: - if not queue_url.lower().startswith('http'): - queue_url = "https://" + queue_url - except AttributeError as exc: - raise ValueError("Queue URL must be a string.") from exc - parsed_url = urlparse(queue_url.rstrip('/')) - - if not parsed_url.netloc: - raise ValueError(f"Invalid URL: {queue_url}") - - queue_path = parsed_url.path.lstrip('/').split('/') - account_path = "" - if len(queue_path) > 1: - account_path = "/" + "/".join(queue_path[:-1]) - account_url = ( - f"{parsed_url.scheme}://{parsed_url.netloc.rstrip('/')}" - f"{account_path}?{parsed_url.query}") - queue_name = unquote(queue_path[-1]) - if not queue_name: - raise ValueError("Invalid URL. 
Please provide a URL with a valid queue name") - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + return _from_queue_url_helper(cls=cls, queue_url=queue_url, credential=credential, **kwargs) @classmethod def from_connection_string( From 2eb11de6de1b4ec523b81a8a93549716d2eea158 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 6 Sep 2023 17:18:01 -0700 Subject: [PATCH 31/71] More feedback --- .../azure/storage/queue/_encryption.py | 8 +-- .../azure/storage/queue/_message_encoding.py | 10 ++-- .../azure/storage/queue/_queue_client.py | 16 +++--- .../storage/queue/_queue_client_helpers.py | 13 +++-- .../storage/queue/_queue_service_client.py | 9 ++-- .../queue/_queue_service_client_helpers.py | 5 +- .../storage/queue/_shared/base_client.py | 11 ++-- .../queue/_shared/base_client_async.py | 3 +- .../azure/storage/queue/_shared/models.py | 51 ++++++++++--------- .../azure/storage/queue/_shared/policies.py | 2 +- .../storage/queue/aio/_queue_client_async.py | 19 +++---- .../queue/aio/_queue_service_client_async.py | 10 ++-- 12 files changed, 85 insertions(+), 72 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 77f383ce5699..5d8928f8e92f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -15,7 +15,7 @@ dumps, loads, ) -from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING +from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import Protocol from cryptography.hazmat.backends import default_backend @@ -700,11 +700,11 @@ def _decrypt_message( # decrypt data decrypted_data = message decryptor = cipher.decryptor() - decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) + decrypted_data = 
(decryptor.update(decrypted_data) + decryptor.finalize()) #type: ignore # unpad data unpadder = PKCS7(128).unpadder() - decrypted_data = (unpadder.update(decrypted_data) + unpadder.finalize()) + decrypted_data = (unpadder.update(decrypted_data) + unpadder.finalize()) #type: ignore elif encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: block_info = encryption_data.encrypted_region_info @@ -972,7 +972,7 @@ def get_blob_encryptor_and_padder( encryptor = cipher.encryptor() padder = PKCS7(128).padder() if should_pad else None - return encryptor, padder + return encryptor, padder #type: ignore [return-value] def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, version: str) -> str: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index 60b2a262860a..cfda11a90172 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -42,8 +42,8 @@ def __call__(self, content: Any) -> str: def configure( self, require_encryption: bool, - key_encryption_key: KeyEncryptionKey, - resolver: Callable[[str], bytes], + key_encryption_key: Optional[KeyEncryptionKey], + resolver: Optional[Callable[[str], bytes]], encryption_version: str = _ENCRYPTION_PROTOCOL_V1 ) -> None: self.require_encryption = require_encryption @@ -78,7 +78,11 @@ def __call__(self, response: "PipelineResponse", obj: Iterable, headers: Dict[st message.message_text = self.decode(content, response) return obj - def configure(self, require_encryption: bool, key_encryption_key: KeyEncryptionKey, resolver: Callable[[str], bytes]) -> None: + def configure( + self, require_encryption: bool, + key_encryption_key: Optional[KeyEncryptionKey], + resolver: Optional[Callable[[str], bytes]] + ) -> None: self.require_encryption = require_encryption self.key_encryption_key = 
key_encryption_key self.resolver = resolver diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 2f6fde464f9a..3cdfe5a535e2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials_async import AsyncTokenCredential from ._models import QueueProperties @@ -82,7 +83,7 @@ class QueueClient(StorageAccountHostsMixin, StorageEncryptionMixin): def __init__( self, account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: parsed_url, sas_token = _parse_url(account_url=account_url, queue_name=queue_name, credential=credential) @@ -127,13 +128,14 @@ def from_queue_url( :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient """ - return _from_queue_url_helper(cls=cls, queue_url=queue_url, credential=credential, **kwargs) + account_url, queue_name = _from_queue_url_helper(cls=cls, queue_url=queue_url, credential=credential) + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @classmethod def from_connection_string( cls, conn_str: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Self: """Create QueueClient from a Connection String. @@ -151,7 +153,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient @@ -168,7 +170,7 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) #type: ignore [arg-type] @distributed_trace def create_queue( @@ -741,14 +743,14 @@ def update_message( self.encryption_version, kwargs) - try: + if isinstance(message, QueueMessage): message_id = message.id message_text = content or message.content receipt = pop_receipt or message.pop_receipt inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - except AttributeError: + else: message_id = message message_text = content receipt = pop_receipt diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index 2bf6846c38fd..cab051987b8d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -9,8 +9,8 @@ from ._shared.base_client import parse_query if TYPE_CHECKING: - from azure.core.credentials import AsyncTokenCredential, AzureNamedKeyCredential, AzureSasCredential, TokenCredential - from azure.storage.queue import QueueClient + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials_async import AsyncTokenCredential from urllib.parse import ParseResult @@ -74,8 +74,7 @@ def _format_url_helper(queue_name: Union[bytes, str], hostname: str, scheme: str def _from_queue_url_helper( cls, queue_url: str, credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: 
disable=line-too-long - **kwargs: Any -) -> "QueueClient": +) -> Tuple[str, str]: """A client to interact with a specific Queue. :param str queue_url: The full URI to the queue, including SAS token if used. @@ -89,8 +88,8 @@ def _from_queue_url_helper( If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. :paramtype Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long - :returns: A queue client. - :rtype: QueueClient + :returns: The parsed out account_url and queue name. + :rtype: Tuple[str, str] """ try: if not queue_url.lower().startswith('http'): @@ -112,4 +111,4 @@ def _from_queue_url_helper( queue_name = unquote(queue_path[-1]) if not queue_name: raise ValueError("Invalid URL. Please provide a URL with a valid queue name") - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + return(account_url, queue_name) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 9104d56d7a2f..85f25c7221b4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials_async import AsyncTokenCredential from ._generated.models import CorsRule from ._models import Metrics, QueueAnalyticsLogging @@ -88,7 +89,7 @@ class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): def __init__( self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long 
+ credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: parsed_url, sas_token = _parse_url(account_url=account_url, credential=credential) @@ -111,7 +112,7 @@ def _format_url(self, hostname: str) -> str: @classmethod def from_connection_string( cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Self: """Create QueueServiceClient from a Connection String. @@ -127,7 +128,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A Queue service client. 
:rtype: ~azure.storage.queue.QueueClient @@ -144,7 +145,7 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, credential=credential, **kwargs) + return cls(account_url, credential=credential, **kwargs) #type: ignore [arg-type] @distributed_trace def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py index a9a44407a94e..84e3c0526484 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py @@ -10,12 +10,13 @@ if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials_async import AsyncTokenCredential from urllib.parse import ParseResult def _parse_url( account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long ) -> Tuple["ParseResult", Any]: """Performs initial input validation and returns the parsed URL and SAS token. @@ -29,7 +30,7 @@ def _parse_url( - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. 
- :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long :returns: The parsed URL and SAS token. :rtype: Tuple[ParseResult, Any] """ diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 6fd0e57fa21b..27cb842b1ec9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -68,11 +68,12 @@ class StorageAccountHostsMixin(object): # pylint: disable=too-many-instance-attributes + _client: Any def __init__( self, parsed_url: Any, service: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) @@ -203,15 +204,15 @@ def api_version(self): def _format_query_string( self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]], # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]], # pylint: disable=line-too-long snapshot: Optional[str] = None, share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", 
"AzureSasCredential", "TokenCredential"]]]: # pylint: disable=line-too-long + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]]]: # pylint: disable=line-too-long query_str = "?" if snapshot: - query_str += f"snapshot={self.snapshot}&" + query_str += f"snapshot={snapshot}&" if share_snapshot: - query_str += f"sharesnapshot={self.snapshot}&" + query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index bcbdfd88e9a0..39077e991c21 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +# mypy: disable-error-code="attr-defined" from typing import ( # pylint: disable=unused-import Union, Optional, Any, Iterable, Dict, List, Type, Tuple, @@ -74,7 +75,7 @@ def _create_pipeline( ) -> Tuple[StorageConfiguration, AsyncPipeline]: self._credential_policy: Optional[Union[AsyncBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy]] = None # pylint: disable=line-too-long if hasattr(credential, 'get_token'): - self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) + self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) #type: ignore [arg-type] elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index 06b76b0a2b17..b3b1f70b133d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -8,6 +8,8 @@ from azure.core import CaseInsensitiveEnumMeta from azure.core.configuration import Configuration +from azure.core.pipeline.policies import UserAgentPolicy +from .._message_encoding import MessageDecodePolicy def get_enum_value(value): @@ -512,28 +514,29 @@ class StorageConfiguration(Configuration): :param int max_range_size: The max range size for file upload. 
""" - def __init__( - self, - max_single_put_size = 64 * 1024 * 1024, - copy_polling_interval = 15, - max_block_size = 4 * 1024 * 1024, - min_large_block_upload_threshold = 4 * 1024 * 1024 + 1, - use_byte_buffer = False, - max_page_size = 4 * 1024 * 1024, - min_large_chunk_upload_threshold = 100 * 1024 * 1024 + 1, - max_single_get_size = 32 * 1024 * 1024, - max_chunk_get_size = 4 * 1024 * 1024, - max_range_size = 4 * 1024 * 1024, - **kwargs, - ): + + max_single_put_size: int + copy_polling_interval: int + max_block_size: int + min_large_block_upload_threshold: int + use_byte_buffer: bool + max_page_size: int + min_large_chunk_upload_threshold: int + max_single_get_size: int + max_chunk_get_size: int + max_range_size: int + user_agent_policy: UserAgentPolicy + message_decode_policy: MessageDecodePolicy + + def __init__(self, **kwargs): super(StorageConfiguration, self).__init__(**kwargs) - self.max_single_put_size = max_single_put_size - self.copy_polling_interval = copy_polling_interval - self.max_block_size = max_block_size - self.min_large_block_upload_threshold = min_large_block_upload_threshold - self.use_byte_buffer = use_byte_buffer - self.max_page_size = max_page_size - self.min_large_chunk_upload_threshold = min_large_chunk_upload_threshold - self.max_single_get_size = max_single_get_size - self.max_chunk_get_size = max_chunk_get_size - self.max_range_size = max_range_size \ No newline at end of file + self.max_single_put_size = 64 * 1024 * 1024 + self.copy_polling_interval = 15 + self.max_block_size = 4 * 1024 * 1024, + self.min_large_block_upload_threshold = 4 * 1024 * 1024 + 1, + self.use_byte_buffer = False + self.max_page_size = 4 * 1024 * 1024 + self.min_large_chunk_upload_threshold = 100 * 1024 * 1024 + 1 + self.max_single_get_size = 32 * 1024 * 1024 + self.max_chunk_get_size = 4 * 1024 * 1024 + self.max_range_size = 4 * 1024 * 1024 diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index d42eade5e1d3..fb362ab097f2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -647,7 +647,7 @@ def get_backoff_time(self, settings: Optional[Dict[str, Any]]) -> Optional[int]: random_range_start = self.backoff - self.random_jitter_range \ if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range - return random_generator.uniform(random_range_start, random_range_end) + return int(random_generator.uniform(random_range_start, random_range_end)) class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 53e808e4b483..a5f94068e69f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -39,7 +39,7 @@ ) if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential from azure.core.credentials_async import AsyncTokenCredential from .._models import QueueProperties @@ -93,7 +93,7 @@ class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, Stora def __init__( self, account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: 
kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) @@ -142,13 +142,14 @@ def from_queue_url( :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient """ - return _from_queue_url_helper(cls=cls, queue_url=queue_url, credential=credential, **kwargs) + account_url, queue_name = _from_queue_url_helper(cls=cls, queue_url=queue_url, credential=credential) + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @classmethod def from_connection_string( cls, conn_str: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Self: """Create QueueClient from a Connection String. @@ -166,7 +167,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient @@ -183,7 +184,7 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) #type: ignore [arg-type] @distributed_trace_async async def create_queue( @@ -673,7 +674,7 @@ def receive_messages( # type: ignore[override] process_storage_error(error) @distributed_trace_async - async def update_message( # type: ignore[override] + async def update_message( self, message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, content: Optional[Any] = None, @@ -739,14 +740,14 @@ async def update_message( # type: ignore[override] self.encryption_version, kwargs) - try: + if isinstance(message, QueueMessage): message_id = message.id message_text = content or message.content receipt = pop_receipt or message.pop_receipt inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - except AttributeError: + else: message_id = message message_text = content receipt = pop_receipt diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index c4497332203f..14a46b7b2427 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -33,7 +33,7 @@ from .._shared.response_handlers import process_storage_error if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential from azure.core.credentials_async import AsyncTokenCredential from .._generated.models import 
CorsRule from .._models import Metrics, QueueAnalyticsLogging @@ -85,7 +85,7 @@ class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin def __init__( self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) @@ -111,7 +111,7 @@ def _format_url(self, hostname: str) -> str: @classmethod def from_connection_string( cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Self: """Create QueueServiceClient from a Connection String. @@ -127,7 +127,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A Queue service client. 
:rtype: ~azure.storage.queue.QueueClient @@ -144,7 +144,7 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, credential=credential, **kwargs) + return cls(account_url, credential=credential, **kwargs) #type: ignore [arg-type] @distributed_trace_async async def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: From 721fc5431f822dc9492ac60a5016bed629cb9a8b Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 6 Sep 2023 20:24:16 -0700 Subject: [PATCH 32/71] More fixes --- .../azure/storage/queue/_models.py | 2 +- .../azure/storage/queue/_shared/models.py | 72 ++++++++++++------ .../azure/storage/queue/_shared/policies.py | 74 ++++++++++++------- .../storage/queue/_shared/policies_async.py | 4 +- 4 files changed, 102 insertions(+), 50 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 715e75bb2055..d4f11a18a68b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -78,7 +78,7 @@ class QueueAnalyticsLogging(GeneratedLogging): delete: bool = False """Indicates whether all delete requests should be logged.""" read: bool = False - """Indicates whether all read requests should be logged""" + """Indicates whether all read requests should be logged.""" write: bool = False """Indicates whether all write requests should be logged.""" retention_policy: RetentionPolicy = RetentionPolicy() diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index b3b1f70b133d..93c88875dbc7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -5,6 +5,7 @@ # 
-------------------------------------------------------------------------- # pylint: disable=too-many-instance-attributes from enum import Enum +from typing import Optional from azure.core import CaseInsensitiveEnumMeta from azure.core.configuration import Configuration @@ -272,6 +273,11 @@ class ResourceTypes(object): files(e.g. Put Blob, Query Entity, Get Messages, Create File, etc.) """ + service: bool = False + container: bool = False + object: bool = False + _str: str + def __init__(self, service=False, container=False, object=False): # pylint: disable=redefined-builtin self.service = service self.container = container @@ -347,9 +353,28 @@ class AccountSasPermissions(object): To enable permanent delete on the blob is permitted. Valid for Object resource type of Blob only. """ - def __init__(self, read=False, write=False, delete=False, - list=False, # pylint: disable=redefined-builtin - add=False, create=False, update=False, process=False, delete_previous_version=False, **kwargs): + + read: bool = False + write: bool = False + delete: bool = False + delete_previous_version: bool = False + list: bool = False + add: bool = False + create: bool = False + update: bool = False + process: bool = False + tag: bool = False + filter_by_tags: bool = False + set_immutability_policy: bool = False + permanent_delete: bool = False + + def __init__( + self, read: bool = False, write: bool = False, + delete: bool = False, list: bool = False, # pylint: disable=redefined-builtin + add:bool = False, create:bool = False, update:bool =False, + process:bool = False, delete_previous_version:bool = False, + **kwargs + ) -> None: self.read = read self.write = write self.delete = delete @@ -426,7 +451,11 @@ class Services(object): Access for the `~azure.storage.fileshare.ShareServiceClient` """ - def __init__(self, blob=False, queue=False, fileshare=False): + blob: bool = False + queue: bool = False + fileshare: bool = False + + def __init__(self, blob: bool = False, queue: bool = 
False, fileshare: bool = False): self.blob = blob self.queue = queue self.fileshare = fileshare @@ -466,22 +495,23 @@ class UserDelegationKey(object): The fields are saved as simple strings since the user does not have to interact with this object; to generate an identify SAS, the user can simply pass it to the right API. - - :ivar str signed_oid: - Object ID of this token. - :ivar str signed_tid: - Tenant ID of the tenant that issued this token. - :ivar str signed_start: - The datetime this token becomes valid. - :ivar str signed_expiry: - The datetime this token expires. - :ivar str signed_service: - What service this key is valid for. - :ivar str signed_version: - The version identifier of the REST service that created this token. - :ivar str value: - The user delegation key. """ + + signed_oid: Optional[str] = None + """Object ID of this token.""" + signed_tid: Optional[str] = None + """Tenant ID of the tenant that issued this token.""" + signed_start: Optional[str] = None + """The datetime this token becomes valid.""" + signed_expiry: Optional[str] = None + """The datetime this token expires.""" + signed_service: Optional[str] = None + """What service this key is valid for.""" + signed_version: Optional[str] = None + """The version identifier of the REST service that created this token.""" + value: Optional[str] = None + """The user delegation key.""" + def __init__(self): self.signed_oid = None self.signed_tid = None @@ -532,8 +562,8 @@ def __init__(self, **kwargs): super(StorageConfiguration, self).__init__(**kwargs) self.max_single_put_size = 64 * 1024 * 1024 self.copy_polling_interval = 15 - self.max_block_size = 4 * 1024 * 1024, - self.min_large_block_upload_threshold = 4 * 1024 * 1024 + 1, + self.max_block_size = 4 * 1024 * 1024 + self.min_large_block_upload_threshold = 4 * 1024 * 1024 + 1 self.use_byte_buffer = False self.max_page_size = 4 * 1024 * 1024 self.min_large_chunk_upload_threshold = 100 * 1024 * 1024 + 1 diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index fb362ab097f2..77525de48539 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -12,7 +12,7 @@ from io import SEEK_SET, UnsupportedOperation import logging import uuid -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, Union, TYPE_CHECKING from wsgiref.handlers import format_date_time try: from urllib.parse import ( @@ -195,7 +195,11 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): This accepts both global configuration, and per-request level with "enable_http_logger" """ - def __init__(self, logging_enable=False, **kwargs): + + logging_enable: bool = False + """Whether logging should be enabled or not.""" + + def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) @@ -398,7 +402,18 @@ class StorageRetryPolicy(HTTPPolicy): The base class for Exponential and Linear retries containing shared code. 
""" - def __init__(self, **kwargs): + total_retries: int = 10 + """The max number of retries.""" + connect_retries: int = 3 + """The max number of connect retries.""" + retry_read: int = 3 + """The max number of read retries.""" + retry_status:int = 3 + """The max number of status retries.""" + retry_to_secondary: bool = False + """Whether the secondary endpoint should be retried.""" + + def __init__(self, **kwargs: Any) -> None: self.total_retries = kwargs.pop('retry_total', 10) self.connect_retries = kwargs.pop('retry_connect', 3) self.read_retries = kwargs.pop('retry_read', 3) @@ -406,12 +421,12 @@ def __init__(self, **kwargs): self.retry_to_secondary = kwargs.pop('retry_to_secondary', False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings, request): + def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the next host location. - :param "PipelineRequest" request: A pipeline request object. + Dict[str, Any]] settings: The configurable values pertaining to the next host location. + :param PipelineRequest request: A pipeline request object. """ if settings['hosts'] and all(settings['hosts'].values()): url = urlparse(request.url) @@ -446,11 +461,11 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: 'history': [] } - def get_backoff_time(self, settings): # pylint: disable=unused-argument + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument """ Formula for computing the current backoff. Should be calculated by child class. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :returns: The backoff time. 
:rtype: float """ @@ -462,15 +477,19 @@ def sleep(self, settings, transport): return transport.sleep(backoff) - def increment(self, settings, request, response=None, error=None): + def increment( + self, settings: Dict[str, Any], + request: "PipelineRequest", response: Optional["PipelineResponse"] = None, + error: Optional[Union[ServiceRequestError, ServiceResponseError]] = None + ) -> bool: """Increment the retry counters. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the increment operation. - :param "PipelineRequest" request: A pipeline request object. - :param "PipelineResponse": A pipeline response object. + Dict[str, Any]] settings: The configurable values pertaining to the increment operation. + :param PipelineRequest request: A pipeline request object. + :param Optional[PipelineResponse] response: A pipeline response object. :param error: An error encountered during the request, or None if the response was received successfully. - :paramtype error: Union[ServiceRequestError, ServiceResponseError] + :paramtype error: Optional[Union[ServiceRequestError, ServiceResponseError]] :returns: Whether the retry attempts are exhausted. 
:rtype: bool """ @@ -491,7 +510,7 @@ def increment(self, settings, request, response=None, error=None): else: # Incrementing because of a server error like a 500 in # status_forcelist and a the given method is in the allowlist - if response: + if response is not None: settings['status'] -= 1 settings['history'].append(RequestHistory(request, http_response=response)) @@ -556,9 +575,12 @@ def send(self, request): class ExponentialRetry(StorageRetryPolicy): """Exponential retry.""" - def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, - retry_to_secondary=False, random_jitter_range=3, **kwargs): - ''' + def __init__( + self, initial_backoff: int = 15, + increment_base: int = 3, retry_total: int = 3, + retry_to_secondary: bool = False, random_jitter_range: int = 3, **kwargs + ) -> None: + """ Constructs an Exponential retry object. The initial_backoff is used for the first retry. Subsequent retries are retried after initial_backoff + increment_power^retry_count seconds. @@ -568,7 +590,7 @@ def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, :param int increment_base: The base, in seconds, to increment the initial_backoff by after the first retry. - :param int max_attempts: + :param int retry_total: The maximum number of retry attempts. :param bool retry_to_secondary: Whether the request should be retried to secondary, if able. This should @@ -577,22 +599,22 @@ def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, :param int random_jitter_range: A number in seconds which indicates a range to jitter/randomize for the back-off interval. For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. 
- ''' + """ self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range super(ExponentialRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings): + def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to get backoff time. + Dict[str, Any]] settings: The configurable values pertaining to get backoff time. :returns: - An integer indicating how long to wait before retrying the request, + A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. - :rtype: int or None + :rtype: float """ random_generator = random.Random() backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) @@ -631,11 +653,11 @@ def __init__( super(LinearRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings: Optional[Dict[str, Any]]) -> Optional[int]: + def get_backoff_time(self, settings: Dict[str, Any]) -> Optional[int]: """ Calculates how long to sleep before retrying. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :returns: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
@@ -647,7 +669,7 @@ def get_backoff_time(self, settings: Optional[Dict[str, Any]]) -> Optional[int]: random_range_start = self.backoff - self.random_jitter_range \ if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range - return int(random_generator.uniform(random_range_start, random_range_end)) + return random_generator.uniform(random_range_start, random_range_end) class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 897c42dd24d9..2c41ef44d4f3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -179,7 +179,7 @@ def get_backoff_time(self, settings): """ Calculates how long to sleep before retrying. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :return: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. @@ -220,7 +220,7 @@ def get_backoff_time(self, settings): """ Calculates how long to sleep before retrying. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :return: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
From 3c6fff175bdb08e8b85f21c3dbd79f4c4e32f7b4 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Thu, 7 Sep 2023 13:38:09 -0700 Subject: [PATCH 33/71] Lint --- .../azure/storage/queue/_encryption.py | 2 +- .../azure/storage/queue/_queue_client.py | 3 +-- .../azure/storage/queue/_queue_client_helpers.py | 15 +-------------- .../azure/storage/queue/_shared/base_client.py | 13 ++++++++----- .../storage/queue/_shared/base_client_async.py | 4 ++-- .../azure/storage/queue/_shared/models.py | 2 +- .../azure/storage/queue/_shared/policies.py | 8 ++++---- .../azure/storage/queue/_shared/policies_async.py | 4 ++-- .../storage/queue/aio/_queue_client_async.py | 5 ++--- .../queue/aio/_queue_service_client_async.py | 2 +- 10 files changed, 23 insertions(+), 35 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 5d8928f8e92f..868e4093ca84 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -15,7 +15,7 @@ dumps, loads, ) -from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING from typing_extensions import Protocol from cryptography.hazmat.backends import default_backend diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 3cdfe5a535e2..4f8d64e9103d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -10,7 +10,6 @@ Any, cast, Dict, List, Optional, TYPE_CHECKING, Tuple, Union ) -from urllib.parse import quote, unquote, urlparse from typing_extensions import Self @@ -128,7 +127,7 @@ def from_queue_url( :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient """ - account_url, queue_name = _from_queue_url_helper(cls=cls, queue_url=queue_url, credential=credential) + account_url, queue_name = _from_queue_url_helper(queue_url=queue_url) return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @classmethod diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index cab051987b8d..fc60cd95bb35 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -71,23 +71,10 @@ def _format_url_helper(queue_name: Union[bytes, str], hostname: str, scheme: str f"{scheme}://{hostname}" f"/{quote(queue_name)}{query_str}") -def _from_queue_url_helper( - cls, queue_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long -) -> Tuple[str, str]: +def _from_queue_url_helper(queue_url: str) -> Tuple[str, str]: """A client to interact with a specific Queue. :param str queue_url: The full URI to the queue, including SAS token if used. - :param credential: - The credentials with which to authenticate. This is optional if the - account URL already has a SAS token. The value can be a SAS token string, - an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, - an account shared access key, or an instance of a TokenCredentials class from azure.identity. - If the resource URI already contains a SAS token, this will be ignored in favor of an explicit credential - - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. 
- If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" - should be the storage account key. - :paramtype Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long :returns: The parsed out account_url and queue name. :rtype: Tuple[str, str] """ diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 27cb842b1ec9..d6998134533d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -8,6 +8,7 @@ from typing import ( # pylint: disable=unused-import Any, Dict, + Iterator, Optional, Tuple, TYPE_CHECKING, @@ -56,7 +57,7 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential - from azure.core.rest import HttpRequest + from azure.core.pipeline.transport import HttpRequest, HttpResponse _LOGGER = logging.getLogger(__name__) _SERVICE_PARAMS = { @@ -224,12 +225,12 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None if hasattr(credential, "get_token"): - self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) # type: ignore [arg-type] + self._credential_policy = 
BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) # type: ignore [arg-type] # pylint: disable=line-too-long elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): @@ -272,10 +273,12 @@ def _batch_send( self, *reqs: "HttpRequest", **kwargs: Any - ) -> None: + ) -> Iterator["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. + :returns: An iterator of HttpResponse objects. + :rtype: Iterator[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) @@ -378,7 +381,7 @@ def _format_shared_key_credential( def parse_connection_str( conn_str: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]], # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]], # pylint: disable=line-too-long service: str ) -> Tuple[Optional[str], Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]]]: # pylint: disable=line-too-long conn_str = conn_str.rstrip(";") diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 39077e991c21..62fc6e43f77a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -70,12 +70,12 @@ async def close(self): await self._client.close() def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, 
"TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Tuple[StorageConfiguration, AsyncPipeline]: self._credential_policy: Optional[Union[AsyncBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy]] = None # pylint: disable=line-too-long if hasattr(credential, 'get_token'): - self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) #type: ignore [arg-type] + self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) #type: ignore [arg-type] # pylint: disable=line-too-long elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index 93c88875dbc7..1de2367dabfe 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -511,7 +511,7 @@ class UserDelegationKey(object): """The version identifier of the REST service that created this token.""" value: Optional[str] = None """The user delegation key.""" - + def __init__(self): self.signed_oid = None self.signed_tid = None diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 77525de48539..d84b1c982f9e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -425,7 +425,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: 
"PipelineRe """ A function which sets the next host location on the request, if applicable. - Dict[str, Any]] settings: The configurable values pertaining to the next host location. + :param Dict[str, Any]] settings: The configurable values pertaining to the next host location. :param PipelineRequest request: A pipeline request object. """ if settings['hosts'] and all(settings['hosts'].values()): @@ -465,7 +465,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl """ Formula for computing the current backoff. Should be calculated by child class. - Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :returns: The backoff time. :rtype: float """ @@ -610,7 +610,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - Dict[str, Any]] settings: The configurable values pertaining to get backoff time. + :param Dict[str, Any]] settings: The configurable values pertaining to get backoff time. :returns: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. @@ -657,7 +657,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> Optional[int]: """ Calculates how long to sleep before retrying. - Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :returns: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 2c41ef44d4f3..2112d395c665 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -179,7 +179,7 @@ def get_backoff_time(self, settings): """ Calculates how long to sleep before retrying. - Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :return: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. @@ -220,7 +220,7 @@ def get_backoff_time(self, settings): """ Calculates how long to sleep before retrying. - Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :return: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index a5f94068e69f..0061ecdb3f45 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -11,7 +11,6 @@ Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union ) -from urllib.parse import quote, unquote, urlparse from typing_extensions import Self @@ -44,7 +43,7 @@ from .._models import QueueProperties -class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] +class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] # pylint: disable=line-too-long """A client to interact with a specific Queue. :param str account_url: @@ -142,7 +141,7 @@ def from_queue_url( :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient """ - account_url, queue_name = _from_queue_url_helper(cls=cls, queue_url=queue_url, credential=credential) + account_url, queue_name = _from_queue_url_helper(queue_url=queue_url) return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @classmethod diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 14a46b7b2427..44177b31ca65 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -39,7 +39,7 @@ from .._models import Metrics, QueueAnalyticsLogging -class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] +class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] # pylint: disable=line-too-long """A client to interact with the Queue Service at the account level. 
This client provides operations to retrieve and configure the account properties From d82d226d28927925d7684ab07058eaed3f7f423a Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Fri, 8 Sep 2023 20:14:37 -0700 Subject: [PATCH 34/71] Feedback 2 --- .../azure/storage/queue/_encryption.py | 20 +- .../azure/storage/queue/_models.py | 167 +++--- .../storage/queue/_queue_service_client.py | 5 +- .../storage/queue/_shared/authentication.py | 2 +- .../azure/storage/queue/_shared/models.py | 3 +- .../azure/storage/queue/_shared/policies.py | 8 +- .../storage/queue/_shared/policies_async.py | 2 +- .../queue/aio/_queue_service_client_async.py | 2 +- .../samples/network_activity_logging.py | 15 +- .../samples/queue_samples_authentication.py | 105 ++-- .../queue_samples_authentication_async.py | 109 ++-- .../samples/queue_samples_hello_world.py | 48 +- .../queue_samples_hello_world_async.py | 62 ++- .../samples/queue_samples_message.py | 490 +++++++++-------- .../samples/queue_samples_message_async.py | 494 +++++++++--------- .../samples/queue_samples_service.py | 133 ++--- .../samples/queue_samples_service_async.py | 135 ++--- 17 files changed, 953 insertions(+), 847 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 868e4093ca84..c757253109bf 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -51,15 +51,33 @@ class KeyEncryptionKey(Protocol): """Protocol that defines what calling functions should be defined for a user-provided key-encryption-key (kek).""" def wrap_key(self, key): + """ + Wraps the specified key using an algorithm of the user's choice. + :param str key: + The user-provided key to be encrypted. + """ ... 
- def unwrap_key(self, key, algorithm): + def unwrap_key(self, key, algorithm): + """ + Unwraps the specified key using an algorithm of the user's choice. + :param str key: + The user-provided key to be unencrypted. + :param str algorithm: + The algorithm used to encrypt the key. This specifies what algorithm to use for the unwrap operation. + """ ... def get_kid(self): + """ + Returns the key ID as specified by the user. + """ ... def get_key_wrap_algorithm(self): + """ + Returns the key wrap algorithm as specified by the user. + """ ... diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index d4f11a18a68b..4181d78907e4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -205,6 +205,72 @@ def _from_generated(cls, generated: Any) -> Self: ) +class QueueSasPermissions(object): + """QueueSasPermissions class to be used with the + :func:`~azure.storage.queue.generate_queue_sas` function and for the AccessPolicies used with + :func:`~azure.storage.queue.QueueClient.set_queue_access_policy`. + + :param bool read: + Read metadata and properties, including message count. Peek at messages. + :param bool add: + Add messages to the queue. + :param bool update: + Update messages in the queue. Note: Use the Process permission with + Update so you can first get the message you want to update. + :param bool process: + Get and delete messages from the queue. 
+ """ + + read: bool = False + """Read metadata and properties, including message count.""" + add: bool = False + """Add messages to the queue.""" + update: bool = False + """Update messages in the queue.""" + process: bool = False + """Get and delete messages from the queue.""" + + def __init__( + self, read: bool = False, + add: bool = False, + update: bool = False, + process: bool = False + ) -> None: + self.read = read + self.add = add + self.update = update + self.process = process + self._str = (('r' if self.read else '') + + ('a' if self.add else '') + + ('u' if self.update else '') + + ('p' if self.process else '')) + + def __str__(self): + return self._str + + @classmethod + def from_string(cls, permission: str) -> Self: + """Create a QueueSasPermissions from a string. + + To specify read, add, update, or process permissions you need only to + include the first letter of the word in the string. E.g. For read and + update permissions, you would provide a string "ru". + + :param str permission: The string which dictates the + read, add, update, or process permissions. + :return: A QueueSasPermissions object + :rtype: ~azure.storage.queue.QueueSasPermissions + """ + p_read = 'r' in permission + p_add = 'a' in permission + p_update = 'u' in permission + p_process = 'p' in permission + + parsed = cls(p_read, p_add, p_update, p_process) + + return parsed + + class AccessPolicy(GenAccessPolicy): """Access Policy class used by the set and get access policy methods. @@ -225,39 +291,37 @@ class AccessPolicy(GenAccessPolicy): both in the Shared Access Signature URL and in the stored access policy, the request will fail with status code 400 (Bad Request). - :param str permission: + :param Optional[QueueSasPermissions] permission: The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions. Required unless an id is given referencing a stored access policy which contains this field. 
This field must be omitted if it has been specified in an associated stored access policy. - :param expiry: + :param Optional[Union["datetime", str]] expiry: The time at which the shared access signature becomes invalid. Required unless an id is given referencing a stored access policy which contains this field. This field must be omitted if it has been specified in an associated stored access policy. Azure will always convert values to UTC. If a date is passed in without timezone info, it is assumed to be UTC. - :type expiry: ~datetime.datetime or str - :param start: + :param Optional[Union["datetime", str]] start: The time at which the shared access signature becomes valid. If omitted, start time for this call is assumed to be the time when the storage service receives the request. Azure will always convert values to UTC. If a date is passed in without timezone info, it is assumed to be UTC. - :type start: ~datetime.datetime or str """ - permission: Optional[str] = None + permission: Optional[QueueSasPermissions] = None """The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions.""" - expiry: Optional[Union["datetime", str]] = None # type: ignore + expiry: Optional[Union["datetime", str]] = None """The time at which the shared access signature becomes invalid.""" - start: Optional[Union["datetime", str]] = None # type: ignore + start: Optional[Union["datetime", str]] = None """The time at which the shared access signature becomes valid.""" def __init__( - self, permission: Optional[str] = None, + self, permission: Optional[QueueSasPermissions] = None, expiry: Optional[Union["datetime", str]] = None, start: Optional[Union["datetime", str]] = None ) -> None: @@ -269,19 +333,19 @@ def __init__( class QueueMessage(DictMixin): """Represents a queue message.""" - id: Optional[str] + id: str """A GUID value assigned to the message by the Queue service that identifies the message in the queue. 
This value may be used together with the value of pop_receipt to delete a message from the queue after it has been retrieved with the receive messages operation.""" - inserted_on: Optional["datetime"] + inserted_on: "datetime" """A UTC date value representing the time the messages was inserted.""" - expires_on: Optional["datetime"] + expires_on: "datetime" """A UTC date value representing the time the message expires.""" - dequeue_count: Optional[int] + dequeue_count: int """Begins with a value of 1 the first time the message is received. This value is incremented each time the message is subsequently received.""" - content: Any = None + content: Any """The message content. Type is determined by the decode_function set on the service. Default is str.""" pop_receipt: Optional[str] @@ -294,10 +358,6 @@ class QueueMessage(DictMixin): Only returned by receive messages operations. Set to None for peek messages.""" def __init__(self, content: Any = None) -> None: - self.id = None - self.inserted_on = None - self.expires_on = None - self.dequeue_count = None self.content = content self.pop_receipt = None self.next_visible_on = None @@ -373,7 +433,7 @@ def _extract_data_cb(self, messages: Any) -> Tuple[str, List[QueueMessage]]: class QueueProperties(DictMixin): """Queue Properties. - :keyword Optional[str] name: + :keyword str name: The name of the queue. :keyword Optional[Dict[str, str]] metadata: A dict containing name-value pairs associated with the queue as metadata. @@ -381,13 +441,12 @@ class QueueProperties(DictMixin): for the list queues operation. 
""" - name: Optional[str] + name: str """The name of the queue.""" metadata: Optional[Dict[str, str]] """A dict containing name-value pairs associated with the queue as metadata.""" def __init__(self, **kwargs: Any) -> None: - self.name = None self.metadata = kwargs.get('metadata') self.approximate_message_count = kwargs.get('x-ms-approximate-messages-count') @@ -466,72 +525,6 @@ def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], List[Qu return next_marker or None, props_list -class QueueSasPermissions(object): - """QueueSasPermissions class to be used with the - :func:`~azure.storage.queue.generate_queue_sas` function and for the AccessPolicies used with - :func:`~azure.storage.queue.QueueClient.set_queue_access_policy`. - - :param bool read: - Read metadata and properties, including message count. Peek at messages. - :param bool add: - Add messages to the queue. - :param bool update: - Update messages in the queue. Note: Use the Process permission with - Update so you can first get the message you want to update. - :param bool process: - Get and delete messages from the queue. - """ - - read: bool = False - """Read metadata and properties, including message count.""" - add: bool = False - """Add messages to the queue.""" - update: bool = False - """Update messages in the queue.""" - process: bool = False - """Get and delete messages from the queue.""" - - def __init__( - self, read: bool = False, - add: bool = False, - update: bool = False, - process: bool = False - ) -> None: - self.read = read - self.add = add - self.update = update - self.process = process - self._str = (('r' if self.read else '') + - ('a' if self.add else '') + - ('u' if self.update else '') + - ('p' if self.process else '')) - - def __str__(self): - return self._str - - @classmethod - def from_string(cls, permission: str) -> Self: - """Create a QueueSasPermissions from a string. 
- - To specify read, add, update, or process permissions you need only to - include the first letter of the word in the string. E.g. For read and - update permissions, you would provide a string "ru". - - :param str permission: The string which dictates the - read, add, update, or process permissions. - :return: A QueueSasPermissions object - :rtype: ~azure.storage.queue.QueueSasPermissions - """ - p_read = 'r' in permission - p_add = 'a' in permission - p_update = 'u' in permission - p_process = 'p' in permission - - parsed = cls(p_read, p_add, p_update, p_process) - - return parsed - - def service_stats_deserialize(generated: Any) -> Dict[str, Any]: """Deserialize a ServiceStats objects into a dict. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 85f25c7221b4..7c168d8df78a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -35,8 +35,7 @@ if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential from azure.core.credentials_async import AsyncTokenCredential - from ._generated.models import CorsRule - from ._models import Metrics, QueueAnalyticsLogging + from ._models import CorsRule, Metrics, QueueAnalyticsLogging class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): @@ -236,7 +235,7 @@ def set_service_properties( You can include up to five CorsRule elements in the list. If an empty list is specified, all CORS rules will be deleted, and CORS will be disabled for the service. - :type cors: List[~azure.storage.queue.CorsRule] + :type cors: Optional[List[~azure.storage.queue.CorsRule]] :keyword int timeout: The timeout parameter is expressed in seconds. 
:returns: None diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py index 7491965f984b..abbbfe88127f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py @@ -6,7 +6,7 @@ import logging import re -from typing import Dict, List, Optional, Tuple +from typing import List, Tuple from urllib.parse import unquote, urlparse try: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index 1de2367dabfe..3301b8935d20 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -10,7 +10,7 @@ from azure.core import CaseInsensitiveEnumMeta from azure.core.configuration import Configuration from azure.core.pipeline.policies import UserAgentPolicy -from .._message_encoding import MessageDecodePolicy +from .._message_encoding import MessageDecodePolicy, NoDecodePolicy def get_enum_value(value): @@ -570,3 +570,4 @@ def __init__(self, **kwargs): self.max_single_get_size = 32 * 1024 * 1024 self.max_chunk_get_size = 4 * 1024 * 1024 self.max_range_size = 4 * 1024 * 1024 + self.message_decode_policy = NoDecodePolicy() diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index d84b1c982f9e..8eb25017ed75 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -50,7 +50,7 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential - from azure.core.pipeline import PipelineRequest, PipelineResponse + from 
azure.core.pipeline.transport import PipelineRequest, PipelineResponse _LOGGER = logging.getLogger(__name__) @@ -653,15 +653,15 @@ def __init__( super(LinearRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings: Dict[str, Any]) -> Optional[int]: + def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. :returns: - An integer indicating how long to wait before retrying the request, + A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. - :rtype: int or None + :rtype: float """ random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 2112d395c665..0b6bfe1fa817 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential - from azure.core.pipeline import PipelineRequest, PipelineResponse + from azure.core.pipeline.transport import PipelineRequest, PipelineResponse _LOGGER = logging.getLogger(__name__) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 44177b31ca65..5804bbc38dd9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -235,7 +235,7 @@ async def 
set_service_properties( You can include up to five CorsRule elements in the list. If an empty list is specified, all CORS rules will be deleted, and CORS will be disabled for the service. - :type cors: list(~azure.storage.queue.CorsRule) + :type cors: Optional[list(~azure.storage.queue.CorsRule)] :keyword int timeout: The timeout parameter is expressed in seconds. diff --git a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py index f1c7e5111319..a9f44c941f2e 100644 --- a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py +++ b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py @@ -59,11 +59,10 @@ queues = service_client.list_queues(logging_enable=True) for queue in queues: print('Queue: {}'.format(queue.name)) - if isinstance(queue.name, str): - queue_client = service_client.get_queue_client(queue.name) - messages = queue_client.peek_messages(max_messages=20, logging_enable=True) - for message in messages: - try: - print(' Message: {!r}'.format(base64.b64decode(message.content))) - except binascii.Error: - print(' Message: {}'.format(message.content)) + queue_client = service_client.get_queue_client(queue.name) +messages = queue_client.peek_messages(max_messages=20, logging_enable=True) +for message in messages: + try: + print(' Message: {!r}'.format(base64.b64decode(message.content))) + except binascii.Error: + print(' Message: {}'.format(message.content)) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py index 3a8fcd0651c9..d8ebea813d9c 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py @@ -30,7 +30,7 @@ from datetime import datetime, timedelta -import os +import os, sys class QueueAuthSamples(object): @@ -51,77 +51,84 @@ def 
authentication_by_connection_string(self): from azure.storage.queue import QueueServiceClient if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # [END auth_from_connection_string] + # [END auth_from_connection_string] - # Get information for the Queue Service - properties = queue_service.get_service_properties() + # Get information for the Queue Service + properties = queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) def authentication_by_shared_key(self): # Instantiate a QueueServiceClient using a shared access key # [START create_queue_service_client] from azure.storage.queue import QueueServiceClient - if self.account_url is not None: + if self.account_url is not None and self.access_key is not None: queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) - # [END create_queue_service_client] + # [END create_queue_service_client] - # Get information for the Queue Service - properties = queue_service.get_service_properties() + # Get information for the Queue Service + properties = queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def authentication_by_active_directory(self): # [START create_queue_service_client_token] # Get a token credential for authentication from azure.identity import ClientSecretCredential - if self.active_directory_tenant_id is not None: - ad_tenant_id = self.active_directory_tenant_id - if self.active_directory_application_id is not None: - ad_application_id = self.active_directory_application_id - if self.active_directory_application_secret is not None: - ad_application_secret = self.active_directory_application_secret - - token_credential = ClientSecretCredential( - ad_tenant_id, - ad_application_id, - ad_application_secret - ) - - # Instantiate a QueueServiceClient using a token credential - from azure.storage.queue import QueueServiceClient - if self.account_url is not None: + if ( + self.active_directory_tenant_id is not None and + self.active_directory_application_id is not None and + self.active_directory_application_secret is not None and + self.account_url is not None + ): + token_credential = ClientSecretCredential( + self.active_directory_tenant_id, + self.active_directory_application_id, + self.active_directory_application_secret + ) + + # Instantiate a QueueServiceClient using a token credential + from azure.storage.queue import QueueServiceClient queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) - # [END create_queue_service_client_token] + # [END create_queue_service_client_token] - # Get information for the Queue Service - properties = queue_service.get_service_properties() + # Get information for the Queue Service + properties = queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def authentication_by_shared_access_signature(self): # Instantiate a QueueServiceClient using a connection string from azure.storage.queue import QueueServiceClient - if self.connection_string is not None: + if ( + self.connection_string is not None and + self.account_name is not None and + self.access_key is not None and + self.account_url + ): queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # Create a SAS token to use for authentication of a client - from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions - - if self.account_name is not None: - account_name = self.account_name - if self.access_key is not None: - access_key = self.access_key - - sas_token = generate_account_sas( - account_name, - access_key, - resource_types=ResourceTypes(service=True), - permission=AccountSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(hours=1) - ) + # Create a SAS token to use for authentication of a client + from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions - if self.account_url is not None: - account_url = self.account_url + sas_token = generate_account_sas( + self.account_name, + self.access_key, + resource_types=ResourceTypes(service=True), + permission=AccountSasPermissions(read=True), + expiry=datetime.utcnow() + timedelta(hours=1) + ) - token_auth_queue_service = QueueServiceClient(account_url=account_url, credential=sas_token) + token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) - # Get information for the Queue Service - properties = token_auth_queue_service.get_service_properties() + # Get information for the Queue Service + properties = token_auth_queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) if __name__ == '__main__': @@ -129,4 +136,4 @@ def authentication_by_shared_access_signature(self): sample.authentication_by_connection_string() sample.authentication_by_shared_key() sample.authentication_by_active_directory() - sample.authentication_by_shared_access_signature() + sample.authentication_by_shared_access_signature() \ No newline at end of file diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py index 14abc6ce59ce..c2b8ea8d6303 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py @@ -31,7 +31,7 @@ from datetime import datetime, timedelta import asyncio -import os +import os, sys class QueueAuthSamplesAsync(object): @@ -52,79 +52,88 @@ async def authentication_by_connection_string_async(self): from azure.storage.queue.aio import QueueServiceClient if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # [END async_auth_from_connection_string] + # [END async_auth_from_connection_string] - # Get information for the Queue Service - async with queue_service: - properties = await queue_service.get_service_properties() + # Get information for the Queue Service + async with queue_service: + properties = await queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def authentication_by_shared_key_async(self): # Instantiate a QueueServiceClient using a shared access key # [START async_create_queue_service_client] from azure.storage.queue.aio import QueueServiceClient - if self.account_url is not None: + if self.account_url is not None and self.access_key is not None: queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) - # [END async_create_queue_service_client] + # [END async_create_queue_service_client] - # Get information for the Queue Service - async with queue_service: - properties = await queue_service.get_service_properties() + # Get information for the Queue Service + async with queue_service: + properties = await queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) async def authentication_by_active_directory_async(self): # [START async_create_queue_service_client_token] # Get a token credential for authentication from azure.identity.aio import ClientSecretCredential - if self.active_directory_tenant_id is not None: - ad_tenant_id = self.active_directory_tenant_id - if self.active_directory_application_id is not None: - ad_application_id = self.active_directory_application_id - if self.active_directory_application_secret is not None: - ad_application_secret = self.active_directory_application_secret - - token_credential = ClientSecretCredential( - ad_tenant_id, - ad_application_id, - ad_application_secret - ) - - # Instantiate a QueueServiceClient using a token credential - from azure.storage.queue.aio import QueueServiceClient - if self.account_url is not None: + if ( + self.active_directory_tenant_id is not None and + self.active_directory_application_id is not None and + self.active_directory_application_secret is not None and + self.account_url is not None + ): + token_credential = ClientSecretCredential( + 
self.active_directory_tenant_id, + self.active_directory_application_id, + self.active_directory_application_secret + ) + + # Instantiate a QueueServiceClient using a token credential + from azure.storage.queue.aio import QueueServiceClient queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) - # [END async_create_queue_service_client_token] + # [END async_create_queue_service_client_token] - # Get information for the Queue Service - async with queue_service: - properties = await queue_service.get_service_properties() + # Get information for the Queue Service + async with queue_service: + properties = await queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) async def authentication_by_shared_access_signature_async(self): # Instantiate a QueueServiceClient using a connection string from azure.storage.queue.aio import QueueServiceClient - if self.connection_string is not None: + if ( + self.connection_string is not None and + self.account_name is not None and + self.access_key is not None and + self.account_url + ): queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # Create a SAS token to use for authentication of a client - from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions - - if self.account_name is not None: - account_name = self.account_name - if self.access_key is not None: - access_key = self.access_key + # Create a SAS token to use for authentication of a client + from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions - sas_token = generate_account_sas( - account_name, - access_key, - resource_types=ResourceTypes(service=True), - permission=AccountSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(hours=1) - ) + sas_token = generate_account_sas( + queue_service.account_name, + 
queue_service.credential.account_key, + resource_types=ResourceTypes(service=True), + permission=AccountSasPermissions(read=True), + expiry=datetime.utcnow() + timedelta(hours=1) + ) - if self.account_url is not None: token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) - # Get information for the Queue Service - async with token_auth_queue_service: - properties = await token_auth_queue_service.get_service_properties() + # Get information for the Queue Service + async with token_auth_queue_service: + properties = await token_auth_queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) async def main(): @@ -135,4 +144,4 @@ async def main(): await sample.authentication_by_shared_access_signature_async() if __name__ == '__main__': - asyncio.run(main()) + asyncio.run(main()) \ No newline at end of file diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py index 5c9efcc02c0d..250a854e9128 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py @@ -21,7 +21,7 @@ """ -import os +import os, sys class QueueHelloWorldSamples(object): @@ -34,8 +34,11 @@ def create_client_with_connection_string(self): if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # Get queue service properties - properties = queue_service.get_service_properties() + # Get queue service properties + properties = queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def queue_and_messages_example(self): # Instantiate the QueueClient from a connection string @@ -43,30 +46,33 @@ def queue_and_messages_example(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") - # Create the queue - # [START create_queue] - queue.create_queue() - # [END create_queue] + # Create the queue + # [START create_queue] + queue.create_queue() + # [END create_queue] - try: - # Send messages - queue.send_message("I'm using queues!") - queue.send_message("This is my second message") + try: + # Send messages + queue.send_message("I'm using queues!") + queue.send_message("This is my second message") - # Receive the messages - response = queue.receive_messages(messages_per_page=2) + # Receive the messages + response = queue.receive_messages(messages_per_page=2) - # Print the content of the messages - for message in response: - print(message.content) + # Print the content of the messages + for message in response: + print(message.content) - finally: - # [START delete_queue] - queue.delete_queue() - # [END delete_queue] + finally: + # [START delete_queue] + queue.delete_queue() + # [END delete_queue] + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) if __name__ == '__main__': sample = QueueHelloWorldSamples() sample.create_client_with_connection_string() - sample.queue_and_messages_example() + sample.queue_and_messages_example() \ No newline at end of file diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py index 694ad61bd55e..bb864fe47f5b 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py @@ -22,7 +22,7 @@ import asyncio -import os +import os, sys class QueueHelloWorldSamplesAsync(object): @@ -35,9 +35,12 @@ async def create_client_with_connection_string_async(self): if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # Get queue service properties - async with queue_service: - properties = await queue_service.get_service_properties() + # Get queue service properties + async with queue_service: + properties = await queue_service.get_service_properties() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def queue_and_messages_example_async(self): # Instantiate the QueueClient from a connection string @@ -45,30 +48,33 @@ async def queue_and_messages_example_async(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") - async with queue: - # Create the queue - # [START async_create_queue] - await queue.create_queue() - # [END async_create_queue] - - try: - # Send messages - await asyncio.gather( - queue.send_message("I'm using queues!"), - queue.send_message("This is my second message") - ) - - # Receive the messages - response = queue.receive_messages(messages_per_page=2) - - # Print the content of the messages - async for message in response: - print(message.content) - - finally: - # [START async_delete_queue] - await queue.delete_queue() - # [END async_delete_queue] + async with queue: + # Create the queue + # [START async_create_queue] + await queue.create_queue() + # [END async_create_queue] + + try: + # Send messages + await asyncio.gather( + queue.send_message("I'm using queues!"), + queue.send_message("This is my second message") + ) + + # Receive the messages + response = queue.receive_messages(messages_per_page=2) + + # Print the content of the messages + async for message in response: + print(message.content) + + finally: + # [START async_delete_queue] + await queue.delete_queue() + # [END async_delete_queue] + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def main(): diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py index a231085824d3..a64b4777aba7 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py @@ -24,7 +24,7 @@ from datetime import datetime, timedelta -import os +import os, sys class QueueMessageSamples(object): @@ -36,53 +36,56 @@ def set_access_policy(self): from azure.storage.queue import QueueClient if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") - # [END create_queue_client_from_connection_string] - - # Create the queue - queue.create_queue() - - # Send a message - queue.send_message("hello world") - - try: - # [START set_access_policy] - # Create an access policy - from azure.storage.queue import AccessPolicy, QueueSasPermissions - access_policy = AccessPolicy() - access_policy.start = datetime.utcnow() - timedelta(hours=1) - access_policy.expiry = datetime.utcnow() + timedelta(hours=1) - access_policy.permission = QueueSasPermissions(read=True) # type: ignore - identifiers = {'my-access-policy-id': access_policy} - - # Set the access policy - queue.set_queue_access_policy(identifiers) - # [END set_access_policy] - - # Use the access policy to generate a SAS token - # [START queue_client_sas_token] - from azure.storage.queue import generate_queue_sas - sas_token = generate_queue_sas( - queue.account_name, - queue.queue_name, - queue.credential.account_key, - policy_id='my-access-policy-id' - ) - # [END queue_client_sas_token] - - # Authenticate with the sas token - # [START create_queue_client] - token_auth_queue = QueueClient.from_queue_url( - queue_url=queue.url, - credential=sas_token - ) - # [END create_queue_client] - - # Use the newly authenticated client to receive messages - 
my_message = token_auth_queue.receive_messages() - - finally: - # Delete the queue - queue.delete_queue() + # [END create_queue_client_from_connection_string] + + # Create the queue + queue.create_queue() + + # Send a message + queue.send_message("hello world") + + try: + # [START set_access_policy] + # Create an access policy + from azure.storage.queue import AccessPolicy, QueueSasPermissions + access_policy = AccessPolicy() + access_policy.start = datetime.utcnow() - timedelta(hours=1) + access_policy.expiry = datetime.utcnow() + timedelta(hours=1) + access_policy.permission = QueueSasPermissions(read=True) + identifiers = {'my-access-policy-id': access_policy} + + # Set the access policy + queue.set_queue_access_policy(identifiers) + # [END set_access_policy] + + # Use the access policy to generate a SAS token + # [START queue_client_sas_token] + from azure.storage.queue import generate_queue_sas + sas_token = generate_queue_sas( + queue.account_name, + queue.queue_name, + queue.credential.account_key, + policy_id='my-access-policy-id' + ) + # [END queue_client_sas_token] + + # Authenticate with the sas token + # [START create_queue_client] + token_auth_queue = QueueClient.from_queue_url( + queue_url=queue.url, + credential=sas_token + ) + # [END create_queue_client] + + # Use the newly authenticated client to receive messages + my_message = token_auth_queue.receive_messages() + + finally: + # Delete the queue + queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def queue_metadata(self): # Instantiate a queue client @@ -90,22 +93,25 @@ def queue_metadata(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") - # Create the queue - queue.create_queue() + # Create the queue + queue.create_queue() - try: - # [START set_queue_metadata] - metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} - queue.set_queue_metadata(metadata=metadata) - # [END set_queue_metadata] + try: + # [START set_queue_metadata] + metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} + queue.set_queue_metadata(metadata=metadata) + # [END set_queue_metadata] - # [START get_queue_properties] - properties = queue.get_queue_properties().metadata - # [END get_queue_properties] + # [START get_queue_properties] + properties = queue.get_queue_properties().metadata + # [END get_queue_properties] - finally: - # Delete the queue - queue.delete_queue() + finally: + # Delete the queue + queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def send_and_receive_messages(self): # Instantiate a queue client @@ -113,41 +119,44 @@ def send_and_receive_messages(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") - # Create the queue - queue.create_queue() - - try: - # [START send_messages] - queue.send_message("message1") - queue.send_message("message2", visibility_timeout=30) # wait 30s before becoming visible - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") - # [END send_messages] - - # [START receive_messages] - # Receive messages one-by-one - messages = queue.receive_messages() - for msg in messages: - print(msg.content) - - # Receive messages by batch - messages = queue.receive_messages(messages_per_page=5) - for msg_batch in messages.by_page(): - for msg in msg_batch: + # Create the queue + queue.create_queue() + + try: + # [START send_messages] + queue.send_message("message1") + queue.send_message("message2", visibility_timeout=30) # wait 30s before becoming visible + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") + # [END send_messages] + + # [START receive_messages] + # Receive messages one-by-one + messages = queue.receive_messages() + for msg in messages: print(msg.content) - queue.delete_message(msg) - # [END receive_messages] - - # Only prints 4 messages because message 2 is not visible yet - # >>message1 - # >>message3 - # >>message4 - # >>message5 - finally: - # Delete the queue - queue.delete_queue() + # Receive messages by batch + messages = queue.receive_messages(messages_per_page=5) + for msg_batch in messages.by_page(): + for msg in msg_batch: + print(msg.content) + queue.delete_message(msg) + # [END receive_messages] + + # Only prints 4 messages because message 2 is not visible yet + # >>message1 + # >>message3 + # >>message4 + # >>message5 + + finally: + # 
Delete the queue + queue.delete_queue() + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) def list_message_pages(self): # Instantiate a queue client @@ -155,33 +164,36 @@ def list_message_pages(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") - # Create the queue - queue.create_queue() - - try: - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") - queue.send_message("message6") - - # [START receive_messages_listing] - # Store two messages in each page - message_batches = queue.receive_messages(messages_per_page=2).by_page() - - # Iterate through the page lists - print(list(next(message_batches))) - print(list(next(message_batches))) - - # There are two iterations in the last page as well. - last_page = next(message_batches) - for message in last_page: - print(message) - # [END receive_messages_listing] - - finally: - queue.delete_queue() + # Create the queue + queue.create_queue() + + try: + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") + queue.send_message("message6") + + # [START receive_messages_listing] + # Store two messages in each page + message_batches = queue.receive_messages(messages_per_page=2).by_page() + + # Iterate through the page lists + print(list(next(message_batches))) + print(list(next(message_batches))) + + # There are two iterations in the last page as well. + last_page = next(message_batches) + for message in last_page: + print(message) + # [END receive_messages_listing] + + finally: + queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def receive_one_message_from_queue(self): # Instantiate a queue client @@ -189,30 +201,33 @@ def receive_one_message_from_queue(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") - # Create the queue - queue.create_queue() + # Create the queue + queue.create_queue() - try: - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") + try: + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") - # [START receive_one_message] - # Pop two messages from the front of the queue - message1 = queue.receive_message() - message2 = queue.receive_message() - # We should see message 3 if we peek - message3 = queue.peek_messages()[0] + # [START receive_one_message] + # Pop two messages from the front of the queue + message1 = queue.receive_message() + message2 = queue.receive_message() + # We should see message 3 if we peek + message3 = queue.peek_messages()[0] - if message1 is not None and hasattr(message1, 'content'): + if not message1 or not message2 or not message3: + raise ValueError("One of the messages are None.") print(message1.content) - if message2 is not None and hasattr(message2, 'content'): print(message2.content) - print(message3.content) - # [END receive_one_message] + print(message3.content) + # [END receive_one_message] - finally: - queue.delete_queue() + finally: + queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def delete_and_clear_messages(self): # Instantiate a queue client @@ -220,32 +235,35 @@ def delete_and_clear_messages(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") - # Create the queue - queue.create_queue() + # Create the queue + queue.create_queue() - try: - # Send messages - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") + try: + # Send messages + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") - # [START delete_message] - # Get the message at the front of the queue - msg = next(queue.receive_messages()) + # [START delete_message] + # Get the message at the front of the queue + msg = next(queue.receive_messages()) - # Delete the specified message - queue.delete_message(msg) - # [END delete_message] + # Delete the specified message + queue.delete_message(msg) + # [END delete_message] - # [START clear_messages] - queue.clear_messages() - # [END clear_messages] + # [START clear_messages] + queue.clear_messages() + # [END clear_messages] - finally: - # Delete the queue - queue.delete_queue() + finally: + # Delete the queue + queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def peek_messages(self): # Instantiate a queue client @@ -253,32 +271,35 @@ def peek_messages(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") - # Create the queue - queue.create_queue() + # Create the queue + queue.create_queue() - try: - # Send messages - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") + try: + # Send messages + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") - # [START peek_message] - # Peek at one message at the front of the queue - msg = queue.peek_messages() + # [START peek_message] + # Peek at one message at the front of the queue + msg = queue.peek_messages() - # Peek at the last 5 messages - messages = queue.peek_messages(max_messages=5) + # Peek at the last 5 messages + messages = queue.peek_messages(max_messages=5) - # Print the last 5 messages - for message in messages: - print(message.content) - # [END peek_message] + # Print the last 5 messages + for message in messages: + print(message.content) + # [END peek_message] - finally: - # Delete the queue - queue.delete_queue() + finally: + # Delete the queue + queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def update_message(self): # Instantiate a queue client @@ -286,31 +307,33 @@ def update_message(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") - # Create the queue - queue.create_queue() + # Create the queue + queue.create_queue() - try: - # [START update_message] - # Send a message - queue.send_message("update me") + try: + # [START update_message] + # Send a message + queue.send_message("update me") - # Receive the message - messages = queue.receive_messages() + # Receive the message + messages = queue.receive_messages() - # Update the message - list_result = next(messages) - if list_result.id is not None: + # Update the message + list_result = next(messages) id = list_result.id - message = queue.update_message( - id, - pop_receipt=list_result.pop_receipt, - visibility_timeout=0, - content="updated") - # [END update_message] - - finally: - # Delete the queue - queue.delete_queue() + message = queue.update_message( + id, + pop_receipt=list_result.pop_receipt, + visibility_timeout=0, + content="updated") + # [END update_message] + + finally: + # Delete the queue + queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def receive_messages_with_max_messages(self): # Instantiate a queue client @@ -318,37 +341,40 @@ def receive_messages_with_max_messages(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue9") - # Create the queue - queue.create_queue() - - try: - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") - queue.send_message("message6") - queue.send_message("message7") - queue.send_message("message8") - queue.send_message("message9") - queue.send_message("message10") - - # Receive messages one-by-one - messages = queue.receive_messages(max_messages=5) - for msg in messages: - print(msg.content) - queue.delete_message(msg) - - # Only prints 5 messages because 'max_messages'=5 - # >>message1 - # >>message2 - # >>message3 - # >>message4 - # >>message5 + # Create the queue + queue.create_queue() + + try: + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") + queue.send_message("message6") + queue.send_message("message7") + queue.send_message("message8") + queue.send_message("message9") + queue.send_message("message10") + + # Receive messages one-by-one + messages = queue.receive_messages(max_messages=5) + for msg in messages: + print(msg.content) + queue.delete_message(msg) - finally: - # Delete the queue - queue.delete_queue() + # Only prints 5 messages because 'max_messages'=5 + # >>message1 + # >>message2 + # >>message3 + # >>message4 + # >>message5 + + finally: + # Delete the queue + queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) if __name__ == '__main__': diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py index f26816b57da6..edb622157112 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py @@ -24,7 +24,7 @@ from datetime import datetime, timedelta import asyncio -import os +import os, sys class QueueMessageSamplesAsync(object): @@ -36,52 +36,55 @@ async def set_access_policy_async(self): from azure.storage.queue.aio import QueueClient if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") - # [END async_create_queue_client_from_connection_string] - - # Create the queue - async with queue: - await queue.create_queue() - - # Send a message - await queue.send_message("hello world") - - try: - # [START async_set_access_policy] - # Create an access policy - from azure.storage.queue import AccessPolicy, QueueSasPermissions - access_policy = AccessPolicy() - access_policy.start = datetime.utcnow() - timedelta(hours=1) - access_policy.expiry = datetime.utcnow() + timedelta(hours=1) - access_policy.permission = QueueSasPermissions(read=True) # type: ignore - identifiers = {'my-access-policy-id': access_policy} - - # Set the access policy - await queue.set_queue_access_policy(identifiers) - # [END async_set_access_policy] - - # Use the access policy to generate a SAS token - from azure.storage.queue import generate_queue_sas - sas_token = generate_queue_sas( - queue.account_name, - queue.queue_name, - queue.credential.account_key, - policy_id='my-access-policy-id' - ) - - # Authenticate with the sas token - # [START async_create_queue_client] - token_auth_queue = QueueClient.from_queue_url( - queue_url=queue.url, - credential=sas_token - ) - # [END 
async_create_queue_client] - - # Use the newly authenticated client to receive messages - my_messages = token_auth_queue.receive_messages() - - finally: - # Delete the queue - await queue.delete_queue() + # [END async_create_queue_client_from_connection_string] + + # Create the queue + async with queue: + await queue.create_queue() + + # Send a message + await queue.send_message("hello world") + + try: + # [START async_set_access_policy] + # Create an access policy + from azure.storage.queue import AccessPolicy, QueueSasPermissions + access_policy = AccessPolicy() + access_policy.start = datetime.utcnow() - timedelta(hours=1) + access_policy.expiry = datetime.utcnow() + timedelta(hours=1) + access_policy.permission = QueueSasPermissions(read=True) + identifiers = {'my-access-policy-id': access_policy} + + # Set the access policy + await queue.set_queue_access_policy(identifiers) + # [END async_set_access_policy] + + # Use the access policy to generate a SAS token + from azure.storage.queue import generate_queue_sas + sas_token = generate_queue_sas( + queue.account_name, + queue.queue_name, + queue.credential.account_key, + policy_id='my-access-policy-id' + ) + + # Authenticate with the sas token + # [START async_create_queue_client] + token_auth_queue = QueueClient.from_queue_url( + queue_url=queue.url, + credential=sas_token + ) + # [END async_create_queue_client] + + # Use the newly authenticated client to receive messages + my_messages = token_auth_queue.receive_messages() + + finally: + # Delete the queue + await queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def queue_metadata_async(self): # Instantiate a queue client @@ -89,23 +92,26 @@ async def queue_metadata_async(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") - # Create the queue - async with queue: - await queue.create_queue() + # Create the queue + async with queue: + await queue.create_queue() - try: - # [START async_set_queue_metadata] - metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} - await queue.set_queue_metadata(metadata=metadata) - # [END async_set_queue_metadata] + try: + # [START async_set_queue_metadata] + metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} + await queue.set_queue_metadata(metadata=metadata) + # [END async_set_queue_metadata] - # [START async_get_queue_properties] - properties = await queue.get_queue_properties() - # [END async_get_queue_properties] + # [START async_get_queue_properties] + properties = await queue.get_queue_properties() + # [END async_get_queue_properties] - finally: - # Delete the queue - await queue.delete_queue() + finally: + # Delete the queue + await queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def send_and_receive_messages_async(self): # Instantiate a queue client @@ -113,44 +119,47 @@ async def send_and_receive_messages_async(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") - # Create the queue - async with queue: - await queue.create_queue() - - try: - # [START async_send_messages] - await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2", visibility_timeout=30), # wait 30s before becoming visible - queue.send_message("message3"), - queue.send_message("message4"), - queue.send_message("message5") - ) - # [END async_send_messages] - - # [START async_receive_messages] - # Receive messages one-by-one - messages = queue.receive_messages() - async for msg in messages: - print(msg.content) - - # Receive messages by batch - messages = queue.receive_messages(messages_per_page=5) - async for msg_batch in messages.by_page(): - for msg in msg_batch: # type: ignore + # Create the queue + async with queue: + await queue.create_queue() + + try: + # [START async_send_messages] + await asyncio.gather( + queue.send_message("message1"), + queue.send_message("message2", visibility_timeout=30), # wait 30s before becoming visible + queue.send_message("message3"), + queue.send_message("message4"), + queue.send_message("message5") + ) + # [END async_send_messages] + + # [START async_receive_messages] + # Receive messages one-by-one + messages = queue.receive_messages() + async for msg in messages: print(msg.content) - await queue.delete_message(msg) - # [END async_receive_messages] - - # Only prints 4 messages because message 2 is not visible yet - # >>message1 - # >>message3 - # >>message4 - # >>message5 - finally: - # Delete the queue - await queue.delete_queue() + # Receive messages by batch + messages = queue.receive_messages(messages_per_page=5) + async for msg_batch in messages.by_page(): + async 
for msg in msg_batch: + print(msg.content) + await queue.delete_message(msg) + # [END async_receive_messages] + + # Only prints 4 messages because message 2 is not visible yet + # >>message1 + # >>message3 + # >>message4 + # >>message5 + + finally: + # Delete the queue + await queue.delete_queue() + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) async def receive_one_message_from_queue(self): # Instantiate a queue client @@ -158,32 +167,35 @@ async def receive_one_message_from_queue(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") - # Create the queue - async with queue: - await queue.create_queue() - - try: - await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2"), - queue.send_message("message3")) - - # [START receive_one_message] - # Pop two messages from the front of the queue - message1 = await queue.receive_message() - message2 = await queue.receive_message() - # We should see message 3 if we peek - message3 = await queue.peek_messages() - - if message1 is not None and hasattr(message1, 'content'): + # Create the queue + async with queue: + await queue.create_queue() + + try: + await asyncio.gather( + queue.send_message("message1"), + queue.send_message("message2"), + queue.send_message("message3")) + + # [START receive_one_message] + # Pop two messages from the front of the queue + message1 = await queue.receive_message() + message2 = await queue.receive_message() + # We should see message 3 if we peek + message3 = await queue.peek_messages() + + if not message1 or not message2 or not message3: + raise ValueError("One of the messages are None.") print(message1.content) - if message2 is not None and hasattr(message2, 'content'): print(message2.content) - print(message3[0].content) - # [END receive_one_message] + print(message3[0].content) + # [END receive_one_message] - finally: - 
await queue.delete_queue() + finally: + await queue.delete_queue() + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) async def delete_and_clear_messages_async(self): # Instantiate a queue client @@ -191,36 +203,39 @@ async def delete_and_clear_messages_async(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") - # Create the queue - async with queue: - await queue.create_queue() - - try: - # Send messages - await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2"), - queue.send_message("message3"), - queue.send_message("message4"), - queue.send_message("message5") - ) - - # [START async_delete_message] - # Get the message at the front of the queue - messages = queue.receive_messages() - async for msg in messages: - # Delete the specified message - await queue.delete_message(msg) - # [END async_delete_message] - break - - # [START async_clear_messages] - await queue.clear_messages() - # [END async_clear_messages] - - finally: - # Delete the queue - await queue.delete_queue() + # Create the queue + async with queue: + await queue.create_queue() + + try: + # Send messages + await asyncio.gather( + queue.send_message("message1"), + queue.send_message("message2"), + queue.send_message("message3"), + queue.send_message("message4"), + queue.send_message("message5") + ) + + # [START async_delete_message] + # Get the message at the front of the queue + messages = queue.receive_messages() + async for msg in messages: + # Delete the specified message + await queue.delete_message(msg) + # [END async_delete_message] + break + + # [START async_clear_messages] + await queue.clear_messages() + # [END async_clear_messages] + + finally: + # Delete the queue + await queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def peek_messages_async(self): # Instantiate a queue client @@ -228,35 +243,38 @@ async def peek_messages_async(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") - # Create the queue - async with queue: - await queue.create_queue() - - try: - # Send messages - await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2"), - queue.send_message("message3"), - queue.send_message("message4"), - queue.send_message("message5") - ) - - # [START async_peek_message] - # Peek at one message at the front of the queue - msg = await queue.peek_messages() - - # Peek at the last 5 messages - messages = await queue.peek_messages(max_messages=5) - - # Print the last 5 messages - for message in messages: - print(message.content) - # [END async_peek_message] - - finally: - # Delete the queue - await queue.delete_queue() + # Create the queue + async with queue: + await queue.create_queue() + + try: + # Send messages + await asyncio.gather( + queue.send_message("message1"), + queue.send_message("message2"), + queue.send_message("message3"), + queue.send_message("message4"), + queue.send_message("message5") + ) + + # [START async_peek_message] + # Peek at one message at the front of the queue + msg = await queue.peek_messages() + + # Peek at the last 5 messages + messages = await queue.peek_messages(max_messages=5) + + # Print the last 5 messages + for message in messages: + print(message.content) + # [END async_peek_message] + + finally: + # Delete the queue + await queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def update_message_async(self): # Instantiate a queue client @@ -264,30 +282,33 @@ async def update_message_async(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") - # Create the queue - async with queue: - await queue.create_queue() - - try: - # [START async_update_message] - # Send a message - await queue.send_message("update me") - - # Receive the message - messages = queue.receive_messages() - - # Update the message - async for message in messages: - message = await queue.update_message( - message, - visibility_timeout=0, - content="updated") - # [END async_update_message] - break - - finally: - # Delete the queue - await queue.delete_queue() + # Create the queue + async with queue: + await queue.create_queue() + + try: + # [START async_update_message] + # Send a message + await queue.send_message("update me") + + # Receive the message + messages = queue.receive_messages() + + # Update the message + async for message in messages: + message = await queue.update_message( + message, + visibility_timeout=0, + content="updated") + # [END async_update_message] + break + + finally: + # Delete the queue + await queue.delete_queue() + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def receive_messages_with_max_messages(self): # Instantiate a queue client @@ -295,38 +316,41 @@ async def receive_messages_with_max_messages(self): if self.connection_string is not None: queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") - # Create the queue - async with queue: - await queue.create_queue() - - try: - await queue.send_message("message1") - await queue.send_message("message2") - await queue.send_message("message3") - await queue.send_message("message4") - await queue.send_message("message5") - await queue.send_message("message6") - await queue.send_message("message7") - await queue.send_message("message8") - await queue.send_message("message9") - await queue.send_message("message10") - - # Receive messages one-by-one - messages = queue.receive_messages(max_messages=5) - async for msg in messages: - print(msg.content) - await queue.delete_message(msg) - - # Only prints 5 messages because 'max_messages'=5 - # >>message1 - # >>message2 - # >>message3 - # >>message4 - # >>message5 - - finally: - # Delete the queue - await queue.delete_queue() + # Create the queue + async with queue: + await queue.create_queue() + + try: + await queue.send_message("message1") + await queue.send_message("message2") + await queue.send_message("message3") + await queue.send_message("message4") + await queue.send_message("message5") + await queue.send_message("message6") + await queue.send_message("message7") + await queue.send_message("message8") + await queue.send_message("message9") + await queue.send_message("message10") + + # Receive messages one-by-one + messages = queue.receive_messages(max_messages=5) + async for msg in messages: + print(msg.content) + await queue.delete_message(msg) + + # Only prints 5 messages because 'max_messages'=5 + # >>message1 + # >>message2 + # >>message3 + # >>message4 + # >>message5 + + finally: + # Delete the queue + await queue.delete_queue() + 
else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) async def main(): diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py index 990498d71025..b2c46dbd8d5a 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py @@ -20,7 +20,7 @@ 1) AZURE_STORAGE_CONNECTION_STRING - the connection string to your storage account """ -import os +import os, sys class QueueServiceSamples(object): @@ -34,42 +34,45 @@ def queue_service_properties(self): if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # [START set_queue_service_properties] - # Create service properties - from typing import List, Optional - from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy - - # Create logging settings - logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - - # Create metrics for requests statistics - hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - - # Create CORS rules - cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) - allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] - allowed_methods = ['GET', 'PUT'] - max_age_in_seconds = 500 - exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] - allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] - cors_rule2 = CorsRule( - allowed_origins, - allowed_methods, - max_age_in_seconds=max_age_in_seconds, - exposed_headers=exposed_headers, - 
allowed_headers=allowed_headers - ) - - cors: Optional[List[CorsRule]] = [cors_rule1, cors_rule2] - - # Set the service properties - queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) # type: ignore - # [END set_queue_service_properties] - - # [START get_queue_service_properties] - properties = queue_service.get_service_properties() - # [END get_queue_service_properties] + # [START set_queue_service_properties] + # Create service properties + from typing import List, Optional + from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy + + # Create logging settings + logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + + # Create metrics for requests statistics + hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + + # Create CORS rules + cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) + allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] + allowed_methods = ['GET', 'PUT'] + max_age_in_seconds = 500 + exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] + allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] + cors_rule2 = CorsRule( + allowed_origins, + allowed_methods, + max_age_in_seconds=max_age_in_seconds, + exposed_headers=exposed_headers, + allowed_headers=allowed_headers + ) + + cors = [cors_rule1, cors_rule2] + + # Set the service properties + queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) + # [END set_queue_service_properties] + + # [START get_queue_service_properties] + properties = queue_service.get_service_properties() + # [END get_queue_service_properties] + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def queues_in_account(self): # Instantiate the QueueServiceClient from a connection string @@ -77,27 +80,30 @@ def queues_in_account(self): if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # [START qsc_create_queue] - queue_service.create_queue("myqueue1") - # [END qsc_create_queue] - - try: - # [START qsc_list_queues] - # List all the queues in the service - list_queues = queue_service.list_queues() - for queue in list_queues: - print(queue) - - # List the queues in the service that start with the name "my" - list_my_queues = queue_service.list_queues(name_starts_with="my") - for queue in list_my_queues: - print(queue) - # [END qsc_list_queues] - - finally: - # [START qsc_delete_queue] - queue_service.delete_queue("myqueue1") - # [END qsc_delete_queue] + # [START qsc_create_queue] + queue_service.create_queue("myqueue1") + # [END qsc_create_queue] + + try: + # [START qsc_list_queues] + # List all the queues in the service + list_queues = queue_service.list_queues() + for queue in list_queues: + print(queue) + + # List the queues in the service that start with the name "my" + list_my_queues = queue_service.list_queues(name_starts_with="my") + for queue in list_my_queues: + print(queue) + # [END qsc_list_queues] + + finally: + # [START qsc_delete_queue] + queue_service.delete_queue("myqueue1") + # [END qsc_delete_queue] + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) def get_queue_client(self): # Instantiate the QueueServiceClient from a connection string @@ -105,10 +111,13 @@ def get_queue_client(self): if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # [START get_queue_client] - # Get the queue client to interact with a specific queue - queue = queue_service.get_queue_client(queue="myqueue2") - # [END get_queue_client] + # [START get_queue_client] + # Get the queue client to interact with a specific queue + queue = queue_service.get_queue_client(queue="myqueue2") + # [END get_queue_client] + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) if __name__ == '__main__': diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py index 409379af9a45..9f823ea55510 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py @@ -22,7 +22,7 @@ import asyncio -import os +import os, sys class QueueServiceSamplesAsync(object): @@ -35,42 +35,45 @@ async def queue_service_properties_async(self): if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - async with queue_service: - # [START async_set_queue_service_properties] - # Create service properties - from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy - - # Create logging settings - logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - - # Create metrics for requests statistics - hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - minute_metrics 
= Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - - # Create CORS rules - cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) - allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] - allowed_methods = ['GET', 'PUT'] - max_age_in_seconds = 500 - exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] - allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] - cors_rule2 = CorsRule( - allowed_origins, - allowed_methods, - max_age_in_seconds=max_age_in_seconds, - exposed_headers=exposed_headers, - allowed_headers=allowed_headers - ) - - cors = [cors_rule1, cors_rule2] - - # Set the service properties - await queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) # type: ignore - # [END async_set_queue_service_properties] - - # [START async_get_queue_service_properties] - properties = await queue_service.get_service_properties() - # [END async_get_queue_service_properties] + async with queue_service: + # [START async_set_queue_service_properties] + # Create service properties + from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy + + # Create logging settings + logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + + # Create metrics for requests statistics + hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + + # Create CORS rules + cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) + allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] + allowed_methods = ['GET', 'PUT'] + max_age_in_seconds = 500 + exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] + allowed_headers = 
["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] + cors_rule2 = CorsRule( + allowed_origins, + allowed_methods, + max_age_in_seconds=max_age_in_seconds, + exposed_headers=exposed_headers, + allowed_headers=allowed_headers + ) + + cors = [cors_rule1, cors_rule2] + + # Set the service properties + await queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) + # [END async_set_queue_service_properties] + + # [START async_get_queue_service_properties] + properties = await queue_service.get_service_properties() + # [END async_get_queue_service_properties] + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) async def queues_in_account_async(self): # Instantiate the QueueServiceClient from a connection string @@ -78,28 +81,31 @@ async def queues_in_account_async(self): if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - async with queue_service: - # [START async_qsc_create_queue] - await queue_service.create_queue("myqueue1") - # [END async_qsc_create_queue] - - try: - # [START async_qsc_list_queues] - # List all the queues in the service - list_queues = queue_service.list_queues() - async for queue in list_queues: - print(queue) - - # List the queues in the service that start with the name "my_" - list_my_queues = queue_service.list_queues(name_starts_with="my_") - async for queue in list_my_queues: - print(queue) - # [END async_qsc_list_queues] - - finally: - # [START async_qsc_delete_queue] - await queue_service.delete_queue("myqueue1") - # [END async_qsc_delete_queue] + async with queue_service: + # [START async_qsc_create_queue] + await queue_service.create_queue("myqueue1") + # [END async_qsc_create_queue] + + try: + # [START async_qsc_list_queues] + # List all the queues in the service + list_queues = queue_service.list_queues() + async for queue in list_queues: + 
print(queue) + + # List the queues in the service that start with the name "my_" + list_my_queues = queue_service.list_queues(name_starts_with="my_") + async for queue in list_my_queues: + print(queue) + # [END async_qsc_list_queues] + + finally: + # [START async_qsc_delete_queue] + await queue_service.delete_queue("myqueue1") + # [END async_qsc_delete_queue] + else: + print("Missing required enviornment variable(s). Please see specific test for more details.") + sys.exit(1) async def get_queue_client_async(self): # Instantiate the QueueServiceClient from a connection string @@ -107,10 +113,13 @@ async def get_queue_client_async(self): if self.connection_string is not None: queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # [START async_get_queue_client] - # Get the queue client to interact with a specific queue - queue = queue_service.get_queue_client(queue="myqueue2") - # [END async_get_queue_client] + # [START async_get_queue_client] + # Get the queue client to interact with a specific queue + queue = queue_service.get_queue_client(queue="myqueue2") + # [END async_get_queue_client] + else: + print("Missing required enviornment variable(s). 
Please see specific test for more details.") + sys.exit(1) async def main(): From d8d96031d5990be23f8b9715c32040d8b124275f Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Mon, 11 Sep 2023 15:52:00 -0700 Subject: [PATCH 35/71] CI --- .../azure-storage-queue/azure/storage/queue/_encryption.py | 2 +- .../azure-storage-queue/azure/storage/queue/_shared/policies.py | 2 +- .../azure/storage/queue/_shared/policies_async.py | 2 +- .../azure/storage/queue/aio/_queue_client_async.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index c757253109bf..b3189c887a05 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -58,7 +58,7 @@ def wrap_key(self, key): """ ... - def unwrap_key(self, key, algorithm): + def unwrap_key(self, key, algorithm): """ Unwraps the specified key using an algorithm of the user's choice. 
:param str key: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 8eb25017ed75..9aa71d758a07 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -50,7 +50,7 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential - from azure.core.pipeline.transport import PipelineRequest, PipelineResponse + from azure.core.pipeline.transport import PipelineRequest, PipelineResponse # pylint: disable=C4750 _LOGGER = logging.getLogger(__name__) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 0b6bfe1fa817..c8da6ec11874 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential - from azure.core.pipeline.transport import PipelineRequest, PipelineResponse + from azure.core.pipeline.transport import PipelineRequest, PipelineResponse # pylint: disable=C4750 _LOGGER = logging.getLogger(__name__) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 0061ecdb3f45..bb19714183f4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -166,7 +166,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. 
If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient From 3ce0c84a0d3c98be7d809291e402af44ed1d36c3 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Tue, 12 Sep 2023 13:42:47 -0700 Subject: [PATCH 36/71] Fix weird invisible unicode char --- .../azure/storage/queue/aio/_queue_client_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index bb19714183f4..dda48d8f4096 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -166,7 +166,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient From ab30dc33793cfa56a1e66f0768f48687d612fd8e Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Tue, 12 Sep 2023 16:43:12 -0700 Subject: [PATCH 37/71] No mypy errors hopefully no CI --- .../azure/storage/queue/_models.py | 13 +++++++------ .../azure/storage/queue/_queue_client.py | 16 +++++++++------- .../azure/storage/queue/_queue_service_client.py | 3 +-- .../azure/storage/queue/_shared/models.py | 2 -- .../storage/queue/aio/_queue_client_async.py | 10 +++++----- .../queue/aio/_queue_service_client_async.py | 9 ++++----- 6 files changed, 26 insertions(+), 27 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 4181d78907e4..de93064c917a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -312,12 +312,12 @@ class AccessPolicy(GenAccessPolicy): be UTC. """ - permission: Optional[QueueSasPermissions] = None + permission: Optional[QueueSasPermissions] = None # type: ignore [assignment] """The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions.""" - expiry: Optional[Union["datetime", str]] = None + expiry: Optional[Union["datetime", str]] = None # type: ignore [assignment] """The time at which the shared access signature becomes invalid.""" - start: Optional[Union["datetime", str]] = None + start: Optional[Union["datetime", str]] = None # type: ignore [assignment] """The time at which the shared access signature becomes valid.""" def __init__( @@ -338,11 +338,11 @@ class QueueMessage(DictMixin): identifies the message in the queue. 
This value may be used together with the value of pop_receipt to delete a message from the queue after it has been retrieved with the receive messages operation.""" - inserted_on: "datetime" + inserted_on: Optional["datetime"] """A UTC date value representing the time the messages was inserted.""" - expires_on: "datetime" + expires_on: Optional["datetime"] """A UTC date value representing the time the message expires.""" - dequeue_count: int + dequeue_count: Optional[int] """Begins with a value of 1 the first time the message is received. This value is incremented each time the message is subsequently received.""" content: Any @@ -358,6 +358,7 @@ class QueueMessage(DictMixin): Only returned by receive messages operations. Set to None for peek messages.""" def __init__(self, content: Any = None) -> None: + self.id = None # type: ignore [assignment] self.content = content self.pop_receipt = None self.next_visible_on = None diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 4f8d64e9103d..29e69fa4107e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -563,10 +563,11 @@ def receive_message( self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function) + resolver=self.key_resolver_function + ) try: message = self._client.messages.dequeue( number_of_messages=1, @@ -647,7 +648,6 @@ def receive_messages( :caption: Receive messages from the queue. 
""" timeout = kwargs.pop('timeout', None) - max_messages = kwargs.pop('max_messages', None) if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( self._config.user_agent_policy.user_agent, @@ -655,10 +655,11 @@ def receive_messages( self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function) + resolver=self.key_resolver_function + ) try: command = functools.partial( self._client.messages.dequeue, @@ -856,10 +857,11 @@ def peek_messages( self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function) + resolver=self.key_resolver_function + ) try: messages = self._client.messages.peek( number_of_messages=max_messages, diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 7c168d8df78a..050d373643b7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -238,7 +238,6 @@ def set_service_properties( :type cors: Optional[List[~azure.storage.queue.CorsRule]] :keyword int timeout: The timeout parameter is expressed in seconds. - :returns: None .. 
admonition:: Example: @@ -254,7 +253,7 @@ def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=cors + cors=cors # type: ignore [arg-type] ) try: self._client.service.set_properties(props, timeout=timeout, **kwargs) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index 3301b8935d20..43ea08a477e5 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -556,7 +556,6 @@ class StorageConfiguration(Configuration): max_chunk_get_size: int max_range_size: int user_agent_policy: UserAgentPolicy - message_decode_policy: MessageDecodePolicy def __init__(self, **kwargs): super(StorageConfiguration, self).__init__(**kwargs) @@ -570,4 +569,3 @@ def __init__(self, **kwargs): self.max_single_get_size = 32 * 1024 * 1024 self.max_chunk_get_size = 4 * 1024 * 1024 self.max_range_size = 4 * 1024 * 1024 - self.message_decode_policy = NoDecodePolicy() diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index dda48d8f4096..fba5035f4363 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -568,10 +568,11 @@ async def receive_message( # type: ignore[override] self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function) + resolver=self.key_resolver_function + ) try: message = await self._client.messages.dequeue( number_of_messages=1, @@ -643,7 +644,6 @@ def receive_messages( # type: 
ignore[override] :caption: Receive messages from the queue. """ timeout = kwargs.pop('timeout', None) - max_messages = kwargs.pop('max_messages', None) if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( self._config.user_agent_policy.user_agent, @@ -651,7 +651,7 @@ def receive_messages( # type: ignore[override] self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function @@ -856,7 +856,7 @@ async def peek_messages( # type: ignore[override] self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self.message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 5804bbc38dd9..48b164e48285 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -35,8 +35,7 @@ if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential from azure.core.credentials_async import AsyncTokenCredential - from .._generated.models import CorsRule - from .._models import Metrics, QueueAnalyticsLogging + from .._models import CorsRule, Metrics, QueueAnalyticsLogging class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] # pylint: disable=line-too-long @@ -144,7 +143,7 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = 
secondary - return cls(account_url, credential=credential, **kwargs) #type: ignore [arg-type] + return cls(account_url, credential=credential, **kwargs) # type: ignore [arg-type] @distributed_trace_async async def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: @@ -235,7 +234,7 @@ async def set_service_properties( You can include up to five CorsRule elements in the list. If an empty list is specified, all CORS rules will be deleted, and CORS will be disabled for the service. - :type cors: Optional[list(~azure.storage.queue.CorsRule)] + :type cors: Optional[List(~azure.storage.queue.CorsRule)] :keyword int timeout: The timeout parameter is expressed in seconds. @@ -253,7 +252,7 @@ async def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=cors + cors=cors # type: ignore [arg-type] ) try: await self._client.service.set_properties(props, timeout=timeout, **kwargs) From 151935652bb9bbfc829571b94029236ff5b088a8 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Wed, 13 Sep 2023 17:22:28 -0700 Subject: [PATCH 38/71] Pylint and CI, but now hinting issues --- .../azure/storage/queue/_encryption.py | 13 ++++++------- .../azure/storage/queue/_shared/models.py | 1 - 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index b3189c887a05..daa9f1c28630 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -718,11 +718,11 @@ def _decrypt_message( # decrypt data decrypted_data = message decryptor = cipher.decryptor() - decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) #type: ignore + decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) # unpad data unpadder = PKCS7(128).unpadder() - decrypted_data = 
(unpadder.update(decrypted_data) + unpadder.finalize()) #type: ignore + decrypted_data = (unpadder.update(decrypted_data) + unpadder.finalize()) elif encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: block_info = encryption_data.encrypted_region_info @@ -735,12 +735,11 @@ def _decrypt_message( nonce_length = encryption_data.encrypted_region_info.nonce_length # First bytes are the nonce - message_as_bytes = message.encode('utf-8') - nonce = message_as_bytes[:nonce_length] - ciphertext_with_tag = message_as_bytes[nonce_length:] + nonce = message[:nonce_length] + ciphertext_with_tag = message[nonce_length:] aesgcm = AESGCM(content_encryption_key) - decrypted_data = (aesgcm.decrypt(nonce, ciphertext_with_tag, None)).decode() + decrypted_data = (aesgcm.decrypt(nonce, ciphertext_with_tag, None)) else: raise ValueError('Specified encryption version is not supported.') @@ -1111,7 +1110,7 @@ def decrypt_queue_message( return message try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver) + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') except Exception as error: raise HttpResponseError( message="Decryption failed.", diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index 43ea08a477e5..368971b2eb86 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -10,7 +10,6 @@ from azure.core import CaseInsensitiveEnumMeta from azure.core.configuration import Configuration from azure.core.pipeline.policies import UserAgentPolicy -from .._message_encoding import MessageDecodePolicy, NoDecodePolicy def get_enum_value(value): From 96c739eb3d166f1677c2f9998de793e138653e33 Mon Sep 17 00:00:00 2001 From: Vincent Tran Date: Tue, 19 Sep 2023 14:39:42 -0700 Subject: [PATCH 
39/71] Fix Lint which should fix CI --- .../azure/storage/queue/_shared/base_client.py | 4 ++-- .../azure/storage/queue/_shared/base_client_async.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index a67d7192c322..0e9073534ace 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -433,11 +433,11 @@ def parse_connection_str( def create_configuration(**kwargs): - # type: (**Any) -> Configuration + # type: (**Any) -> StorageConfiguration # Backwards compatibility if someone is not passing sdk_moniker if not kwargs.get("sdk_moniker"): kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}" - config = Configuration(**kwargs) + config = StorageConfiguration(**kwargs) config.headers_policy = StorageHeadersPolicy(**kwargs) config.user_agent_policy = UserAgentPolicy(**kwargs) config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 62fc6e43f77a..fc1c9b5e1923 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -44,7 +44,6 @@ from azure.core.pipeline.transport import HttpRequest from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential - from azure.core.configuration import Configuration _LOGGER = logging.getLogger(__name__) From e236f2eaf421056c0521a8db3fe0efcf727e8682 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Mon, 25 Sep 2023 16:28:29 -0700 Subject: [PATCH 40/71] Attempted 
overload --- .../azure/storage/queue/_message_encoding.py | 1 - .../azure/storage/queue/_models.py | 39 ++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index cfda11a90172..78e67f40949e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -# pylint: disable=unused-argument from base64 import b64encode, b64decode from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index de93064c917a..f379fe02873a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -7,7 +7,7 @@ # pylint: disable=super-init-not-called import sys -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TYPE_CHECKING, Union from azure.core.exceptions import HttpResponseError from azure.core.paging import PageIterator from ._shared.response_handlers import return_context_and_deserialized, process_storage_error @@ -357,11 +357,40 @@ class QueueMessage(DictMixin): """A UTC date value representing the time the message will next be visible. Only returned by receive messages operations. 
Set to None for peek messages.""" - def __init__(self, content: Any = None) -> None: - self.id = None # type: ignore [assignment] + @overload + def __init__(self, *, id: str, content: Any = None) -> None: + ... + + @overload + def __init__( + self, *, + id: Optional[str] = None, + inserted_on: Optional["datetime"] = None, + expires_on: Optional["datetime"] = None, + dequeue_count: Optional[int] = None, + content: Any = None, + pop_receipt: Optional[str] = None, + next_visible_on: Optional["datetime"] = None + ) -> None: + ... + + def __init__( + self, *, + id: Optional[str] = None, + inserted_on: Optional["datetime"] = None, + expires_on: Optional["datetime"] = None, + dequeue_count: Optional[int] = None, + content: Any = None, + pop_receipt: Optional[str] = None, + next_visible_on: Optional["datetime"] = None + ) -> None: + self.id = id # type: ignore [assignment] + self.inserted_on = inserted_on + self.expires_on = expires_on + self.dequeue_count = dequeue_count self.content = content - self.pop_receipt = None - self.next_visible_on = None + self.pop_receipt = pop_receipt + self.next_visible_on = next_visible_on @classmethod def _from_generated(cls, generated: Any) -> Self: From fea8e482a3e1d5fbd5e12411f71427eb0f2de052 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Mon, 25 Sep 2023 17:45:33 -0700 Subject: [PATCH 41/71] Just hanging encryption and message_encoding Qs, fix import ordering --- .../azure/storage/queue/_deserialize.py | 2 +- .../azure/storage/queue/_encryption.py | 4 +- .../azure/storage/queue/_message_encoding.py | 17 ++++++--- .../azure/storage/queue/_models.py | 4 +- .../azure/storage/queue/_queue_client.py | 5 +-- .../storage/queue/_queue_service_client.py | 1 - .../storage/queue/_shared/base_client.py | 20 +++++----- .../queue/_shared/base_client_async.py | 32 +++++++++------- .../azure/storage/queue/_shared/parser.py | 2 +- .../azure/storage/queue/_shared/policies.py | 24 ++++++------ .../storage/queue/_shared/policies_async.py | 38 
+++++++++++-------- .../queue/_shared/response_handlers.py | 20 ++++------ .../azure/storage/queue/aio/_models.py | 5 +-- .../storage/queue/aio/_queue_client_async.py | 3 +- .../queue/aio/_queue_service_client_async.py | 3 +- 15 files changed, 92 insertions(+), 88 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py index ff3f207b95e4..f2016049827e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py @@ -8,8 +8,8 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import ResourceExistsError -from ._shared.models import StorageErrorCode from ._models import QueueProperties +from ._shared.models import StorageErrorCode from ._shared.response_handlers import deserialize_metadata if TYPE_CHECKING: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index daa9f1c28630..d18f08c7b2ba 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -5,8 +5,8 @@ # license information. 
# -------------------------------------------------------------------------- -import os import math +import os import sys import warnings from collections import OrderedDict @@ -29,7 +29,7 @@ from azure.core.utils import CaseInsensitiveDict from ._version import VERSION -from ._shared import encode_base64, decode_base64_to_bytes +from ._shared import decode_base64_to_bytes, encode_base64 if TYPE_CHECKING: from azure.core.pipeline import PipelineResponse diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index 78e67f40949e..ec7ca9df9ef8 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -4,12 +4,12 @@ # license information. # -------------------------------------------------------------------------- -from base64 import b64encode, b64decode +from base64 import b64decode, b64encode from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union from azure.core.exceptions import DecodeError -from ._encryption import decrypt_queue_message, encrypt_queue_message, _ENCRYPTION_PROTOCOL_V1, KeyEncryptionKey +from ._encryption import decrypt_queue_message, encrypt_queue_message, KeyEncryptionKey, _ENCRYPTION_PROTOCOL_V1 if TYPE_CHECKING: from azure.core.pipeline import PipelineResponse @@ -58,7 +58,14 @@ def encode(self, content: Any) -> str: class MessageDecodePolicy(object): - def __init__(self): + require_encryption: Optional[bool] = None + """Indicates whether encryption is required or not.""" + key_encryption_key: Optional[KeyEncryptionKey] = None + """The user-provided key-encryption-key.""" + resolver: Optional[Callable[[str], bytes]] = None + """The user-provided key resolver.""" + + def __init__(self) -> None: self.require_encryption = False self.key_encryption_key = None self.resolver = None @@ -159,9 +166,7 @@ class 
NoEncodePolicy(MessageEncodePolicy): def encode(self, content: str) -> str: if isinstance(content, bytes): - raise TypeError( - "Message content must not be bytes. Use the BinaryBase64EncodePolicy to send bytes." - ) + raise TypeError("Message content must not be bytes. Use the BinaryBase64EncodePolicy to send bytes.") return content diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index f379fe02873a..788d8737f822 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -10,13 +10,13 @@ from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TYPE_CHECKING, Union from azure.core.exceptions import HttpResponseError from azure.core.paging import PageIterator -from ._shared.response_handlers import return_context_and_deserialized, process_storage_error +from ._shared.response_handlers import process_storage_error, return_context_and_deserialized from ._shared.models import DictMixin from ._generated.models import AccessPolicy as GenAccessPolicy +from ._generated.models import CorsRule as GeneratedCorsRule from ._generated.models import Logging as GeneratedLogging from ._generated.models import Metrics as GeneratedMetrics from ._generated.models import RetentionPolicy as GeneratedRetentionPolicy -from ._generated.models import CorsRule as GeneratedCorsRule if sys.version_info >= (3, 11): from typing import Self # pylint: disable=no-name-in-module, ungrouped-imports diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 29e69fa4107e..23ba7eb50fd2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -8,9 +8,8 @@ import warnings from typing import ( Any, cast, Dict, 
List, Optional, - TYPE_CHECKING, Tuple, Union + Tuple, TYPE_CHECKING, Union ) - from typing_extensions import Self from azure.core.exceptions import HttpResponseError @@ -22,7 +21,7 @@ from ._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier from ._message_encoding import NoDecodePolicy, NoEncodePolicy from ._models import AccessPolicy, MessagesPaged, QueueMessage -from ._queue_client_helpers import _parse_url, _format_url_helper, _from_queue_url_helper +from ._queue_client_helpers import _format_url_helper, _from_queue_url_helper, _parse_url from ._serialize import get_api_version from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin from ._shared.request_handlers import add_metadata_headers, serialize_iso diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 050d373643b7..c811e826ceef 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -9,7 +9,6 @@ Any, Dict, List, Optional, TYPE_CHECKING, Union ) - from typing_extensions import Self from azure.core.exceptions import HttpResponseError diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 0e9073534ace..36242700faa2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -5,7 +5,7 @@ # -------------------------------------------------------------------------- import logging import uuid -from typing import ( # pylint: disable=unused-import +from typing import ( Any, Dict, Iterator, @@ -18,13 +18,13 @@ try: from urllib.parse import parse_qs, quote except ImportError: - from urlparse import parse_qs 
# type: ignore - from urllib2 import quote # type: ignore + from urlparse import parse_qs # type: ignore [no-redef] + from urllib2 import quote # type: ignore [no-redef] from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import Pipeline -from azure.core.pipeline.transport import RequestsTransport, HttpTransport # pylint: disable=non-abstract-transport-import, no-name-in-module +from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module from azure.core.pipeline.policies import ( AzureSasCredentialPolicy, BearerTokenCredentialPolicy, @@ -36,11 +36,9 @@ UserAgentPolicy, ) +from .authentication import SharedKeyCredentialPolicy from .constants import CONNECTION_TIMEOUT, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE from .models import LocationMode, StorageConfiguration -from .authentication import SharedKeyCredentialPolicy -from .shared_access_signature import QueryStringConstants -from .request_handlers import serialize_batch_body, _get_batch_request_delimiter from .policies import ( ExponentialRetry, QueueMessagePolicy, @@ -51,8 +49,10 @@ StorageRequestHook, StorageResponseHook, ) +from .request_handlers import serialize_batch_body, _get_batch_request_delimiter +from .response_handlers import PartialBatchErrorException, process_storage_error +from .shared_access_signature import QueryStringConstants from .._version import VERSION -from .response_handlers import process_storage_error, PartialBatchErrorException if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -268,7 +268,6 @@ def _create_pipeline( config.transport = transport # type: ignore return config, Pipeline(transport, policies=policies) - # Given a series of request, do a Storage batch call. 
def _batch_send( self, *reqs: "HttpRequest", @@ -432,8 +431,7 @@ def parse_connection_str( -def create_configuration(**kwargs): - # type: (**Any) -> StorageConfiguration +def create_configuration(**kwargs: Any) -> StorageConfiguration: # Backwards compatibility if someone is not passing sdk_moniker if not kwargs.get("sdk_moniker"): kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}" diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index fc1c9b5e1923..d2dad1d30a94 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -5,16 +5,16 @@ # -------------------------------------------------------------------------- # mypy: disable-error-code="attr-defined" -from typing import ( # pylint: disable=unused-import - Union, Optional, Any, Iterable, Dict, List, Type, Tuple, - TYPE_CHECKING +from typing import ( + Any, Dict, Iterator, Optional, + Tuple, TYPE_CHECKING, Union ) import logging -from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential -from azure.core.pipeline import AsyncPipeline from azure.core.async_paging import AsyncList +from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential from azure.core.exceptions import HttpResponseError +from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import ( AsyncBearerTokenCredentialPolicy, AsyncRedirectPolicy, @@ -25,9 +25,10 @@ ) from azure.core.pipeline.transport import AsyncHttpTransport -from .constants import CONNECTION_TIMEOUT, READ_TIMEOUT, STORAGE_OAUTH_SCOPE from .authentication import SharedKeyCredentialPolicy from .base_client import create_configuration +from .constants import CONNECTION_TIMEOUT, READ_TIMEOUT, STORAGE_OAUTH_SCOPE +from .models import StorageConfiguration 
from .policies import ( QueueMessagePolicy, StorageContentValidation, @@ -36,14 +37,12 @@ StorageRequestHook, ) from .policies_async import AsyncStorageResponseHook -from .models import StorageConfiguration - -from .response_handlers import process_storage_error, PartialBatchErrorException +from .response_handlers import PartialBatchErrorException, process_storage_error if TYPE_CHECKING: - from azure.core.pipeline.transport import HttpRequest from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential + from azure.core.pipeline.transport import HttpRequest, HttpResponse _LOGGER = logging.getLogger(__name__) @@ -116,12 +115,17 @@ def _create_pipeline( config.transport = transport #type: ignore return config, AsyncPipeline(transport, policies=policies) #type: ignore - # Given a series of request, do a Storage batch call. async def _batch_send( self, - *reqs, # type: HttpRequest - **kwargs - ): + *reqs: "HttpRequest", + **kwargs: Any + ) -> AsyncList["HttpResponse"]: + """Given a series of request, do a Storage batch call. + + :param HttpRequest reqs: A collection of HttpRequest objects. + :returns: An AsyncList of HttpResponse objects. 
+ :rtype: AsyncList[HttpResponse] + """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) client = self._client diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py index da4ef58cc394..cd59cfe104ca 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py @@ -21,7 +21,7 @@ def _str(value): _str = str -def _to_utc_datetime(value): +def _to_utc_datetime(value: datetime) -> str: return value.strftime('%Y-%m-%dT%H:%M:%SZ') def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 9aa71d758a07..68daa530e09a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -6,29 +6,30 @@ import base64 import hashlib -import re -import random -from time import time -from io import SEEK_SET, UnsupportedOperation import logging +import random +import re import uuid -from typing import Any, Dict, Optional, Union, TYPE_CHECKING +from io import SEEK_SET, UnsupportedOperation +from time import time +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from wsgiref.handlers import format_date_time try: from urllib.parse import ( - urlparse, parse_qsl, - urlunparse, urlencode, + urlparse, + urlunparse, ) except ImportError: from urllib import urlencode # type: ignore from urlparse import ( # type: ignore - urlparse, parse_qsl, - urlunparse, + urlparse, + urlunparse ) +from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError from azure.core.pipeline.policies import ( BearerTokenCredentialPolicy, 
HeadersPolicy, @@ -37,7 +38,6 @@ RequestHistory, SansIOHTTPPolicy, ) -from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError from .authentication import StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE, STORAGE_OAUTH_SCOPE @@ -134,7 +134,7 @@ def on_request(self, request): class StorageHeadersPolicy(HeadersPolicy): request_id_header_name = 'x-ms-client-request-id' - def on_request(self, request: "PipelineRequest")-> None: + def on_request(self, request: "PipelineRequest") -> None: super(StorageHeadersPolicy, self).on_request(request) current_time = format_date_time(time()) request.http_request.headers['x-ms-date'] = current_time @@ -349,7 +349,7 @@ class StorageContentValidation(SansIOHTTPPolicy): """ header_name = 'Content-MD5' - def __init__(self, **kwargs): # pylint: disable=unused-argument + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super(StorageContentValidation, self).__init__() @staticmethod diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index c8da6ec11874..207cd1320c02 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -6,12 +6,12 @@ # pylint: disable=invalid-overridden-method import asyncio -import random import logging -from typing import Any, TYPE_CHECKING +import random +from typing import Any, Dict, TYPE_CHECKING -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy from azure.core.exceptions import AzureError +from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy from .authentication import StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE, STORAGE_OAUTH_SCOPE @@ -45,8 +45,7 @@ def __init__(self, **kwargs): # pylint: 
disable=unused-argument self._response_callback = kwargs.get('raw_response_hook') super(AsyncStorageResponseHook, self).__init__() - async def send(self, request): - # type: (PipelineRequest) -> PipelineResponse + async def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 data_stream_total = request.context.get('data_stream_total') if data_stream_total is None: @@ -145,9 +144,12 @@ async def send(self, request): class ExponentialRetry(AsyncStorageRetryPolicy): """Exponential retry.""" - def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, - retry_to_secondary=False, random_jitter_range=3, **kwargs): - ''' + def __init__( + self, initial_backoff: int = 15, + increment_base: int = 3, retry_total: int = 3, + retry_to_secondary: bool = False, random_jitter_range: int = 3, **kwargs + ) -> None: + """ Constructs an Exponential retry object. The initial_backoff is used for the first retry. Subsequent retries are retried after initial_backoff + increment_power^retry_count seconds. For example, by default the first retry @@ -168,14 +170,14 @@ def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, :param int random_jitter_range: A number in seconds which indicates a range to jitter/randomize for the back-off interval. For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. - ''' + """ self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range super(ExponentialRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings): + def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. 
@@ -195,7 +197,13 @@ def get_backoff_time(self, settings): class LinearRetry(AsyncStorageRetryPolicy): """Linear retry.""" - def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_jitter_range=3, **kwargs): + def __init__( + self, backoff: int = 15, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, + **kwargs: Any + ) -> None: """ Constructs a Linear retry object. @@ -216,7 +224,7 @@ def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_j super(LinearRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings): + def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. @@ -238,12 +246,10 @@ def get_backoff_time(self, settings): class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """ Custom Bearer token credential policy for following Storage Bearer challenges """ - def __init__(self, credential, **kwargs): - # type: (AsyncTokenCredential, **Any) -> None + def __init__(self, credential: AsyncTokenCredential, **kwargs: Any) -> None: super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, STORAGE_OAUTH_SCOPE, **kwargs) - async def on_challenge(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> bool + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py index 0103c7a3d0ed..313009997a8e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py +++ 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py @@ -3,26 +3,22 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import NoReturn, TYPE_CHECKING import logging +from typing import NoReturn from xml.etree.ElementTree import Element -from azure.core.pipeline.policies import ContentDecodePolicy from azure.core.exceptions import ( + ClientAuthenticationError, + DecodeError, HttpResponseError, - ResourceNotFoundError, - ResourceModifiedError, ResourceExistsError, - ClientAuthenticationError, - DecodeError) + ResourceModifiedError, + ResourceNotFoundError, +) +from azure.core.pipeline.policies import ContentDecodePolicy -from .parser import _to_utc_datetime from .models import StorageErrorCode, UserDelegationKey, get_enum_value - - -if TYPE_CHECKING: - from datetime import datetime - from azure.core.exceptions import AzureError +from .parser import _to_utc_datetime _LOGGER = logging.getLogger(__name__) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index 027ab28d372f..564558eec50b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -7,12 +7,11 @@ # pylint: disable=super-init-not-called from typing import Any, Callable, List, Optional, Tuple + from azure.core.async_paging import AsyncPageIterator from azure.core.exceptions import HttpResponseError -from .._shared.response_handlers import ( - process_storage_error, - return_context_and_deserialized) from .._models import QueueMessage, QueueProperties +from .._shared.response_handlers import process_storage_error, return_context_and_deserialized class MessagesPaged(AsyncPageIterator): diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index fba5035f4363..1990640e9ef4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -11,7 +11,6 @@ Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union ) - from typing_extensions import Self from azure.core.async_paging import AsyncItemPaged @@ -25,7 +24,7 @@ from .._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier from .._message_encoding import NoDecodePolicy, NoEncodePolicy from .._models import AccessPolicy, QueueMessage -from .._queue_client_helpers import _parse_url, _format_url_helper, _from_queue_url_helper +from .._queue_client_helpers import _format_url_helper, _from_queue_url_helper, _parse_url from .._serialize import get_api_version from .._shared.base_client import parse_connection_str, StorageAccountHostsMixin from .._shared.base_client_async import AsyncStorageAccountHostsMixin diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 48b164e48285..6a66f4e982ab 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -10,7 +10,6 @@ Any, Dict, List, Optional, TYPE_CHECKING, Union ) - from typing_extensions import Self from azure.core.async_paging import AsyncItemPaged @@ -28,8 +27,8 @@ from .._serialize import get_api_version from .._shared.base_client import parse_connection_str, StorageAccountHostsMixin from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper -from .._shared.policies_async import ExponentialRetry from 
.._shared.models import LocationMode +from .._shared.policies_async import ExponentialRetry from .._shared.response_handlers import process_storage_error if TYPE_CHECKING: From e9173f2554ee36e151ec43d1293c23ef047aeafe Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Thu, 28 Sep 2023 16:28:21 -0700 Subject: [PATCH 42/71] Addressed encryption feedback --- .../azure/storage/queue/_encryption.py | 144 ++++++++---------- .../azure/storage/queue/_message_encoding.py | 10 +- .../storage/queue/_shared/policies_async.py | 2 +- 3 files changed, 72 insertions(+), 84 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index d18f08c7b2ba..98693f5f9921 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -48,36 +48,17 @@ class KeyEncryptionKey(Protocol): - """Protocol that defines what calling functions should be defined for a user-provided key-encryption-key (kek).""" - def wrap_key(self, key): - """ - Wraps the specified key using an algorithm of the user's choice. - :param str key: - The user-provided key to be encrypted. - """ + def wrap_key(self, key: bytes) -> bytes: ... - def unwrap_key(self, key, algorithm): - """ - Unwraps the specified key using an algorithm of the user's choice. - :param str key: - The user-provided key to be unencrypted. - :param str algorithm: - The algorithm used to encrypt the key. This specifies what algorithm to use for the unwrap operation. - """ + def unwrap_key(self, key: bytes, algorithm: str) -> bytes: ... - def get_kid(self): - """ - Returns the key ID as specified by the user. - """ + def get_kid(self) -> str: ... - def get_key_wrap_algorithm(self): - """ - Returns the key wrap algorithm as specified by the user. - """ + def get_key_wrap_algorithm(self) -> str: ... 
@@ -415,7 +396,7 @@ def get_adjusted_download_range_and_offset( start_offset, end_offset = 0, end if encryption_data.encrypted_region_info is None: - raise ValueError("Field required for V2 encryption is missing (encrypted_region_info).") + raise ValueError("Missing required metadata for Encryption V2") nonce_length = encryption_data.encrypted_region_info.nonce_length data_length = encryption_data.encrypted_region_info.data_length @@ -526,8 +507,8 @@ def _generate_encryption_data_dict( encryption_agent['EncryptionAlgorithm'] = _EncryptionAlgorithm.AES_GCM_256 encrypted_region_info = OrderedDict() - encrypted_region_info['DataLength'] = _GCM_REGION_DATA_LENGTH - encrypted_region_info['NonceLength'] = _GCM_NONCE_LENGTH + encrypted_region_info['DataLength'] = str(_GCM_REGION_DATA_LENGTH) + encrypted_region_info['NonceLength'] = str(_GCM_NONCE_LENGTH) encryption_data_dict = OrderedDict() encryption_data_dict['WrappedContentKey'] = wrapped_content_key @@ -611,15 +592,15 @@ def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: def _validate_and_unwrap_cek( encryption_data: _EncryptionData, - key_encryption_key: KeyEncryptionKey, - key_resolver: Optional[Callable[[str], bytes]] = None + key_encryption_key: Optional[KeyEncryptionKey] = None, + key_resolver: Optional[Callable[[str], KeyEncryptionKey]] = None ) -> bytes: """ Extracts and returns the content_encryption_key stored in the encryption_data object and performs necessary validation on all parameters. :param _EncryptionData encryption_data: The encryption metadata of the retrieved value. - :param KeyEncryptionKey key_encryption_key: + :param Optional[KeyEncryptionKey] key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: wrap_key(key) - Wraps the specified key using an algorithm of the user's choice. @@ -627,7 +608,7 @@ def _validate_and_unwrap_cek( - Returns the algorithm used to wrap the specified symmetric key. 
get_kid() - Returns a string key id for this key-encryption-key. - :param Optional[Callable[[str], bytes]] key_resolver: + :param Optional[Callable[[str], KeyEncryptionKey]] key_resolver: A function used that, given a key_id, will return a key_encryption_key. Please refer to high-level service object instance variables for more details. :return: The content_encryption_key stored in the encryption_data object. @@ -644,13 +625,14 @@ def _validate_and_unwrap_cek( else: raise ValueError('Specified encryption version is not supported.') - content_encryption_key = b'' + content_encryption_key: Optional[bytes] = None # If the resolver exists, give priority to the key it finds. if key_resolver is not None: - key_encryption_key = key_resolver(encryption_data.wrapped_content_key.key_id) #type: ignore [assignment] + key_encryption_key = key_resolver(encryption_data.wrapped_content_key.key_id) - _validate_not_none('key_encryption_key', key_encryption_key) + if key_encryption_key is None: + raise ValueError("Unable to decrypt. key_resolver and key_encryption_key cannot both be None.") if not hasattr(key_encryption_key, 'get_kid') or not callable(key_encryption_key.get_kid): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid')) if not hasattr(key_encryption_key, 'unwrap_key') or not callable(key_encryption_key.unwrap_key): @@ -678,21 +660,21 @@ def _validate_and_unwrap_cek( def _decrypt_message( - message: str, + message: bytes, encryption_data: _EncryptionData, - key_encryption_key: KeyEncryptionKey, - resolver: Optional[Callable] = None -) -> str: + key_encryption_key: Optional[KeyEncryptionKey] = None, + resolver: Optional[Callable[[str], KeyEncryptionKey]] = None +) -> bytes: """ Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding. Unwraps the content-encryption-key using the user-provided or resolved key-encryption-key (kek). Returns the original plaintext. 
- :param str message: + :param bytes message: The ciphertext to be decrypted. :param _EncryptionData encryption_data: The metadata associated with this ciphertext. - :param KeyEncryptionKey key_encryption_key: + :param Optional[KeyEncryptionKey] key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: wrap_key(key) - Wraps the specified key using an algorithm of the user's choice. @@ -700,11 +682,11 @@ def _decrypt_message( - Returns the algorithm used to wrap the specified symmetric key. get_kid() - Returns a string key id for this key-encryption-key. - :param Callable resolver: + :param Optional[Callable[[str], KeyEncryptionKey]] resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The decrypted plaintext. - :rtype: str + :rtype: bytes """ _validate_not_none('message', message) content_encryption_key = _validate_and_unwrap_cek(encryption_data, key_encryption_key, resolver) @@ -716,9 +698,8 @@ def _decrypt_message( cipher = _generate_AES_CBC_cipher(content_encryption_key, encryption_data.content_encryption_IV) # decrypt data - decrypted_data = message decryptor = cipher.decryptor() - decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) + decrypted_data = (decryptor.update(message) + decryptor.finalize()) # unpad data unpadder = PKCS7(128).unpadder() @@ -730,7 +711,7 @@ def _decrypt_message( raise ValueError("Missing required metadata for decryption.") if encryption_data.encrypted_region_info is None: - raise ValueError("Field required for V2 encryption is missing (encrypted_region_info).") + raise ValueError("Missing required metadata for Encryption V2") nonce_length = encryption_data.encrypted_region_info.nonce_length @@ -739,7 +720,7 @@ def _decrypt_message( ciphertext_with_tag = message[nonce_length:] aesgcm = AESGCM(content_encryption_key) - decrypted_data = (aesgcm.decrypt(nonce, ciphertext_with_tag, None)) + 
decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) else: raise ValueError('Specified encryption version is not supported.') @@ -810,41 +791,43 @@ def encrypt_blob(blob: bytes, key_encryption_key: KeyEncryptionKey, version: str def generate_blob_encryption_data( - key_encryption_key: KeyEncryptionKey, - version: str - ) -> Tuple[bytes, Optional[bytes], str]: + key_encryption_key: Optional[KeyEncryptionKey], + version: str +) -> Tuple[Optional[bytes], Optional[bytes], Optional[str]]: """ Generates the encryption_metadata for the blob. - :param KeyEncryptionKey key_encryption_key: + :param Optional[KeyEncryptionKey] key_encryption_key: The key-encryption-key used to wrap the cek associate with this blob. :param str version: The client encryption version to use. :return: A tuple containing the cek and iv for this blob as well as the serialized encryption metadata for the blob. - :rtype: (bytes, Optional[bytes], str) + :rtype: (Optional[bytes], Optional[bytes], Optional[str]) """ + encryption_data = None + content_encryption_key = None initialization_vector = None + if key_encryption_key: + _validate_key_encryption_key_wrap(key_encryption_key) + content_encryption_key = os.urandom(32) + # Initialization vector only needed for V1 + if version == _ENCRYPTION_PROTOCOL_V1: + initialization_vector = os.urandom(16) + encryption_data_dict = _generate_encryption_data_dict(key_encryption_key, + content_encryption_key, + initialization_vector, + version) + encryption_data_dict['EncryptionMode'] = 'FullBlob' + encryption_data = dumps(encryption_data_dict) - _validate_key_encryption_key_wrap(key_encryption_key) - content_encryption_key = os.urandom(32) - # Initialization vector only needed for V1 - if version == _ENCRYPTION_PROTOCOL_V1: - initialization_vector = os.urandom(16) - encryption_data = _generate_encryption_data_dict(key_encryption_key, - content_encryption_key, - initialization_vector, - version) - encryption_data['EncryptionMode'] = 'FullBlob' - 
encryption_data_dump = dumps(encryption_data) - - return content_encryption_key, initialization_vector, encryption_data_dump + return content_encryption_key, initialization_vector, encryption_data def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements require_encryption: bool, key_encryption_key: KeyEncryptionKey, - key_resolver: Callable[[str], bytes], + key_resolver: Optional[Callable[[str], KeyEncryptionKey]], content: bytes, start_offset: int, end_offset: int, @@ -863,9 +846,10 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements - Returns the algorithm used to wrap the specified symmetric key. get_kid() - Returns a string key id for this key-encryption-key. - :param object key_resolver: + :param key_resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. + :paramtype key_resolver: Optional[Callable[[str], KeyEncryptionKey]] :param bytes content: The encrypted blob content. 
:param int start_offset: @@ -901,7 +885,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements if version == _ENCRYPTION_PROTOCOL_V1: blob_type = response_headers['x-ms-blob-type'] - iv = b'' + iv: Optional[bytes] = None unpad = False if 'content-range' in response_headers: content_range = response_headers['content-range'] @@ -920,17 +904,20 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements content = content[16:] start_offset -= 16 else: - iv = encryption_data.content_encryption_IV #type: ignore [assignment] + iv = encryption_data.content_encryption_IV if end_range == blob_size - 1: unpad = True else: unpad = True - iv = encryption_data.content_encryption_IV #type: ignore [assignment] + iv = encryption_data.content_encryption_IV if blob_type == 'PageBlob': unpad = False + if iv is None: + raise ValueError("Missing required metadata for Encryption V1") + cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) decryptor = cipher.decryptor() @@ -947,7 +934,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements offset = 0 if encryption_data.encrypted_region_info is None: - raise ValueError("Field required for V2 encryption is missing (encrypted_region_info).") + raise ValueError("Missing required metadata for Encryption V2") nonce_length = encryption_data.encrypted_region_info.nonce_length data_length = encryption_data.encrypted_region_info.data_length @@ -977,10 +964,10 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements def get_blob_encryptor_and_padder( - cek: bytes, - iv: bytes, + cek: Optional[bytes], + iv: Optional[bytes], should_pad: bool -) -> Tuple["AEADEncryptionContext", "PaddingContext"]: +) -> Tuple[Optional["AEADEncryptionContext"], Optional["PaddingContext"]]: encryptor = None padder = None @@ -988,8 +975,9 @@ def get_blob_encryptor_and_padder( cipher = _generate_AES_CBC_cipher(cek, iv) encryptor = cipher.encryptor() padder = PKCS7(128).padder() if 
should_pad else None + return encryptor, padder - return encryptor, padder #type: ignore [return-value] + return encryptor, padder def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, version: str) -> str: @@ -998,7 +986,7 @@ def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, ve Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek). Returns a json-formatted string containing the encrypted message and the encryption metadata. - :param object message: + :param str message: The plain text message to be encrypted. :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: @@ -1066,8 +1054,8 @@ def decrypt_queue_message( message: str, response: "PipelineResponse", require_encryption: bool, - key_encryption_key: KeyEncryptionKey, - resolver: Callable[[str], bytes] + key_encryption_key: Optional[KeyEncryptionKey], + resolver: Optional[Callable[[str], KeyEncryptionKey]] ) -> str: """ Returns the decrypted message contents from an EncryptedQueueMessage. @@ -1078,7 +1066,7 @@ def decrypt_queue_message( The pipeline response used to generate an error with. :param bool require_encryption: If set, will enforce that the retrieved messages are encrypted and decrypt them. - :param KeyEncryptionKey key_encryption_key: + :param Optional[KeyEncryptionKey] key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: wrap_key(key) - Wraps the specified key using an algorithm of the user's choice. @@ -1086,7 +1074,7 @@ def decrypt_queue_message( - Returns the algorithm used to wrap the specified symmetric key. get_kid() - Returns a string key id for this key-encryption-key. - :param Callable resolver: + :param Optional[Callable[[str], KeyEncryptionKey]] resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. 
:return: The plain text message from the queue message. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index ec7ca9df9ef8..8c5e2288df50 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -23,7 +23,7 @@ class MessageEncodePolicy(object): """Indicates the version of encryption being used.""" key_encryption_key: Optional[KeyEncryptionKey] = None """The user-provided key-encryption-key.""" - resolver: Optional[Callable[[str], bytes]] = None + resolver: Optional[Callable[[str], KeyEncryptionKey]] = None """The user-provided key resolver.""" def __init__(self) -> None: @@ -42,7 +42,7 @@ def __call__(self, content: Any) -> str: def configure( self, require_encryption: bool, key_encryption_key: Optional[KeyEncryptionKey], - resolver: Optional[Callable[[str], bytes]], + resolver: Optional[Callable[[str], KeyEncryptionKey]], encryption_version: str = _ENCRYPTION_PROTOCOL_V1 ) -> None: self.require_encryption = require_encryption @@ -58,11 +58,11 @@ def encode(self, content: Any) -> str: class MessageDecodePolicy(object): - require_encryption: Optional[bool] = None + require_encryption: bool = False """Indicates whether encryption is required or not.""" key_encryption_key: Optional[KeyEncryptionKey] = None """The user-provided key-encryption-key.""" - resolver: Optional[Callable[[str], bytes]] = None + resolver: Optional[Callable[[str], KeyEncryptionKey]] = None """The user-provided key resolver.""" def __init__(self) -> None: @@ -87,7 +87,7 @@ def __call__(self, response: "PipelineResponse", obj: Iterable, headers: Dict[st def configure( self, require_encryption: bool, key_encryption_key: Optional[KeyEncryptionKey], - resolver: Optional[Callable[[str], bytes]] + resolver: Optional[Callable[[str], KeyEncryptionKey]] ) -> None: 
self.require_encryption = require_encryption self.key_encryption_key = key_encryption_key diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 207cd1320c02..3498338118e7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -246,7 +246,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """ Custom Bearer token credential policy for following Storage Bearer challenges """ - def __init__(self, credential: AsyncTokenCredential, **kwargs: Any) -> None: + def __init__(self, credential: "AsyncTokenCredential", **kwargs: Any) -> None: super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, STORAGE_OAUTH_SCOPE, **kwargs) async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: From 0bb66d2f0ff4e1546681b26b42fc6d3a205e64ac Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Mon, 2 Oct 2023 13:09:07 -0700 Subject: [PATCH 43/71] Feedback, need base client decouple, need CI passing --- .../azure/storage/queue/_encryption.py | 4 +- .../azure/storage/queue/_message_encoding.py | 12 ++--- .../azure/storage/queue/_models.py | 14 +++--- .../azure/storage/queue/_queue_client.py | 13 +++-- .../storage/queue/_queue_client_helpers.py | 2 +- .../storage/queue/_queue_service_client.py | 8 ++-- .../queue/_queue_service_client_helpers.py | 2 +- .../storage/queue/_shared/base_client.py | 25 ++++------ .../queue/_shared/base_client_async.py | 47 ++++++++++++++++--- .../azure/storage/queue/_shared/models.py | 21 +++++++-- .../azure/storage/queue/_shared/policies.py | 34 ++++++-------- .../storage/queue/_shared/policies_async.py | 11 +++-- .../storage/queue/aio/_queue_client_async.py | 19 
+++++--- .../queue/aio/_queue_service_client_async.py | 9 ++-- 14 files changed, 134 insertions(+), 87 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 98693f5f9921..a875ee5cae0e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -294,10 +294,8 @@ def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: :rtype: bool """ # If encryption_data is None, assume no encryption - if encryption_data and (encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2): - return True - return False + return bool(encryption_data and (encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2)) def modify_user_agent_for_encryption( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index 8c5e2288df50..ce490e354b68 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -17,25 +17,25 @@ class MessageEncodePolicy(object): - require_encryption: Optional[bool] = None + require_encryption: bool """Indicates whether encryption is required or not.""" - encryption_version: Optional[str] = None + encryption_version: str """Indicates the version of encryption being used.""" - key_encryption_key: Optional[KeyEncryptionKey] = None + key_encryption_key: Optional[KeyEncryptionKey] """The user-provided key-encryption-key.""" - resolver: Optional[Callable[[str], KeyEncryptionKey]] = None + resolver: Optional[Callable[[str], KeyEncryptionKey]] """The user-provided key resolver.""" def __init__(self) -> None: self.require_encryption = False - self.encryption_version = None + self.encryption_version = _ENCRYPTION_PROTOCOL_V1 
self.key_encryption_key = None self.resolver = None def __call__(self, content: Any) -> str: if content: content = self.encode(content) - if self.key_encryption_key is not None and isinstance(self.encryption_version, str): + if self.key_encryption_key is not None: content = encrypt_queue_message(content, self.key_encryption_key, self.encryption_version) return content diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 788d8737f822..1de88ea78129 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -312,12 +312,12 @@ class AccessPolicy(GenAccessPolicy): be UTC. """ - permission: Optional[QueueSasPermissions] = None # type: ignore [assignment] + permission: Optional[QueueSasPermissions] """The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions.""" - expiry: Optional[Union["datetime", str]] = None # type: ignore [assignment] + expiry: Optional[Union["datetime", str]] """The time at which the shared access signature becomes invalid.""" - start: Optional[Union["datetime", str]] = None # type: ignore [assignment] + start: Optional[Union["datetime", str]] """The time at which the shared access signature becomes valid.""" def __init__( @@ -345,7 +345,7 @@ class QueueMessage(DictMixin): dequeue_count: Optional[int] """Begins with a value of 1 the first time the message is received. This value is incremented each time the message is subsequently received.""" - content: Any + content: Optional[Any] """The message content. Type is determined by the decode_function set on the service. 
Default is str.""" pop_receipt: Optional[str] @@ -368,7 +368,7 @@ def __init__( inserted_on: Optional["datetime"] = None, expires_on: Optional["datetime"] = None, dequeue_count: Optional[int] = None, - content: Any = None, + content: Optional[Any] = None, pop_receipt: Optional[str] = None, next_visible_on: Optional["datetime"] = None ) -> None: @@ -380,7 +380,7 @@ def __init__( inserted_on: Optional["datetime"] = None, expires_on: Optional["datetime"] = None, dequeue_count: Optional[int] = None, - content: Any = None, + content: Optional[Any] = None, pop_receipt: Optional[str] = None, next_visible_on: Optional["datetime"] = None ) -> None: @@ -475,6 +475,8 @@ class QueueProperties(DictMixin): """The name of the queue.""" metadata: Optional[Dict[str, str]] """A dict containing name-value pairs associated with the queue as metadata.""" + approximate_message_count: int + """The approximate number of messages contained in the queue.""" def __init__(self, **kwargs: Any) -> None: self.metadata = kwargs.get('metadata') diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 23ba7eb50fd2..9061d6178156 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -57,6 +57,7 @@ class QueueClient(StorageAccountHostsMixin, StorageEncryptionMixin): - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long :keyword str api_version: The Storage API version to use for requests. 
Default value is the most recent service version that is compatible with the current SDK. Setting to an older version may result in reduced feature compatibility. @@ -81,7 +82,7 @@ class QueueClient(StorageAccountHostsMixin, StorageEncryptionMixin): def __init__( self, account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: parsed_url, sas_token = _parse_url(account_url=account_url, queue_name=queue_name, credential=credential) @@ -102,7 +103,11 @@ def _format_url(self, hostname: str) -> str: :returns: The formatted endpoint URL according to the specified location mode hostname. :rtype: str """ - return _format_url_helper(queue_name=self.queue_name, hostname=hostname, scheme=self.scheme, query_str=self._query_str) # pylint: disable=line-too-long + return _format_url_helper( + queue_name=self.queue_name, + hostname=hostname, + scheme=self.scheme, + query_str=self._query_str) @classmethod def from_queue_url( @@ -133,7 +138,7 @@ def from_queue_url( def from_connection_string( cls, conn_str: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Self: """Create QueueClient from a Connection String. @@ -151,7 +156,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. 
If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index fc60cd95bb35..54ef0bd347dc 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -32,7 +32,7 @@ def _parse_url( - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential, TokenCredential]] # pylint: disable=line-too-long :returns: The parsed URL and SAS token. 
:rtype: Tuple[ParseResult, Any] """ diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index c811e826ceef..af9dad1e4b7d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -62,6 +62,7 @@ class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] # pylint: disable=line-too-long :keyword str api_version: The Storage API version to use for requests. Default value is the most recent service version that is compatible with the current SDK. Setting to an older version may result in reduced feature compatibility. 
@@ -87,7 +88,7 @@ class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): def __init__( self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: parsed_url, sas_token = _parse_url(account_url=account_url, credential=credential) @@ -110,7 +111,7 @@ def _format_url(self, hostname: str) -> str: @classmethod def from_connection_string( cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Self: """Create QueueServiceClient from a Connection String. @@ -126,7 +127,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential, AsyncTokenCredential]] # pylint: disable=line-too-long :returns: A Queue service client. :rtype: ~azure.storage.queue.QueueClient @@ -413,7 +414,6 @@ def get_queue_client( :caption: Get the queue client. 
""" if isinstance(queue, QueueProperties): - if queue.name is not None: queue_name = queue.name else: queue_name = queue diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py index 84e3c0526484..881d3d6929cc 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py @@ -30,7 +30,7 @@ def _parse_url( - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential, TokenCredential]] # pylint: disable=line-too-long :returns: The parsed URL and SAS token. 
:rtype: Tuple[ParseResult, Any] """ diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 36242700faa2..f52238d3a1b2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -14,12 +14,7 @@ TYPE_CHECKING, Union, ) - -try: - from urllib.parse import parse_qs, quote -except ImportError: - from urlparse import parse_qs # type: ignore [no-redef] - from urllib2 import quote # type: ignore [no-redef] +from urllib.parse import parse_qs, quote from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential from azure.core.exceptions import HttpResponseError @@ -74,7 +69,7 @@ def __init__( self, parsed_url: Any, service: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) @@ -96,7 +91,7 @@ def __init__( raise ValueError("Token credential is only supported with HTTPS.") secondary_hostname = None - if hasattr(self.credential, "account_name") and self.credential is not None: + if hasattr(self.credential, "account_name"): self.account_name = self.credential.account_name secondary_hostname = f"{self.credential.account_name}-secondary.{service_name}.{SERVICE_HOST_BASE}" @@ -205,10 +200,10 @@ def api_version(self): def _format_query_string( self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]], # pylint: 
disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]], # pylint: disable=line-too-long snapshot: Optional[str] = None, share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]]]: # pylint: disable=line-too-long + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]]]: # pylint: disable=line-too-long query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -225,12 +220,12 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None if hasattr(credential, "get_token"): - self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) # type: ignore [arg-type] # pylint: disable=line-too-long + self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) # type: ignore elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): @@ -361,7 +356,7 @@ def __exit__(self, *args): # pylint: disable=arguments-differ def _format_shared_key_credential( account_name: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], 
AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long ) -> Any: if isinstance(credential, str): if not account_name: @@ -380,9 +375,9 @@ def _format_shared_key_credential( def parse_connection_str( conn_str: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]], # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]], # pylint: disable=line-too-long service: str -) -> Tuple[Optional[str], Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]]]: # pylint: disable=line-too-long +) -> Tuple[Optional[str], Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]]]: # pylint: disable=line-too-long conn_str = conn_str.rstrip(";") conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index d2dad1d30a94..ce3c96aaadc6 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -5,11 +5,9 @@ # -------------------------------------------------------------------------- # mypy: disable-error-code="attr-defined" -from typing import ( - Any, Dict, Iterator, Optional, - Tuple, TYPE_CHECKING, Union -) import logging +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union +from urllib.parse import parse_qs from azure.core.async_paging import AsyncList from 
azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential @@ -38,6 +36,7 @@ ) from .policies_async import AsyncStorageResponseHook from .response_handlers import PartialBatchErrorException, process_storage_error +from .shared_access_signature import QueryStringConstants if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -67,13 +66,37 @@ async def close(self): """ await self._client.close() + def _format_query_string( + self, sas_token: Optional[str], + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]], # pylint: disable=line-too-long + snapshot: Optional[str] = None, + share_snapshot: Optional[str] = None + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]]]: # pylint: disable=line-too-long + query_str = "?" + if snapshot: + query_str += f"snapshot={snapshot}&" + if share_snapshot: + query_str += f"sharesnapshot={share_snapshot}&" + if sas_token and isinstance(credential, AzureSasCredential): + raise ValueError( + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + if is_credential_sastoken(credential): + query_str += credential.lstrip("?") # type: ignore [union-attr] + credential = None + elif sas_token: + query_str += sas_token + return query_str.rstrip("?&"), credential + def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Tuple[StorageConfiguration, AsyncPipeline]: - self._credential_policy: Optional[Union[AsyncBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy]] = 
None # pylint: disable=line-too-long + self._credential_policy: Optional[ + Union[AsyncBearerTokenCredentialPolicy, + SharedKeyCredentialPolicy, + AzureSasCredentialPolicy]] = None if hasattr(credential, 'get_token'): - self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) #type: ignore [arg-type] # pylint: disable=line-too-long + self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE) #type: ignore elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): @@ -177,6 +200,16 @@ async def _batch_send( except HttpResponseError as error: process_storage_error(error) +def is_credential_sastoken(credential: Any) -> bool: + if not credential or not isinstance(credential, str): + return False + + sas_values = QueryStringConstants.to_list() + parsed_query = parse_qs(credential.lstrip("?")) + if parsed_query and all(k in sas_values for k in parsed_query.keys()): + return True + return False + class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an inner client created diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index 368971b2eb86..e9be33980b8f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -277,7 +277,12 @@ class ResourceTypes(object): object: bool = False _str: str - def __init__(self, service=False, container=False, object=False): # pylint: disable=redefined-builtin + def __init__( + self, + service: bool = False, + container: bool = False, + object: bool = False # pylint: disable=redefined-builtin + ) -> None: self.service = service self.container = container self.object = object @@ -368,10 +373,16 @@ class AccountSasPermissions(object): permanent_delete: bool = 
False def __init__( - self, read: bool = False, write: bool = False, - delete: bool = False, list: bool = False, # pylint: disable=redefined-builtin - add:bool = False, create:bool = False, update:bool =False, - process:bool = False, delete_previous_version:bool = False, + self, + read: bool = False, + write: bool = False, + delete: bool = False, + list: bool = False, # pylint: disable=redefined-builtin + add: bool = False, + create: bool = False, + update: bool = False, + process: bool = False, + delete_previous_version: bool = False, **kwargs ) -> None: self.read = read diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 68daa530e09a..6695daca3973 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -13,21 +13,13 @@ from io import SEEK_SET, UnsupportedOperation from time import time from typing import Any, Dict, Optional, TYPE_CHECKING, Union -from wsgiref.handlers import format_date_time -try: - from urllib.parse import ( +from urllib.parse import ( parse_qsl, urlencode, urlparse, urlunparse, - ) -except ImportError: - from urllib import urlencode # type: ignore - from urlparse import ( # type: ignore - parse_qsl, - urlparse, - urlunparse - ) +) +from wsgiref.handlers import format_date_time from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError from azure.core.pipeline.policies import ( @@ -196,9 +188,6 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): This accepts both global configuration, and per-request level with "enable_http_logger" """ - logging_enable: bool = False - """Whether logging should be enabled or not.""" - def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) super(StorageLoggingPolicy, 
self).__init__(logging_enable=logging_enable, **kwargs) @@ -478,8 +467,10 @@ def sleep(self, settings, transport): transport.sleep(backoff) def increment( - self, settings: Dict[str, Any], - request: "PipelineRequest", response: Optional["PipelineResponse"] = None, + self, + settings: Dict[str, Any], + request: "PipelineRequest", + response: Optional["PipelineResponse"] = None, error: Optional[Union[ServiceRequestError, ServiceResponseError]] = None ) -> bool: """Increment the retry counters. @@ -510,7 +501,7 @@ def increment( else: # Incrementing because of a server error like a 500 in # status_forcelist and a the given method is in the allowlist - if response is not None: + if response: settings['status'] -= 1 settings['history'].append(RequestHistory(request, http_response=response)) @@ -576,9 +567,12 @@ class ExponentialRetry(StorageRetryPolicy): """Exponential retry.""" def __init__( - self, initial_backoff: int = 15, - increment_base: int = 3, retry_total: int = 3, - retry_to_secondary: bool = False, random_jitter_range: int = 3, **kwargs + self, + initial_backoff: int = 15, + increment_base: int = 3, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, **kwargs ) -> None: """ Constructs an Exponential retry object. 
The initial_backoff is used for diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 3498338118e7..f7d21396e663 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -85,7 +85,7 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": pipeline_obj.context['upload_stream_current'] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): - await response_callback(response) #type: ignore + await response_callback(response) # type: ignore else: response_callback(response) request.context['response_callback'] = response_callback @@ -145,9 +145,12 @@ class ExponentialRetry(AsyncStorageRetryPolicy): """Exponential retry.""" def __init__( - self, initial_backoff: int = 15, - increment_base: int = 3, retry_total: int = 3, - retry_to_secondary: bool = False, random_jitter_range: int = 3, **kwargs + self, + initial_backoff: int = 15, + increment_base: int = 3, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, **kwargs ) -> None: """ Constructs an Exponential retry object. 
The initial_backoff is used for diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 1990640e9ef4..b2842358cc5f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -37,7 +37,7 @@ ) if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential from azure.core.credentials_async import AsyncTokenCredential from .._models import QueueProperties @@ -59,6 +59,7 @@ class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, Stora - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long :keyword str api_version: The Storage API version to use for requests. Default value is the most recent service version that is compatible with the current SDK. Setting to an older version may result in reduced feature compatibility. 
@@ -91,7 +92,7 @@ class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, Stora def __init__( self, account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) @@ -116,7 +117,11 @@ def _format_url(self, hostname: str) -> str: :returns: The formatted endpoint URL according to the specified location mode hostname. :rtype: str """ - return _format_url_helper(queue_name=self.queue_name, hostname=hostname, scheme=self.scheme, query_str=self._query_str) # pylint: disable=line-too-long + return _format_url_helper( + queue_name=self.queue_name, + hostname=hostname, + scheme=self.scheme, + query_str=self._query_str) @classmethod def from_queue_url( @@ -136,7 +141,7 @@ def from_queue_url( - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient """ @@ -147,7 +152,7 @@ def from_queue_url( def from_connection_string( cls, conn_str: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Self: """Create QueueClient from a Connection String. @@ -165,7 +170,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient @@ -229,7 +234,7 @@ async def create_queue( process_storage_error(error) @distributed_trace_async - async def delete_queue(self, **kwargs: Any) -> None: # type: ignore[override] + async def delete_queue(self, **kwargs: Any) -> None: """Deletes the specified queue and any messages it contains. 
When a queue is successfully deleted, it is immediately marked for deletion diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 6a66f4e982ab..5624b048c1d8 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -37,7 +37,7 @@ from .._models import CorsRule, Metrics, QueueAnalyticsLogging -class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] # pylint: disable=line-too-long +class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # pylint: disable=line-too-long """A client to interact with the Queue Service at the account level. This client provides operations to retrieve and configure the account properties @@ -58,6 +58,7 @@ class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long :keyword str api_version: The Storage API version to use for requests. Default value is the most recent service version that is compatible with the current SDK. Setting to an older version may result in reduced feature compatibility. 
@@ -83,7 +84,7 @@ class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin def __init__( self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) @@ -109,7 +110,7 @@ def _format_url(self, hostname: str) -> str: @classmethod def from_connection_string( cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Self: """Create QueueServiceClient from a Connection String. @@ -125,7 +126,7 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :paramtype credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long :returns: A Queue service client. 
:rtype: ~azure.storage.queue.QueueClient From d100264fa8bf9463e8761adcffad6ca08cc69619 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Tue, 3 Oct 2023 17:37:36 -0700 Subject: [PATCH 44/71] Fixing CI --- .../azure/storage/queue/_encryption.py | 8 ++-- .../azure/storage/queue/_models.py | 48 +++++-------------- .../storage/queue/_queue_service_client.py | 2 +- .../storage/queue/_shared/base_client.py | 2 +- .../queue/_shared/base_client_async.py | 2 +- 5 files changed, 18 insertions(+), 44 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index a875ee5cae0e..26a200dc374d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -505,8 +505,8 @@ def _generate_encryption_data_dict( encryption_agent['EncryptionAlgorithm'] = _EncryptionAlgorithm.AES_GCM_256 encrypted_region_info = OrderedDict() - encrypted_region_info['DataLength'] = str(_GCM_REGION_DATA_LENGTH) - encrypted_region_info['NonceLength'] = str(_GCM_NONCE_LENGTH) + encrypted_region_info['DataLength'] = _GCM_REGION_DATA_LENGTH + encrypted_region_info['NonceLength'] = _GCM_NONCE_LENGTH encryption_data_dict = OrderedDict() encryption_data_dict['WrappedContentKey'] = wrapped_content_key @@ -514,7 +514,7 @@ def _generate_encryption_data_dict( if version == _ENCRYPTION_PROTOCOL_V1: encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: - encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info + encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info # type: ignore[assignment] encryption_data_dict['KeyWrappingMetadata'] = OrderedDict({'EncryptionLibrary': 'Python ' + VERSION}) return encryption_data_dict @@ -711,7 +711,7 @@ def _decrypt_message( if encryption_data.encrypted_region_info is None: raise ValueError("Missing required 
metadata for Encryption V2") - nonce_length = encryption_data.encrypted_region_info.nonce_length + nonce_length = int(encryption_data.encrypted_region_info.nonce_length) # First bytes are the nonce nonce = message[:nonce_length] diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 1de88ea78129..f3cb43817daa 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -312,12 +312,12 @@ class AccessPolicy(GenAccessPolicy): be UTC. """ - permission: Optional[QueueSasPermissions] + permission: Optional[QueueSasPermissions] #type: ignore [assignment] """The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions.""" - expiry: Optional[Union["datetime", str]] + expiry: Optional[Union["datetime", str]] #type: ignore [assignment] """The time at which the shared access signature becomes invalid.""" - start: Optional[Union["datetime", str]] + start: Optional[Union["datetime", str]] #type: ignore [assignment] """The time at which the shared access signature becomes valid.""" def __init__( @@ -357,40 +357,14 @@ class QueueMessage(DictMixin): """A UTC date value representing the time the message will next be visible. Only returned by receive messages operations. Set to None for peek messages.""" - @overload - def __init__(self, *, id: str, content: Any = None) -> None: - ... - - @overload - def __init__( - self, *, - id: Optional[str] = None, - inserted_on: Optional["datetime"] = None, - expires_on: Optional["datetime"] = None, - dequeue_count: Optional[int] = None, - content: Optional[Any] = None, - pop_receipt: Optional[str] = None, - next_visible_on: Optional["datetime"] = None - ) -> None: - ... 
- - def __init__( - self, *, - id: Optional[str] = None, - inserted_on: Optional["datetime"] = None, - expires_on: Optional["datetime"] = None, - dequeue_count: Optional[int] = None, - content: Optional[Any] = None, - pop_receipt: Optional[str] = None, - next_visible_on: Optional["datetime"] = None - ) -> None: - self.id = id # type: ignore [assignment] - self.inserted_on = inserted_on - self.expires_on = expires_on - self.dequeue_count = dequeue_count + def __init__(self, content=None): + self.id = None + self.inserted_on = None + self.expires_on = None + self.dequeue_count = None self.content = content - self.pop_receipt = pop_receipt - self.next_visible_on = next_visible_on + self.pop_receipt = None + self.next_visible_on = None @classmethod def _from_generated(cls, generated: Any) -> Self: @@ -480,7 +454,7 @@ class QueueProperties(DictMixin): def __init__(self, **kwargs: Any) -> None: self.metadata = kwargs.get('metadata') - self.approximate_message_count = kwargs.get('x-ms-approximate-messages-count') + self.approximate_message_count = kwargs.get('x-ms-approximate-messages-count') #type: ignore [assignment] @classmethod def _from_generated(cls, generated: Any) -> Self: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index af9dad1e4b7d..22e4ae07d54a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -414,7 +414,7 @@ def get_queue_client( :caption: Get the queue client. 
""" if isinstance(queue, QueueProperties): - queue_name = queue.name + queue_name = queue.name else: queue_name = queue diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index f52238d3a1b2..77199367434b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -479,6 +479,6 @@ def is_credential_sastoken(credential: Any) -> bool: sas_values = QueryStringConstants.to_list() parsed_query = parse_qs(credential.lstrip("?")) - if parsed_query and all(k in sas_values for k in parsed_query.keys()): + if parsed_query and all(k in sas_values for k in parsed_query): return True return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index ce3c96aaadc6..ec6f3362216f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -206,7 +206,7 @@ def is_credential_sastoken(credential: Any) -> bool: sas_values = QueryStringConstants.to_list() parsed_query = parse_qs(credential.lstrip("?")) - if parsed_query and all(k in sas_values for k in parsed_query.keys()): + if parsed_query and all(k in sas_values for k in parsed_query): return True return False From 0d5d67e00b54b06fe9a544ee4ebb6e4cb51c43f8 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Wed, 4 Oct 2023 12:53:03 -0700 Subject: [PATCH 45/71] CI --- sdk/storage/azure-storage-queue/azure/storage/queue/_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index f3cb43817daa..919507b6e474 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -7,7 +7,7 @@ # pylint: disable=super-init-not-called import sys -from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from azure.core.exceptions import HttpResponseError from azure.core.paging import PageIterator from ._shared.response_handlers import process_storage_error, return_context_and_deserialized From 0efb8bcf0fd41195f21060dcfee4fa3de0d79f54 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Thu, 5 Oct 2023 18:19:51 -0700 Subject: [PATCH 46/71] CI --- .../azure/storage/queue/_encryption.py | 6 +-- .../azure/storage/queue/_models.py | 42 ++++++++++++++++++- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 26a200dc374d..012589168253 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -15,7 +15,7 @@ dumps, loads, ) -from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING +from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import Protocol from cryptography.hazmat.backends import default_backend @@ -508,13 +508,13 @@ def _generate_encryption_data_dict( encrypted_region_info['DataLength'] = _GCM_REGION_DATA_LENGTH encrypted_region_info['NonceLength'] = _GCM_NONCE_LENGTH - encryption_data_dict = OrderedDict() + encryption_data_dict: OrderedDict[str, Any] = OrderedDict() encryption_data_dict['WrappedContentKey'] = wrapped_content_key encryption_data_dict['EncryptionAgent'] = encryption_agent if version == _ENCRYPTION_PROTOCOL_V1: 
encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: - encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info # type: ignore[assignment] + encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info encryption_data_dict['KeyWrappingMetadata'] = OrderedDict({'EncryptionLibrary': 'Python ' + VERSION}) return encryption_data_dict diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 919507b6e474..17c7cf1cf6c8 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -7,7 +7,7 @@ # pylint: disable=super-init-not-called import sys -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TYPE_CHECKING, Union from azure.core.exceptions import HttpResponseError from azure.core.paging import PageIterator from ._shared.response_handlers import process_storage_error, return_context_and_deserialized @@ -357,6 +357,46 @@ class QueueMessage(DictMixin): """A UTC date value representing the time the message will next be visible. Only returned by receive messages operations. Set to None for peek messages.""" + # @overload + # def __init__( + # self, + # *, + # id: str, + # inserted_on: Optional["datetime"] = None, + # expires_on: Optional["datetime"] = None, + # dequeue_count: Optional[int] = None, + # content: Optional[Any] = None, + # pop_receipt: Optional[str] = None, + # next_visible_on: Optional["datetime"] = None + # ) -> None: + # ... 
+ + # @overload + # def __init__( + # self, + # content: Optional[Any] = None, + # *, + # id: Optional[str], + # inserted_on: Optional["datetime"] = None, + # expires_on: Optional["datetime"] = None, + # dequeue_count: Optional[int] = None, + # pop_receipt: Optional[str] = None, + # next_visible_on: Optional["datetime"] = None + # ) -> None: + # ... + + # def __init__(self, *args: Any, **kwargs: Any) -> None: + # self.id = kwargs.pop('id', None) + # self.inserted_on = kwargs.pop('inserted_on', None) + # self.expires_on = kwargs.pop('expires_on', None) + # self.dequeue_count = kwargs.pop('dequeue_count', None) + # if args: + # self.content = args[0] + # else: + # self.content = kwargs.pop('content', None) + # self.pop_receipt = kwargs.pop('pop_receipt', None) + # self.next_visible_on = kwargs.pop('next_visible_on', None) + def __init__(self, content=None): self.id = None self.inserted_on = None From f95247705e42a8679aea6650416b8979ef841b6e Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 6 Oct 2023 16:51:14 -0700 Subject: [PATCH 47/71] Ignores :( --- .../azure/storage/queue/_queue_client.py | 2 +- .../azure/storage/queue/_queue_service_client.py | 2 +- .../azure/storage/queue/_shared/base_client.py | 2 +- .../azure/storage/queue/aio/_queue_client_async.py | 2 +- .../azure/storage/queue/aio/_queue_service_client_async.py | 6 +++--- .../azure-storage-queue/samples/network_activity_logging.py | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 9061d6178156..703fecf1042a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -169,7 +169,7 @@ def from_connection_string( :dedent: 8 :caption: Create the queue client from connection string. 
""" - account_url, secondary, credential = parse_connection_str( + account_url, secondary, credential = parse_connection_str( #type: ignore conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 22e4ae07d54a..79351bd59f0d 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -140,7 +140,7 @@ def from_connection_string( :dedent: 8 :caption: Creating the QueueServiceClient with a connection string. """ - account_url, secondary, credential = parse_connection_str( + account_url, secondary, credential = parse_connection_str( #type: ignore conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 77199367434b..989f62ee3bc0 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -220,7 +220,7 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] = None, # pylint: disable=line-too-long + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index b2842358cc5f..ea264bc82401 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -183,7 +183,7 @@ def from_connection_string( :dedent: 8 :caption: Create the queue client from connection string. """ - account_url, secondary, credential = parse_connection_str( + account_url, secondary, credential = parse_connection_str( #type: ignore conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 5624b048c1d8..0baa18a1d3b0 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -32,12 +32,12 @@ from .._shared.response_handlers import process_storage_error if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential from azure.core.credentials_async import AsyncTokenCredential from .._models import CorsRule, Metrics, QueueAnalyticsLogging -class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # pylint: disable=line-too-long +class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] # pylint: disable=line-too-long """A client to interact with the Queue Service at the account level. 
This client provides operations to retrieve and configure the account properties @@ -139,7 +139,7 @@ def from_connection_string( :dedent: 8 :caption: Creating the QueueServiceClient with a connection string. """ - account_url, secondary, credential = parse_connection_str( + account_url, secondary, credential = parse_connection_str( #type: ignore conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary diff --git a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py index a9f44c941f2e..2370cea048e5 100644 --- a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py +++ b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py @@ -63,6 +63,6 @@ messages = queue_client.peek_messages(max_messages=20, logging_enable=True) for message in messages: try: - print(' Message: {!r}'.format(base64.b64decode(message.content))) + print(' Message: {!r}'.format(base64.b64decode(message.content))) #type: ignore [arg-type] except binascii.Error: print(' Message: {}'.format(message.content)) From 22267f7d1166ed46e9bdc17fb30a46421d76eb3c Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 6 Oct 2023 17:19:22 -0700 Subject: [PATCH 48/71] More fixes --- .../azure/storage/queue/_queue_client.py | 2 +- .../storage/queue/_queue_service_client.py | 2 +- .../storage/queue/_shared/base_client.py | 6 +- .../queue/_shared/base_client_async.py | 60 ++++++++++++++++++- .../storage/queue/aio/_queue_client_async.py | 6 +- .../queue/aio/_queue_service_client_async.py | 6 +- 6 files changed, 70 insertions(+), 12 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 703fecf1042a..9061d6178156 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -169,7 +169,7 @@ def from_connection_string( :dedent: 8 :caption: Create the queue client from connection string. """ - account_url, secondary, credential = parse_connection_str( #type: ignore + account_url, secondary, credential = parse_connection_str( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 79351bd59f0d..22e4ae07d54a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -140,7 +140,7 @@ def from_connection_string( :dedent: 8 :caption: Creating the QueueServiceClient with a connection string. """ - account_url, secondary, credential = parse_connection_str( #type: ignore + account_url, secondary, credential = parse_connection_str( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 989f62ee3bc0..641e44a39319 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -220,7 +220,7 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] = None, # pylint: 
disable=line-too-long **kwargs: Any ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None @@ -375,9 +375,9 @@ def _format_shared_key_credential( def parse_connection_str( conn_str: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]], # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]], # pylint: disable=line-too-long service: str -) -> Tuple[Optional[str], Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]]]: # pylint: disable=line-too-long +) -> Tuple[Optional[str], Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]]]: # pylint: disable=line-too-long conn_str = conn_str.rstrip(";") conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index ec6f3362216f..170ca5b1b407 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -25,7 +25,7 @@ from .authentication import SharedKeyCredentialPolicy from .base_client import create_configuration -from .constants import CONNECTION_TIMEOUT, READ_TIMEOUT, STORAGE_OAUTH_SCOPE +from .constants import CONNECTION_TIMEOUT, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE from .models import StorageConfiguration from .policies import ( QueueMessagePolicy, @@ -44,6 +44,13 @@ from azure.core.pipeline.transport import HttpRequest, HttpResponse _LOGGER = logging.getLogger(__name__) +_SERVICE_PARAMS = { + 
"blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"}, + "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"}, + "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"}, + "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"}, +} + class AsyncStorageAccountHostsMixin(object): @@ -200,6 +207,57 @@ async def _batch_send( except HttpResponseError as error: process_storage_error(error) +def parse_connection_str( + conn_str: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]], # pylint: disable=line-too-long + service: str +) -> Tuple[Optional[str], Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]]]: # pylint: disable=line-too-long + conn_str = conn_str.rstrip(";") + conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] + if any(len(tup) != 2 for tup in conn_settings_list): + raise ValueError("Connection string is either blank or malformed.") + conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) + endpoints = _SERVICE_PARAMS[service] + primary = None + secondary = None + if not credential: + try: + credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]} + except KeyError: + credential = conn_settings.get("SHAREDACCESSSIGNATURE") + if endpoints["primary"] in conn_settings: + primary = conn_settings[endpoints["primary"]] + if endpoints["secondary"] in conn_settings: + secondary = conn_settings[endpoints["secondary"]] + else: + if endpoints["secondary"] in conn_settings: + raise ValueError("Connection string specifies only secondary endpoint.") + try: + primary =( + f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" + f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" + ) + secondary = ( + f"{conn_settings['ACCOUNTNAME']}-secondary." 
+ f"{service}.{conn_settings['ENDPOINTSUFFIX']}" + ) + except KeyError: + pass + + if not primary: + try: + primary = ( + f"https://{conn_settings['ACCOUNTNAME']}." + f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}" + ) + except KeyError as exc: + raise ValueError("Connection string missing required connection details.") from exc + if service == "dfs": + primary = primary.replace(".blob.", ".dfs.") + if secondary: + secondary = secondary.replace(".blob.", ".dfs.") + return primary, secondary, credential + def is_credential_sastoken(credential: Any) -> bool: if not credential or not isinstance(credential, str): return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index ea264bc82401..1c36018957a8 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -26,8 +26,8 @@ from .._models import AccessPolicy, QueueMessage from .._queue_client_helpers import _format_url_helper, _from_queue_url_helper, _parse_url from .._serialize import get_api_version -from .._shared.base_client import parse_connection_str, StorageAccountHostsMixin -from .._shared.base_client_async import AsyncStorageAccountHostsMixin +from .._shared.base_client import StorageAccountHostsMixin +from .._shared.base_client_async import AsyncStorageAccountHostsMixin, parse_connection_str from .._shared.policies_async import ExponentialRetry from .._shared.request_handlers import add_metadata_headers, serialize_iso from .._shared.response_handlers import ( @@ -183,7 +183,7 @@ def from_connection_string( :dedent: 8 :caption: Create the queue client from connection string. 
""" - account_url, secondary, credential = parse_connection_str( #type: ignore + account_url, secondary, credential = parse_connection_str( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 0baa18a1d3b0..961db2800c95 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -25,8 +25,8 @@ from .._models import QueueProperties, service_properties_deserialize, service_stats_deserialize from .._queue_service_client_helpers import _parse_url from .._serialize import get_api_version -from .._shared.base_client import parse_connection_str, StorageAccountHostsMixin -from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper +from .._shared.base_client import StorageAccountHostsMixin +from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper, parse_connection_str from .._shared.models import LocationMode from .._shared.policies_async import ExponentialRetry from .._shared.response_handlers import process_storage_error @@ -139,7 +139,7 @@ def from_connection_string( :dedent: 8 :caption: Creating the QueueServiceClient with a connection string. 
""" - account_url, secondary, credential = parse_connection_str( #type: ignore + account_url, secondary, credential = parse_connection_str( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary From 20713590ac6576809972fec3b02b740688f165c0 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Mon, 9 Oct 2023 14:01:27 -0700 Subject: [PATCH 49/71] CI --- .../azure-storage-queue/azure/storage/queue/_encryption.py | 2 +- sdk/storage/azure-storage-queue/azure/storage/queue/_models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 012589168253..ecb3114fb4fe 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -15,7 +15,7 @@ dumps, loads, ) -from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING from typing_extensions import Protocol from cryptography.hazmat.backends import default_backend diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 17c7cf1cf6c8..bae02e9701a3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -7,7 +7,7 @@ # pylint: disable=super-init-not-called import sys -from typing import Any, Callable, Dict, List, Optional, overload, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from azure.core.exceptions import HttpResponseError from azure.core.paging import PageIterator from ._shared.response_handlers import process_storage_error, return_context_and_deserialized From 
1844459afa3967c9e7edaee92e9393a18979f68b Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Mon, 9 Oct 2023 16:33:12 -0700 Subject: [PATCH 50/71] Pylint again --- .../azure-storage-queue/azure/storage/queue/_encryption.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index ecb3114fb4fe..a997ec430d60 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -468,7 +468,7 @@ def _generate_encryption_data_dict( cek: bytes, iv: Optional[bytes], version: str - ) -> Dict[str, Any]: + ) -> "OrderedDict[str, Any]": """ Generates and returns the encryption metadata as a dict. @@ -477,7 +477,7 @@ def _generate_encryption_data_dict( :param Optional[bytes] iv: The initialization vector. Only required for AES-CBC. :param str version: The client encryption version used. :return: A dict containing all the encryption metadata. - :rtype: dict + :rtype: OrderedDict[str, Any] """ # Encrypt the cek. 
if version == _ENCRYPTION_PROTOCOL_V1: From 8768e4df1f4cf7be2f576d896c3324bf29d5481f Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Tue, 10 Oct 2023 11:09:11 -0700 Subject: [PATCH 51/71] Adjust typehint --- .../azure-storage-queue/azure/storage/queue/_encryption.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index a997ec430d60..e03f39bd64df 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -468,7 +468,7 @@ def _generate_encryption_data_dict( cek: bytes, iv: Optional[bytes], version: str - ) -> "OrderedDict[str, Any]": + ) -> OrderedDict: """ Generates and returns the encryption metadata as a dict. @@ -477,7 +477,7 @@ def _generate_encryption_data_dict( :param Optional[bytes] iv: The initialization vector. Only required for AES-CBC. :param str version: The client encryption version used. :return: A dict containing all the encryption metadata. - :rtype: OrderedDict[str, Any] + :rtype: OrderedDict """ # Encrypt the cek. 
if version == _ENCRYPTION_PROTOCOL_V1: From 080bea5b85b3604ca94f29987fc639c8ae2d3e36 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Tue, 10 Oct 2023 12:58:30 -0700 Subject: [PATCH 52/71] Try SO workaround 41207128 --- .../azure-storage-queue/azure/storage/queue/_encryption.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index e03f39bd64df..905ddcc057f4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -16,6 +16,7 @@ loads, ) from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING +from typing import OrderedDict as TypedOrderedDict from typing_extensions import Protocol from cryptography.hazmat.backends import default_backend @@ -468,7 +469,7 @@ def _generate_encryption_data_dict( cek: bytes, iv: Optional[bytes], version: str - ) -> OrderedDict: + ) -> TypedOrderedDict[str, Any]: """ Generates and returns the encryption metadata as a dict. @@ -477,7 +478,7 @@ def _generate_encryption_data_dict( :param Optional[bytes] iv: The initialization vector. Only required for AES-CBC. :param str version: The client encryption version used. :return: A dict containing all the encryption metadata. - :rtype: OrderedDict + :rtype: Dict[str, Any] """ # Encrypt the cek. 
if version == _ENCRYPTION_PROTOCOL_V1: @@ -508,7 +509,7 @@ def _generate_encryption_data_dict( encrypted_region_info['DataLength'] = _GCM_REGION_DATA_LENGTH encrypted_region_info['NonceLength'] = _GCM_NONCE_LENGTH - encryption_data_dict: OrderedDict[str, Any] = OrderedDict() + encryption_data_dict: TypedOrderedDict[str, Any] = OrderedDict() encryption_data_dict['WrappedContentKey'] = wrapped_content_key encryption_data_dict['EncryptionAgent'] = encryption_agent if version == _ENCRYPTION_PROTOCOL_V1: From 8fa308be1537210a3957e225974f1c922c185de4 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Wed, 11 Oct 2023 16:50:15 -0700 Subject: [PATCH 53/71] Reimport Pylint --- .../azure/storage/queue/_shared/base_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index b4e5154ae8b3..2e55fd160dc9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -31,7 +31,6 @@ UserAgentPolicy, ) -from .models import LocationMode from .authentication import SharedKeyCredentialPolicy from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE from .models import LocationMode, StorageConfiguration From 509e538210d090115e274d44cafe5ce89c5d8229 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Thu, 12 Oct 2023 18:28:50 -0700 Subject: [PATCH 54/71] First round of feedback --- .../azure/storage/queue/_encryption.py | 11 ++--- .../azure/storage/queue/_models.py | 20 ++++++-- .../azure/storage/queue/_queue_client.py | 3 +- .../storage/queue/_queue_service_client.py | 16 +++---- .../storage/queue/_shared/base_client.py | 46 +++++-------------- .../queue/_shared/base_client_async.py | 18 ++++---- .../azure/storage/queue/_shared/policies.py | 34 ++++++++++---- 
.../storage/queue/_shared/policies_async.py | 18 +++++++- .../storage/queue/aio/_queue_client_async.py | 22 ++++----- .../queue/aio/_queue_service_client_async.py | 13 +++--- .../samples/network_activity_logging.py | 2 +- 11 files changed, 108 insertions(+), 95 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 766a93202fab..5f64d2549132 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -74,8 +74,7 @@ def _validate_key_encryption_key_wrap(kek: KeyEncryptionKey): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'wrap_key')) if not hasattr(kek, 'get_kid') or not callable(kek.get_kid): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid')) - if not (hasattr(kek, 'get_key_wrap_algorithm') or - not callable(kek.get_key_wrap_algorithm)): + if not hasattr(kek, 'get_key_wrap_algorithm') or not callable(kek.get_key_wrap_algorithm): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_key_wrap_algorithm')) @@ -286,7 +285,7 @@ def _encrypt_region(self, data: bytes) -> bytes: return nonce + ciphertext_with_tag -def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: +def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: """ Determine whether the given encryption data signifies version 2.0. @@ -649,8 +648,9 @@ def _validate_and_unwrap_cek( if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid(): raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.') # Will throw an exception if the specified algorithm is not supported. 
- content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, - encryption_data.wrapped_content_key.algorithm) + content_encryption_key = key_encryption_key.unwrap_key( + encryption_data.wrapped_content_key.encrypted_key, + encryption_data.wrapped_content_key.algorithm) # For V2, the version is included with the cek. We need to validate it # and remove it from the actual cek. @@ -984,7 +984,6 @@ def get_blob_encryptor_and_padder( cipher = _generate_AES_CBC_cipher(cek, iv) encryptor = cipher.encryptor() padder = PKCS7(128).padder() if should_pad else None - return encryptor, padder return encryptor, padder diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index bae02e9701a3..a8b871fee7a2 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -194,6 +194,19 @@ def __init__(self, allowed_origins: List[str], allowed_methods: List[str], **kwa self.exposed_headers = ','.join(kwargs.get('exposed_headers', [])) self.max_age_in_seconds = kwargs.get('max_age_in_seconds', 0) + @staticmethod + def to_generated(rules: Optional[List["CorsRule"]]) -> List[GeneratedCorsRule]: + if rules is not None: + cors_rules = rules[0] + generated_cors = [GeneratedCorsRule( + allowed_origins=cors_rules.allowed_origins, + allowed_methods=cors_rules.allowed_methods, + allowed_headers=cors_rules.allowed_headers, + exposed_headers=cors_rules.exposed_headers, + max_age_in_seconds=cors_rules.max_age_in_seconds + )] + return generated_cors + @classmethod def _from_generated(cls, generated: Any) -> Self: return cls( @@ -489,12 +502,12 @@ class QueueProperties(DictMixin): """The name of the queue.""" metadata: Optional[Dict[str, str]] """A dict containing name-value pairs associated with the queue as metadata.""" - approximate_message_count: int + approximate_message_count: 
Optional[int] """The approximate number of messages contained in the queue.""" def __init__(self, **kwargs: Any) -> None: self.metadata = kwargs.get('metadata') - self.approximate_message_count = kwargs.get('x-ms-approximate-messages-count') #type: ignore [assignment] + self.approximate_message_count = kwargs.get('x-ms-approximate-messages-count') @classmethod def _from_generated(cls, generated: Any) -> Self: @@ -567,8 +580,7 @@ def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], List[Qu self.marker = self._response.marker self.results_per_page = self._response.max_results props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access - next_marker = self._response.next_marker - return next_marker or None, props_list + return self._response.next_marker or None, props_list def service_stats_deserialize(generated: Any) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 3ca5896feda7..d8422edc30c9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -33,7 +33,6 @@ if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential - from azure.core.credentials_async import AsyncTokenCredential from ._models import QueueProperties @@ -183,7 +182,7 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) #type: ignore [arg-type] + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @distributed_trace def create_queue( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 3c5223efc8fd..a80cc5caa23a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -19,6 +19,7 @@ from ._generated import AzureQueueStorage from ._generated.models import StorageServiceProperties from ._models import ( + CorsRule, QueueProperties, QueuePropertiesPaged, service_properties_deserialize, @@ -33,8 +34,7 @@ if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential - from azure.core.credentials_async import AsyncTokenCredential - from ._models import CorsRule, Metrics, QueueAnalyticsLogging + from ._models import Metrics, QueueAnalyticsLogging class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): @@ -150,7 +150,7 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, credential=credential, **kwargs) #type: ignore [arg-type] + return cls(account_url, credential=credential, **kwargs) @distributed_trace def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: @@ -217,7 +217,7 @@ def set_service_properties( self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, hour_metrics: Optional["Metrics"] = None, minute_metrics: Optional["Metrics"] = None, - cors: Optional[List["CorsRule"]] = None, + cors: Optional[List[CorsRule]] = None, **kwargs: Any ) -> None: """Sets the properties of a storage account's Queue service, including @@ -259,7 +259,7 @@ def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=cors # type: ignore [arg-type] + cors=CorsRule.to_generated(cors) ) try: self._client.service.set_properties(props, timeout=timeout, **kwargs) @@ -358,8 +358,7 @@ def create_queue( 
@distributed_trace def delete_queue( - self, - queue: Union["QueueProperties", str], + self, queue: Union["QueueProperties", str], **kwargs: Any ) -> None: """Deletes the specified queue and any messages it contains. @@ -395,8 +394,7 @@ def delete_queue( queue_client.delete_queue(timeout=timeout, **kwargs) def get_queue_client( - self, - queue: Union["QueueProperties", str], + self, queue: Union["QueueProperties", str], **kwargs: Any ) -> QueueClient: """Get a client to interact with the specified queue. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 2e55fd160dc9..aa0acb0a9fda 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -7,6 +7,7 @@ import uuid from typing import ( Any, + cast, Dict, Iterator, Optional, @@ -16,7 +17,7 @@ ) from urllib.parse import parse_qs, quote -from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential +from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import Pipeline from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module @@ -50,7 +51,6 @@ from .._version import VERSION if TYPE_CHECKING: - from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline.transport import HttpRequest, HttpResponse @@ -69,7 +69,7 @@ def __init__( self, parsed_url: Any, service: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], 
AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) @@ -200,10 +200,10 @@ def api_version(self): def _format_query_string( self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]], # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]], # pylint: disable=line-too-long snapshot: Optional[str] = None, share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]]]: # pylint: disable=line-too-long + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]: # pylint: disable=line-too-long query_str = "?" 
if snapshot: query_str += f"snapshot={snapshot}&" @@ -213,14 +213,15 @@ def _format_query_string( raise ValueError( "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") if is_credential_sastoken(credential): - query_str += credential.lstrip("?") # type: ignore [union-attr] + credential = cast(str, credential) + query_str += credential.lstrip("?") credential = None elif sas_token: query_str += sas_token return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] = None, # pylint: disable=line-too-long + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None @@ -229,7 +230,7 @@ def _create_pipeline( audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE - self._credential_policy = BearerTokenCredentialPolicy(credential, audience) + self._credential_policy = BearerTokenCredentialPolicy(cast(TokenCredential, credential), audience) elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): @@ -360,7 +361,7 @@ def __exit__(self, *args): # pylint: disable=arguments-differ def _format_shared_key_credential( account_name: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long ) -> Any: if isinstance(credential, str): if not account_name: @@ -379,16 +380,15 @@ def 
_format_shared_key_credential( def parse_connection_str( conn_str: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]], # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], # pylint: disable=line-too-long service: str -) -> Tuple[Optional[str], Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]]]: # pylint: disable=line-too-long +) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]: # pylint: disable=line-too-long conn_str = conn_str.rstrip(";") conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): raise ValueError("Connection string is either blank or malformed.") conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) endpoints = _SERVICE_PARAMS[service] - primary = None secondary = None if not credential: try: @@ -440,28 +440,6 @@ def create_configuration(**kwargs: Any) -> StorageConfiguration: config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) config.logging_policy = StorageLoggingPolicy(**kwargs) config.proxy_policy = ProxyPolicy(**kwargs) - - # Storage settings - config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024) - config.copy_polling_interval = 15 - - # Block blob uploads - config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024) - config.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) - config.use_byte_buffer = kwargs.get("use_byte_buffer", False) - - # Page blob uploads - config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024) - - # Datalake file uploads - config.min_large_chunk_upload_threshold = 
kwargs.get("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) - - # Blob downloads - config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024) - config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024) - - # File uploads - config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024) return config diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 81db2be32166..737d9691b928 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -6,11 +6,12 @@ # mypy: disable-error-code="attr-defined" import logging -from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING, Union from urllib.parse import parse_qs from azure.core.async_paging import AsyncList from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential +from azure.core.credentials_async import AsyncTokenCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import ( @@ -39,8 +40,6 @@ from .shared_access_signature import QueryStringConstants if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline.transport import HttpRequest, HttpResponse _LOGGER = logging.getLogger(__name__) @@ -75,10 +74,10 @@ async def close(self): def _format_query_string( self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]], # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", 
"AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long snapshot: Optional[str] = None, share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]]]: # pylint: disable=line-too-long + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -95,7 +94,7 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long **kwargs: Any ) -> Tuple[StorageConfiguration, AsyncPipeline]: self._credential_policy: Optional[ @@ -107,7 +106,7 @@ def _create_pipeline( audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE - self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, audience) + self._credential_policy = AsyncBearerTokenCredentialPolicy(cast(AsyncTokenCredential, credential), audience) elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): @@ -213,16 +212,15 @@ async def _batch_send( def parse_connection_str( conn_str: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]], # pylint: disable=line-too-long + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], # pylint: disable=line-too-long service: str -) -> Tuple[Optional[str], 
Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential"]]]: # pylint: disable=line-too-long +) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long conn_str = conn_str.rstrip(";") conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): raise ValueError("Connection string is either blank or malformed.") conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) endpoints = _SERVICE_PARAMS[service] - primary = None secondary = None if not credential: try: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 6a50c5c4e62a..68e7b7b5d557 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -42,7 +42,10 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential - from azure.core.pipeline.transport import PipelineRequest, PipelineResponse # pylint: disable=C4750 + from azure.core.pipeline.transport import ( + PipelineRequest, + PipelineResponse + ) _LOGGER = logging.getLogger(__name__) @@ -391,15 +394,15 @@ class StorageRetryPolicy(HTTPPolicy): The base class for Exponential and Linear retries containing shared code. 
""" - total_retries: int = 10 + total_retries: int """The max number of retries.""" - connect_retries: int = 3 + connect_retries: int """The max number of connect retries.""" - retry_read: int = 3 + retry_read: int """The max number of read retries.""" - retry_status:int = 3 + retry_status: int """The max number of status retries.""" - retry_to_secondary: bool = False + retry_to_secondary: bool """Whether the secondary endpoint should be retried.""" def __init__(self, **kwargs: Any) -> None: @@ -467,8 +470,7 @@ def sleep(self, settings, transport): transport.sleep(backoff) def increment( - self, - settings: Dict[str, Any], + self, settings: Dict[str, Any], request: "PipelineRequest", response: Optional["PipelineResponse"] = None, error: Optional[Union[ServiceRequestError, ServiceResponseError]] = None @@ -566,9 +568,16 @@ def send(self, request): class ExponentialRetry(StorageRetryPolicy): """Exponential retry.""" + initial_backoff: int + """The initial backoff interval, in seconds, for the first retry.""" + increment_base: int + """The base, in seconds, to increment the initial_backoff by after the + first retry.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + def __init__( - self, - initial_backoff: int = 15, + self, initial_backoff: int = 15, increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, @@ -620,6 +629,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class LinearRetry(StorageRetryPolicy): """Linear retry.""" + initial_backoff: int + """The backoff interval, in seconds, between retries.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + def __init__( self, backoff: int = 15, retry_total: int = 3, diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index c2f0a8da10a1..d3bf6106bd96 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -19,7 +19,10 @@ if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential - from azure.core.pipeline.transport import PipelineRequest, PipelineResponse # pylint: disable=C4750 + from azure.core.pipeline.transport import ( + PipelineRequest, + PipelineResponse + ) _LOGGER = logging.getLogger(__name__) @@ -144,6 +147,14 @@ async def send(self, request): class ExponentialRetry(AsyncStorageRetryPolicy): """Exponential retry.""" + initial_backoff: int + """The initial backoff interval, in seconds, for the first retry.""" + increment_base: int + """The base, in seconds, to increment the initial_backoff by after the + first retry.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + def __init__( self, initial_backoff: int = 15, @@ -200,6 +211,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class LinearRetry(AsyncStorageRetryPolicy): """Linear retry.""" + initial_backoff: int + """The backoff interval, in seconds, between retries.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + def __init__( self, backoff: int = 15, retry_total: int = 3, diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 03a893a1a024..8f22f1386824 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -190,7 +190,7 @@ def from_connection_string( conn_str, credential, 
'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) #type: ignore [arg-type] + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @distributed_trace_async async def create_queue( @@ -272,7 +272,7 @@ async def delete_queue(self, **kwargs: Any) -> None: process_storage_error(error) @distributed_trace_async - async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": # type: ignore[override] + async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": """Returns all user-defined metadata for the specified queue. The data returned does not include the queue's list of messages. @@ -342,7 +342,7 @@ async def set_queue_metadata( process_storage_error(error) @distributed_trace_async - async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: # type: ignore[override] + async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: """Returns details about any stored access policies specified on the queue that may be used with Shared Access Signatures. 
@@ -365,7 +365,7 @@ async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace_async - async def set_queue_access_policy( # type: ignore[override] + async def set_queue_access_policy( self, signed_identifiers: Dict[str, AccessPolicy], **kwargs: Any ) -> None: @@ -422,7 +422,7 @@ async def set_queue_access_policy( # type: ignore[override] process_storage_error(error) @distributed_trace_async - async def send_message( # type: ignore[override] + async def send_message( self, content: Any, *, visibility_timeout: Optional[int] = None, @@ -525,7 +525,7 @@ async def send_message( # type: ignore[override] process_storage_error(error) @distributed_trace_async - async def receive_message( # type: ignore[override] + async def receive_message( self, *, visibility_timeout: Optional[int] = None, **kwargs: Any @@ -595,7 +595,7 @@ async def receive_message( # type: ignore[override] process_storage_error(error) @distributed_trace - def receive_messages( # type: ignore[override] + def receive_messages( self, *, messages_per_page: Optional[int] = None, visibility_timeout: Optional[int] = None, @@ -796,7 +796,7 @@ async def update_message( cls=return_response_headers, queue_message_id=message_id, **kwargs - )) + )) new_message = QueueMessage(content=message_text) new_message.id = message_id new_message.inserted_on = inserted_on @@ -809,7 +809,7 @@ async def update_message( process_storage_error(error) @distributed_trace_async - async def peek_messages( # type: ignore[override] + async def peek_messages( self, max_messages: Optional[int] = None, **kwargs: Any ) -> List[QueueMessage]: @@ -880,7 +880,7 @@ async def peek_messages( # type: ignore[override] process_storage_error(error) @distributed_trace_async - async def clear_messages(self, **kwargs: Any) -> None: # type: ignore[override] + async def clear_messages(self, **kwargs: Any) -> None: """Deletes all messages from the 
specified queue. :keyword int timeout: @@ -906,7 +906,7 @@ async def clear_messages(self, **kwargs: Any) -> None: # type: ignore[override] process_storage_error(error) @distributed_trace_async - async def delete_message( # type: ignore[override] + async def delete_message( self, message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, **kwargs: Any diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 3db9aa986b91..675224a54cd8 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -22,7 +22,7 @@ from .._encryption import StorageEncryptionMixin from .._generated.aio import AzureQueueStorage from .._generated.models import StorageServiceProperties -from .._models import QueueProperties, service_properties_deserialize, service_stats_deserialize +from .._models import CorsRule, QueueProperties, service_properties_deserialize, service_stats_deserialize from .._queue_service_client_helpers import _parse_url from .._serialize import get_api_version from .._shared.base_client import StorageAccountHostsMixin @@ -34,7 +34,7 @@ if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential from azure.core.credentials_async import AsyncTokenCredential - from .._models import CorsRule, Metrics, QueueAnalyticsLogging + from .._models import Metrics, QueueAnalyticsLogging class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] # pylint: disable=line-too-long @@ -146,7 +146,7 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, credential=credential, **kwargs) # type: 
ignore [arg-type] + return cls(account_url, credential=credential, **kwargs) @distributed_trace_async async def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: @@ -213,7 +213,7 @@ async def set_service_properties( self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, hour_metrics: Optional["Metrics"] = None, minute_metrics: Optional["Metrics"] = None, - cors: Optional[List["CorsRule"]] = None, + cors: Optional[List[CorsRule]] = None, **kwargs: Any ) -> None: """Sets the properties of a storage account's Queue service, including @@ -255,7 +255,7 @@ async def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=cors # type: ignore [arg-type] + cors=CorsRule.to_generated(cors) ) try: await self._client.service.set_properties(props, timeout=timeout, **kwargs) @@ -390,8 +390,7 @@ async def delete_queue( await queue_client.delete_queue(timeout=timeout, **kwargs) def get_queue_client( - self, - queue: Union["QueueProperties", str], + self, queue: Union["QueueProperties", str], **kwargs: Any ) -> QueueClient: """Get a client to interact with the specified queue. 
diff --git a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py index 2370cea048e5..a9f44c941f2e 100644 --- a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py +++ b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py @@ -63,6 +63,6 @@ messages = queue_client.peek_messages(max_messages=20, logging_enable=True) for message in messages: try: - print(' Message: {!r}'.format(base64.b64decode(message.content))) #type: ignore [arg-type] + print(' Message: {!r}'.format(base64.b64decode(message.content))) except binascii.Error: print(' Message: {}'.format(message.content)) From 574b50c6ee26643d77f34a97af9181c8386a2e96 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 13 Oct 2023 11:37:39 -0700 Subject: [PATCH 55/71] Fix CorsRule to_generated method --- .../azure/storage/queue/_models.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index a8b871fee7a2..067fb5b1b300 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -195,17 +195,22 @@ def __init__(self, allowed_origins: List[str], allowed_methods: List[str], **kwa self.max_age_in_seconds = kwargs.get('max_age_in_seconds', 0) @staticmethod - def to_generated(rules: Optional[List["CorsRule"]]) -> List[GeneratedCorsRule]: - if rules is not None: - cors_rules = rules[0] - generated_cors = [GeneratedCorsRule( - allowed_origins=cors_rules.allowed_origins, - allowed_methods=cors_rules.allowed_methods, - allowed_headers=cors_rules.allowed_headers, - exposed_headers=cors_rules.exposed_headers, - max_age_in_seconds=cors_rules.max_age_in_seconds - )] - return generated_cors + def to_generated(rules: Optional[List["CorsRule"]]) -> 
Optional[List[GeneratedCorsRule]]: + if rules is None: + return rules + + generated_cors_list = [] + for cors_rule in rules: + generated_cors = GeneratedCorsRule( + allowed_origins=cors_rule.allowed_origins, + allowed_methods=cors_rule.allowed_methods, + allowed_headers=cors_rule.allowed_headers, + exposed_headers=cors_rule.exposed_headers, + max_age_in_seconds=cors_rule.max_age_in_seconds + ) + generated_cors_list.append(generated_cors) + + return generated_cors_list @classmethod def _from_generated(cls, generated: Any) -> Self: From 7e90a97b91eee54ec051101e89581c65b5a4eb29 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 13 Oct 2023 13:20:41 -0700 Subject: [PATCH 56/71] Fix PyLint & test failures --- .../azure/storage/queue/_shared/base_client.py | 1 + .../azure/storage/queue/_shared/base_client_async.py | 1 + .../azure-storage-queue/azure/storage/queue/_shared/policies.py | 2 +- .../azure/storage/queue/_shared/policies_async.py | 2 +- 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index aa0acb0a9fda..00df9314d101 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -389,6 +389,7 @@ def parse_connection_str( raise ValueError("Connection string is either blank or malformed.") conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) endpoints = _SERVICE_PARAMS[service] + primary = None secondary = None if not credential: try: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 737d9691b928..cdb8932ecdc6 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -221,6 +221,7 @@ def parse_connection_str( raise ValueError("Connection string is either blank or malformed.") conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) endpoints = _SERVICE_PARAMS[service] + primary = None secondary = None if not credential: try: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 68e7b7b5d557..fe4351668b4b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential - from azure.core.pipeline.transport import ( + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, PipelineResponse ) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index d3bf6106bd96..bf03d3690598 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential - from azure.core.pipeline.transport import ( + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, PipelineResponse ) From 1c21f590d68ee8369206d6784e6782ce6817b3de Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Tue, 17 Oct 2023 14:29:08 -0700 Subject: [PATCH 57/71] Feedback --- .../azure/storage/queue/_models.py | 56 +- .../azure/storage/queue/_queue_client.py | 6 +- .../storage/queue/_queue_client_helpers.py | 4 +- .../storage/queue/_shared/base_client.py | 14 +- 
.../queue/_shared/base_client_async.py | 16 +- .../azure/storage/queue/_shared/parser.py | 15 +- .../storage/queue/aio/_queue_client_async.py | 6 +- .../samples/queue_samples_authentication.py | 129 ++-- .../queue_samples_authentication_async.py | 145 +++-- .../samples/queue_samples_hello_world.py | 75 +-- .../queue_samples_hello_world_async.py | 83 +-- .../samples/queue_samples_message.py | 607 +++++++++--------- .../samples/queue_samples_message_async.py | 580 ++++++++--------- .../samples/queue_samples_service.py | 169 ++--- .../samples/queue_samples_service_async.py | 167 ++--- 15 files changed, 1027 insertions(+), 1045 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 067fb5b1b300..2922786cb433 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -363,7 +363,7 @@ class QueueMessage(DictMixin): dequeue_count: Optional[int] """Begins with a value of 1 the first time the message is received. This value is incremented each time the message is subsequently received.""" - content: Optional[Any] + content: Any """The message content. Type is determined by the decode_function set on the service. Default is str.""" pop_receipt: Optional[str] @@ -375,54 +375,14 @@ class QueueMessage(DictMixin): """A UTC date value representing the time the message will next be visible. Only returned by receive messages operations. Set to None for peek messages.""" - # @overload - # def __init__( - # self, - # *, - # id: str, - # inserted_on: Optional["datetime"] = None, - # expires_on: Optional["datetime"] = None, - # dequeue_count: Optional[int] = None, - # content: Optional[Any] = None, - # pop_receipt: Optional[str] = None, - # next_visible_on: Optional["datetime"] = None - # ) -> None: - # ... 
- - # @overload - # def __init__( - # self, - # content: Optional[Any] = None, - # *, - # id: Optional[str], - # inserted_on: Optional["datetime"] = None, - # expires_on: Optional["datetime"] = None, - # dequeue_count: Optional[int] = None, - # pop_receipt: Optional[str] = None, - # next_visible_on: Optional["datetime"] = None - # ) -> None: - # ... - - # def __init__(self, *args: Any, **kwargs: Any) -> None: - # self.id = kwargs.pop('id', None) - # self.inserted_on = kwargs.pop('inserted_on', None) - # self.expires_on = kwargs.pop('expires_on', None) - # self.dequeue_count = kwargs.pop('dequeue_count', None) - # if args: - # self.content = args[0] - # else: - # self.content = kwargs.pop('content', None) - # self.pop_receipt = kwargs.pop('pop_receipt', None) - # self.next_visible_on = kwargs.pop('next_visible_on', None) - - def __init__(self, content=None): - self.id = None - self.inserted_on = None - self.expires_on = None - self.dequeue_count = None + def __init__(self, content: Any = None, **kwargs: Any) -> None: + self.id = kwargs.pop('id', None) + self.inserted_on = kwargs.pop('inserted_on', None) + self.expires_on = kwargs.pop('expires_on', None) + self.dequeue_count = kwargs.pop('dequeue_count', None) self.content = content - self.pop_receipt = None - self.next_visible_on = None + self.pop_receipt = kwargs.pop('pop_receipt', None) + self.next_visible_on = kwargs.pop('next_visible_on', None) @classmethod def _from_generated(cls, generated: Any) -> Self: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index d8422edc30c9..d1924de13073 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -22,7 +22,7 @@ from ._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier from ._message_encoding import NoDecodePolicy, NoEncodePolicy from 
._models import AccessPolicy, MessagesPaged, QueueMessage -from ._queue_client_helpers import _format_url_helper, _from_queue_url_helper, _parse_url +from ._queue_client_helpers import _format_url, _from_queue_url, _parse_url from ._serialize import get_api_version from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin from ._shared.request_handlers import add_metadata_headers, serialize_iso @@ -106,7 +106,7 @@ def _format_url(self, hostname: str) -> str: :returns: The formatted endpoint URL according to the specified location mode hostname. :rtype: str """ - return _format_url_helper( + return _format_url( queue_name=self.queue_name, hostname=hostname, scheme=self.scheme, @@ -137,7 +137,7 @@ def from_queue_url( :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient """ - account_url, queue_name = _from_queue_url_helper(queue_url=queue_url) + account_url, queue_name = _from_queue_url(queue_url=queue_url) return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @classmethod diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index 54ef0bd347dc..878a938fff65 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -53,7 +53,7 @@ def _parse_url( return parsed_url, sas_token -def _format_url_helper(queue_name: Union[bytes, str], hostname: str, scheme: str, query_str: str) -> str: +def _format_url(queue_name: Union[bytes, str], hostname: str, scheme: str, query_str: str) -> str: """Format the endpoint URL according to the current location mode hostname. :param Union[bytes, str] queue_name: The name of the queue. 
@@ -71,7 +71,7 @@ def _format_url_helper(queue_name: Union[bytes, str], hostname: str, scheme: str f"{scheme}://{hostname}" f"/{quote(queue_name)}{query_str}") -def _from_queue_url_helper(queue_url: str) -> Tuple[str, str]: +def _from_queue_url(queue_url: str) -> Tuple[str, str]: """A client to interact with a specific Queue. :param str queue_url: The full URI to the queue, including SAS token if used. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 00df9314d101..ea4ff2fb365a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -45,6 +45,7 @@ StorageRequestHook, StorageResponseHook, ) +from .parser import _is_credential_sastoken from .request_handlers import serialize_batch_body, _get_batch_request_delimiter from .response_handlers import PartialBatchErrorException, process_storage_error from .shared_access_signature import QueryStringConstants @@ -212,7 +213,7 @@ def _format_query_string( if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") - if is_credential_sastoken(credential): + if _is_credential_sastoken(credential): credential = cast(str, credential) query_str += credential.lstrip("?") credential = None @@ -454,14 +455,3 @@ def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]: snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot") return snapshot, sas_token - - -def is_credential_sastoken(credential: Any) -> bool: - if not credential or not isinstance(credential, str): - return False - - sas_values = QueryStringConstants.to_list() - parsed_query = parse_qs(credential.lstrip("?")) - if parsed_query and all(k in sas_values for k in parsed_query): - return True 
- return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index cdb8932ecdc6..242aa8fda877 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -7,7 +7,6 @@ import logging from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING, Union -from urllib.parse import parse_qs from azure.core.async_paging import AsyncList from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential @@ -35,9 +34,9 @@ StorageHosts, StorageRequestHook, ) +from .parser import _is_credential_sastoken from .policies_async import AsyncStorageResponseHook from .response_handlers import PartialBatchErrorException, process_storage_error -from .shared_access_signature import QueryStringConstants if TYPE_CHECKING: from azure.core.pipeline.transport import HttpRequest, HttpResponse @@ -86,7 +85,7 @@ def _format_query_string( if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") - if is_credential_sastoken(credential): + if _is_credential_sastoken(credential): query_str += credential.lstrip("?") # type: ignore [union-attr] credential = None elif sas_token: @@ -261,17 +260,6 @@ def parse_connection_str( secondary = secondary.replace(".blob.", ".dfs.") return primary, secondary, credential -def is_credential_sastoken(credential: Any) -> bool: - if not credential or not isinstance(credential, str): - return False - - sas_values = QueryStringConstants.to_list() - parsed_query = parse_qs(credential.lstrip("?")) - if parsed_query and all(k in sas_values for k in parsed_query): - return True - return False - - class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an 
inner client created by a `get_client` method does not close the outer transport for the parent diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py index cd59cfe104ca..304635c4cf4e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py @@ -6,7 +6,10 @@ import sys from datetime import datetime, timezone -from typing import Optional +from typing import Any, Optional +from urllib.parse import parse_qs + +from .shared_access_signature import QueryStringConstants EPOCH_AS_FILETIME = 116444736000000000 # January 1, 1970 as MS filetime HUNDREDS_OF_NANOSECONDS = 10000000 @@ -59,3 +62,13 @@ def _filetime_to_datetime(filetime: str) -> Optional[datetime]: # Try RFC 1123 as backup return _rfc_1123_to_datetime(filetime) + +def _is_credential_sastoken(credential: Any) -> bool: + if not credential or not isinstance(credential, str): + return False + + sas_values = QueryStringConstants.to_list() + parsed_query = parse_qs(credential.lstrip("?")) + if parsed_query and all(k in sas_values for k in parsed_query): + return True + return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 8f22f1386824..e457ea742281 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -24,7 +24,7 @@ from .._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier from .._message_encoding import NoDecodePolicy, NoEncodePolicy from .._models import AccessPolicy, QueueMessage -from .._queue_client_helpers import _format_url_helper, _from_queue_url_helper, _parse_url +from .._queue_client_helpers import _format_url, _from_queue_url, 
_parse_url from .._serialize import get_api_version from .._shared.base_client import StorageAccountHostsMixin from .._shared.base_client_async import AsyncStorageAccountHostsMixin, parse_connection_str @@ -120,7 +120,7 @@ def _format_url(self, hostname: str) -> str: :returns: The formatted endpoint URL according to the specified location mode hostname. :rtype: str """ - return _format_url_helper( + return _format_url( queue_name=self.queue_name, hostname=hostname, scheme=self.scheme, @@ -148,7 +148,7 @@ def from_queue_url( :returns: A queue client. :rtype: ~azure.storage.queue.QueueClient """ - account_url, queue_name = _from_queue_url_helper(queue_url=queue_url) + account_url, queue_name = _from_queue_url(queue_url=queue_url) return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @classmethod diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py index d8ebea813d9c..9e1fd88d0c54 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py @@ -30,7 +30,8 @@ from datetime import datetime, timedelta -import os, sys +import os +import sys class QueueAuthSamples(object): @@ -46,89 +47,91 @@ class QueueAuthSamples(object): active_directory_tenant_id = os.getenv("ACTIVE_DIRECTORY_TENANT_ID") def authentication_by_connection_string(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: authentication_by_connection_string") + sys.exit(1) + # Instantiate a QueueServiceClient using a connection string # [START auth_from_connection_string] from azure.storage.queue import QueueServiceClient - if self.connection_string is not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # [END auth_from_connection_string] - - # Get information for the Queue Service - properties = queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + # [END auth_from_connection_string] + + # Get information for the Queue Service + properties = queue_service.get_service_properties() def authentication_by_shared_key(self): + if self.account_url is None or self.access_key is None: + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_shared_key") + sys.exit(1) + # Instantiate a QueueServiceClient using a shared access key # [START create_queue_service_client] from azure.storage.queue import QueueServiceClient - if self.account_url is not None and self.access_key is not None: - queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) - # [END create_queue_service_client] - - # Get information for the Queue Service - properties = queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) + # [END create_queue_service_client] + + # Get information for the Queue Service + properties = queue_service.get_service_properties() def authentication_by_active_directory(self): + if (self.active_directory_tenant_id is None or + self.active_directory_application_id is None or + self.active_directory_application_secret is None or + self.account_url is None + ): + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_active_directory") + sys.exit(1) + # [START create_queue_service_client_token] # Get a token credential for authentication from azure.identity import ClientSecretCredential - if ( - self.active_directory_tenant_id is not None and - self.active_directory_application_id is not None and - self.active_directory_application_secret is not None and - self.account_url is not None + token_credential = ClientSecretCredential( + self.active_directory_tenant_id, + self.active_directory_application_id, + self.active_directory_application_secret + ) + + # Instantiate a QueueServiceClient using a token credential + from azure.storage.queue import QueueServiceClient + queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) + # [END create_queue_service_client_token] + + # Get information for the Queue Service + properties = queue_service.get_service_properties() + + def authentication_by_shared_access_signature(self): + if (self.connection_string is None or + self.account_name is None or + self.access_key is None or + self.account_url is None ): - token_credential = ClientSecretCredential( - self.active_directory_tenant_id, - self.active_directory_application_id, - self.active_directory_application_secret - ) - - # Instantiate a QueueServiceClient using a token credential - from azure.storage.queue 
import QueueServiceClient - queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) - # [END create_queue_service_client_token] - - # Get information for the Queue Service - properties = queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_shared_access_signature") sys.exit(1) - def authentication_by_shared_access_signature(self): # Instantiate a QueueServiceClient using a connection string from azure.storage.queue import QueueServiceClient - if ( - self.connection_string is not None and - self.account_name is not None and - self.access_key is not None and - self.account_url - ): - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # Create a SAS token to use for authentication of a client - from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions + # Create a SAS token to use for authentication of a client + from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions - sas_token = generate_account_sas( - self.account_name, - self.access_key, - resource_types=ResourceTypes(service=True), - permission=AccountSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(hours=1) - ) + sas_token = generate_account_sas( + self.account_name, + self.access_key, + resource_types=ResourceTypes(service=True), + permission=AccountSasPermissions(read=True), + expiry=datetime.utcnow() + timedelta(hours=1) + ) - token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) + token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) - # Get 
information for the Queue Service - properties = token_auth_queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + # Get information for the Queue Service + properties = token_auth_queue_service.get_service_properties() if __name__ == '__main__': @@ -136,4 +139,4 @@ def authentication_by_shared_access_signature(self): sample.authentication_by_connection_string() sample.authentication_by_shared_key() sample.authentication_by_active_directory() - sample.authentication_by_shared_access_signature() \ No newline at end of file + sample.authentication_by_shared_access_signature() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py index c2b8ea8d6303..ebcf80199410 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py @@ -31,7 +31,8 @@ from datetime import datetime, timedelta import asyncio -import os, sys +import os +import sys class QueueAuthSamplesAsync(object): @@ -47,93 +48,95 @@ class QueueAuthSamplesAsync(object): active_directory_tenant_id = os.getenv("ACTIVE_DIRECTORY_TENANT_ID") async def authentication_by_connection_string_async(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: authentication_by_connection_string_async") + sys.exit(1) + # Instantiate a QueueServiceClient using a connection string # [START async_auth_from_connection_string] from azure.storage.queue.aio import QueueServiceClient - if self.connection_string is not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # [END async_auth_from_connection_string] - - # Get information for the Queue Service - async with queue_service: - properties = await queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + # [END async_auth_from_connection_string] + + # Get information for the Queue Service + async with queue_service: + properties = await queue_service.get_service_properties() async def authentication_by_shared_key_async(self): + if self.account_url is None or self.access_key is None: + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_shared_key_async") + sys.exit(1) + # Instantiate a QueueServiceClient using a shared access key # [START async_create_queue_service_client] from azure.storage.queue.aio import QueueServiceClient - if self.account_url is not None and self.access_key is not None: - queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) - # [END async_create_queue_service_client] - - # Get information for the Queue Service - async with queue_service: - properties = await queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) + # [END async_create_queue_service_client] + + # Get information for the Queue Service + async with queue_service: + properties = await queue_service.get_service_properties() async def authentication_by_active_directory_async(self): + if (self.active_directory_tenant_id is None or + self.active_directory_application_id is None or + self.active_directory_application_secret is None or + self.account_url is None + ): + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_active_directory_async") + sys.exit(1) + # [START async_create_queue_service_client_token] # Get a token credential for authentication from azure.identity.aio import ClientSecretCredential - if ( - self.active_directory_tenant_id is not None and - self.active_directory_application_id is not None and - self.active_directory_application_secret is not None and - self.account_url is not None + token_credential = ClientSecretCredential( + self.active_directory_tenant_id, + self.active_directory_application_id, + self.active_directory_application_secret + ) + + # Instantiate a QueueServiceClient using a token credential + from azure.storage.queue.aio import QueueServiceClient + queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) + # [END async_create_queue_service_client_token] + + # Get information for the Queue Service + async with queue_service: + properties = await queue_service.get_service_properties() + + async def authentication_by_shared_access_signature_async(self): + if (self.connection_string is None or + self.account_name is None or + self.access_key is None or + self.account_url is None ): - token_credential = ClientSecretCredential( - self.active_directory_tenant_id, - self.active_directory_application_id, - 
self.active_directory_application_secret - ) - - # Instantiate a QueueServiceClient using a token credential - from azure.storage.queue.aio import QueueServiceClient - queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) - # [END async_create_queue_service_client_token] - - # Get information for the Queue Service - async with queue_service: - properties = await queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_shared_access_signature_async") sys.exit(1) - async def authentication_by_shared_access_signature_async(self): # Instantiate a QueueServiceClient using a connection string from azure.storage.queue.aio import QueueServiceClient - if ( - self.connection_string is not None and - self.account_name is not None and - self.access_key is not None and - self.account_url - ): - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - - # Create a SAS token to use for authentication of a client - from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions - - sas_token = generate_account_sas( - queue_service.account_name, - queue_service.credential.account_key, - resource_types=ResourceTypes(service=True), - permission=AccountSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(hours=1) - ) - - token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) - - # Get information for the Queue Service - async with token_auth_queue_service: - properties = await token_auth_queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + + # Create a SAS token to use for authentication of a client + from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions + + sas_token = generate_account_sas( + queue_service.account_name, + queue_service.credential.account_key, + resource_types=ResourceTypes(service=True), + permission=AccountSasPermissions(read=True), + expiry=datetime.utcnow() + timedelta(hours=1) + ) + + token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) + + # Get information for the Queue Service + async with token_auth_queue_service: + properties = await token_auth_queue_service.get_service_properties() async def main(): @@ -144,4 +147,4 @@ async def main(): await sample.authentication_by_shared_access_signature_async() if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py index 250a854e9128..d798563b4e90 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py @@ -21,7 +21,8 @@ """ -import os, sys +import os +import sys class QueueHelloWorldSamples(object): @@ -29,50 +30,52 @@ class QueueHelloWorldSamples(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") def create_client_with_connection_string(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: create_client_with_connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient - if self.connection_string is not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - # Get queue service properties - properties = queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + # Get queue service properties + properties = queue_service.get_service_properties() def queue_and_messages_example(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: queue_and_messages_example") + sys.exit(1) + # Instantiate the QueueClient from a connection string from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") - - # Create the queue - # [START create_queue] - queue.create_queue() - # [END create_queue] - - try: - # Send messages - queue.send_message("I'm using queues!") - queue.send_message("This is my second message") - - # Receive the messages - response = queue.receive_messages(messages_per_page=2) - - # Print the content of the messages - for message in response: - print(message.content) - - finally: - # [START delete_queue] - queue.delete_queue() - # [END delete_queue] - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") + + # Create the queue + # [START create_queue] + queue.create_queue() + # [END create_queue] + + try: + # Send messages + queue.send_message("I'm using queues!") + queue.send_message("This is my second message") + + # Receive the messages + response = queue.receive_messages(messages_per_page=2) + + # Print the content of the messages + for message in response: + print(message.content) + + finally: + # [START delete_queue] + queue.delete_queue() + # [END delete_queue] if __name__ == '__main__': sample = QueueHelloWorldSamples() sample.create_client_with_connection_string() - sample.queue_and_messages_example() \ No newline at end of file + sample.queue_and_messages_example() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py index bb864fe47f5b..34aa77894fb2 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py @@ -22,7 +22,8 @@ import asyncio -import os, sys +import os +import sys class QueueHelloWorldSamplesAsync(object): @@ -30,51 +31,53 @@ class QueueHelloWorldSamplesAsync(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") async def create_client_with_connection_string_async(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: create_client_with_connection_string_async") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient - if self.connection_string is not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - - # Get queue service properties - async with queue_service: - properties = await queue_service.get_service_properties() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + + # Get queue service properties + async with queue_service: + properties = await queue_service.get_service_properties() async def queue_and_messages_example_async(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: queue_and_messages_example_async") + sys.exit(1) + # Instantiate the QueueClient from a connection string from azure.storage.queue.aio import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") - - async with queue: - # Create the queue - # [START async_create_queue] - await queue.create_queue() - # [END async_create_queue] - - try: - # Send messages - await asyncio.gather( - queue.send_message("I'm using queues!"), - queue.send_message("This is my second message") - ) - - # Receive the messages - response = queue.receive_messages(messages_per_page=2) - - # Print the content of the messages - async for message in response: - print(message.content) - - finally: - # [START async_delete_queue] - await queue.delete_queue() - # [END async_delete_queue] - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") + + async with queue: + # Create the queue + # [START async_create_queue] + await queue.create_queue() + # [END async_create_queue] + + try: + # Send messages + await asyncio.gather( + queue.send_message("I'm using queues!"), + queue.send_message("This is my second message") + ) + + # Receive the messages + response = queue.receive_messages(messages_per_page=2) + + # Print the content of the messages + async for message in response: + print(message.content) + + finally: + # [START async_delete_queue] + await queue.delete_queue() + # [END async_delete_queue] async def main(): diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py index a64b4777aba7..e77df1a6bceb 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py @@ -24,357 +24,362 @@ from datetime import datetime, timedelta -import os, sys +import os +import sys class QueueMessageSamples(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") + if connection_string is None: + print("Missing required environment variable: connection_string" + '\n' + + "Please ensure your environment variable for 'AZURE_STORAGE_CONNECTION_STRING' is set and try again.") + sys.exit(1) def set_access_policy(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # [START create_queue_client_from_connection_string] from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") - # [END create_queue_client_from_connection_string] + queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") 
+ # [END create_queue_client_from_connection_string] + + # Create the queue + queue.create_queue() + + # Send a message + queue.send_message("hello world") + + try: + # [START set_access_policy] + # Create an access policy + from azure.storage.queue import AccessPolicy, QueueSasPermissions + access_policy = AccessPolicy() + access_policy.start = datetime.utcnow() - timedelta(hours=1) + access_policy.expiry = datetime.utcnow() + timedelta(hours=1) + access_policy.permission = QueueSasPermissions(read=True) + identifiers = {'my-access-policy-id': access_policy} + + # Set the access policy + queue.set_queue_access_policy(identifiers) + # [END set_access_policy] + + # Use the access policy to generate a SAS token + # [START queue_client_sas_token] + from azure.storage.queue import generate_queue_sas + sas_token = generate_queue_sas( + queue.account_name, + queue.queue_name, + queue.credential.account_key, + policy_id='my-access-policy-id' + ) + # [END queue_client_sas_token] + + # Authenticate with the sas token + # [START create_queue_client] + token_auth_queue = QueueClient.from_queue_url( + queue_url=queue.url, + credential=sas_token + ) + # [END create_queue_client] + + # Use the newly authenticated client to receive messages + my_message = token_auth_queue.receive_messages() + + finally: + # Delete the queue + queue.delete_queue() - # Create the queue - queue.create_queue() - - # Send a message - queue.send_message("hello world") - - try: - # [START set_access_policy] - # Create an access policy - from azure.storage.queue import AccessPolicy, QueueSasPermissions - access_policy = AccessPolicy() - access_policy.start = datetime.utcnow() - timedelta(hours=1) - access_policy.expiry = datetime.utcnow() + timedelta(hours=1) - access_policy.permission = QueueSasPermissions(read=True) - identifiers = {'my-access-policy-id': access_policy} - - # Set the access policy - queue.set_queue_access_policy(identifiers) - # [END set_access_policy] - - # Use the access policy to 
generate a SAS token - # [START queue_client_sas_token] - from azure.storage.queue import generate_queue_sas - sas_token = generate_queue_sas( - queue.account_name, - queue.queue_name, - queue.credential.account_key, - policy_id='my-access-policy-id' - ) - # [END queue_client_sas_token] - - # Authenticate with the sas token - # [START create_queue_client] - token_auth_queue = QueueClient.from_queue_url( - queue_url=queue.url, - credential=sas_token - ) - # [END create_queue_client] - - # Use the newly authenticated client to receive messages - my_message = token_auth_queue.receive_messages() - - finally: - # Delete the queue - queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") + def queue_metadata(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") sys.exit(1) - def queue_metadata(self): # Instantiate a queue client from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") - - # Create the queue - queue.create_queue() - - try: - # [START set_queue_metadata] - metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} - queue.set_queue_metadata(metadata=metadata) - # [END set_queue_metadata] - - # [START get_queue_properties] - properties = queue.get_queue_properties().metadata - # [END get_queue_properties] - - finally: - # Delete the queue - queue.delete_queue() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") + + # Create the queue + queue.create_queue() + + try: + # [START set_queue_metadata] + metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} + queue.set_queue_metadata(metadata=metadata) + # [END set_queue_metadata] + + # [START get_queue_properties] + properties = queue.get_queue_properties().metadata + # [END get_queue_properties] + + finally: + # Delete the queue + queue.delete_queue() def send_and_receive_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") - - # Create the queue - queue.create_queue() - - try: - # [START send_messages] - queue.send_message("message1") - queue.send_message("message2", visibility_timeout=30) # wait 30s before becoming visible - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") - # [END send_messages] - - # [START receive_messages] - # Receive messages one-by-one - messages = queue.receive_messages() - for msg in messages: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") + + # Create the queue + queue.create_queue() + + try: + # [START send_messages] + queue.send_message("message1") + queue.send_message("message2", visibility_timeout=30) # wait 30s before becoming visible + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") + # [END send_messages] + + # [START receive_messages] + # Receive messages one-by-one + messages = queue.receive_messages() + for msg in messages: + print(msg.content) + + # Receive messages by batch + messages = queue.receive_messages(messages_per_page=5) + for msg_batch 
in messages.by_page(): + for msg in msg_batch: print(msg.content) + queue.delete_message(msg) + # [END receive_messages] - # Receive messages by batch - messages = queue.receive_messages(messages_per_page=5) - for msg_batch in messages.by_page(): - for msg in msg_batch: - print(msg.content) - queue.delete_message(msg) - # [END receive_messages] - - # Only prints 4 messages because message 2 is not visible yet - # >>message1 - # >>message3 - # >>message4 - # >>message5 - - finally: - # Delete the queue - queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + # Only prints 4 messages because message 2 is not visible yet + # >>message1 + # >>message3 + # >>message4 + # >>message5 + + finally: + # Delete the queue + queue.delete_queue() def list_message_pages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") - - # Create the queue - queue.create_queue() - - try: - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") - queue.send_message("message6") - - # [START receive_messages_listing] - # Store two messages in each page - message_batches = queue.receive_messages(messages_per_page=2).by_page() - - # Iterate through the page lists - print(list(next(message_batches))) - print(list(next(message_batches))) - - # There are two iterations in the last page as well. - last_page = next(message_batches) - for message in last_page: - print(message) - # [END receive_messages_listing] - - finally: - queue.delete_queue() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") + + # Create the queue + queue.create_queue() + + try: + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") + queue.send_message("message6") + + # [START receive_messages_listing] + # Store two messages in each page + message_batches = queue.receive_messages(messages_per_page=2).by_page() + + # Iterate through the page lists + print(list(next(message_batches))) + print(list(next(message_batches))) + + # There are two iterations in the last page as well. + last_page = next(message_batches) + for message in last_page: + print(message) + # [END receive_messages_listing] + + finally: + queue.delete_queue() def receive_one_message_from_queue(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") - - # Create the queue - queue.create_queue() - - try: - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - - # [START receive_one_message] - # Pop two messages from the front of the queue - message1 = queue.receive_message() - message2 = queue.receive_message() - # We should see message 3 if we peek - message3 = queue.peek_messages()[0] - - if not message1 or not message2 or not message3: - raise ValueError("One of the messages are None.") - print(message1.content) - print(message2.content) - print(message3.content) - # [END receive_one_message] - - finally: - queue.delete_queue() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") + + # Create the queue + queue.create_queue() + + try: + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + + # [START receive_one_message] + # Pop two messages from the front of the queue + message1 = queue.receive_message() + message2 = queue.receive_message() + # We should see message 3 if we peek + message3 = queue.peek_messages()[0] + + if not message1 or not message2 or not message3: + raise ValueError("One of the messages are None.") + + print(message1.content) + print(message2.content) + print(message3.content) + # [END receive_one_message] + + finally: + queue.delete_queue() def delete_and_clear_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") + queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") - # Create the queue - queue.create_queue() + # Create the queue + queue.create_queue() - try: - # Send messages - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") + try: + # Send messages + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") - # [START delete_message] - # Get the message at the front of the queue - msg = next(queue.receive_messages()) + # [START delete_message] + # Get the message at the front of the queue + msg = next(queue.receive_messages()) - # Delete the specified message - queue.delete_message(msg) - # [END delete_message] + # 
Delete the specified message + queue.delete_message(msg) + # [END delete_message] - # [START clear_messages] - queue.clear_messages() - # [END clear_messages] + # [START clear_messages] + queue.clear_messages() + # [END clear_messages] - finally: - # Delete the queue - queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + finally: + # Delete the queue + queue.delete_queue() def peek_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") - - # Create the queue - queue.create_queue() - - try: - # Send messages - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") - - # [START peek_message] - # Peek at one message at the front of the queue - msg = queue.peek_messages() - - # Peek at the last 5 messages - messages = queue.peek_messages(max_messages=5) - - # Print the last 5 messages - for message in messages: - print(message.content) - # [END peek_message] - - finally: - # Delete the queue - queue.delete_queue() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") + + # Create the queue + queue.create_queue() + + try: + # Send messages + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") + + # [START peek_message] + # Peek at one message at the front of the queue + msg = queue.peek_messages() + + # Peek at the last 5 messages + messages = queue.peek_messages(max_messages=5) + + # Print the last 5 messages + for message in messages: + print(message.content) + # [END peek_message] + + finally: + # Delete the queue + queue.delete_queue() def update_message(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") - - # Create the queue - queue.create_queue() - - try: - # [START update_message] - # Send a message - queue.send_message("update me") - - # Receive the message - messages = queue.receive_messages() - - # Update the message - list_result = next(messages) - id = list_result.id - message = queue.update_message( - id, - pop_receipt=list_result.pop_receipt, - visibility_timeout=0, - content="updated") - # [END update_message] - - finally: - # Delete the queue - queue.delete_queue() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") + + # Create the queue + queue.create_queue() + + try: + # [START update_message] + # Send a message + queue.send_message("update me") + + # Receive the message + messages = queue.receive_messages() + + # Update the message + list_result = next(messages) + message = queue.update_message( + list_result.id, + pop_receipt=list_result.pop_receipt, + visibility_timeout=0, + content="updated") + # [END update_message] + + finally: + # Delete the queue + queue.delete_queue() def receive_messages_with_max_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue9") - - # Create the queue - queue.create_queue() - - try: - queue.send_message("message1") - queue.send_message("message2") - queue.send_message("message3") - queue.send_message("message4") - queue.send_message("message5") - queue.send_message("message6") - queue.send_message("message7") - queue.send_message("message8") - queue.send_message("message9") - queue.send_message("message10") - - # Receive messages one-by-one - messages = queue.receive_messages(max_messages=5) - for msg in messages: - print(msg.content) - queue.delete_message(msg) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue9") + + # Create the queue + queue.create_queue() + + try: + queue.send_message("message1") + queue.send_message("message2") + queue.send_message("message3") + queue.send_message("message4") + queue.send_message("message5") + queue.send_message("message6") + queue.send_message("message7") + queue.send_message("message8") + queue.send_message("message9") + queue.send_message("message10") + + # Receive 
messages one-by-one + messages = queue.receive_messages(max_messages=5) + for msg in messages: + print(msg.content) + queue.delete_message(msg) - # Only prints 5 messages because 'max_messages'=5 - # >>message1 - # >>message2 - # >>message3 - # >>message4 - # >>message5 - - finally: - # Delete the queue - queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + # Only prints 5 messages because 'max_messages'=5 + # >>message1 + # >>message2 + # >>message3 + # >>message4 + # >>message5 + + finally: + # Delete the queue + queue.delete_queue() if __name__ == '__main__': @@ -387,4 +392,4 @@ def receive_messages_with_max_messages(self): sample.delete_and_clear_messages() sample.peek_messages() sample.update_message() - sample.receive_messages_with_max_messages() \ No newline at end of file + sample.receive_messages_with_max_messages() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py index edb622157112..f9f7b3364263 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py @@ -24,333 +24,339 @@ from datetime import datetime, timedelta import asyncio -import os, sys +import os +import sys class QueueMessageSamplesAsync(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") + if connection_string is None: + print("Missing required environment variable: connection_string" + '\n' + + "Please ensure your environment variable for 'AZURE_STORAGE_CONNECTION_STRING' is set and try again.") + sys.exit(1) async def set_access_policy_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # [START async_create_queue_client_from_connection_string] from azure.storage.queue.aio import QueueClient - if 
self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") - # [END async_create_queue_client_from_connection_string] - - # Create the queue - async with queue: - await queue.create_queue() + queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") + # [END async_create_queue_client_from_connection_string] + + # Create the queue + async with queue: + await queue.create_queue() + + # Send a message + await queue.send_message("hello world") + + try: + # [START async_set_access_policy] + # Create an access policy + from azure.storage.queue import AccessPolicy, QueueSasPermissions + access_policy = AccessPolicy() + access_policy.start = datetime.utcnow() - timedelta(hours=1) + access_policy.expiry = datetime.utcnow() + timedelta(hours=1) + access_policy.permission = QueueSasPermissions(read=True) + identifiers = {'my-access-policy-id': access_policy} + + # Set the access policy + await queue.set_queue_access_policy(identifiers) + # [END async_set_access_policy] + + # Use the access policy to generate a SAS token + from azure.storage.queue import generate_queue_sas + sas_token = generate_queue_sas( + queue.account_name, + queue.queue_name, + queue.credential.account_key, + policy_id='my-access-policy-id' + ) + + # Authenticate with the sas token + # [START async_create_queue_client] + token_auth_queue = QueueClient.from_queue_url( + queue_url=queue.url, + credential=sas_token + ) + # [END async_create_queue_client] + + # Use the newly authenticated client to receive messages + my_messages = token_auth_queue.receive_messages() + + finally: + # Delete the queue + await queue.delete_queue() - # Send a message - await queue.send_message("hello world") - - try: - # [START async_set_access_policy] - # Create an access policy - from azure.storage.queue import AccessPolicy, QueueSasPermissions - access_policy = AccessPolicy() - access_policy.start = datetime.utcnow() - timedelta(hours=1) - 
access_policy.expiry = datetime.utcnow() + timedelta(hours=1) - access_policy.permission = QueueSasPermissions(read=True) - identifiers = {'my-access-policy-id': access_policy} - - # Set the access policy - await queue.set_queue_access_policy(identifiers) - # [END async_set_access_policy] - - # Use the access policy to generate a SAS token - from azure.storage.queue import generate_queue_sas - sas_token = generate_queue_sas( - queue.account_name, - queue.queue_name, - queue.credential.account_key, - policy_id='my-access-policy-id' - ) - - # Authenticate with the sas token - # [START async_create_queue_client] - token_auth_queue = QueueClient.from_queue_url( - queue_url=queue.url, - credential=sas_token - ) - # [END async_create_queue_client] - - # Use the newly authenticated client to receive messages - my_messages = token_auth_queue.receive_messages() - - finally: - # Delete the queue - await queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") + async def queue_metadata_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") sys.exit(1) - async def queue_metadata_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") - - # Create the queue - async with queue: - await queue.create_queue() - - try: - # [START async_set_queue_metadata] - metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} - await queue.set_queue_metadata(metadata=metadata) - # [END async_set_queue_metadata] - - # [START async_get_queue_properties] - properties = await queue.get_queue_properties() - # [END async_get_queue_properties] - - finally: - # Delete the queue - await queue.delete_queue() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") + + # Create the queue + async with queue: + await queue.create_queue() + + try: + # [START async_set_queue_metadata] + metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} + await queue.set_queue_metadata(metadata=metadata) + # [END async_set_queue_metadata] + + # [START async_get_queue_properties] + properties = await queue.get_queue_properties() + # [END async_get_queue_properties] + + finally: + # Delete the queue + await queue.delete_queue() async def send_and_receive_messages_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") - - # Create the queue - async with queue: - await queue.create_queue() - - try: - # [START async_send_messages] - await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2", visibility_timeout=30), # wait 30s before becoming visible - queue.send_message("message3"), - queue.send_message("message4"), - queue.send_message("message5") - ) - # [END async_send_messages] - - # [START async_receive_messages] - # Receive messages one-by-one - messages = queue.receive_messages() - async for msg in messages: + queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") + + # Create the queue + async with queue: + await queue.create_queue() + + try: + # [START async_send_messages] + await asyncio.gather( + queue.send_message("message1"), + queue.send_message("message2", visibility_timeout=30), # wait 30s before becoming visible + queue.send_message("message3"), + queue.send_message("message4"), + queue.send_message("message5") + ) + # [END async_send_messages] + + # [START 
async_receive_messages] + # Receive messages one-by-one + messages = queue.receive_messages() + async for msg in messages: + print(msg.content) + + # Receive messages by batch + messages = queue.receive_messages(messages_per_page=5) + async for msg_batch in messages.by_page(): + async for msg in msg_batch: print(msg.content) + await queue.delete_message(msg) + # [END async_receive_messages] - # Receive messages by batch - messages = queue.receive_messages(messages_per_page=5) - async for msg_batch in messages.by_page(): - async for msg in msg_batch: - print(msg.content) - await queue.delete_message(msg) - # [END async_receive_messages] - - # Only prints 4 messages because message 2 is not visible yet - # >>message1 - # >>message3 - # >>message4 - # >>message5 - - finally: - # Delete the queue - await queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + # Only prints 4 messages because message 2 is not visible yet + # >>message1 + # >>message3 + # >>message4 + # >>message5 + + finally: + # Delete the queue + await queue.delete_queue() async def receive_one_message_from_queue(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") - - # Create the queue - async with queue: - await queue.create_queue() - - try: - await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2"), - queue.send_message("message3")) - - # [START receive_one_message] - # Pop two messages from the front of the queue - message1 = await queue.receive_message() - message2 = await queue.receive_message() - # We should see message 3 if we peek - message3 = await queue.peek_messages() - - if not message1 or not message2 
or not message3: - raise ValueError("One of the messages are None.") - print(message1.content) - print(message2.content) - print(message3[0].content) - # [END receive_one_message] - - finally: - await queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") + + # Create the queue + async with queue: + await queue.create_queue() + + try: + await asyncio.gather( + queue.send_message("message1"), + queue.send_message("message2"), + queue.send_message("message3")) + + # [START receive_one_message] + # Pop two messages from the front of the queue + message1 = await queue.receive_message() + message2 = await queue.receive_message() + # We should see message 3 if we peek + message3 = await queue.peek_messages() + + if not message1 or not message2 or not message3: + raise ValueError("One of the messages are None.") + + print(message1.content) + print(message2.content) + print(message3[0].content) + # [END receive_one_message] + + finally: + await queue.delete_queue() async def delete_and_clear_messages_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") - - # Create the queue - async with queue: - await queue.create_queue() - - try: - # Send messages - await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2"), - queue.send_message("message3"), - queue.send_message("message4"), - queue.send_message("message5") - ) - - # [START async_delete_message] - # Get the message at the front of the queue - messages = queue.receive_messages() - async for msg in messages: - # Delete the specified message - await 
queue.delete_message(msg) - # [END async_delete_message] - break - - # [START async_clear_messages] - await queue.clear_messages() - # [END async_clear_messages] - - finally: - # Delete the queue - await queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") + + # Create the queue + async with queue: + await queue.create_queue() + + try: + # Send messages + await asyncio.gather( + queue.send_message("message1"), + queue.send_message("message2"), + queue.send_message("message3"), + queue.send_message("message4"), + queue.send_message("message5") + ) + + # [START async_delete_message] + # Get the message at the front of the queue + messages = queue.receive_messages() + async for msg in messages: + # Delete the specified message + await queue.delete_message(msg) + # [END async_delete_message] + break + + # [START async_clear_messages] + await queue.clear_messages() + # [END async_clear_messages] + + finally: + # Delete the queue + await queue.delete_queue() async def peek_messages_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") - - # Create the queue - async with queue: - await queue.create_queue() - - try: - # Send messages - await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2"), - queue.send_message("message3"), - queue.send_message("message4"), - queue.send_message("message5") - ) - - # [START async_peek_message] - # Peek at one message at the front of the queue - msg = await queue.peek_messages() - - # Peek at the last 5 messages - messages = await queue.peek_messages(max_messages=5) - - 
# Print the last 5 messages - for message in messages: - print(message.content) - # [END async_peek_message] - - finally: - # Delete the queue - await queue.delete_queue() - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") + + # Create the queue + async with queue: + await queue.create_queue() + + try: + # Send messages + await asyncio.gather( + queue.send_message("message1"), + queue.send_message("message2"), + queue.send_message("message3"), + queue.send_message("message4"), + queue.send_message("message5") + ) + + # [START async_peek_message] + # Peek at one message at the front of the queue + msg = await queue.peek_messages() + + # Peek at the last 5 messages + messages = await queue.peek_messages(max_messages=5) + + # Print the last 5 messages + for message in messages: + print(message.content) + # [END async_peek_message] + + finally: + # Delete the queue + await queue.delete_queue() async def update_message_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") - - # Create the queue - async with queue: - await queue.create_queue() - - try: - # [START async_update_message] - # Send a message - await queue.send_message("update me") - - # Receive the message - messages = queue.receive_messages() - - # Update the message - async for message in messages: - message = await queue.update_message( - message, - visibility_timeout=0, - content="updated") - # [END async_update_message] - break - - finally: - # Delete the queue - await queue.delete_queue() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") + + # Create the queue + async with queue: + await queue.create_queue() + + try: + # [START async_update_message] + # Send a message + await queue.send_message("update me") + + # Receive the message + messages = queue.receive_messages() + + # Update the message + async for message in messages: + message = await queue.update_message( + message, + visibility_timeout=0, + content="updated") + # [END async_update_message] + break + + finally: + # Delete the queue + await queue.delete_queue() async def receive_messages_with_max_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient - if self.connection_string is not None: - queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") - - # Create the queue - async with queue: - await queue.create_queue() - - try: - await queue.send_message("message1") - await queue.send_message("message2") - await queue.send_message("message3") - await queue.send_message("message4") - await queue.send_message("message5") - await queue.send_message("message6") - await queue.send_message("message7") - await queue.send_message("message8") - await queue.send_message("message9") - await queue.send_message("message10") - - # Receive messages one-by-one - messages = queue.receive_messages(max_messages=5) - async for msg in messages: - print(msg.content) - await queue.delete_message(msg) - - # Only prints 5 messages because 'max_messages'=5 - # >>message1 - # >>message2 - # >>message3 - # >>message4 - # >>message5 - - finally: - # Delete the queue - await queue.delete_queue() - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") + + # Create the queue + async with queue: + await queue.create_queue() + + try: + await queue.send_message("message1") + await queue.send_message("message2") + await queue.send_message("message3") + await queue.send_message("message4") + await queue.send_message("message5") + await queue.send_message("message6") + await queue.send_message("message7") + await queue.send_message("message8") + await queue.send_message("message9") + await queue.send_message("message10") + + # Receive messages one-by-one + messages = queue.receive_messages(max_messages=5) + async for msg in messages: + print(msg.content) + await queue.delete_message(msg) + + # Only prints 5 messages because 'max_messages'=5 + # >>message1 + # >>message2 + # >>message3 + # >>message4 + # >>message5 + + finally: + # Delete the queue + await queue.delete_queue() async def main(): diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py index b2c46dbd8d5a..a8065b2e4043 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py @@ -20,108 +20,111 @@ 1) AZURE_STORAGE_CONNECTION_STRING - the connection string to your storage account """ -import os, sys +import os +import sys class QueueServiceSamples(object): - if os.getenv("AZURE_STORAGE_CONNECTION_STRING") is not None: - connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") + connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") + if connection_string is None: + print("Missing required environment variable: connection_string" + '\n' + + "Please ensure your environment variable for 'AZURE_STORAGE_CONNECTION_STRING' is set and try again.") + sys.exit(1) def queue_service_properties(self): + if self.connection_string is None: 
+ print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient - if self.connection_string is not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - - # [START set_queue_service_properties] - # Create service properties - from typing import List, Optional - from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy - - # Create logging settings - logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - - # Create metrics for requests statistics - hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - - # Create CORS rules - cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) - allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] - allowed_methods = ['GET', 'PUT'] - max_age_in_seconds = 500 - exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] - allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] - cors_rule2 = CorsRule( - allowed_origins, - allowed_methods, - max_age_in_seconds=max_age_in_seconds, - exposed_headers=exposed_headers, - allowed_headers=allowed_headers - ) - - cors = [cors_rule1, cors_rule2] - - # Set the service properties - queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) - # [END set_queue_service_properties] - - # [START get_queue_service_properties] - properties = queue_service.get_service_properties() - # [END get_queue_service_properties] - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + + # [START set_queue_service_properties] + # Create service properties + from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy + + # Create logging settings + logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + + # Create metrics for requests statistics + hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + + # Create CORS rules + cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) + allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] + allowed_methods = ['GET', 'PUT'] + max_age_in_seconds = 500 + exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] + allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] + cors_rule2 = CorsRule( + allowed_origins, + allowed_methods, + max_age_in_seconds=max_age_in_seconds, + exposed_headers=exposed_headers, + allowed_headers=allowed_headers + ) + + cors = [cors_rule1, cors_rule2] + + # Set the service properties + queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) + # [END set_queue_service_properties] + + # [START get_queue_service_properties] + properties = queue_service.get_service_properties() + # [END get_queue_service_properties] def queues_in_account(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient - if self.connection_string is not None: - queue_service = 
QueueServiceClient.from_connection_string(conn_str=self.connection_string) - - # [START qsc_create_queue] - queue_service.create_queue("myqueue1") - # [END qsc_create_queue] - - try: - # [START qsc_list_queues] - # List all the queues in the service - list_queues = queue_service.list_queues() - for queue in list_queues: - print(queue) - - # List the queues in the service that start with the name "my" - list_my_queues = queue_service.list_queues(name_starts_with="my") - for queue in list_my_queues: - print(queue) - # [END qsc_list_queues] - - finally: - # [START qsc_delete_queue] - queue_service.delete_queue("myqueue1") - # [END qsc_delete_queue] - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + + # [START qsc_create_queue] + queue_service.create_queue("myqueue1") + # [END qsc_create_queue] + + try: + # [START qsc_list_queues] + # List all the queues in the service + list_queues = queue_service.list_queues() + for queue in list_queues: + print(queue) + + # List the queues in the service that start with the name "my" + list_my_queues = queue_service.list_queues(name_starts_with="my") + for queue in list_my_queues: + print(queue) + # [END qsc_list_queues] + + finally: + # [START qsc_delete_queue] + queue_service.delete_queue("myqueue1") + # [END qsc_delete_queue] def get_queue_client(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient, QueueClient - if self.connection_string is not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - - # [START get_queue_client] - # Get the queue client to interact with a specific queue - queue = 
queue_service.get_queue_client(queue="myqueue2") - # [END get_queue_client] - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + + # [START get_queue_client] + # Get the queue client to interact with a specific queue + queue = queue_service.get_queue_client(queue="myqueue2") + # [END get_queue_client] if __name__ == '__main__': sample = QueueServiceSamples() sample.queue_service_properties() sample.queues_in_account() - sample.get_queue_client() \ No newline at end of file + sample.get_queue_client() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py index 9f823ea55510..54694d73637d 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py @@ -22,104 +22,109 @@ import asyncio -import os, sys +import os +import sys class QueueServiceSamplesAsync(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") + if connection_string is None: + print("Missing required environment variable: connection_string" + '\n' + + "Please ensure your environment variable for 'AZURE_STORAGE_CONNECTION_STRING' is set and try again.") + sys.exit(1) async def queue_service_properties_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient - if self.connection_string is not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - - async with queue_service: - # [START async_set_queue_service_properties] - # Create service properties - from azure.storage.queue import 
QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy - - # Create logging settings - logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - - # Create metrics for requests statistics - hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - - # Create CORS rules - cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) - allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] - allowed_methods = ['GET', 'PUT'] - max_age_in_seconds = 500 - exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] - allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] - cors_rule2 = CorsRule( - allowed_origins, - allowed_methods, - max_age_in_seconds=max_age_in_seconds, - exposed_headers=exposed_headers, - allowed_headers=allowed_headers - ) - - cors = [cors_rule1, cors_rule2] - - # Set the service properties - await queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) - # [END async_set_queue_service_properties] - - # [START async_get_queue_service_properties] - properties = await queue_service.get_service_properties() - # [END async_get_queue_service_properties] - else: - print("Missing required enviornment variable(s). 
Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + + async with queue_service: + # [START async_set_queue_service_properties] + # Create service properties + from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy + + # Create logging settings + logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + + # Create metrics for requests statistics + hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + + # Create CORS rules + cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) + allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] + allowed_methods = ['GET', 'PUT'] + max_age_in_seconds = 500 + exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] + allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] + cors_rule2 = CorsRule( + allowed_origins, + allowed_methods, + max_age_in_seconds=max_age_in_seconds, + exposed_headers=exposed_headers, + allowed_headers=allowed_headers + ) + + cors = [cors_rule1, cors_rule2] + + # Set the service properties + await queue_service.set_service_properties(logging, hour_metrics, minute_metrics, cors) + # [END async_set_queue_service_properties] + + # [START async_get_queue_service_properties] + properties = await queue_service.get_service_properties() + # [END async_get_queue_service_properties] async def queues_in_account_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient - if 
self.connection_string is not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - - async with queue_service: - # [START async_qsc_create_queue] - await queue_service.create_queue("myqueue1") - # [END async_qsc_create_queue] - - try: - # [START async_qsc_list_queues] - # List all the queues in the service - list_queues = queue_service.list_queues() - async for queue in list_queues: - print(queue) - - # List the queues in the service that start with the name "my_" - list_my_queues = queue_service.list_queues(name_starts_with="my_") - async for queue in list_my_queues: - print(queue) - # [END async_qsc_list_queues] - - finally: - # [START async_qsc_delete_queue] - await queue_service.delete_queue("myqueue1") - # [END async_qsc_delete_queue] - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + + async with queue_service: + # [START async_qsc_create_queue] + await queue_service.create_queue("myqueue1") + # [END async_qsc_create_queue] + + try: + # [START async_qsc_list_queues] + # List all the queues in the service + list_queues = queue_service.list_queues() + async for queue in list_queues: + print(queue) + + # List the queues in the service that start with the name "my_" + list_my_queues = queue_service.list_queues(name_starts_with="my_") + async for queue in list_my_queues: + print(queue) + # [END async_qsc_list_queues] + + finally: + # [START async_qsc_delete_queue] + await queue_service.delete_queue("myqueue1") + # [END async_qsc_delete_queue] async def get_queue_client_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient, QueueClient - if self.connection_string is 
not None: - queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) - - # [START async_get_queue_client] - # Get the queue client to interact with a specific queue - queue = queue_service.get_queue_client(queue="myqueue2") - # [END async_get_queue_client] - else: - print("Missing required enviornment variable(s). Please see specific test for more details.") - sys.exit(1) + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) + + # [START async_get_queue_client] + # Get the queue client to interact with a specific queue + queue = queue_service.get_queue_client(queue="myqueue2") + # [END async_get_queue_client] async def main(): From bb9b84abd16ead62a77e17e3ef7067855d4a7c24 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Tue, 17 Oct 2023 15:07:30 -0700 Subject: [PATCH 58/71] Revert moving out sas token helper, causes circular dependency --- .../azure/storage/queue/_shared/base_client.py | 11 ++++++++++- .../storage/queue/_shared/base_client_async.py | 13 ++++++++++++- .../azure/storage/queue/_shared/parser.py | 15 +-------------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index ea4ff2fb365a..798617b95b87 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -45,7 +45,6 @@ StorageRequestHook, StorageResponseHook, ) -from .parser import _is_credential_sastoken from .request_handlers import serialize_batch_body, _get_batch_request_delimiter from .response_handlers import PartialBatchErrorException, process_storage_error from .shared_access_signature import QueryStringConstants @@ -455,3 +454,13 @@ def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]: snapshot = 
parsed_query.get("snapshot") or parsed_query.get("sharesnapshot") return snapshot, sas_token + +def _is_credential_sastoken(credential: Any) -> bool: + if not credential or not isinstance(credential, str): + return False + + sas_values = QueryStringConstants.to_list() + parsed_query = parse_qs(credential.lstrip("?")) + if parsed_query and all(k in sas_values for k in parsed_query): + return True + return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 242aa8fda877..ad05d94ab2d7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -7,6 +7,7 @@ import logging from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING, Union +from urllib.parse import parse_qs from azure.core.async_paging import AsyncList from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential @@ -34,9 +35,9 @@ StorageHosts, StorageRequestHook, ) -from .parser import _is_credential_sastoken from .policies_async import AsyncStorageResponseHook from .response_handlers import PartialBatchErrorException, process_storage_error +from .shared_access_signature import QueryStringConstants if TYPE_CHECKING: from azure.core.pipeline.transport import HttpRequest, HttpResponse @@ -260,6 +261,16 @@ def parse_connection_str( secondary = secondary.replace(".blob.", ".dfs.") return primary, secondary, credential +def _is_credential_sastoken(credential: Any) -> bool: + if not credential or not isinstance(credential, str): + return False + + sas_values = QueryStringConstants.to_list() + parsed_query = parse_qs(credential.lstrip("?")) + if parsed_query and all(k in sas_values for k in parsed_query): + return True + return False + class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an inner 
client created by a `get_client` method does not close the outer transport for the parent diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py index 304635c4cf4e..cd59cfe104ca 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py @@ -6,10 +6,7 @@ import sys from datetime import datetime, timezone -from typing import Any, Optional -from urllib.parse import parse_qs - -from .shared_access_signature import QueryStringConstants +from typing import Optional EPOCH_AS_FILETIME = 116444736000000000 # January 1, 1970 as MS filetime HUNDREDS_OF_NANOSECONDS = 10000000 @@ -62,13 +59,3 @@ def _filetime_to_datetime(filetime: str) -> Optional[datetime]: # Try RFC 1123 as backup return _rfc_1123_to_datetime(filetime) - -def _is_credential_sastoken(credential: Any) -> bool: - if not credential or not isinstance(credential, str): - return False - - sas_values = QueryStringConstants.to_list() - parsed_query = parse_qs(credential.lstrip("?")) - if parsed_query and all(k in sas_values for k in parsed_query): - return True - return False From d84538e648e7a482c60d12d7a0d9be0f3ce6b987 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Wed, 18 Oct 2023 11:20:35 -0700 Subject: [PATCH 59/71] Pylint --- .../azure/storage/queue/_shared/base_client.py | 2 +- .../azure/storage/queue/_shared/base_client_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 798617b95b87..0ede562eec68 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -52,7 +52,7 @@ if TYPE_CHECKING: from 
azure.core.credentials_async import AsyncTokenCredential - from azure.core.pipeline.transport import HttpRequest, HttpResponse + from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756 _LOGGER = logging.getLogger(__name__) _SERVICE_PARAMS = { diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index ad05d94ab2d7..429eda14729b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -40,7 +40,7 @@ from .shared_access_signature import QueryStringConstants if TYPE_CHECKING: - from azure.core.pipeline.transport import HttpRequest, HttpResponse + from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756 _LOGGER = logging.getLogger(__name__) _SERVICE_PARAMS = { From de5b4eb288296c5cba7f6084933f80174423b3ab Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Thu, 19 Oct 2023 18:57:27 -0700 Subject: [PATCH 60/71] Typing feedback --- .../azure/storage/queue/_models.py | 2 +- .../azure/storage/queue/_queue_client.py | 28 +++++++++---------- .../storage/queue/_shared/base_client.py | 11 +------- .../queue/_shared/base_client_async.py | 12 +------- .../azure/storage/queue/_shared/policies.py | 3 +- .../storage/queue/_shared_access_signature.py | 11 ++++++++ .../storage/queue/aio/_queue_client_async.py | 28 +++++++++---------- .../samples/network_activity_logging.py | 12 ++++---- .../samples/queue_samples_message.py | 4 --- .../samples/queue_samples_message_async.py | 4 --- .../samples/queue_samples_service.py | 4 --- .../samples/queue_samples_service_async.py | 4 --- 12 files changed, 50 insertions(+), 73 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py 
index 2922786cb433..febfb7158977 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -375,7 +375,7 @@ class QueueMessage(DictMixin): """A UTC date value representing the time the message will next be visible. Only returned by receive messages operations. Set to None for peek messages.""" - def __init__(self, content: Any = None, **kwargs: Any) -> None: + def __init__(self, content: Optional[Any] = None, **kwargs: Any) -> None: self.id = kwargs.pop('id', None) self.inserted_on = kwargs.pop('inserted_on', None) self.expires_on = kwargs.pop('expires_on', None) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index d1924de13073..e802476ec939 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -92,8 +92,8 @@ def __init__( self.queue_name = queue_name self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) - self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() - self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() + self._message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() + self._message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._configure_encryption(kwargs) @@ -489,7 +489,7 @@ def send_message( kwargs) try: - self.message_encode_policy.configure( + self._message_encode_policy.configure( 
require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function, @@ -501,11 +501,11 @@ def send_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." ) - self.message_encode_policy.configure( + self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) - encoded_content = self.message_encode_policy(content) + encoded_content = self._message_encode_policy(content) new_message = GenQueueMessage(message_text=encoded_content) try: @@ -576,7 +576,7 @@ def receive_message( self.encryption_version, kwargs) - self.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function @@ -586,7 +586,7 @@ def receive_message( number_of_messages=1, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self.message_decode_policy, + cls=self._message_decode_policy, **kwargs ) wrapped_message = QueueMessage._from_generated( # pylint: disable=protected-access @@ -668,7 +668,7 @@ def receive_messages( self.encryption_version, kwargs) - self.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function @@ -678,7 +678,7 @@ def receive_messages( self._client.messages.dequeue, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self.message_decode_policy, + cls=self._message_decode_policy, **kwargs ) if max_messages is not None and messages_per_page is not None: @@ -775,7 +775,7 @@ def update_message( raise ValueError("pop_receipt must be present") if message_text is not None: try: - self.message_encode_policy.configure( + self._message_encode_policy.configure( self.require_encryption, 
self.key_encryption_key, self.key_resolver_function, @@ -787,11 +787,11 @@ def update_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." ) - self.message_encode_policy.configure( + self._message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function) - encoded_message_text = self.message_encode_policy(message_text) + encoded_message_text = self._message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: updated = None @@ -870,7 +870,7 @@ def peek_messages( self.encryption_version, kwargs) - self.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function @@ -879,7 +879,7 @@ def peek_messages( messages = self._client.messages.peek( number_of_messages=max_messages, timeout=timeout, - cls=self.message_decode_policy, + cls=self._message_decode_policy, **kwargs) wrapped_messages = [] for peeked in messages: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 0ede562eec68..6c4066b2792c 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -49,6 +49,7 @@ from .response_handlers import PartialBatchErrorException, process_storage_error from .shared_access_signature import QueryStringConstants from .._version import VERSION +from .._shared_access_signature import _is_credential_sastoken if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -454,13 +455,3 @@ def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]: snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot") return snapshot, sas_token 
- -def _is_credential_sastoken(credential: Any) -> bool: - if not credential or not isinstance(credential, str): - return False - - sas_values = QueryStringConstants.to_list() - parsed_query = parse_qs(credential.lstrip("?")) - if parsed_query and all(k in sas_values for k in parsed_query): - return True - return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 429eda14729b..f336cca8add9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -37,7 +37,7 @@ ) from .policies_async import AsyncStorageResponseHook from .response_handlers import PartialBatchErrorException, process_storage_error -from .shared_access_signature import QueryStringConstants +from .._shared_access_signature import _is_credential_sastoken if TYPE_CHECKING: from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756 @@ -261,16 +261,6 @@ def parse_connection_str( secondary = secondary.replace(".blob.", ".dfs.") return primary, secondary, credential -def _is_credential_sastoken(credential: Any) -> bool: - if not credential or not isinstance(credential, str): - return False - - sas_values = QueryStringConstants.to_list() - parsed_query = parse_qs(credential.lstrip("?")) - if parsed_query and all(k in sas_values for k in parsed_query): - return True - return False - class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an inner client created by a `get_client` method does not close the outer transport for the parent diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index fe4351668b4b..73105fe979ee 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -581,7 +581,8 @@ def __init__( increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, - random_jitter_range: int = 3, **kwargs + random_jitter_range: int = 3, + **kwargs: Any ) -> None: """ Constructs an Exponential retry object. The initial_backoff is used for diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py index 6cfbca0f04d4..d25272fd8f04 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py @@ -5,6 +5,7 @@ # -------------------------------------------------------------------------- from typing import Any, Optional, TYPE_CHECKING, Union +from urllib.parse import parse_qs from azure.storage.queue._shared import sign_string from azure.storage.queue._shared.constants import X_MS_VERSION @@ -271,3 +272,13 @@ def generate_queue_sas( ip=ip, **kwargs ) + +def _is_credential_sastoken(credential: Any) -> bool: + if not credential or not isinstance(credential, str): + return False + + sas_values = QueryStringConstants.to_list() + parsed_query = parse_qs(credential.lstrip("?")) + if parsed_query and all(k in sas_values for k in parsed_query): + return True + return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index e457ea742281..ad0a39f17461 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -105,8 +105,8 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, 
credential) super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) - self.message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() - self.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() + self._message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() + self._message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._loop = loop @@ -487,7 +487,7 @@ async def send_message( kwargs) try: - self.message_encode_policy.configure( + self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function, @@ -499,11 +499,11 @@ async def send_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." 
) - self.message_encode_policy.configure( + self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) - encoded_content = self.message_encode_policy(content) + encoded_content = self._message_encode_policy(content) new_message = GenQueueMessage(message_text=encoded_content) try: @@ -575,7 +575,7 @@ async def receive_message( self.encryption_version, kwargs) - self.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function @@ -585,7 +585,7 @@ async def receive_message( number_of_messages=1, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self.message_decode_policy, + cls=self._message_decode_policy, **kwargs ) wrapped_message = QueueMessage._from_generated( # pylint: disable=protected-access @@ -658,7 +658,7 @@ def receive_messages( self.encryption_version, kwargs) - self.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function @@ -668,7 +668,7 @@ def receive_messages( self._client.messages.dequeue, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self.message_decode_policy, + cls=self._message_decode_policy, **kwargs ) if max_messages is not None and messages_per_page is not None: @@ -765,7 +765,7 @@ async def update_message( raise ValueError("pop_receipt must be present") if message_text is not None: try: - self.message_encode_policy.configure( + self._message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function, @@ -778,12 +778,12 @@ async def update_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." 
) - self.message_encode_policy.configure( + self._message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function ) - encoded_message_text = self.message_encode_policy(message_text) + encoded_message_text = self._message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: updated = None @@ -863,14 +863,14 @@ async def peek_messages( self.encryption_version, kwargs) - self.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function ) try: messages = await self._client.messages.peek( - number_of_messages=max_messages, timeout=timeout, cls=self.message_decode_policy, **kwargs + number_of_messages=max_messages, timeout=timeout, cls=self._message_decode_policy, **kwargs ) wrapped_messages = [] for peeked in messages: diff --git a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py index a9f44c941f2e..a6440ebdf02a 100644 --- a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py +++ b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py @@ -60,9 +60,9 @@ for queue in queues: print('Queue: {}'.format(queue.name)) queue_client = service_client.get_queue_client(queue.name) -messages = queue_client.peek_messages(max_messages=20, logging_enable=True) -for message in messages: - try: - print(' Message: {!r}'.format(base64.b64decode(message.content))) - except binascii.Error: - print(' Message: {}'.format(message.content)) + messages = queue_client.peek_messages(max_messages=20, logging_enable=True) + for message in messages: + try: + print(' Message: {!r}'.format(base64.b64decode(message.content))) + except binascii.Error: + print(' Message: {}'.format(message.content)) diff --git 
a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py index e77df1a6bceb..44172f3d53c7 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py @@ -31,10 +31,6 @@ class QueueMessageSamples(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") - if connection_string is None: - print("Missing required environment variable: connection_string" + '\n' + - "Please ensure your environment variable for 'AZURE_STORAGE_CONNECTION_STRING' is set and try again.") - sys.exit(1) def set_access_policy(self): if self.connection_string is None: diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py index f9f7b3364263..c5aafcdfc80d 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py @@ -31,10 +31,6 @@ class QueueMessageSamplesAsync(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") - if connection_string is None: - print("Missing required environment variable: connection_string" + '\n' + - "Please ensure your environment variable for 'AZURE_STORAGE_CONNECTION_STRING' is set and try again.") - sys.exit(1) async def set_access_policy_async(self): if self.connection_string is None: diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py index a8065b2e4043..d10c230db526 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py @@ -27,10 +27,6 @@ class QueueServiceSamples(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") - if connection_string is None: - print("Missing 
required environment variable: connection_string" + '\n' + - "Please ensure your environment variable for 'AZURE_STORAGE_CONNECTION_STRING' is set and try again.") - sys.exit(1) def queue_service_properties(self): if self.connection_string is None: diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py index 54694d73637d..b57816180892 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py @@ -29,10 +29,6 @@ class QueueServiceSamplesAsync(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") - if connection_string is None: - print("Missing required environment variable: connection_string" + '\n' + - "Please ensure your environment variable for 'AZURE_STORAGE_CONNECTION_STRING' is set and try again.") - sys.exit(1) async def queue_service_properties_async(self): if self.connection_string is None: From 40678185322c652a5735557d482ccfd19bc027eb Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 20 Oct 2023 13:48:01 -0700 Subject: [PATCH 61/71] Fix CI --- .../azure/storage/queue/_shared/base_client_async.py | 1 - .../azure-storage-queue/tests/test_queue_encodings.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index f336cca8add9..8d917afa531a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -7,7 +7,6 @@ import logging from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING, Union -from urllib.parse import parse_qs from azure.core.async_paging import AsyncList from azure.core.credentials import 
AzureNamedKeyCredential, AzureSasCredential diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py b/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py index 6e7eea541b07..82d8a1532c59 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py @@ -70,8 +70,8 @@ def test_message_text_xml(self, **kwargs): queue = qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX)) # Asserts - assert isinstance(queue.message_encode_policy, NoEncodePolicy) - assert isinstance(queue.message_decode_policy, NoDecodePolicy) + assert isinstance(queue._message_encode_policy, NoEncodePolicy) + assert isinstance(queue._message_decode_policy, NoDecodePolicy) self._validate_encoding(queue, message) @QueuePreparer() @@ -225,8 +225,8 @@ def test_message_no_encoding(self): message_decode_policy=None) # Asserts - assert isinstance(queue.message_encode_policy, NoEncodePolicy) - assert isinstance(queue.message_decode_policy, NoDecodePolicy) + assert isinstance(queue._message_encode_policy, NoEncodePolicy) + assert isinstance(queue._message_decode_policy, NoDecodePolicy) # ------------------------------------------------------------------------------ From 519421aad47f4adf49eaecf950066c83bdafd4a1 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 20 Oct 2023 15:00:40 -0700 Subject: [PATCH 62/71] Use QueueMessage ctor --- .../azure/storage/queue/_queue_client.py | 29 ++++++++++--------- .../storage/queue/aio/_queue_client_async.py | 29 ++++++++++--------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index e802476ec939..6f8dd2d0d538 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -515,12 +515,14 @@ def 
send_message( message_time_to_live=time_to_live, timeout=timeout, **kwargs) - queue_message = QueueMessage(content=content) - queue_message.id = enqueued[0].message_id - queue_message.inserted_on = enqueued[0].insertion_time - queue_message.expires_on = enqueued[0].expiration_time - queue_message.pop_receipt = enqueued[0].pop_receipt - queue_message.next_visible_on = enqueued[0].time_next_visible + queue_message = QueueMessage( + content=content, + id=enqueued[0].message_id, + inserted_on=enqueued[0].insertion_time, + expires_on=enqueued[0].expiration_time, + pop_receipt = enqueued[0].pop_receipt, + next_visible_on = enqueued[0].time_next_visible + ) return queue_message except HttpResponseError as error: process_storage_error(error) @@ -804,13 +806,14 @@ def update_message( cls=return_response_headers, queue_message_id=message_id, **kwargs)) - new_message = QueueMessage(content=message_text) - new_message.id = message_id - new_message.inserted_on = inserted_on - new_message.expires_on = expires_on - new_message.dequeue_count = dequeue_count - new_message.pop_receipt = response['popreceipt'] - new_message.next_visible_on = response['time_next_visible'] + new_message = QueueMessage( + content=message_text, + id=message_id, + inserted_on=inserted_on, + expires_on=expires_on, + pop_receipt = response['popreceipt'], + next_visible_on = response['time_next_visible'] + ) return new_message except HttpResponseError as error: process_storage_error(error) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index ad0a39f17461..e281feb2034f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -514,12 +514,14 @@ async def send_message( timeout=timeout, **kwargs ) - queue_message = QueueMessage(content=content) - queue_message.id = 
enqueued[0].message_id - queue_message.inserted_on = enqueued[0].insertion_time - queue_message.expires_on = enqueued[0].expiration_time - queue_message.pop_receipt = enqueued[0].pop_receipt - queue_message.next_visible_on = enqueued[0].time_next_visible + queue_message = QueueMessage( + content=content, + id=enqueued[0].message_id, + inserted_on=enqueued[0].insertion_time, + expires_on=enqueued[0].expiration_time, + pop_receipt = enqueued[0].pop_receipt, + next_visible_on = enqueued[0].time_next_visible + ) return queue_message except HttpResponseError as error: process_storage_error(error) @@ -797,13 +799,14 @@ async def update_message( queue_message_id=message_id, **kwargs )) - new_message = QueueMessage(content=message_text) - new_message.id = message_id - new_message.inserted_on = inserted_on - new_message.expires_on = expires_on - new_message.dequeue_count = dequeue_count - new_message.pop_receipt = response["popreceipt"] - new_message.next_visible_on = response["time_next_visible"] + new_message = QueueMessage( + content=message_text, + id=message_id, + inserted_on=inserted_on, + expires_on=expires_on, + pop_receipt = response['popreceipt'], + next_visible_on = response['time_next_visible'] + ) return new_message except HttpResponseError as error: process_storage_error(error) From 98ac8ea4f283bef7641d430afac64279e0e4cfc1 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Mon, 23 Oct 2023 11:59:17 -0700 Subject: [PATCH 63/71] Pylint --- .../azure-storage-queue/azure/storage/queue/_queue_client.py | 1 + .../azure/storage/queue/aio/_queue_client_async.py | 1 + 2 files changed, 2 insertions(+) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 6f8dd2d0d538..df2e8975c21c 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -810,6 +810,7 @@ def 
update_message( content=message_text, id=message_id, inserted_on=inserted_on, + dequeue_count=dequeue_count, expires_on=expires_on, pop_receipt = response['popreceipt'], next_visible_on = response['time_next_visible'] diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index e281feb2034f..e1b554007561 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -803,6 +803,7 @@ async def update_message( content=message_text, id=message_id, inserted_on=inserted_on, + dequeue_count=dequeue_count, expires_on=expires_on, pop_receipt = response['popreceipt'], next_visible_on = response['time_next_visible'] From 542895e9d32286ec8a35a0b4375dee00eadeba6e Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Mon, 23 Oct 2023 12:18:05 -0700 Subject: [PATCH 64/71] Privatize CorsRule method --- sdk/storage/azure-storage-queue/azure/storage/queue/_models.py | 2 +- .../azure/storage/queue/_queue_service_client.py | 2 +- .../azure/storage/queue/aio/_queue_service_client_async.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index febfb7158977..bdda1c7b33ab 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -195,7 +195,7 @@ def __init__(self, allowed_origins: List[str], allowed_methods: List[str], **kwa self.max_age_in_seconds = kwargs.get('max_age_in_seconds', 0) @staticmethod - def to_generated(rules: Optional[List["CorsRule"]]) -> Optional[List[GeneratedCorsRule]]: + def _to_generated(rules: Optional[List["CorsRule"]]) -> Optional[List[GeneratedCorsRule]]: if rules is None: return rules diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index a80cc5caa23a..1f3aa691add7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -259,7 +259,7 @@ def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=CorsRule.to_generated(cors) + cors=CorsRule._to_generated(cors) ) try: self._client.service.set_properties(props, timeout=timeout, **kwargs) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 675224a54cd8..7e28a1b0cee1 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -255,7 +255,7 @@ async def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=CorsRule.to_generated(cors) + cors=CorsRule._to_generated(cors) ) try: await self._client.service.set_properties(props, timeout=timeout, **kwargs) From 7be5de592cf8abbf3098092b980d89c572823a7a Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Mon, 23 Oct 2023 15:15:02 -0700 Subject: [PATCH 65/71] Disable for protected access --- .../azure/storage/queue/_queue_service_client.py | 6 +++--- .../azure/storage/queue/aio/_queue_service_client_async.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 1f3aa691add7..6d95b59135aa 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -259,7 +259,7 @@ def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=CorsRule._to_generated(cors) + cors=CorsRule._to_generated(cors) # pylint: disable=protected-access ) try: self._client.service.set_properties(props, timeout=timeout, **kwargs) @@ -423,8 +423,8 @@ def get_queue_client( queue_name = queue _pipeline = Pipeline( - transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access - policies=self._pipeline._impl_policies # type: ignore # pylint: disable = protected-access + transport=TransportWrapper(self._pipeline._transport), # pylint: disable=protected-access + policies=self._pipeline._impl_policies # type: ignore # pylint: disable=protected-access ) return QueueClient( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 7e28a1b0cee1..39861cd9371b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -255,7 +255,7 @@ async def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=CorsRule._to_generated(cors) + cors=CorsRule._to_generated(cors) # pylint: disable=protected-access ) try: await self._client.service.set_properties(props, timeout=timeout, **kwargs) @@ -420,8 +420,8 @@ def get_queue_client( queue_name = queue _pipeline = AsyncPipeline( - transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access - policies=self._pipeline._impl_policies # type: ignore # pylint: disable = protected-access + 
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable=protected-access + policies=self._pipeline._impl_policies # type: ignore # pylint: disable=protected-access ) return QueueClient( From b5b0e65101c1d1d5503cc14673e921a4cd5538b7 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Thu, 26 Oct 2023 16:11:51 -0700 Subject: [PATCH 66/71] Flip on mypy --- sdk/storage/azure-storage-queue/pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-queue/pyproject.toml b/sdk/storage/azure-storage-queue/pyproject.toml index ff779a518143..78755ba24174 100644 --- a/sdk/storage/azure-storage-queue/pyproject.toml +++ b/sdk/storage/azure-storage-queue/pyproject.toml @@ -1,5 +1,5 @@ [tool.azure-sdk-build] -mypy = false +mypy = true pyright = false -type_check_samples = false -verifytypes = false +type_check_samples = true +verifytypes = true From 7ba0e60c5614f4f4dc6a268bf6fbe2e32b31a1db Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Wed, 1 Nov 2023 13:57:38 -0700 Subject: [PATCH 67/71] Missed configuration value parsing --- .../azure/storage/queue/_shared/models.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index e9be33980b8f..552d0c632ca5 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -569,13 +569,13 @@ class StorageConfiguration(Configuration): def __init__(self, **kwargs): super(StorageConfiguration, self).__init__(**kwargs) - self.max_single_put_size = 64 * 1024 * 1024 + self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024) self.copy_polling_interval = 15 - self.max_block_size = 4 * 1024 * 1024 - self.min_large_block_upload_threshold = 4 * 1024 * 1024 + 1 - 
self.use_byte_buffer = False - self.max_page_size = 4 * 1024 * 1024 - self.min_large_chunk_upload_threshold = 100 * 1024 * 1024 + 1 - self.max_single_get_size = 32 * 1024 * 1024 - self.max_chunk_get_size = 4 * 1024 * 1024 - self.max_range_size = 4 * 1024 * 1024 + self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024) + self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1) + self.use_byte_buffer = kwargs.pop('use_byte_buffer', False) + self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024) + self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1) + self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024) + self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024) + self.max_range_size = kwargs.pop('max_get_range_size', 4 * 1024 * 1024) From 93cb85a240e3cbad75a10cbac483a7f4136aa2b0 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Wed, 1 Nov 2023 15:38:00 -0700 Subject: [PATCH 68/71] Double-spaced --- .../azure/storage/queue/_shared/base_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 6c4066b2792c..92b2a7aed224 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -431,7 +431,6 @@ def parse_connection_str( return primary, secondary, credential - def create_configuration(**kwargs: Any) -> StorageConfiguration: # Backwards compatibility if someone is not passing sdk_moniker if not kwargs.get("sdk_moniker"): From b6f4291801e36668a163af8f2e4d33d6563a52ed Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 3 Nov 2023 12:59:41 -0700 Subject: [PATCH 69/71] Typo --- 
.../azure-storage-queue/azure/storage/queue/_shared/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index 552d0c632ca5..34ce22b7a632 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -578,4 +578,4 @@ def __init__(self, **kwargs): self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1) self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024) self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024) - self.max_range_size = kwargs.pop('max_get_range_size', 4 * 1024 * 1024) + self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 1024) From 7ae8e95482434ebb0422242573f459b65e3b1683 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 10 Nov 2023 13:35:35 -0800 Subject: [PATCH 70/71] Last minute nits --- .../azure/storage/queue/_shared/base_client.py | 18 +++++++++--------- .../storage/queue/_shared/base_client_async.py | 10 +++------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 92b2a7aed224..8746ae3195fd 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -285,15 +285,15 @@ def _batch_send( batch_id = str(uuid.uuid1()) request = self._client._client.post( # pylint: disable=protected-access - url=( - f'{self.scheme}://{self.primary_hostname}/' - f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" - f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" - ), - headers={ - 
'x-ms-version': self.api_version, - "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False) - } + url=( + f'{self.scheme}://{self.primary_hostname}/' + f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" + f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" + ), + headers={ + 'x-ms-version': self.api_version, + "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False) + } ) policies = [StorageHeadersPolicy()] diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 8d917afa531a..8ddb5b390e11 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -160,18 +160,14 @@ async def _batch_send( """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) - client = self._client - scheme = self.scheme - primary_hostname = self.primary_hostname - api_version = self.api_version - request = client._client.post( # pylint: disable=protected-access + request = self._client._client.post( # pylint: disable=protected-access url=( - f'{scheme}://{primary_hostname}/' + f'{self.scheme}://{self.primary_hostname}/' f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), headers={ - 'x-ms-version': api_version + 'x-ms-version': self.api_version } ) From 9c0d2cbaafebeac3ac8184282c009f6bd8627317 Mon Sep 17 00:00:00 2001 From: Vincent Phat Tran Date: Fri, 10 Nov 2023 13:46:05 -0800 Subject: [PATCH 71/71] Extra newline in encryption --- .../azure-storage-queue/azure/storage/queue/_encryption.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 5f64d2549132..5ad7a2e9a2cc 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -294,7 +294,6 @@ def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: :rtype: bool """ # If encryption_data is None, assume no encryption - return bool(encryption_data and (encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2))