diff --git a/sdk/storage/azure-storage-queue/azure/__init__.py b/sdk/storage/azure-storage-queue/azure/__init__.py index 59cb70146572..0d1f7edf5dc6 100644 --- a/sdk/storage/azure-storage-queue/azure/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: str +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-queue/azure/storage/__init__.py b/sdk/storage/azure-storage-queue/azure/storage/__init__.py index 59cb70146572..0d1f7edf5dc6 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/storage/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: str +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py index 50f79edb6076..f2016049827e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py @@ -5,14 +5,22 @@ # -------------------------------------------------------------------------- # pylint: disable=unused-argument -from azure.core.exceptions import ResourceExistsError +from typing import Any, Dict, TYPE_CHECKING -from ._shared.models import StorageErrorCode +from azure.core.exceptions import ResourceExistsError from ._models import QueueProperties +from ._shared.models import StorageErrorCode from ._shared.response_handlers import deserialize_metadata +if TYPE_CHECKING: + from azure.core.pipeline import PipelineResponse + -def deserialize_queue_properties(response, obj, headers): +def deserialize_queue_properties( + response: "PipelineResponse", + obj: Any, + headers: Dict[str, Any] +) -> QueueProperties: metadata = 
deserialize_metadata(response, obj, headers) queue_properties = QueueProperties( metadata=metadata, @@ -21,9 +29,13 @@ def deserialize_queue_properties(response, obj, headers): return queue_properties -def deserialize_queue_creation(response, obj, headers): +def deserialize_queue_creation( + response: "PipelineResponse", + obj: Any, + headers: Dict[str, Any] +) -> Dict[str, Any]: response = response.http_response - if response.status_code == 204: + if response.status_code == 204: # type: ignore error_code = StorageErrorCode.queue_already_exists error = ResourceExistsError( message=( @@ -31,8 +43,8 @@ def deserialize_queue_creation(response, obj, headers): f"RequestId:{headers['x-ms-request-id']}\n" f"Time:{headers['Date']}\n" f"ErrorCode:{error_code}"), - response=response) - error.error_code = error_code - error.additional_info = {} + response=response) # type: ignore + error.error_code = error_code # type: ignore + error.additional_info = {} # type: ignore raise error return headers diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index c4d2e66fbebc..5ad7a2e9a2cc 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -5,8 +5,8 @@ # license information. 
# -------------------------------------------------------------------------- -import os import math +import os import sys import warnings from collections import OrderedDict @@ -15,7 +15,9 @@ dumps, loads, ) -from typing import Any, BinaryIO, Dict, Optional, Tuple +from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TYPE_CHECKING +from typing import OrderedDict as TypedOrderedDict +from typing_extensions import Protocol from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher @@ -28,7 +30,12 @@ from azure.core.utils import CaseInsensitiveDict from ._version import VERSION -from ._shared import encode_base64, decode_base64_to_bytes +from ._shared import decode_base64_to_bytes, encode_base64 + +if TYPE_CHECKING: + from azure.core.pipeline import PipelineResponse + from cryptography.hazmat.primitives.ciphers import AEADEncryptionContext + from cryptography.hazmat.primitives.padding import PaddingContext _ENCRYPTION_PROTOCOL_V1 = '1.0' @@ -41,12 +48,27 @@ '{0} does not define a complete interface. Value of {1} is either missing or invalid.' -def _validate_not_none(param_name, param): +class KeyEncryptionKey(Protocol): + + def wrap_key(self, key: bytes) -> bytes: + ... + + def unwrap_key(self, key: bytes, algorithm: str) -> bytes: + ... + + def get_kid(self) -> str: + ... + + def get_key_wrap_algorithm(self) -> str: + ... + + +def _validate_not_none(param_name: str, param: Any): if param is None: raise ValueError(f'{param_name} should not be None.') -def _validate_key_encryption_key_wrap(kek): +def _validate_key_encryption_key_wrap(kek: KeyEncryptionKey): # Note that None is not callable and so will fail the second clause of each check. 
if not hasattr(kek, 'wrap_key') or not callable(kek.wrap_key): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'wrap_key')) @@ -57,7 +79,7 @@ def _validate_key_encryption_key_wrap(kek): class StorageEncryptionMixin(object): - def _configure_encryption(self, kwargs): + def _configure_encryption(self, kwargs: Dict[str, Any]): self.require_encryption = kwargs.get("require_encryption", False) self.encryption_version = kwargs.get("encryption_version", "1.0") self.key_encryption_key = kwargs.get("key_encryption_key") @@ -80,15 +102,17 @@ class _EncryptionAlgorithm(object): class _WrappedContentKey: """ Represents the envelope key details stored on the service. - - :param str algorithm: - The algorithm used for wrapping. - :param bytes encrypted_key: - The encrypted content-encryption-key. - :param str key_id: - The key-encryption-key identifier string. """ - def __init__(self, algorithm, encrypted_key, key_id): + + def __init__(self, algorithm: str, encrypted_key: bytes, key_id: str) -> None: + """ + :param str algorithm: + The algorithm used for wrapping. + :param bytes encrypted_key: + The encrypted content-encryption-key. + :param str key_id: + The key-encryption-key identifier string. + """ _validate_not_none('algorithm', algorithm) _validate_not_none('encrypted_key', encrypted_key) _validate_not_none('key_id', key_id) @@ -102,16 +126,17 @@ class _EncryptedRegionInfo: """ Represents the length of encryption elements. This is only used for Encryption V2. - - :param int data_length: - The length of the encryption region data (not including nonce + tag). - :param str nonce_length: - The length of nonce used when encrypting. - :param int tag_length: - The length of the encryption tag. """ - def __init__(self, data_length, nonce_length, tag_length): + def __init__(self, data_length: int, nonce_length: int, tag_length: int) -> None: + """ + :param int data_length: + The length of the encryption region data (not including nonce + tag). 
+ :param int nonce_length: + The length of nonce used when encrypting. + :param int tag_length: + The length of the encryption tag. + """ _validate_not_none('data_length', data_length) _validate_not_none('nonce_length', nonce_length) _validate_not_none('tag_length', tag_length) @@ -125,14 +150,15 @@ class _EncryptionAgent: """ Represents the encryption agent stored on the service. It consists of the encryption protocol version and encryption algorithm used. - - :param _EncryptionAlgorithm encryption_algorithm: - The algorithm used for encrypting the message contents. - :param str protocol: - The protocol version used for encryption. """ - def __init__(self, encryption_algorithm, protocol): + def __init__(self, encryption_algorithm: _EncryptionAlgorithm, protocol: str) -> None: + """ + :param _EncryptionAlgorithm encryption_algorithm: + The algorithm used for encrypting the message contents. + :param str protocol: + The protocol version used for encryption. + """ _validate_not_none('encryption_algorithm', encryption_algorithm) _validate_not_none('protocol', protocol) @@ -143,30 +169,30 @@ def __init__(self, encryption_algorithm, protocol): class _EncryptionData: """ Represents the encryption data that is stored on the service. - - :param Optional[bytes] content_encryption_IV: - The content encryption initialization vector. - Required for AES-CBC (V1). - :param Optional[_EncryptedRegionInfo] encrypted_region_info: - The info about the authenticated block sizes. - Required for AES-GCM (V2). - :param _EncryptionAgent encryption_agent: - The encryption agent. - :param _WrappedContentKey wrapped_content_key: - An object that stores the wrapping algorithm, the key identifier, - and the encrypted key bytes. - :param dict key_wrapping_metadata: - A dict containing metadata related to the key wrapping. 
""" def __init__( - self, - content_encryption_IV, - encrypted_region_info, - encryption_agent, - wrapped_content_key, - key_wrapping_metadata - ): + self, content_encryption_IV: Optional[bytes], + encrypted_region_info: Optional[_EncryptedRegionInfo], + encryption_agent: _EncryptionAgent, + wrapped_content_key: _WrappedContentKey, + key_wrapping_metadata: Dict[str, Any] + ) -> None: + """ + :param Optional[bytes] content_encryption_IV: + The content encryption initialization vector. + Required for AES-CBC (V1). + :param Optional[_EncryptedRegionInfo] encrypted_region_info: + The info about the autenticated block sizes. + Required for AES-GCM (V2). + :param _EncryptionAgent encryption_agent: + The encryption agent. + :param _WrappedContentKey wrapped_content_key: + An object that stores the wrapping algorithm, the key identifier, + and the encrypted key bytes. + :param Dict[str, Any] key_wrapping_metadata: + A dict containing metadata related to the key wrapping. + """ _validate_not_none('encryption_agent', encryption_agent) _validate_not_none('wrapped_content_key', wrapped_content_key) @@ -191,15 +217,15 @@ class GCMBlobEncryptionStream: it's streamed. Data is read and encrypted in regions. The stream will use the same encryption key and will generate a guaranteed unique nonce for each encryption region. - - :param bytes content_encryption_key: The encryption key to use. - :param BinaryIO data_stream: The data stream to read data from. """ def __init__( - self, - content_encryption_key: bytes, + self, content_encryption_key: bytes, data_stream: BinaryIO, - ): + ) -> None: + """ + :param bytes content_encryption_key: The encryption key to use. + :param BinaryIO data_stream: The data stream to read data from. 
+ """ self.content_encryption_key = content_encryption_key self.data_stream = data_stream @@ -268,7 +294,7 @@ def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: :rtype: bool """ # If encryption_data is None, assume no encryption - return encryption_data and encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2 + return bool(encryption_data and (encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2)) def modify_user_agent_for_encryption( @@ -376,6 +402,9 @@ def get_adjusted_download_range_and_offset( elif encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V2: start_offset, end_offset = 0, end + if encryption_data.encrypted_region_info is None: + raise ValueError("Missing required metadata for Encryption V2") + nonce_length = encryption_data.encrypted_region_info.nonce_length data_length = encryption_data.encrypted_region_info.data_length tag_length = encryption_data.encrypted_region_info.tag_length @@ -420,17 +449,17 @@ def parse_encryption_data(metadata: Dict[str, Any]) -> Optional[_EncryptionData] return None -def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_EncryptionData]) -> int: +def adjust_blob_size_for_encryption(size: int, encryption_data: _EncryptionData) -> int: """ Adjusts the given blob size for encryption by subtracting the size of the encryption data (nonce + tag). This only has an affect for encryption V2. :param int size: The original blob size. - :param Optional[_EncryptionData] encryption_data: The encryption data to determine version and sizes. + :param _EncryptionData encryption_data: The encryption data to determine version and sizes. :return: The new blob size. 
:rtype: int """ - if is_encryption_v2(encryption_data): + if is_encryption_v2(encryption_data) and encryption_data.encrypted_region_info is not None: nonce_length = encryption_data.encrypted_region_info.nonce_length data_length = encryption_data.encrypted_region_info.data_length tag_length = encryption_data.encrypted_region_info.tag_length @@ -443,17 +472,22 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp return size -def _generate_encryption_data_dict(kek, cek, iv, version): - ''' +def _generate_encryption_data_dict( + kek: KeyEncryptionKey, + cek: bytes, + iv: Optional[bytes], + version: str + ) -> TypedOrderedDict[str, Any]: + """ Generates and returns the encryption metadata as a dict. - :param object kek: The key encryption key. See calling functions for more information. + :param KeyEncryptionKey kek: The key encryption key. See calling functions for more information. :param bytes cek: The content encryption key. :param Optional[bytes] iv: The initialization vector. Only required for AES-CBC. :param str version: The client encryption version used. :return: A dict containing all the encryption metadata. - :rtype: dict - ''' + :rtype: Dict[str, Any] + """ # Encrypt the cek. 
if version == _ENCRYPTION_PROTOCOL_V1: wrapped_cek = kek.wrap_key(cek) @@ -483,20 +517,20 @@ def _generate_encryption_data_dict(kek, cek, iv, version): encrypted_region_info['DataLength'] = _GCM_REGION_DATA_LENGTH encrypted_region_info['NonceLength'] = _GCM_NONCE_LENGTH - encryption_data_dict = OrderedDict() + encryption_data_dict: TypedOrderedDict[str, Any] = OrderedDict() encryption_data_dict['WrappedContentKey'] = wrapped_content_key encryption_data_dict['EncryptionAgent'] = encryption_agent if version == _ENCRYPTION_PROTOCOL_V1: encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info - encryption_data_dict['KeyWrappingMetadata'] = {'EncryptionLibrary': 'Python ' + VERSION} + encryption_data_dict['KeyWrappingMetadata'] = OrderedDict({'EncryptionLibrary': 'Python ' + VERSION}) return encryption_data_dict -def _dict_to_encryption_data(encryption_data_dict): - ''' +def _dict_to_encryption_data(encryption_data_dict: Dict[str, Any]) -> _EncryptionData: + """ Converts the specified dictionary to an EncryptionData object for eventual use in decryption. @@ -504,7 +538,7 @@ def _dict_to_encryption_data(encryption_data_dict): The dictionary containing the encryption data. :return: an _EncryptionData object built from the dictionary. :rtype: _EncryptionData - ''' + """ try: protocol = encryption_data_dict['EncryptionAgent']['Protocol'] if protocol not in [_ENCRYPTION_PROTOCOL_V1, _ENCRYPTION_PROTOCOL_V2]: @@ -547,15 +581,15 @@ def _dict_to_encryption_data(encryption_data_dict): return encryption_data -def _generate_AES_CBC_cipher(cek, iv): - ''' +def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: + """ Generates and returns an encryption cipher for AES CBC using the given cek and iv. :param bytes[] cek: The content encryption key for the cipher. :param bytes[] iv: The initialization vector for the cipher. 
:return: A cipher for encrypting in AES256 CBC. :rtype: ~cryptography.hazmat.primitives.ciphers.Cipher - ''' + """ backend = default_backend() algorithm = AES(cek) @@ -563,21 +597,30 @@ def _generate_AES_CBC_cipher(cek, iv): return Cipher(algorithm, mode, backend) -def _validate_and_unwrap_cek(encryption_data, key_encryption_key=None, key_resolver=None): - ''' +def _validate_and_unwrap_cek( + encryption_data: _EncryptionData, + key_encryption_key: Optional[KeyEncryptionKey] = None, + key_resolver: Optional[Callable[[str], KeyEncryptionKey]] = None +) -> bytes: + """ Extracts and returns the content_encryption_key stored in the encryption_data object and performs necessary validation on all parameters. :param _EncryptionData encryption_data: The encryption metadata of the retrieved value. - :param obj key_encryption_key: - The key_encryption_key used to unwrap the cek. Please refer to high-level service object - instance variables for more details. - :param func key_resolver: + :param Optional[KeyEncryptionKey] key_encryption_key: + The user-provided key-encryption-key. Must implement the following methods: + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. + get_kid() + - Returns a string key id for this key-encryption-key. + :param Optional[Callable[[str], KeyEncryptionKey]] key_resolver: A function used that, given a key_id, will return a key_encryption_key. Please refer to high-level service object instance variables for more details. - :return: the content_encryption_key stored in the encryption_data object. - :rtype: bytes[] - ''' + :return: The content_encryption_key stored in the encryption_data object. 
+ :rtype: bytes + """ _validate_not_none('encrypted_key', encryption_data.wrapped_content_key.encrypted_key) @@ -589,13 +632,14 @@ def _validate_and_unwrap_cek(encryption_data, key_encryption_key=None, key_resol else: raise ValueError('Specified encryption version is not supported.') - content_encryption_key = None + content_encryption_key: Optional[bytes] = None # If the resolver exists, give priority to the key it finds. if key_resolver is not None: key_encryption_key = key_resolver(encryption_data.wrapped_content_key.key_id) - _validate_not_none('key_encryption_key', key_encryption_key) + if key_encryption_key is None: + raise ValueError("Unable to decrypt. key_resolver and key_encryption_key cannot both be None.") if not hasattr(key_encryption_key, 'get_kid') or not callable(key_encryption_key.get_kid): raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid')) if not hasattr(key_encryption_key, 'unwrap_key') or not callable(key_encryption_key.unwrap_key): @@ -603,8 +647,9 @@ def _validate_and_unwrap_cek(encryption_data, key_encryption_key=None, key_resol if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid(): raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.') # Will throw an exception if the specified algorithm is not supported. - content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key, - encryption_data.wrapped_content_key.algorithm) + content_encryption_key = key_encryption_key.unwrap_key( + encryption_data.wrapped_content_key.encrypted_key, + encryption_data.wrapped_content_key.algorithm) # For V2, the version is included with the cek. We need to validate it # and remove it from the actual cek. 
@@ -622,27 +667,34 @@ def _validate_and_unwrap_cek(encryption_data, key_encryption_key=None, key_resol return content_encryption_key -def _decrypt_message(message, encryption_data, key_encryption_key=None, resolver=None): +def _decrypt_message( + message: bytes, + encryption_data: _EncryptionData, + key_encryption_key: Optional[KeyEncryptionKey] = None, + resolver: Optional[Callable[[str], KeyEncryptionKey]] = None +) -> bytes: """ Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding. Unwraps the content-encryption-key using the user-provided or resolved key-encryption-key (kek). - Returns the original plaintex. + Returns the original plaintext. - :param str message: + :param bytes message: The ciphertext to be decrypted. :param _EncryptionData encryption_data: The metadata associated with this ciphertext. - :param object key_encryption_key: + :param Optional[KeyEncryptionKey] key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: - unwrap_key(key, algorithm) - - returns the unwrapped form of the specified symmetric key using the string-specified algorithm. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. get_kid() - - returns a string key id for this key-encryption-key. - :param Callable resolver: + - Returns a string key id for this key-encryption-key. + :param Optional[Callable[[str], KeyEncryptionKey]] resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The decrypted plaintext. 
- :rtype: str + :rtype: bytes """ _validate_not_none('message', message) content_encryption_key = _validate_and_unwrap_cek(encryption_data, key_encryption_key, resolver) @@ -654,9 +706,8 @@ def _decrypt_message(message, encryption_data, key_encryption_key=None, resolver cipher = _generate_AES_CBC_cipher(content_encryption_key, encryption_data.content_encryption_IV) # decrypt data - decrypted_data = message decryptor = cipher.decryptor() - decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) + decrypted_data = (decryptor.update(message) + decryptor.finalize()) # unpad data unpadder = PKCS7(128).unpadder() @@ -667,7 +718,10 @@ def _decrypt_message(message, encryption_data, key_encryption_key=None, resolver if not block_info or not block_info.nonce_length: raise ValueError("Missing required metadata for decryption.") - nonce_length = encryption_data.encrypted_region_info.nonce_length + if encryption_data.encrypted_region_info is None: + raise ValueError("Missing required metadata for Encryption V2") + + nonce_length = int(encryption_data.encrypted_region_info.nonce_length) # First bytes are the nonce nonce = message[:nonce_length] @@ -682,8 +736,8 @@ def _decrypt_message(message, encryption_data, key_encryption_key=None, resolver return decrypted_data -def encrypt_blob(blob, key_encryption_key, version): - ''' +def encrypt_blob(blob: bytes, key_encryption_key: KeyEncryptionKey, version: str) -> Tuple[str, bytes]: + """ Encrypts the given blob using the given encryption protocol version. Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek). Returns a json-formatted string containing the encryption metadata. This method should @@ -692,15 +746,18 @@ def encrypt_blob(blob, key_encryption_key, version): :param bytes blob: The blob to be encrypted. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. 
Must implement the following methods: - wrap_key(key)--wraps the specified key using an algorithm of the user's choice. - get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key. - get_kid()--returns a string key id for this key-encryption-key. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. + get_kid() + - Returns a string key id for this key-encryption-key. :param str version: The client encryption version to use. :return: A tuple of json-formatted string containing the encryption metadata and the encrypted blob data. :rtype: (str, bytes) - ''' + """ _validate_not_none('blob', blob) _validate_not_none('key_encryption_key', key_encryption_key) @@ -741,17 +798,21 @@ def encrypt_blob(blob, key_encryption_key, version): return dumps(encryption_data), encrypted_data -def generate_blob_encryption_data(key_encryption_key, version): - ''' +def generate_blob_encryption_data( + key_encryption_key: Optional[KeyEncryptionKey], + version: str +) -> Tuple[Optional[bytes], Optional[bytes], Optional[str]]: + """ Generates the encryption_metadata for the blob. - :param object key_encryption_key: + :param Optional[KeyEncryptionKey] key_encryption_key: The key-encryption-key used to wrap the cek associate with this blob. :param str version: The client encryption version to use. :return: A tuple containing the cek and iv for this blob as well as the serialized encryption metadata for the blob. 
- :rtype: (bytes, Optional[bytes], str) - ''' + :rtype: (Optional[bytes], Optional[bytes], Optional[str]) + """ + encryption_data = None content_encryption_key = None initialization_vector = None @@ -761,37 +822,42 @@ def generate_blob_encryption_data(key_encryption_key, version): # Initialization vector only needed for V1 if version == _ENCRYPTION_PROTOCOL_V1: initialization_vector = os.urandom(16) - encryption_data = _generate_encryption_data_dict(key_encryption_key, + encryption_data_dict = _generate_encryption_data_dict(key_encryption_key, content_encryption_key, initialization_vector, version) - encryption_data['EncryptionMode'] = 'FullBlob' - encryption_data = dumps(encryption_data) + encryption_data_dict['EncryptionMode'] = 'FullBlob' + encryption_data = dumps(encryption_data_dict) return content_encryption_key, initialization_vector, encryption_data def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements - require_encryption, - key_encryption_key, - key_resolver, - content, - start_offset, - end_offset, - response_headers): + require_encryption: bool, + key_encryption_key: KeyEncryptionKey, + key_resolver: Optional[Callable[[str], KeyEncryptionKey]], + content: bytes, + start_offset: int, + end_offset: int, + response_headers: Dict[str, Any] +) -> bytes: """ Decrypts the given blob contents and returns only the requested range. :param bool require_encryption: Whether the calling blob service requires objects to be decrypted. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: - wrap_key(key)--wraps the specified key using an algorithm of the user's choice. - get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key. - get_kid()--returns a string key id for this key-encryption-key. - :param object key_resolver: + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. 
+ get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. + get_kid() + - Returns a string key id for this key-encryption-key. + :param key_resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. + :paramtype key_resolver: Optional[Callable[[str], KeyEncryptionKey]] :param bytes content: The encrypted blob content. :param int start_offset: @@ -827,7 +893,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements if version == _ENCRYPTION_PROTOCOL_V1: blob_type = response_headers['x-ms-blob-type'] - iv = None + iv: Optional[bytes] = None unpad = False if 'content-range' in response_headers: content_range = response_headers['content-range'] @@ -857,6 +923,9 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements if blob_type == 'PageBlob': unpad = False + if iv is None: + raise ValueError("Missing required metadata for Encryption V1") + cipher = _generate_AES_CBC_cipher(content_encryption_key, iv) decryptor = cipher.decryptor() @@ -872,6 +941,9 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements total_size = len(content) offset = 0 + if encryption_data.encrypted_region_info is None: + raise ValueError("Missing required metadata for Encryption V2") + nonce_length = encryption_data.encrypted_region_info.nonce_length data_length = encryption_data.encrypted_region_info.data_length tag_length = encryption_data.encrypted_region_info.tag_length @@ -899,7 +971,11 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements raise ValueError('Specified encryption version is not supported.') -def get_blob_encryptor_and_padder(cek, iv, should_pad): +def get_blob_encryptor_and_padder( + cek: Optional[bytes], + iv: Optional[bytes], + should_pad: bool +) -> Tuple[Optional["AEADEncryptionContext"], Optional["PaddingContext"]]: encryptor = None padder = None @@ -911,23 +987,26 @@ def 
get_blob_encryptor_and_padder(cek, iv, should_pad): return encryptor, padder -def encrypt_queue_message(message, key_encryption_key, version): - ''' +def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, version: str) -> str: + """ Encrypts the given plain text message using the given protocol version. Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek). Returns a json-formatted string containing the encrypted message and the encryption metadata. - :param object message: + :param str message: The plain text message to be encrypted. - :param object key_encryption_key: + :param KeyEncryptionKey key_encryption_key: The user-provided key-encryption-key. Must implement the following methods: - wrap_key(key)--wraps the specified key using an algorithm of the user's choice. - get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key. - get_kid()--returns a string key id for this key-encryption-key. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. + get_kid() + - Returns a string key id for this key-encryption-key. :param str version: The client encryption version to use. :return: A json-formatted string containing the encrypted message and the encryption metadata. :rtype: str - ''' + """ _validate_not_none('message', message) _validate_not_none('key_encryption_key', key_encryption_key) @@ -935,7 +1014,7 @@ def encrypt_queue_message(message, key_encryption_key, version): # Queue encoding functions all return unicode strings, and encryption should # operate on binary strings. 
- message = message.encode('utf-8') + message_as_bytes: bytes = message.encode('utf-8') if version == _ENCRYPTION_PROTOCOL_V1: # AES256 CBC uses 256 bit (32 byte) keys and always with 16 byte blocks @@ -946,7 +1025,7 @@ def encrypt_queue_message(message, key_encryption_key, version): # PKCS7 with 16 byte blocks ensures compatibility with AES. padder = PKCS7(128).padder() - padded_data = padder.update(message) + padder.finalize() + padded_data = padder.update(message_as_bytes) + padder.finalize() # Encrypt the data. encryptor = cipher.encryptor() @@ -962,8 +1041,8 @@ def encrypt_queue_message(message, key_encryption_key, version): aesgcm = AESGCM(content_encryption_key) # Returns ciphertext + tag - ciphertext_with_tag = aesgcm.encrypt(nonce, message, None) - encrypted_data = nonce + ciphertext_with_tag + ciphertext_with_tag = aesgcm.encrypt(nonce, message_as_bytes, None) + encrypted_data = nonce + ciphertext_with_tag else: raise ValueError("Invalid encryption version specified.") @@ -978,7 +1057,13 @@ def encrypt_queue_message(message, key_encryption_key, version): return dumps(queue_message) -def decrypt_queue_message(message, response, require_encryption, key_encryption_key, resolver): +def decrypt_queue_message( + message: str, + response: "PipelineResponse", + require_encryption: bool, + key_encryption_key: Optional[KeyEncryptionKey], + resolver: Optional[Callable[[str], KeyEncryptionKey]] +) -> str: """ Returns the decrypted message contents from an EncryptedQueueMessage. If no encryption metadata is present, will return the unaltered message. @@ -988,13 +1073,15 @@ def decrypt_queue_message(message, response, require_encryption, key_encryption_ The pipeline response used to generate an error with. :param bool require_encryption: If set, will enforce that the retrieved messages are encrypted and decrypt them. - :param object key_encryption_key: + :param Optional[KeyEncryptionKey] key_encryption_key: + The user-provided key-encryption-key. 
Must implement the following methods: - unwrap_key(key, algorithm) - - returns the unwrapped form of the specified symmetric key usingthe string-specified algorithm. + wrap_key(key) + - Wraps the specified key using an algorithm of the user's choice. + get_key_wrap_algorithm() + - Returns the algorithm used to wrap the specified symmetric key. get_kid() - - returns a string key id for this key-encryption-key. - :param Callable resolver: + - Returns a string key id for this key-encryption-key. + :param Optional[Callable[[str], KeyEncryptionKey]] resolver: The user-provided key resolver. Uses the kid string to return a key-encryption-key implementing the interface defined above. :return: The plain text message from the queue message. @@ -1003,10 +1090,10 @@ def decrypt_queue_message(message, response, require_encryption, key_encryption_ response = response.http_response try: - message = loads(message) + deserialized_message: Dict[str, Any] = loads(message) - encryption_data = _dict_to_encryption_data(message['EncryptionData']) - decoded_data = decode_base64_to_bytes(message['EncryptedMessageContents']) + encryption_data = _dict_to_encryption_data(deserialized_message['EncryptionData']) + decoded_data = decode_base64_to_bytes(deserialized_message['EncryptedMessageContents']) except (KeyError, ValueError) as exc: # Message was not json formatted and so was not encrypted # or the user provided a json formatted message @@ -1022,5 +1109,5 @@ def decrypt_queue_message(message, response, require_encryption, key_encryption_ except Exception as error: raise HttpResponseError( message="Decryption failed.", - response=response, + response=response, #type: ignore [arg-type] error=error) from error diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index 9b3d3209d192..ce490e354b68 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -3,32 +3,48 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -# pylint: disable=unused-argument -import sys -from base64 import b64encode, b64decode +from base64 import b64decode, b64encode +from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union from azure.core.exceptions import DecodeError -from ._encryption import decrypt_queue_message, encrypt_queue_message, _ENCRYPTION_PROTOCOL_V1 +from ._encryption import decrypt_queue_message, encrypt_queue_message, KeyEncryptionKey, _ENCRYPTION_PROTOCOL_V1 + +if TYPE_CHECKING: + from azure.core.pipeline import PipelineResponse class MessageEncodePolicy(object): - def __init__(self): + require_encryption: bool + """Indicates whether encryption is required or not.""" + encryption_version: str + """Indicates the version of encryption being used.""" + key_encryption_key: Optional[KeyEncryptionKey] + """The user-provided key-encryption-key.""" + resolver: Optional[Callable[[str], KeyEncryptionKey]] + """The user-provided key resolver.""" + + def __init__(self) -> None: self.require_encryption = False - self.encryption_version = None + self.encryption_version = _ENCRYPTION_PROTOCOL_V1 self.key_encryption_key = None self.resolver = None - def __call__(self, content): + def __call__(self, content: Any) -> str: if content: content = self.encode(content) if self.key_encryption_key is not None: content = encrypt_queue_message(content, self.key_encryption_key, self.encryption_version) return content - def configure(self, require_encryption, key_encryption_key, resolver, encryption_version=_ENCRYPTION_PROTOCOL_V1): + def configure( + self, require_encryption: bool, + key_encryption_key: Optional[KeyEncryptionKey], + resolver: Optional[Callable[[str], KeyEncryptionKey]], + encryption_version: str = 
_ENCRYPTION_PROTOCOL_V1 + ) -> None: self.require_encryption = require_encryption self.encryption_version = encryption_version self.key_encryption_key = key_encryption_key @@ -36,18 +52,25 @@ def configure(self, require_encryption, key_encryption_key, resolver, encryption if self.require_encryption and not self.key_encryption_key: raise ValueError("Encryption required but no key was provided.") - def encode(self, content): + def encode(self, content: Any) -> str: raise NotImplementedError("Must be implemented by child class.") class MessageDecodePolicy(object): - def __init__(self): + require_encryption: bool = False + """Indicates whether encryption is required or not.""" + key_encryption_key: Optional[KeyEncryptionKey] = None + """The user-provided key-encryption-key.""" + resolver: Optional[Callable[[str], KeyEncryptionKey]] = None + """The user-provided key resolver.""" + + def __init__(self) -> None: self.require_encryption = False self.key_encryption_key = None self.resolver = None - def __call__(self, response, obj, headers): + def __call__(self, response: "PipelineResponse", obj: Iterable, headers: Dict[str, Any]) -> object: for message in obj: if message.message_text in [None, "", b""]: continue @@ -61,12 +84,16 @@ def __call__(self, response, obj, headers): message.message_text = self.decode(content, response) return obj - def configure(self, require_encryption, key_encryption_key, resolver): + def configure( + self, require_encryption: bool, + key_encryption_key: Optional[KeyEncryptionKey], + resolver: Optional[Callable[[str], KeyEncryptionKey]] + ) -> None: self.require_encryption = require_encryption self.key_encryption_key = key_encryption_key self.resolver = resolver - def decode(self, content, response): + def decode(self, content: Any, response: "PipelineResponse") -> Union[bytes, str]: raise NotImplementedError("Must be implemented by child class.") @@ -77,7 +104,7 @@ class TextBase64EncodePolicy(MessageEncodePolicy): is not text, a TypeError will 
be raised. Input text must support UTF-8. """ - def encode(self, content): + def encode(self, content: str) -> str: if not isinstance(content, str): raise TypeError("Message content must be text for base 64 encoding.") return b64encode(content.encode('utf-8')).decode('utf-8') @@ -91,14 +118,14 @@ class TextBase64DecodePolicy(MessageDecodePolicy): support UTF-8. """ - def decode(self, content, response): + def decode(self, content: str, response: "PipelineResponse") -> str: try: return b64decode(content.encode('utf-8')).decode('utf-8') except (ValueError, TypeError) as error: # ValueError for Python 3, TypeError for Python 2 raise DecodeError( message="Message content is not valid base 64.", - response=response, + response=response, #type: ignore error=error) from error @@ -109,7 +136,7 @@ class BinaryBase64EncodePolicy(MessageEncodePolicy): is not bytes, a TypeError will be raised. """ - def encode(self, content): + def encode(self, content: bytes) -> str: if not isinstance(content, bytes): raise TypeError("Message content must be bytes for base 64 encoding.") return b64encode(content).decode('utf-8') @@ -122,7 +149,7 @@ class BinaryBase64DecodePolicy(MessageDecodePolicy): is not valid base 64, a DecodeError will be raised. """ - def decode(self, content, response): + def decode(self, content: str, response: "PipelineResponse") -> bytes: response = response.http_response try: return b64decode(content.encode('utf-8')) @@ -130,23 +157,21 @@ def decode(self, content, response): # ValueError for Python 3, TypeError for Python 2 raise DecodeError( message="Message content is not valid base 64.", - response=response, + response=response, #type: ignore error=error) from error class NoEncodePolicy(MessageEncodePolicy): """Bypass any message content encoding.""" - def encode(self, content): - if isinstance(content, bytes) and sys.version_info > (3,): - raise TypeError( - "Message content must not be bytes. Use the BinaryBase64EncodePolicy to send bytes." 
- ) + def encode(self, content: str) -> str: + if isinstance(content, bytes): + raise TypeError("Message content must not be bytes. Use the BinaryBase64EncodePolicy to send bytes.") return content class NoDecodePolicy(MessageDecodePolicy): """Bypass any message content decoding.""" - def decode(self, content, response): + def decode(self, content: str, response: "PipelineResponse") -> str: return content diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 47f832cd6327..bdda1c7b33ab 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -6,16 +6,59 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called -from typing import List # pylint: disable=unused-import +import sys +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from azure.core.exceptions import HttpResponseError from azure.core.paging import PageIterator -from ._shared.response_handlers import return_context_and_deserialized, process_storage_error +from ._shared.response_handlers import process_storage_error, return_context_and_deserialized from ._shared.models import DictMixin from ._generated.models import AccessPolicy as GenAccessPolicy +from ._generated.models import CorsRule as GeneratedCorsRule from ._generated.models import Logging as GeneratedLogging from ._generated.models import Metrics as GeneratedMetrics from ._generated.models import RetentionPolicy as GeneratedRetentionPolicy -from ._generated.models import CorsRule as GeneratedCorsRule + +if sys.version_info >= (3, 11): + from typing import Self # pylint: disable=no-name-in-module, ungrouped-imports +else: + from typing_extensions import Self # pylint: disable=ungrouped-imports + +if TYPE_CHECKING: + from datetime import datetime + + +class 
RetentionPolicy(GeneratedRetentionPolicy): + """The retention policy which determines how long the associated data should + persist. + + All required parameters must be populated in order to send to Azure. + + :param bool enabled: Required. Indicates whether a retention policy is enabled + for the storage service. + :param int days: Indicates the number of days that metrics or logging or + soft-deleted data should be retained. All data older than this value will + be deleted. + """ + + enabled: bool = False + """Indicates whether a retention policy is enabled for the storage service.""" + days: Optional[int] = None + """Indicates the number of days that metrics or logging or soft-deleted data should be retained.""" + + def __init__(self, enabled: bool = False, days: Optional[int] = None) -> None: + self.enabled = enabled + self.days = days + if self.enabled and (self.days is None): + raise ValueError("If policy is enabled, 'days' must be specified.") + + @classmethod + def _from_generated(cls, generated: Any) -> Self: + if not generated: + return cls() + return cls( + enabled=generated.enabled, + days=generated.days, + ) class QueueAnalyticsLogging(GeneratedLogging): @@ -27,11 +70,21 @@ class QueueAnalyticsLogging(GeneratedLogging): :keyword bool delete: Required. Indicates whether all delete requests should be logged. :keyword bool read: Required. Indicates whether all read requests should be logged. :keyword bool write: Required. Indicates whether all write requests should be logged. - :keyword ~azure.storage.queue.RetentionPolicy retention_policy: Required. - The retention policy for the metrics. + :keyword ~azure.storage.queue.RetentionPolicy retention_policy: The retention policy for the metrics. 
""" - def __init__(self, **kwargs): + version: str = '1.0' + """The version of Storage Analytics to configure.""" + delete: bool = False + """Indicates whether all delete requests should be logged.""" + read: bool = False + """Indicates whether all read requests should be logged.""" + write: bool = False + """Indicates whether all write requests should be logged.""" + retention_policy: RetentionPolicy = RetentionPolicy() + """The retention policy for the metrics.""" + + def __init__(self, **kwargs: Any) -> None: self.version = kwargs.get('version', '1.0') self.delete = kwargs.get('delete', False) self.read = kwargs.get('read', False) @@ -39,7 +92,7 @@ def __init__(self, **kwargs): self.retention_policy = kwargs.get('retention_policy') or RetentionPolicy() @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: if not generated: return cls() return cls( @@ -58,20 +111,28 @@ class Metrics(GeneratedMetrics): :keyword str version: The version of Storage Analytics to configure. :keyword bool enabled: Required. Indicates whether metrics are enabled for the service. - :keyword bool include_ap_is: Indicates whether metrics should generate summary + :keyword bool include_apis: Indicates whether metrics should generate summary statistics for called API operations. - :keyword ~azure.storage.queue.RetentionPolicy retention_policy: Required. - The retention policy for the metrics. + :keyword ~azure.storage.queue.RetentionPolicy retention_policy: The retention policy for the metrics. 
""" - def __init__(self, **kwargs): + version: str = '1.0' + """The version of Storage Analytics to configure.""" + enabled: bool = False + """Indicates whether metrics are enabled for the service.""" + include_apis: Optional[bool] + """Indicates whether metrics should generate summary statistics for called API operations.""" + retention_policy: RetentionPolicy = RetentionPolicy() + """The retention policy for the metrics.""" + + def __init__(self, **kwargs: Any) -> None: self.version = kwargs.get('version', '1.0') self.enabled = kwargs.get('enabled', False) self.include_apis = kwargs.get('include_apis') self.retention_policy = kwargs.get('retention_policy') or RetentionPolicy() @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: if not generated: return cls() return cls( @@ -82,35 +143,6 @@ def _from_generated(cls, generated): ) -class RetentionPolicy(GeneratedRetentionPolicy): - """The retention policy which determines how long the associated data should - persist. - - All required parameters must be populated in order to send to Azure. - - :param bool enabled: Required. Indicates whether a retention policy is enabled - for the storage service. - :param int days: Indicates the number of days that metrics or logging or - soft-deleted data should be retained. All data older than this value will - be deleted. - """ - - def __init__(self, enabled=False, days=None): - self.enabled = enabled - self.days = days - if self.enabled and (self.days is None): - raise ValueError("If policy is enabled, 'days' must be specified.") - - @classmethod - def _from_generated(cls, generated): - if not generated: - return cls() - return cls( - enabled=generated.enabled, - days=generated.days, - ) - - class CorsRule(GeneratedCorsRule): """CORS is an HTTP feature that enables a web application running under one domain to access resources in another domain. 
Web browsers implement a @@ -120,36 +152,68 @@ class CorsRule(GeneratedCorsRule): All required parameters must be populated in order to send to Azure. - :param list(str) allowed_origins: + :param List[str] allowed_origins: A list of origin domains that will be allowed via CORS, or "*" to allow - all domains. The list of must contain at least one entry. Limited to 64 + all domains. The list must contain at least one entry. Limited to 64 origin domains. Each allowed origin can have up to 256 characters. - :param list(str) allowed_methods: + :param List[str] allowed_methods: A list of HTTP methods that are allowed to be executed by the origin. - The list of must contain at least one entry. For Azure Storage, + The list must contain at least one entry. For Azure Storage, permitted methods are DELETE, GET, HEAD, MERGE, POST, OPTIONS or PUT. :keyword int max_age_in_seconds: The number of seconds that the client/browser should cache a pre-flight response. - :keyword list(str) exposed_headers: + :keyword List[str] exposed_headers: Defaults to an empty list. A list of response headers to expose to CORS clients. Limited to 64 defined headers and two prefixed headers. Each header can be up to 256 characters. - :keyword list(str) allowed_headers: + :keyword List[str] allowed_headers: Defaults to an empty list. A list of headers allowed to be part of the cross-origin request. Limited to 64 defined headers and 2 prefixed headers. Each header can be up to 256 characters. 
""" - def __init__(self, allowed_origins, allowed_methods, **kwargs): + allowed_origins: str + """The comma-delimited string representation of the list of origin domains that will be allowed via + CORS, or "*" to allow all domains.""" + allowed_methods: str + """The comma-delimited string representation of the list HTTP methods that are allowed to be executed + by the origin.""" + max_age_in_seconds: int + """The number of seconds that the client/browser should cache a pre-flight response.""" + exposed_headers: str + """The comma-delimited string representation of the list of response headers to expose to CORS clients.""" + allowed_headers: str + """The comma-delimited string representation of the list of headers allowed to be part of the cross-origin + request.""" + + def __init__(self, allowed_origins: List[str], allowed_methods: List[str], **kwargs: Any) -> None: self.allowed_origins = ','.join(allowed_origins) self.allowed_methods = ','.join(allowed_methods) self.allowed_headers = ','.join(kwargs.get('allowed_headers', [])) self.exposed_headers = ','.join(kwargs.get('exposed_headers', [])) self.max_age_in_seconds = kwargs.get('max_age_in_seconds', 0) + @staticmethod + def _to_generated(rules: Optional[List["CorsRule"]]) -> Optional[List[GeneratedCorsRule]]: + if rules is None: + return rules + + generated_cors_list = [] + for cors_rule in rules: + generated_cors = GeneratedCorsRule( + allowed_origins=cors_rule.allowed_origins, + allowed_methods=cors_rule.allowed_methods, + allowed_headers=cors_rule.allowed_headers, + exposed_headers=cors_rule.exposed_headers, + max_age_in_seconds=cors_rule.max_age_in_seconds + ) + generated_cors_list.append(generated_cors) + + return generated_cors_list + @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: return cls( [generated.allowed_origins], [generated.allowed_methods], @@ -159,6 +223,72 @@ def _from_generated(cls, generated): ) +class QueueSasPermissions(object): + 
"""QueueSasPermissions class to be used with the + :func:`~azure.storage.queue.generate_queue_sas` function and for the AccessPolicies used with + :func:`~azure.storage.queue.QueueClient.set_queue_access_policy`. + + :param bool read: + Read metadata and properties, including message count. Peek at messages. + :param bool add: + Add messages to the queue. + :param bool update: + Update messages in the queue. Note: Use the Process permission with + Update so you can first get the message you want to update. + :param bool process: + Get and delete messages from the queue. + """ + + read: bool = False + """Read metadata and properties, including message count.""" + add: bool = False + """Add messages to the queue.""" + update: bool = False + """Update messages in the queue.""" + process: bool = False + """Get and delete messages from the queue.""" + + def __init__( + self, read: bool = False, + add: bool = False, + update: bool = False, + process: bool = False + ) -> None: + self.read = read + self.add = add + self.update = update + self.process = process + self._str = (('r' if self.read else '') + + ('a' if self.add else '') + + ('u' if self.update else '') + + ('p' if self.process else '')) + + def __str__(self): + return self._str + + @classmethod + def from_string(cls, permission: str) -> Self: + """Create a QueueSasPermissions from a string. + + To specify read, add, update, or process permissions you need only to + include the first letter of the word in the string. E.g. For read and + update permissions, you would provide a string "ru". + + :param str permission: The string which dictates the + read, add, update, or process permissions. 
+ :return: A QueueSasPermissions object + :rtype: ~azure.storage.queue.QueueSasPermissions + """ + p_read = 'r' in permission + p_add = 'a' in permission + p_update = 'u' in permission + p_process = 'p' in permission + + parsed = cls(p_read, p_add, p_update, p_process) + + return parsed + + class AccessPolicy(GenAccessPolicy): """Access Policy class used by the set and get access policy methods. @@ -179,74 +309,83 @@ class AccessPolicy(GenAccessPolicy): both in the Shared Access Signature URL and in the stored access policy, the request will fail with status code 400 (Bad Request). - :param str permission: + :param Optional[QueueSasPermissions] permission: The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions. Required unless an id is given referencing a stored access policy which contains this field. This field must be omitted if it has been specified in an associated stored access policy. - :param expiry: + :param Optional[Union["datetime", str]] expiry: The time at which the shared access signature becomes invalid. Required unless an id is given referencing a stored access policy which contains this field. This field must be omitted if it has been specified in an associated stored access policy. Azure will always convert values to UTC. If a date is passed in without timezone info, it is assumed to be UTC. - :type expiry: ~datetime.datetime or str - :param start: + :param Optional[Union["datetime", str]] start: The time at which the shared access signature becomes valid. If omitted, start time for this call is assumed to be the time when the storage service receives the request. Azure will always convert values to UTC. If a date is passed in without timezone info, it is assumed to be UTC. 
- :type start: ~datetime.datetime or str """ - def __init__(self, permission=None, expiry=None, start=None): + permission: Optional[QueueSasPermissions] #type: ignore [assignment] + """The permissions associated with the shared access signature. The user is restricted to + operations allowed by the permissions.""" + expiry: Optional[Union["datetime", str]] #type: ignore [assignment] + """The time at which the shared access signature becomes invalid.""" + start: Optional[Union["datetime", str]] #type: ignore [assignment] + """The time at which the shared access signature becomes valid.""" + + def __init__( + self, permission: Optional[QueueSasPermissions] = None, + expiry: Optional[Union["datetime", str]] = None, + start: Optional[Union["datetime", str]] = None + ) -> None: self.start = start self.expiry = expiry self.permission = permission class QueueMessage(DictMixin): - """Represents a queue message. + """Represents a queue message.""" - :ivar str id: - A GUID value assigned to the message by the Queue service that + id: str + """A GUID value assigned to the message by the Queue service that identifies the message in the queue. This value may be used together with the value of pop_receipt to delete a message from the queue after - it has been retrieved with the receive messages operation. - :ivar date inserted_on: - A UTC date value representing the time the messages was inserted. - :ivar date expires_on: - A UTC date value representing the time the message expires. - :ivar int dequeue_count: - Begins with a value of 1 the first time the message is received. This - value is incremented each time the message is subsequently received. - :param obj content: - The message content. Type is determined by the decode_function set on - the service. Default is str. 
- :ivar str pop_receipt: - A receipt str which can be used together with the message_id element to + it has been retrieved with the receive messages operation.""" + inserted_on: Optional["datetime"] + """A UTC date value representing the time the messages was inserted.""" + expires_on: Optional["datetime"] + """A UTC date value representing the time the message expires.""" + dequeue_count: Optional[int] + """Begins with a value of 1 the first time the message is received. This + value is incremented each time the message is subsequently received.""" + content: Any + """The message content. Type is determined by the decode_function set on + the service. Default is str.""" + pop_receipt: Optional[str] + """A receipt str which can be used together with the message_id element to delete a message from the queue after it has been retrieved with the receive messages operation. Only returned by receive messages operations. Set to - None for peek messages. - :ivar date next_visible_on: - A UTC date value representing the time the message will next be visible. - Only returned by receive messages operations. Set to None for peek messages. - """ - - def __init__(self, content=None): - self.id = None - self.inserted_on = None - self.expires_on = None - self.dequeue_count = None + None for peek messages.""" + next_visible_on: Optional["datetime"] + """A UTC date value representing the time the message will next be visible. + Only returned by receive messages operations. 
Set to None for peek messages.""" + + def __init__(self, content: Optional[Any] = None, **kwargs: Any) -> None: + self.id = kwargs.pop('id', None) + self.inserted_on = kwargs.pop('inserted_on', None) + self.expires_on = kwargs.pop('expires_on', None) + self.dequeue_count = kwargs.pop('dequeue_count', None) self.content = content - self.pop_receipt = None - self.next_visible_on = None + self.pop_receipt = kwargs.pop('pop_receipt', None) + self.next_visible_on = kwargs.pop('next_visible_on', None) @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: message = cls(content=generated.message_text) message.id = generated.message_id message.inserted_on = generated.insertion_time @@ -261,13 +400,26 @@ def _from_generated(cls, generated): class MessagesPaged(PageIterator): """An iterable of Queue Messages. - :param callable command: Function to retrieve the next page of items. - :param int results_per_page: The maximum number of messages to retrieve per + :param Callable command: Function to retrieve the next page of items. + :param Optional[int] results_per_page: The maximum number of messages to retrieve per call. - :param int max_messages: The maximum number of messages to retrieve from + :param Optional[int] max_messages: The maximum number of messages to retrieve from the queue. 
""" - def __init__(self, command, results_per_page=None, continuation_token=None, max_messages=None): + + command: Callable + """Function to retrieve the next page of items.""" + results_per_page: Optional[int] = None + """The maximum number of messages to retrieve per call.""" + max_messages: Optional[int] = None + """The maximum number of messages to retrieve from the queue.""" + + def __init__( + self, command: Callable, + results_per_page: Optional[int] = None, + continuation_token: Optional[str] = None, + max_messages: Optional[int] = None + ) -> None: if continuation_token is not None: raise ValueError("This operation does not support continuation token") @@ -279,7 +431,7 @@ def __init__(self, command, results_per_page=None, continuation_token=None, max_ self.results_per_page = results_per_page self._max_messages = max_messages - def _get_next_cb(self, continuation_token): + def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: if self._max_messages is not None: if self.results_per_page is None: @@ -291,7 +443,7 @@ def _get_next_cb(self, continuation_token): except HttpResponseError as error: process_storage_error(error) - def _extract_data_cb(self, messages): + def _extract_data_cb(self, messages: Any) -> Tuple[str, List[QueueMessage]]: # There is no concept of continuation token, so raising on my own condition if not messages: raise StopIteration("End of paging") @@ -303,20 +455,27 @@ def _extract_data_cb(self, messages): class QueueProperties(DictMixin): """Queue Properties. - :ivar str name: The name of the queue. - :keyword dict(str,str) metadata: + :keyword str name: + The name of the queue. + :keyword Optional[Dict[str, str]] metadata: A dict containing name-value pairs associated with the queue as metadata. This var is set to None unless the include=metadata param was included - for the list queues operation. If this parameter was specified but the + for the list queues operation. 
""" - def __init__(self, **kwargs): - self.name = None + name: Optional[str] = None + """The name of the queue.""" + metadata: Optional[Dict[str, str]] + """A dict containing name-value pairs associated with the queue as metadata.""" + approximate_message_count: Optional[int] + """The approximate number of messages contained in the queue.""" + + def __init__(self, **kwargs: Any) -> None: self.metadata = kwargs.get('metadata') self.approximate_message_count = kwargs.get('x-ms-approximate-messages-count') @classmethod - def _from_generated(cls, generated): + def _from_generated(cls, generated: Any) -> Self: props = cls() props.name = generated.name props.metadata = generated.metadata @@ -326,24 +485,40 @@ class QueuePropertiesPaged(PageIterator): """An iterable of Queue properties. - :ivar str service_endpoint: The service URL. - :ivar str prefix: A queue name prefix being used to filter the list. - :ivar str marker: The continuation token of the current page of results. - :ivar int results_per_page: The maximum number of results retrieved per API call. - :ivar str next_marker: The continuation token to retrieve the next page of results. - :ivar str location_mode: The location mode being used to list results. The available - options include "primary" and "secondary". - :param callable command: Function to retrieve the next page of items. - :param str prefix: Filters the results to return only queues whose names + :param Callable command: Function to retrieve the next page of items. + :param Optional[str] prefix: Filters the results to return only queues whose names begin with the specified prefix. - :param int results_per_page: The maximum number of queue names to retrieve per + :param Optional[int] results_per_page: The maximum number of queue names to retrieve per call. :param str continuation_token: An opaque continuation token.
""" - def __init__(self, command, prefix=None, results_per_page=None, continuation_token=None): + + service_endpoint: Optional[str] + """The service URL.""" + prefix: Optional[str] + """A queue name prefix being used to filter the list.""" + marker: Optional[str] + """The continuation token of the current page of results.""" + results_per_page: Optional[int] = None + """The maximum number of results retrieved per API call.""" + next_marker: str + """The continuation token to retrieve the next page of results.""" + location_mode: Optional[str] + """The location mode being used to list results. The available options include "primary" and "secondary".""" + command: Callable + """Function to retrieve the next page of items.""" + _response: Any + """Function to retrieve the next page of items.""" + + def __init__( + self, command: Callable, + prefix: Optional[str] = None, + results_per_page: Optional[int] = None, + continuation_token: Optional[str] = None + ) -> None: super(QueuePropertiesPaged, self).__init__( self._get_next_cb, - self._extract_data_cb, + self._extract_data_cb, #type: ignore continuation_token=continuation_token or "" ) self._command = command @@ -353,7 +528,7 @@ def __init__(self, command, prefix=None, results_per_page=None, continuation_tok self.results_per_page = results_per_page self.location_mode = None - def _get_next_cb(self, continuation_token): + def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: return self._command( marker=continuation_token or None, @@ -363,7 +538,7 @@ def _get_next_cb(self, continuation_token): except HttpResponseError as error: process_storage_error(error) - def _extract_data_cb(self, get_next_return): + def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return self.service_endpoint = self._response.service_endpoint self.prefix = self._response.prefix @@ -373,59 +548,13 @@ def _extract_data_cb(self, 
get_next_return): return self._response.next_marker or None, props_list -class QueueSasPermissions(object): - """QueueSasPermissions class to be used with the - :func:`~azure.storage.queue.generate_queue_sas` function and for the AccessPolicies used with - :func:`~azure.storage.queue.QueueClient.set_queue_access_policy`. +def service_stats_deserialize(generated: Any) -> Dict[str, Any]: + """Deserialize a ServiceStats objects into a dict. - :param bool read: - Read metadata and properties, including message count. Peek at messages. - :param bool add: - Add messages to the queue. - :param bool update: - Update messages in the queue. Note: Use the Process permission with - Update so you can first get the message you want to update. - :param bool process: - Get and delete messages from the queue. + :param Any generated: The service stats returned from the generated code. + :returns: The deserialized ServiceStats as a Dict. + :rtype: Dict[str, Any] """ - def __init__(self, read=False, add=False, update=False, process=False): - self.read = read - self.add = add - self.update = update - self.process = process - self._str = (('r' if self.read else '') + - ('a' if self.add else '') + - ('u' if self.update else '') + - ('p' if self.process else '')) - - def __str__(self): - return self._str - - @classmethod - def from_string(cls, permission): - """Create a QueueSasPermissions from a string. - - To specify read, add, update, or process permissions you need only to - include the first letter of the word in the string. E.g. For read and - update permissions, you would provide a string "ru". - - :param str permission: The string which dictates the - read, add, update, or process permissions. 
- :return: A QueueSasPermissions object - :rtype: ~azure.storage.queue.QueueSasPermissions - """ - p_read = 'r' in permission - p_add = 'a' in permission - p_update = 'u' in permission - p_process = 'p' in permission - - parsed = cls(p_read, p_add, p_update, p_process) - - return parsed - - -# Deserialize a ServiceStats objects into a dict. -def service_stats_deserialize(generated): return { 'geo_replication': { 'status': generated.geo_replication.status, @@ -434,8 +563,13 @@ def service_stats_deserialize(generated): } -# Deserialize a ServiceProperties objects into a dict. -def service_properties_deserialize(generated): +def service_properties_deserialize(generated: Any) -> Dict[str, Any]: + """Deserialize a ServiceProperties objects into a dict. + + :param Any generated: The service properties returned from the generated code. + :returns: The deserialized ServiceProperties as a Dict. + :rtype: Dict[str, Any] + """ return { 'analytics_logging': QueueAnalyticsLogging._from_generated(generated.logging), # pylint: disable=protected-access 'hour_metrics': Metrics._from_generated(generated.hour_metrics), # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index 9245efa18a73..df2e8975c21c 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -7,29 +7,29 @@ import functools import warnings -from typing import ( # pylint: disable=unused-import - Any, Dict, List, Optional, Union, - TYPE_CHECKING) -from urllib.parse import urlparse, quote, unquote - +from typing import ( + Any, cast, Dict, List, Optional, + Tuple, TYPE_CHECKING, Union +) from typing_extensions import Self from azure.core.exceptions import HttpResponseError from azure.core.paging import ItemPaged from azure.core.tracing.decorator import distributed_trace -from 
._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query +from ._deserialize import deserialize_queue_creation, deserialize_queue_properties +from ._encryption import modify_user_agent_for_encryption, StorageEncryptionMixin +from ._generated import AzureQueueStorage +from ._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier +from ._message_encoding import NoDecodePolicy, NoEncodePolicy +from ._models import AccessPolicy, MessagesPaged, QueueMessage +from ._queue_client_helpers import _format_url, _from_queue_url, _parse_url +from ._serialize import get_api_version +from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin from ._shared.request_handlers import add_metadata_headers, serialize_iso from ._shared.response_handlers import ( process_storage_error, - return_response_headers, - return_headers_and_deserialized) -from ._generated import AzureQueueStorage -from ._generated.models import SignedIdentifier, QueueMessage as GenQueueMessage -from ._deserialize import deserialize_queue_properties, deserialize_queue_creation -from ._encryption import modify_user_agent_for_encryption, StorageEncryptionMixin -from ._message_encoding import NoEncodePolicy, NoDecodePolicy -from ._models import QueueMessage, AccessPolicy, MessagesPaged -from ._serialize import get_api_version + return_headers_and_deserialized, + return_response_headers) if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential @@ -57,6 +57,7 @@ class QueueClient(StorageAccountHostsMixin, StorageEncryptionMixin): - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. 
+ :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long :keyword str api_version: The Storage API version to use for requests. Default value is the most recent service version that is compatible with the current SDK. Setting to an older version may result in reduced feature compatibility. @@ -82,50 +83,41 @@ class QueueClient(StorageAccountHostsMixin, StorageEncryptionMixin): :caption: Create the queue client with url and credential. """ def __init__( - self, account_url: str, - queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> None: - try: - if not account_url.lower().startswith('http'): - account_url = "https://" + account_url - except AttributeError as exc: - raise ValueError("Account URL must be a string.") from exc - parsed_url = urlparse(account_url.rstrip('/')) - if not queue_name: - raise ValueError("Please specify a queue name.") - if not parsed_url.netloc: - raise ValueError(f"Invalid URL: {parsed_url}") - - _, sas_token = parse_query(parsed_url.query) - if not sas_token and not credential: - raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") - + self, account_url: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: + parsed_url, sas_token = _parse_url(account_url=account_url, queue_name=queue_name, credential=credential) self.queue_name = queue_name self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) - - self._config.message_encode_policy = kwargs.get('message_encode_policy', None) 
or NoEncodePolicy() - self._config.message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() + self._message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() + self._message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._configure_encryption(kwargs) - def _format_url(self, hostname): - queue_name = self.queue_name - if isinstance(queue_name, str): - queue_name = queue_name.encode('UTF-8') - return ( - f"{self.scheme}://{hostname}" - f"/{quote(queue_name)}{self._query_str}") + def _format_url(self, hostname: str) -> str: + """Format the endpoint URL according to the current location + mode hostname. + + :param str hostname: The current location mode hostname. + :returns: The formatted endpoint URL according to the specified location mode hostname. + :rtype: str + """ + return _format_url( + queue_name=self.queue_name, + hostname=hostname, + scheme=self.scheme, + query_str=self._query_str) @classmethod def from_queue_url( - cls, queue_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> Self: + cls, queue_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: """A client to interact with a specific Queue. :param str queue_url: The full URI to the queue, including SAS token if used. @@ -145,35 +137,16 @@ def from_queue_url( :returns: A queue client. 
:rtype: ~azure.storage.queue.QueueClient """ - try: - if not queue_url.lower().startswith('http'): - queue_url = "https://" + queue_url - except AttributeError as exc: - raise ValueError("Queue URL must be a string.") from exc - parsed_url = urlparse(queue_url.rstrip('/')) - - if not parsed_url.netloc: - raise ValueError(f"Invalid URL: {queue_url}") - - queue_path = parsed_url.path.lstrip('/').split('/') - account_path = "" - if len(queue_path) > 1: - account_path = "/" + "/".join(queue_path[:-1]) - account_url = ( - f"{parsed_url.scheme}://{parsed_url.netloc.rstrip('/')}" - f"{account_path}?{parsed_url.query}") - queue_name = unquote(queue_path[-1]) - if not queue_name: - raise ValueError("Invalid URL. Please provide a URL with a valid queue name") + account_url, queue_name = _from_queue_url(queue_url=queue_url) return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @classmethod def from_connection_string( - cls, conn_str: str, - queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> Self: + cls, conn_str: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: """Create QueueClient from a Connection String. 
:param str conn_str: @@ -209,17 +182,20 @@ def from_connection_string( conn_str, credential, 'queue') if 'secondary_hostname' not in kwargs: kwargs['secondary_hostname'] = secondary - return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) # type: ignore + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) @distributed_trace - def create_queue(self, **kwargs): - # type: (Any) -> None + def create_queue( + self, *, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: """Creates a new queue in the storage account. If a queue with the same name already exists, the operation fails with a `ResourceExistsError`. - :keyword dict(str,str) metadata: + :keyword Dict[str,str] metadata: A dict containing name-value pairs to associate with the queue as metadata. Note that metadata names preserve the case with which they were created, but are case-insensitive when set or read. @@ -243,11 +219,10 @@ def create_queue(self, **kwargs): :caption: Create a queue. """ headers = kwargs.pop('headers', {}) - metadata = kwargs.pop('metadata', None) timeout = kwargs.pop('timeout', None) - headers.update(add_metadata_headers(metadata)) # type: ignore + headers.update(add_metadata_headers(metadata)) try: - return self._client.queue.create( # type: ignore + return self._client.queue.create( metadata=metadata, timeout=timeout, headers=headers, @@ -257,8 +232,7 @@ def create_queue(self, **kwargs): process_storage_error(error) @distributed_trace - def delete_queue(self, **kwargs): - # type: (Any) -> None + def delete_queue(self, **kwargs: Any) -> None: """Deletes the specified queue and any messages it contains. 
When a queue is successfully deleted, it is immediately marked for deletion @@ -293,8 +267,7 @@ def delete_queue(self, **kwargs): process_storage_error(error) @distributed_trace - def get_queue_properties(self, **kwargs): - # type: (Any) -> QueueProperties + def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": """Returns all user-defined metadata for the specified queue. The data returned does not include the queue's list of messages. @@ -315,20 +288,20 @@ def get_queue_properties(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - response = self._client.queue.get_properties( + response = cast("QueueProperties", self._client.queue.get_properties( timeout=timeout, cls=deserialize_queue_properties, - **kwargs) + **kwargs)) except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name - return response # type: ignore + return response @distributed_trace def set_queue_metadata( - self, metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any - ) -> Dict[str, Any]: + self, metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: """Sets user-defined metadata on the specified queue. Metadata is associated with the queue as name-value pairs. @@ -336,6 +309,7 @@ def set_queue_metadata( :param Optional[Dict[str, Any]] metadata: A dict containing name-value pairs to associate with the queue as metadata. + :type metadata: Optional[Dict[str, str]] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. 
@@ -356,9 +330,9 @@ def set_queue_metadata( """ timeout = kwargs.pop('timeout', None) headers = kwargs.pop('headers', {}) - headers.update(add_metadata_headers(metadata)) # type: ignore + headers.update(add_metadata_headers(metadata)) try: - return self._client.queue.set_metadata( # type: ignore + return self._client.queue.set_metadata( timeout=timeout, headers=headers, cls=return_response_headers, @@ -367,8 +341,7 @@ def set_queue_metadata( process_storage_error(error) @distributed_trace - def get_queue_access_policy(self, **kwargs): - # type: (Any) -> Dict[str, AccessPolicy] + def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: """Returns details about any stored access policies specified on the queue that may be used with Shared Access Signatures. @@ -379,24 +352,23 @@ def get_queue_access_policy(self, **kwargs): see `here `_. :return: A dictionary of access policies associated with the queue. - :rtype: dict(str, ~azure.storage.queue.AccessPolicy) + :rtype: Dict[str, ~azure.storage.queue.AccessPolicy] """ timeout = kwargs.pop('timeout', None) try: - _, identifiers = self._client.queue.get_access_policy( + _, identifiers = cast(Tuple[Dict, List], self._client.queue.get_access_policy( timeout=timeout, cls=return_headers_and_deserialized, - **kwargs) + **kwargs)) except HttpResponseError as error: process_storage_error(error) return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace - def set_queue_access_policy(self, - signed_identifiers, # type: Dict[str, AccessPolicy] - **kwargs # type: Any - ): - # type: (...) -> None + def set_queue_access_policy( + self, signed_identifiers: Dict[str, AccessPolicy], + **kwargs: Any + ) -> None: """Sets stored access policies for the queue that may be used with Shared Access Signatures. @@ -415,7 +387,7 @@ def set_queue_access_policy(self, SignedIdentifier access policies to associate with the queue. This may contain up to 5 elements. 
An empty dict will clear the access policies set on the service. - :type signed_identifiers: dict(str, ~azure.storage.queue.AccessPolicy) + :type signed_identifiers: Dict[str, ~azure.storage.queue.AccessPolicy] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. @@ -443,10 +415,9 @@ def set_queue_access_policy(self, value.start = serialize_iso(value.start) value.expiry = serialize_iso(value.expiry) identifiers.append(SignedIdentifier(id=key, access_policy=value)) - signed_identifiers = identifiers # type: ignore try: self._client.queue.set_access_policy( - queue_acl=signed_identifiers or None, + queue_acl=identifiers or None, timeout=timeout, **kwargs) except HttpResponseError as error: @@ -454,11 +425,12 @@ def set_queue_access_policy(self, @distributed_trace def send_message( - self, - content, # type: Any - **kwargs # type: Any - ): - # type: (...) -> QueueMessage + self, content: Any, + *, + visibility_timeout: Optional[int] = None, + time_to_live: Optional[int] = None, + **kwargs: Any + ) -> QueueMessage: """Adds a new message to the back of the message queue. The visibility timeout specifies the time that the message will be @@ -472,7 +444,7 @@ def send_message( If the key-encryption-key field is set on the local service object, this method will encrypt the content before uploading. - :param obj content: + :param Any content: Message content. Allowed type is determined by the encode_function set on the service. Default is str. The encoded message can be up to 64KB in size. @@ -508,8 +480,6 @@ def send_message( :dedent: 12 :caption: Send messages. 
""" - visibility_timeout = kwargs.pop('visibility_timeout', None) - time_to_live = kwargs.pop('time_to_live', None) timeout = kwargs.pop('timeout', None) if self.key_encryption_key: modify_user_agent_for_encryption( @@ -519,7 +489,7 @@ def send_message( kwargs) try: - self._config.message_encode_policy.configure( + self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function, @@ -531,11 +501,11 @@ def send_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." ) - self._config.message_encode_policy.configure( + self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) - encoded_content = self._config.message_encode_policy(content) + encoded_content = self._message_encode_policy(content) new_message = GenQueueMessage(message_text=encoded_content) try: @@ -545,18 +515,24 @@ def send_message( message_time_to_live=time_to_live, timeout=timeout, **kwargs) - queue_message = QueueMessage(content=content) - queue_message.id = enqueued[0].message_id - queue_message.inserted_on = enqueued[0].insertion_time - queue_message.expires_on = enqueued[0].expiration_time - queue_message.pop_receipt = enqueued[0].pop_receipt - queue_message.next_visible_on = enqueued[0].time_next_visible + queue_message = QueueMessage( + content=content, + id=enqueued[0].message_id, + inserted_on=enqueued[0].insertion_time, + expires_on=enqueued[0].expiration_time, + pop_receipt = enqueued[0].pop_receipt, + next_visible_on = enqueued[0].time_next_visible + ) return queue_message except HttpResponseError as error: process_storage_error(error) @distributed_trace - def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: + def receive_message( + self, *, + visibility_timeout: Optional[int] = None, + **kwargs: Any + ) -> 
Optional[QueueMessage]: """Removes one message from the front of the queue. When the message is retrieved from the queue, the response includes the message @@ -594,7 +570,6 @@ def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: :dedent: 12 :caption: Receive one message from the queue. """ - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( @@ -603,16 +578,17 @@ def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function) + resolver=self.key_resolver_function + ) try: message = self._client.messages.dequeue( number_of_messages=1, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self._message_decode_policy, **kwargs ) wrapped_message = QueueMessage._from_generated( # pylint: disable=protected-access @@ -622,8 +598,13 @@ def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: process_storage_error(error) @distributed_trace - def receive_messages(self, **kwargs): - # type: (Any) -> ItemPaged[QueueMessage] + def receive_messages( + self, *, + messages_per_page: Optional[int] = None, + visibility_timeout: Optional[int] = None, + max_messages: Optional[int] = None, + **kwargs: Any + ) -> ItemPaged[QueueMessage]: """Removes one or more messages from the front of the queue. When a message is retrieved from the queue, the response includes the message @@ -660,14 +641,14 @@ def receive_messages(self, **kwargs): larger than 7 days. The visibility timeout of a message cannot be set to a value later than the expiry time. visibility_timeout should be set to a value smaller than the time-to-live value. 
+ :keyword int max_messages: + An integer that specifies the maximum number of messages to retrieve from the queue. :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. This value is not tracked or validated on the client. To configure client-side network timesouts see `here `_. - :keyword int max_messages: - An integer that specifies the maximum number of messages to retrieve from the queue. :return: Returns a message iterator of dict-like Message objects. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.queue.QueueMessage] @@ -681,10 +662,7 @@ def receive_messages(self, **kwargs): :dedent: 12 :caption: Receive messages from the queue. """ - messages_per_page = kwargs.pop('messages_per_page', None) - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) - max_messages = kwargs.pop('max_messages', None) if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( self._config.user_agent_policy.user_agent, @@ -692,16 +670,17 @@ def receive_messages(self, **kwargs): self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function) + resolver=self.key_resolver_function + ) try: command = functools.partial( self._client.messages.dequeue, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self._message_decode_policy, **kwargs ) if max_messages is not None and messages_per_page is not None: @@ -713,13 +692,14 @@ def receive_messages(self, **kwargs): process_storage_error(error) @distributed_trace - def update_message(self, - message, # type: Any - pop_receipt=None, # type: Optional[str] - content=None, # type: 
Optional[Any] - **kwargs # type: Any - ): - # type: (...) -> QueueMessage + def update_message( + self, message: Union[str, QueueMessage], + pop_receipt: Optional[str] = None, + content: Optional[Any] = None, + *, + visibility_timeout: Optional[int] = None, + **kwargs: Any + ) -> QueueMessage: """Updates the visibility timeout of a message. You can also use this operation to update the contents of a message. @@ -740,7 +720,7 @@ def update_message(self, :param str pop_receipt: A valid pop receipt value returned from an earlier call to the :func:`~receive_messages` or :func:`~update_message` operation. - :param obj content: + :param Any content: Message content. Allowed type is determined by the encode_function set on the service. Default is str. :keyword int visibility_timeout: @@ -770,7 +750,6 @@ def update_message(self, :dedent: 12 :caption: Update a message. """ - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( @@ -779,14 +758,14 @@ def update_message(self, self.encryption_version, kwargs) - try: + if isinstance(message, QueueMessage): message_id = message.id message_text = content or message.content receipt = pop_receipt or message.pop_receipt inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - except AttributeError: + else: message_id = message message_text = content receipt = pop_receipt @@ -798,7 +777,7 @@ def update_message(self, raise ValueError("pop_receipt must be present") if message_text is not None: try: - self._config.message_encode_policy.configure( + self._message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function, @@ -810,40 +789,41 @@ def update_message(self, Consider updating your encryption information/implementation. \ Retrying without encryption_version." 
) - self._config.message_encode_policy.configure( + self._message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function) - encoded_message_text = self._config.message_encode_policy(message_text) + encoded_message_text = self._message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: - updated = None # type: ignore + updated = None try: - response = self._client.message_id.update( + response = cast(QueueMessage, self._client.message_id.update( queue_message=updated, visibilitytimeout=visibility_timeout or 0, timeout=timeout, pop_receipt=receipt, cls=return_response_headers, queue_message_id=message_id, - **kwargs) - new_message = QueueMessage(content=message_text) - new_message.id = message_id - new_message.inserted_on = inserted_on - new_message.expires_on = expires_on - new_message.dequeue_count = dequeue_count - new_message.pop_receipt = response['popreceipt'] - new_message.next_visible_on = response['time_next_visible'] + **kwargs)) + new_message = QueueMessage( + content=message_text, + id=message_id, + inserted_on=inserted_on, + dequeue_count=dequeue_count, + expires_on=expires_on, + pop_receipt = response['popreceipt'], + next_visible_on = response['time_next_visible'] + ) return new_message except HttpResponseError as error: process_storage_error(error) @distributed_trace - def peek_messages(self, - max_messages=None, # type: Optional[int] - **kwargs # type: Any - ): - # type: (...) -> List[QueueMessage] + def peek_messages( + self, max_messages: Optional[int] = None, + **kwargs: Any + ) -> List[QueueMessage]: """Retrieves one or more messages from the front of the queue, but does not alter the visibility of the message. 
@@ -894,15 +874,16 @@ def peek_messages(self, self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function) + resolver=self.key_resolver_function + ) try: messages = self._client.messages.peek( number_of_messages=max_messages, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self._message_decode_policy, **kwargs) wrapped_messages = [] for peeked in messages: @@ -912,8 +893,7 @@ def peek_messages(self, process_storage_error(error) @distributed_trace - def clear_messages(self, **kwargs): - # type: (Any) -> None + def clear_messages(self, **kwargs: Any) -> None: """Deletes all messages from the specified queue. :keyword int timeout: @@ -939,12 +919,11 @@ def clear_messages(self, **kwargs): process_storage_error(error) @distributed_trace - def delete_message(self, - message, # type: Any - pop_receipt=None, # type: Optional[str] - **kwargs # type: Any - ): - # type: (...) -> None + def delete_message( + self, message: Union[str, QueueMessage], + pop_receipt: Optional[str] = None, + **kwargs: Any + ) -> None: """Deletes the specified message. Normally after a client retrieves a message with the receive messages operation, @@ -980,10 +959,12 @@ def delete_message(self, :caption: Delete a message. 
""" timeout = kwargs.pop('timeout', None) - try: + + receipt: Optional[str] + if isinstance(message, QueueMessage): message_id = message.id receipt = pop_receipt or message.pop_receipt - except AttributeError: + else: message_id = message receipt = pop_receipt diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py new file mode 100644 index 000000000000..878a938fff65 --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -0,0 +1,101 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union +from urllib.parse import quote, unquote, urlparse +from ._shared.base_client import parse_query + +if TYPE_CHECKING: + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials_async import AsyncTokenCredential + from urllib.parse import ParseResult + + +def _parse_url( + account_url: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long +) -> Tuple["ParseResult", Any]: + """Performs initial input validation and returns the parsed URL and SAS token. + + :param str account_url: The URL to the storage account. + :param str queue_name: The name of the queue. + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token. 
The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + If the resource URI already contains a SAS token, this will be ignored in favor of an explicit credential + - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential, TokenCredential]] # pylint: disable=line-too-long + :returns: The parsed URL and SAS token. + :rtype: Tuple[ParseResult, Any] + """ + try: + if not account_url.lower().startswith('http'): + account_url = "https://" + account_url + except AttributeError as exc: + raise ValueError("Account URL must be a string.") from exc + parsed_url = urlparse(account_url.rstrip('/')) + if not queue_name: + raise ValueError("Please specify a queue name.") + if not parsed_url.netloc: + raise ValueError(f"Invalid URL: {parsed_url}") + + _, sas_token = parse_query(parsed_url.query) + if not sas_token and not credential: + raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") + + return parsed_url, sas_token + +def _format_url(queue_name: Union[bytes, str], hostname: str, scheme: str, query_str: str) -> str: + """Format the endpoint URL according to the current location mode hostname. + + :param Union[bytes, str] queue_name: The name of the queue. + :param str hostname: The current location mode hostname. + :param str scheme: The scheme for the current location mode hostname. + :param str query_str: The query string of the endpoint URL being formatted. + :returns: The formatted endpoint URL according to the specified location mode hostname. 
+ :rtype: str + """ + if isinstance(queue_name, str): + queue_name = queue_name.encode('UTF-8') + else: + pass + return ( + f"{scheme}://{hostname}" + f"/{quote(queue_name)}{query_str}") + +def _from_queue_url(queue_url: str) -> Tuple[str, str]: + """A client to interact with a specific Queue. + + :param str queue_url: The full URI to the queue, including SAS token if used. + :returns: The parsed out account_url and queue name. + :rtype: Tuple[str, str] + """ + try: + if not queue_url.lower().startswith('http'): + queue_url = "https://" + queue_url + except AttributeError as exc: + raise ValueError("Queue URL must be a string.") from exc + parsed_url = urlparse(queue_url.rstrip('/')) + + if not parsed_url.netloc: + raise ValueError(f"Invalid URL: {queue_url}") + + queue_path = parsed_url.path.lstrip('/').split('/') + account_path = "" + if len(queue_path) > 1: + account_path = "/" + "/".join(queue_path[:-1]) + account_url = ( + f"{parsed_url.scheme}://{parsed_url.netloc.rstrip('/')}" + f"{account_path}?{parsed_url.query}") + queue_name = unquote(queue_path[-1]) + if not queue_name: + raise ValueError("Invalid URL. 
Please provide a URL with a valid queue name") + return(account_url, queue_name) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index 8abf3420add4..6d95b59135aa 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -6,38 +6,35 @@ import functools from typing import ( - Any, Dict, List, Optional, Union, - TYPE_CHECKING) -from urllib.parse import urlparse - + Any, Dict, List, Optional, + TYPE_CHECKING, Union +) from typing_extensions import Self from azure.core.exceptions import HttpResponseError from azure.core.paging import ItemPaged from azure.core.pipeline import Pipeline from azure.core.tracing.decorator import distributed_trace -from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query -from ._shared.models import LocationMode -from ._shared.response_handlers import process_storage_error +from ._encryption import StorageEncryptionMixin from ._generated import AzureQueueStorage from ._generated.models import StorageServiceProperties -from ._encryption import StorageEncryptionMixin from ._models import ( + CorsRule, + QueueProperties, QueuePropertiesPaged, - service_stats_deserialize, service_properties_deserialize, + service_stats_deserialize, ) -from ._serialize import get_api_version from ._queue_client import QueueClient +from ._queue_service_client_helpers import _parse_url +from ._serialize import get_api_version +from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin, TransportWrapper +from ._shared.models import LocationMode +from ._shared.response_handlers import process_storage_error if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential - from ._models import ( - CorsRule, - 
Metrics, - QueueProperties, - QueueAnalyticsLogging, - ) + from ._models import Metrics, QueueAnalyticsLogging class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): @@ -65,6 +62,7 @@ class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] # pylint: disable=line-too-long :keyword str api_version: The Storage API version to use for requests. Default value is the most recent service version that is compatible with the current SDK. Setting to an older version may result in reduced feature compatibility. @@ -92,37 +90,33 @@ class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): """ def __init__( - self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> None: - try: - if not account_url.lower().startswith('http'): - account_url = "https://" + account_url - except AttributeError as exc: - raise ValueError("Account URL must be a string.") from exc - parsed_url = urlparse(account_url.rstrip('/')) - if not parsed_url.netloc: - raise ValueError(f"Invalid URL: {account_url}") - - _, sas_token = parse_query(parsed_url.query) - if not sas_token and not credential: - raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") + self, account_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: + parsed_url, sas_token = 
_parse_url(account_url=account_url, credential=credential) self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueServiceClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._configure_encryption(kwargs) - def _format_url(self, hostname): + def _format_url(self, hostname: str) -> str: + """Format the endpoint URL according to the current location + mode hostname. + + :param str hostname: The current location mode hostname. + :returns: The formatted endpoint URL according to the specified location mode hostname. + :rtype: str + """ return f"{self.scheme}://{hostname}/{self._query_str}" @classmethod def from_connection_string( - cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> Self: + cls, conn_str: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: """Create QueueServiceClient from a Connection String. :param str conn_str: @@ -159,8 +153,7 @@ def from_connection_string( return cls(account_url, credential=credential, **kwargs) @distributed_trace - def get_service_stats(self, **kwargs): - # type: (Any) -> Dict[str, Any] + def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: """Retrieves statistics related to replication for the Queue service. 
It is only available when read-access geo-redundant replication is enabled for @@ -186,15 +179,14 @@ def get_service_stats(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - stats = self._client.service.get_statistics( # type: ignore + stats = self._client.service.get_statistics( timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs) return service_stats_deserialize(stats) except HttpResponseError as error: process_storage_error(error) @distributed_trace - def get_service_properties(self, **kwargs): - # type: (Any) -> Dict[str, Any] + def get_service_properties(self, **kwargs: Any) -> Dict[str, Any]: """Gets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -215,20 +207,19 @@ def get_service_properties(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - service_props = self._client.service.get_properties(timeout=timeout, **kwargs) # type: ignore + service_props = self._client.service.get_properties(timeout=timeout, **kwargs) return service_properties_deserialize(service_props) except HttpResponseError as error: process_storage_error(error) @distributed_trace - def set_service_properties( # type: ignore - self, analytics_logging=None, # type: Optional[QueueAnalyticsLogging] - hour_metrics=None, # type: Optional[Metrics] - minute_metrics=None, # type: Optional[Metrics] - cors=None, # type: Optional[List[CorsRule]] - **kwargs # type: Any - ): - # type: (...) -> None + def set_service_properties( + self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, + hour_metrics: Optional["Metrics"] = None, + minute_metrics: Optional["Metrics"] = None, + cors: Optional[List[CorsRule]] = None, + **kwargs: Any + ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -250,7 +241,7 @@ def set_service_properties( # type: ignore You can include up to five CorsRule elements in the list. 
If an empty list is specified, all CORS rules will be deleted, and CORS will be disabled for the service. - :type cors: list(~azure.storage.queue.CorsRule) + :type cors: Optional[List[~azure.storage.queue.CorsRule]] :keyword int timeout: The timeout parameter is expressed in seconds. @@ -268,20 +259,19 @@ def set_service_properties( # type: ignore logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=cors + cors=CorsRule._to_generated(cors) # pylint: disable=protected-access ) try: - self._client.service.set_properties(props, timeout=timeout, **kwargs) # type: ignore + self._client.service.set_properties(props, timeout=timeout, **kwargs) except HttpResponseError as error: process_storage_error(error) @distributed_trace def list_queues( - self, name_starts_with=None, # type: Optional[str] - include_metadata=False, # type: Optional[bool] - **kwargs # type: Any - ): - # type: (...) -> ItemPaged[QueueProperties] + self, name_starts_with: Optional[str] = None, + include_metadata: Optional[bool] = False, + **kwargs: Any + ) -> ItemPaged["QueueProperties"]: """Returns a generator to list the queues under the specified account. The generator will lazily follow the continuation tokens returned by @@ -331,11 +321,10 @@ def list_queues( @distributed_trace def create_queue( - self, name, # type: str - metadata=None, # type: Optional[Dict[str, str]] - **kwargs # type: Any - ): - # type: (...) -> QueueClient + self, name: str, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> QueueClient: """Creates a new queue under the specified account. If a queue with the same name already exists, the operation fails. @@ -345,7 +334,7 @@ def create_queue( :param metadata: A dict with name_value pairs to associate with the queue as metadata. Example: {'Category': 'test'} - :type metadata: dict(str, str) + :type metadata: Dict[str, str] :keyword int timeout: The timeout parameter is expressed in seconds. 
:return: A QueueClient for the newly created Queue. @@ -369,11 +358,9 @@ def create_queue( @distributed_trace def delete_queue( - self, - queue, # type: Union[QueueProperties, str] - **kwargs # type: Any - ): - # type: (...) -> None + self, queue: Union["QueueProperties", str], + **kwargs: Any + ) -> None: """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -406,11 +393,10 @@ def delete_queue( kwargs.setdefault('merge_span', True) queue_client.delete_queue(timeout=timeout, **kwargs) - def get_queue_client(self, - queue, # type: Union[QueueProperties, str] - **kwargs # type: Any - ): - # type: (...) -> QueueClient + def get_queue_client( + self, queue: Union["QueueProperties", str], + **kwargs: Any + ) -> QueueClient: """Get a client to interact with the specified queue. The queue need not already exist. @@ -431,14 +417,14 @@ def get_queue_client(self, :dedent: 8 :caption: Get the queue client. """ - try: + if isinstance(queue, QueueProperties): queue_name = queue.name - except AttributeError: + else: queue_name = queue _pipeline = Pipeline( - transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access - policies=self._pipeline._impl_policies # pylint: disable = protected-access + transport=TransportWrapper(self._pipeline._transport), # pylint: disable=protected-access + policies=self._pipeline._impl_policies # type: ignore # pylint: disable=protected-access ) return QueueClient( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py new file mode 100644 index 000000000000..881d3d6929cc --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py @@ -0,0 +1,50 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union +from urllib.parse import urlparse +from ._shared.base_client import parse_query + +if TYPE_CHECKING: + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials_async import AsyncTokenCredential + from urllib.parse import ParseResult + + +def _parse_url( + account_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long +) -> Tuple["ParseResult", Any]: + """Performs initial input validation and returns the parsed URL and SAS token. + + :param str account_url: The URL to the storage account. + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + If the resource URI already contains a SAS token, this will be ignored in favor of an explicit credential + - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential, TokenCredential]] # pylint: disable=line-too-long + :returns: The parsed URL and SAS token. 
+ :rtype: Tuple[ParseResult, Any] + """ + try: + if not account_url.lower().startswith('http'): + account_url = "https://" + account_url + except AttributeError as exc: + raise ValueError("Account URL must be a string.") from exc + parsed_url = urlparse(account_url.rstrip('/')) + if not parsed_url.netloc: + raise ValueError(f"Invalid URL: {account_url}") + + _, sas_token = parse_query(parsed_url.query) + if not sas_token and not credential: + raise ValueError("You need to provide either a SAS token or an account shared key to authenticate.") + + return parsed_url, sas_token diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py index 804e0b78e597..abbbfe88127f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py @@ -55,7 +55,7 @@ def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str # Build list of sorted tuples sorted_headers = [] for key in header_keys: - sorted_headers.append((key, header_dict.get(key))) + sorted_headers.append((key, header_dict.pop(key))) return sorted_headers diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index ab5136ae0a7b..8746ae3195fd 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -5,26 +5,22 @@ # -------------------------------------------------------------------------- import logging import uuid -from typing import ( # pylint: disable=unused-import +from typing import ( Any, + cast, Dict, + Iterator, Optional, Tuple, TYPE_CHECKING, Union, ) +from urllib.parse import parse_qs, quote -try: - from urllib.parse import parse_qs, quote -except 
ImportError: - from urlparse import parse_qs # type: ignore - from urllib2 import quote # type: ignore - -from azure.core.configuration import Configuration -from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential +from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import Pipeline -from azure.core.pipeline.transport import RequestsTransport, HttpTransport # pylint: disable=non-abstract-transport-import, no-name-in-module +from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module from azure.core.pipeline.policies import ( AzureSasCredentialPolicy, BearerTokenCredentialPolicy, @@ -36,11 +32,9 @@ UserAgentPolicy, ) -from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE -from .models import LocationMode from .authentication import SharedKeyCredentialPolicy -from .shared_access_signature import QueryStringConstants -from .request_handlers import serialize_batch_body, _get_batch_request_delimiter +from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE +from .models import LocationMode, StorageConfiguration from .policies import ( ExponentialRetry, QueueMessagePolicy, @@ -51,11 +45,15 @@ StorageRequestHook, StorageResponseHook, ) +from .request_handlers import serialize_batch_body, _get_batch_request_delimiter +from .response_handlers import PartialBatchErrorException, process_storage_error +from .shared_access_signature import QueryStringConstants from .._version import VERSION -from .response_handlers import process_storage_error, PartialBatchErrorException +from .._shared_access_signature import _is_credential_sastoken if TYPE_CHECKING: - from azure.core.credentials import TokenCredential + from 
azure.core.credentials_async import AsyncTokenCredential + from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756 _LOGGER = logging.getLogger(__name__) _SERVICE_PARAMS = { @@ -67,14 +65,14 @@ class StorageAccountHostsMixin(object): # pylint: disable=too-many-instance-attributes + _client: Any def __init__( self, - parsed_url, # type: Any - service, # type: str - credential=None, # type: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "TokenCredential"]] # pylint: disable=line-too-long - **kwargs # type: Any - ): - # type: (...) -> None + parsed_url: Any, + service: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts") self.scheme = parsed_url.scheme @@ -137,7 +135,7 @@ def url(self): def primary_endpoint(self): """The full primary endpoint URL. - :type: str + :rtype: str """ return self._format_url(self._hosts[LocationMode.PRIMARY]) @@ -145,7 +143,7 @@ def primary_endpoint(self): def primary_hostname(self): """The hostname of the primary endpoint. - :type: str + :rtype: str """ return self._hosts[LocationMode.PRIMARY] @@ -156,7 +154,7 @@ def secondary_endpoint(self): If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. - :type: str + :rtype: str :raise ValueError: """ if not self._hosts[LocationMode.SECONDARY]: @@ -170,7 +168,7 @@ def secondary_hostname(self): If not available this will be None. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. 
- :type: str or None + :rtype: Optional[str] """ return self._hosts[LocationMode.SECONDARY] @@ -180,7 +178,7 @@ def location_mode(self): By default this will be "primary". Options include "primary" and "secondary". - :type: str + :rtype: str """ return self._location_mode @@ -197,35 +195,43 @@ def location_mode(self, value): def api_version(self): """The version of the Storage API used for requests. - :type: str + :rtype: str """ return self._client._config.version # pylint: disable=protected-access - def _format_query_string(self, sas_token, credential, snapshot=None, share_snapshot=None): + def _format_query_string( + self, sas_token: Optional[str], + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]], # pylint: disable=line-too-long + snapshot: Optional[str] = None, + share_snapshot: Optional[str] = None + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]: # pylint: disable=line-too-long query_str = "?" 
if snapshot: - query_str += f"snapshot={self.snapshot}&" + query_str += f"snapshot={snapshot}&" if share_snapshot: - query_str += f"sharesnapshot={self.snapshot}&" + query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") - if is_credential_sastoken(credential): + if _is_credential_sastoken(credential): + credential = cast(str, credential) query_str += credential.lstrip("?") credential = None elif sas_token: query_str += sas_token return query_str.rstrip("?&"), credential - def _create_pipeline(self, credential, **kwargs): - # type: (Any, **Any) -> Tuple[Configuration, Pipeline] - self._credential_policy = None + def _create_pipeline( + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Tuple[StorageConfiguration, Pipeline]: + self._credential_policy: Any = None if hasattr(credential, "get_token"): if kwargs.get('audience'): audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE - self._credential_policy = BearerTokenCredentialPolicy(credential, audience) + self._credential_policy = BearerTokenCredentialPolicy(cast(TokenCredential, credential), audience) elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): @@ -236,11 +242,11 @@ def _create_pipeline(self, credential, **kwargs): config = kwargs.get("_configuration") or create_configuration(**kwargs) if kwargs.get("_pipeline"): return config, kwargs["_pipeline"] - config.transport = kwargs.get("transport") # type: ignore + transport = kwargs.get("transport") kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) - if not 
config.transport: - config.transport = RequestsTransport(**kwargs) + if not transport: + transport = RequestsTransport(**kwargs) policies = [ QueueMessagePolicy(), config.proxy_policy, @@ -259,15 +265,21 @@ def _create_pipeline(self, credential, **kwargs): HttpLoggingPolicy(**kwargs) ] if kwargs.get("_additional_pipeline_policies"): - policies = policies + kwargs.get("_additional_pipeline_policies") - return config, Pipeline(config.transport, policies=policies) + policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore + config.transport = transport # type: ignore + return config, Pipeline(transport, policies=policies) - # Given a series of request, do a Storage batch call. def _batch_send( self, - *reqs, # type: HttpRequest - **kwargs - ): + *reqs: "HttpRequest", + **kwargs: Any + ) -> Iterator["HttpResponse"]: + """Given a series of request, do a Storage batch call. + + :param HttpRequest reqs: A collection of HttpRequest objects. + :returns: An iterator of HttpResponse objects. 
+ :rtype: Iterator[HttpResponse] + """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) batch_id = str(uuid.uuid1()) @@ -348,7 +360,10 @@ def __exit__(self, *args): # pylint: disable=arguments-differ pass -def _format_shared_key_credential(account_name, credential): +def _format_shared_key_credential( + account_name: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long +) -> Any: if isinstance(credential, str): if not account_name: raise ValueError("Unable to determine account name for shared key credential.") @@ -364,12 +379,16 @@ def _format_shared_key_credential(account_name, credential): return credential -def parse_connection_str(conn_str, credential, service): +def parse_connection_str( + conn_str: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], # pylint: disable=line-too-long + service: str +) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]: # pylint: disable=line-too-long conn_str = conn_str.rstrip(";") - conn_settings = [s.split("=", 1) for s in conn_str.split(";")] - if any(len(tup) != 2 for tup in conn_settings): + conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] + if any(len(tup) != 2 for tup in conn_settings_list): raise ValueError("Connection string is either blank or malformed.") - conn_settings = dict((key.upper(), val) for key, val in conn_settings) + conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) endpoints = _SERVICE_PARAMS[service] primary = None secondary = None @@ -412,43 +431,20 @@ def parse_connection_str(conn_str, credential, service): return primary, secondary, credential -def create_configuration(**kwargs): - # type: (**Any) -> 
Configuration +def create_configuration(**kwargs: Any) -> StorageConfiguration: # Backwards compatibility if someone is not passing sdk_moniker if not kwargs.get("sdk_moniker"): kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}" - config = Configuration(**kwargs) + config = StorageConfiguration(**kwargs) config.headers_policy = StorageHeadersPolicy(**kwargs) config.user_agent_policy = UserAgentPolicy(**kwargs) config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) config.logging_policy = StorageLoggingPolicy(**kwargs) config.proxy_policy = ProxyPolicy(**kwargs) - - # Storage settings - config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024) - config.copy_polling_interval = 15 - - # Block blob uploads - config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024) - config.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) - config.use_byte_buffer = kwargs.get("use_byte_buffer", False) - - # Page blob uploads - config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024) - - # Datalake file uploads - config.min_large_chunk_upload_threshold = kwargs.get("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) - - # Blob downloads - config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024) - config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024) - - # File uploads - config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024) return config -def parse_query(query_str): +def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]: sas_values = QueryStringConstants.to_list() parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()} sas_params = [f"{k}={quote(v, safe='')}" for k, v in parsed_query.items() if k in sas_values] @@ -458,14 +454,3 @@ def parse_query(query_str): snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot") return 
snapshot, sas_token - - -def is_credential_sastoken(credential): - if not credential or not isinstance(credential, str): - return False - - sas_values = QueryStringConstants.to_list() - parsed_query = parse_qs(credential.lstrip("?")) - if parsed_query and all(k in sas_values for k in parsed_query.keys()): - return True - return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 1ca23b131a81..8ddb5b390e11 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -3,17 +3,16 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +# mypy: disable-error-code="attr-defined" -from typing import ( # pylint: disable=unused-import - Union, Optional, Any, Iterable, Dict, List, Type, Tuple, - TYPE_CHECKING -) import logging +from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING, Union -from azure.core.credentials import AzureSasCredential -from azure.core.pipeline import AsyncPipeline from azure.core.async_paging import AsyncList +from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential +from azure.core.credentials_async import AsyncTokenCredential from azure.core.exceptions import HttpResponseError +from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import ( AsyncBearerTokenCredentialPolicy, AsyncRedirectPolicy, @@ -24,9 +23,10 @@ ) from azure.core.pipeline.transport import AsyncHttpTransport -from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, STORAGE_OAUTH_SCOPE from .authentication import SharedKeyCredentialPolicy from .base_client import create_configuration +from .constants import CONNECTION_TIMEOUT, 
DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE +from .models import StorageConfiguration from .policies import ( QueueMessagePolicy, StorageContentValidation, @@ -35,15 +35,20 @@ StorageRequestHook, ) from .policies_async import AsyncStorageResponseHook - -from .response_handlers import process_storage_error, PartialBatchErrorException +from .response_handlers import PartialBatchErrorException, process_storage_error +from .._shared_access_signature import _is_credential_sastoken if TYPE_CHECKING: - from azure.core.pipeline import Pipeline - from azure.core.pipeline.transport import HttpRequest - from azure.core.configuration import Configuration + from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756 _LOGGER = logging.getLogger(__name__) +_SERVICE_PARAMS = { + "blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"}, + "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"}, + "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"}, + "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"}, +} + class AsyncStorageAccountHostsMixin(object): @@ -66,15 +71,41 @@ async def close(self): """ await self._client.close() - def _create_pipeline(self, credential, **kwargs): - # type: (Any, **Any) -> Tuple[Configuration, Pipeline] - self._credential_policy = None + def _format_query_string( + self, sas_token: Optional[str], + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long + snapshot: Optional[str] = None, + share_snapshot: Optional[str] = None + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long + query_str = "?" 
+ if snapshot: + query_str += f"snapshot={snapshot}&" + if share_snapshot: + query_str += f"sharesnapshot={share_snapshot}&" + if sas_token and isinstance(credential, AzureSasCredential): + raise ValueError( + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + if _is_credential_sastoken(credential): + query_str += credential.lstrip("?") # type: ignore [union-attr] + credential = None + elif sas_token: + query_str += sas_token + return query_str.rstrip("?&"), credential + + def _create_pipeline( + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Tuple[StorageConfiguration, AsyncPipeline]: + self._credential_policy: Optional[ + Union[AsyncBearerTokenCredentialPolicy, + SharedKeyCredentialPolicy, + AzureSasCredentialPolicy]] = None if hasattr(credential, 'get_token'): if kwargs.get('audience'): audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE - self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, audience) + self._credential_policy = AsyncBearerTokenCredentialPolicy(cast(AsyncTokenCredential, credential), audience) elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): @@ -84,15 +115,16 @@ def _create_pipeline(self, credential, **kwargs): config = kwargs.get('_configuration') or create_configuration(**kwargs) if kwargs.get('_pipeline'): return config, kwargs['_pipeline'] - config.transport = kwargs.get('transport') # type: ignore + transport = kwargs.get('transport') kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) - if not config.transport: + if not transport: try: from azure.core.pipeline.transport import AioHttpTransport # pylint: 
disable=non-abstract-transport-import except ImportError as exc: raise ImportError("Unable to create async transport. Please check aiohttp is installed.") from exc - config.transport = AioHttpTransport(**kwargs) + transport = AioHttpTransport(**kwargs) + hosts = self._hosts policies = [ QueueMessagePolicy(), config.headers_policy, @@ -103,7 +135,7 @@ def _create_pipeline(self, credential, **kwargs): self._credential_policy, ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), - StorageHosts(hosts=self._hosts, **kwargs), # type: ignore + StorageHosts(hosts=hosts, **kwargs), config.retry_policy, config.logging_policy, AsyncStorageResponseHook(**kwargs), @@ -111,15 +143,21 @@ def _create_pipeline(self, credential, **kwargs): HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): - policies = policies + kwargs.get("_additional_pipeline_policies") - return config, AsyncPipeline(config.transport, policies=policies) + policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore + config.transport = transport #type: ignore + return config, AsyncPipeline(transport, policies=policies) #type: ignore - # Given a series of request, do a Storage batch call. async def _batch_send( self, - *reqs, # type: HttpRequest - **kwargs - ): + *reqs: "HttpRequest", + **kwargs: Any + ) -> AsyncList["HttpResponse"]: + """Given a series of request, do a Storage batch call. + + :param HttpRequest reqs: A collection of HttpRequest objects. + :returns: An AsyncList of HttpResponse objects. 
+ :rtype: AsyncList[HttpResponse] + """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) request = self._client._client.post( # pylint: disable=protected-access @@ -135,7 +173,7 @@ async def _batch_send( policies = [StorageHeadersPolicy()] if self._credential_policy: - policies.append(self._credential_policy) + policies.append(self._credential_policy) # type: ignore request.set_multipart_mixed( *reqs, @@ -167,6 +205,56 @@ async def _batch_send( except HttpResponseError as error: process_storage_error(error) +def parse_connection_str( + conn_str: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], # pylint: disable=line-too-long + service: str +) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long + conn_str = conn_str.rstrip(";") + conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] + if any(len(tup) != 2 for tup in conn_settings_list): + raise ValueError("Connection string is either blank or malformed.") + conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) + endpoints = _SERVICE_PARAMS[service] + primary = None + secondary = None + if not credential: + try: + credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]} + except KeyError: + credential = conn_settings.get("SHAREDACCESSSIGNATURE") + if endpoints["primary"] in conn_settings: + primary = conn_settings[endpoints["primary"]] + if endpoints["secondary"] in conn_settings: + secondary = conn_settings[endpoints["secondary"]] + else: + if endpoints["secondary"] in conn_settings: + raise ValueError("Connection string specifies only secondary endpoint.") + try: + primary =( + f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" + 
f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" + ) + secondary = ( + f"{conn_settings['ACCOUNTNAME']}-secondary." + f"{service}.{conn_settings['ENDPOINTSUFFIX']}" + ) + except KeyError: + pass + + if not primary: + try: + primary = ( + f"https://{conn_settings['ACCOUNTNAME']}." + f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}" + ) + except KeyError as exc: + raise ValueError("Connection string missing required connection details.") from exc + if service == "dfs": + primary = primary.replace(".blob.", ".dfs.") + if secondary: + secondary = secondary.replace(".blob.", ".dfs.") + return primary, secondary, credential class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an inner client created diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index b6519ebe9f81..34ce22b7a632 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -5,8 +5,11 @@ # -------------------------------------------------------------------------- # pylint: disable=too-many-instance-attributes from enum import Enum +from typing import Optional from azure.core import CaseInsensitiveEnumMeta +from azure.core.configuration import Configuration +from azure.core.pipeline.policies import UserAgentPolicy def get_enum_value(value): @@ -269,7 +272,17 @@ class ResourceTypes(object): files(e.g. Put Blob, Query Entity, Get Messages, Create File, etc.) 
""" - def __init__(self, service=False, container=False, object=False): # pylint: disable=redefined-builtin + service: bool = False + container: bool = False + object: bool = False + _str: str + + def __init__( + self, + service: bool = False, + container: bool = False, + object: bool = False # pylint: disable=redefined-builtin + ) -> None: self.service = service self.container = container self.object = object @@ -344,9 +357,34 @@ class AccountSasPermissions(object): To enable permanent delete on the blob is permitted. Valid for Object resource type of Blob only. """ - def __init__(self, read=False, write=False, delete=False, - list=False, # pylint: disable=redefined-builtin - add=False, create=False, update=False, process=False, delete_previous_version=False, **kwargs): + + read: bool = False + write: bool = False + delete: bool = False + delete_previous_version: bool = False + list: bool = False + add: bool = False + create: bool = False + update: bool = False + process: bool = False + tag: bool = False + filter_by_tags: bool = False + set_immutability_policy: bool = False + permanent_delete: bool = False + + def __init__( + self, + read: bool = False, + write: bool = False, + delete: bool = False, + list: bool = False, # pylint: disable=redefined-builtin + add: bool = False, + create: bool = False, + update: bool = False, + process: bool = False, + delete_previous_version: bool = False, + **kwargs + ) -> None: self.read = read self.write = write self.delete = delete @@ -423,7 +461,11 @@ class Services(object): Access for the `~azure.storage.fileshare.ShareServiceClient` """ - def __init__(self, blob=False, queue=False, fileshare=False): + blob: bool = False + queue: bool = False + fileshare: bool = False + + def __init__(self, blob: bool = False, queue: bool = False, fileshare: bool = False): self.blob = blob self.queue = queue self.fileshare = fileshare @@ -463,22 +505,23 @@ class UserDelegationKey(object): The fields are saved as simple strings since the user 
does not have to interact with this object; to generate an identify SAS, the user can simply pass it to the right API. - - :ivar str signed_oid: - Object ID of this token. - :ivar str signed_tid: - Tenant ID of the tenant that issued this token. - :ivar str signed_start: - The datetime this token becomes valid. - :ivar str signed_expiry: - The datetime this token expires. - :ivar str signed_service: - What service this key is valid for. - :ivar str signed_version: - The version identifier of the REST service that created this token. - :ivar str value: - The user delegation key. """ + + signed_oid: Optional[str] = None + """Object ID of this token.""" + signed_tid: Optional[str] = None + """Tenant ID of the tenant that issued this token.""" + signed_start: Optional[str] = None + """The datetime this token becomes valid.""" + signed_expiry: Optional[str] = None + """The datetime this token expires.""" + signed_service: Optional[str] = None + """What service this key is valid for.""" + signed_version: Optional[str] = None + """The version identifier of the REST service that created this token.""" + value: Optional[str] = None + """The user delegation key.""" + def __init__(self): self.signed_oid = None self.signed_tid = None @@ -487,3 +530,52 @@ def __init__(self): self.signed_service = None self.signed_version = None self.value = None + + +class StorageConfiguration(Configuration): + """ + Specifies the configurable values used in Azure Storage. + + :param int max_single_put_size: If the blob size is less than or equal max_single_put_size, then the blob will be + uploaded with only one http PUT request. If the blob size is larger than max_single_put_size, + the blob will be uploaded in chunks. Defaults to 64*1024*1024, or 64MB. + :param int copy_polling_interval: The interval in seconds for polling copy operations. + :param int max_block_size: The maximum chunk size for uploading a block blob in chunks. + Defaults to 4*1024*1024, or 4MB. 
+ :param int min_large_block_upload_threshold: The minimum chunk size required to use the memory efficient + algorithm when uploading a block blob. + :param bool use_byte_buffer: Use a byte buffer for block blob uploads. Defaults to False. + :param int max_page_size: The maximum chunk size for uploading a page blob. Defaults to 4*1024*1024, or 4MB. + :param int min_large_chunk_upload_threshold: The max size for a single put operation. + :param int max_single_get_size: The maximum size for a blob to be downloaded in a single call, + the exceeded part will be downloaded in chunks (could be parallel). Defaults to 32*1024*1024, or 32MB. + :param int max_chunk_get_size: The maximum chunk size used for downloading a blob. Defaults to 4*1024*1024, + or 4MB. + :param int max_range_size: The max range size for file upload. + + """ + + max_single_put_size: int + copy_polling_interval: int + max_block_size: int + min_large_block_upload_threshold: int + use_byte_buffer: bool + max_page_size: int + min_large_chunk_upload_threshold: int + max_single_get_size: int + max_chunk_get_size: int + max_range_size: int + user_agent_policy: UserAgentPolicy + + def __init__(self, **kwargs): + super(StorageConfiguration, self).__init__(**kwargs) + self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024) + self.copy_polling_interval = 15 + self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024) + self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1) + self.use_byte_buffer = kwargs.pop('use_byte_buffer', False) + self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024) + self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1) + self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024) + self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024) + self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 
1024) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py index c76c376f30e0..cd59cfe104ca 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py @@ -6,6 +6,7 @@ import sys from datetime import datetime, timezone +from typing import Optional EPOCH_AS_FILETIME = 116444736000000000 # January 1, 1970 as MS filetime HUNDREDS_OF_NANOSECONDS = 10000000 @@ -20,10 +21,10 @@ def _str(value): _str = str -def _to_utc_datetime(value): +def _to_utc_datetime(value: datetime) -> str: return value.strftime('%Y-%m-%dT%H:%M:%SZ') -def _rfc_1123_to_datetime(rfc_1123: str) -> datetime: +def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]: """Converts an RFC 1123 date string to a UTC datetime. :param str rfc_1123: The time and date in RFC 1123 format. @@ -35,7 +36,7 @@ def _rfc_1123_to_datetime(rfc_1123: str) -> datetime: return datetime.strptime(rfc_1123, "%a, %d %b %Y %H:%M:%S %Z") -def _filetime_to_datetime(filetime: str) -> datetime: +def _filetime_to_datetime(filetime: str) -> Optional[datetime]: """Converts an MS filetime string to a UTC datetime. "0" indicates None. If parsing MS Filetime fails, tries RFC 1123 as backup. 
@@ -48,11 +49,11 @@ def _filetime_to_datetime(filetime: str) -> datetime: # Try to convert to MS Filetime try: - filetime = int(filetime) - if filetime == 0: + temp_filetime = int(filetime) + if temp_filetime == 0: return None - return datetime.fromtimestamp((filetime - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS, tz=timezone.utc) + return datetime.fromtimestamp((temp_filetime - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS, tz=timezone.utc) except ValueError: pass diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index a0c060aaaad9..73105fe979ee 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -6,29 +6,22 @@ import base64 import hashlib -import re -import random -from time import time -from io import SEEK_SET, UnsupportedOperation import logging +import random +import re import uuid -from typing import Any, TYPE_CHECKING -from wsgiref.handlers import format_date_time -try: - from urllib.parse import ( - urlparse, +from io import SEEK_SET, UnsupportedOperation +from time import time +from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from urllib.parse import ( parse_qsl, - urlunparse, urlencode, - ) -except ImportError: - from urllib import urlencode # type: ignore - from urlparse import ( # type: ignore urlparse, - parse_qsl, urlunparse, - ) +) +from wsgiref.handlers import format_date_time +from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError from azure.core.pipeline.policies import ( BearerTokenCredentialPolicy, HeadersPolicy, @@ -37,7 +30,6 @@ RequestHistory, SansIOHTTPPolicy, ) -from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError from .authentication import StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE @@ -50,7 +42,10 @@ if TYPE_CHECKING: 
from azure.core.credentials import TokenCredential - from azure.core.pipeline import PipelineRequest, PipelineResponse + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import + PipelineRequest, + PipelineResponse + ) _LOGGER = logging.getLogger(__name__) @@ -134,8 +129,7 @@ def on_request(self, request): class StorageHeadersPolicy(HeadersPolicy): request_id_header_name = 'x-ms-client-request-id' - def on_request(self, request): - # type: (PipelineRequest, Any) -> None + def on_request(self, request: "PipelineRequest") -> None: super(StorageHeadersPolicy, self).on_request(request) current_time = format_date_time(time()) request.http_request.headers['x-ms-date'] = current_time @@ -165,8 +159,7 @@ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument self.hosts = hosts super(StorageHosts, self).__init__() - def on_request(self, request): - # type: (PipelineRequest, Any) -> None + def on_request(self, request: "PipelineRequest") -> None: request.context.options['hosts'] = self.hosts parsed_url = urlparse(request.http_request.url) @@ -197,12 +190,12 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): This accepts both global configuration, and per-request level with "enable_http_logger" """ - def __init__(self, logging_enable=False, **kwargs): + + def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) - def on_request(self, request): - # type: (PipelineRequest, Any) -> None + def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request options = request.context.options self.logging_body = self.logging_body or options.pop("logging_body", False) @@ -242,8 +235,7 @@ def on_request(self, request): except Exception as err: # pylint: disable=broad-except _LOGGER.debug("Failed to log request: %r", err) - def on_response(self, 
request, response): - # type: (PipelineRequest, PipelineResponse, Any) -> None + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: if response.context.pop("logging_enable", self.enable_http_logger): if not _LOGGER.isEnabledFor(logging.DEBUG): return @@ -286,8 +278,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument self._request_callback = kwargs.get('raw_request_hook') super(StorageRequestHook, self).__init__() - def on_request(self, request): - # type: (PipelineRequest, **Any) -> PipelineResponse + def on_request(self, request: "PipelineRequest") -> None: request_callback = request.context.options.pop('raw_request_hook', self._request_callback) if request_callback: request_callback(request) @@ -299,8 +290,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument self._response_callback = kwargs.get('raw_response_hook') super(StorageResponseHook, self).__init__() - def send(self, request): - # type: (PipelineRequest) -> PipelineResponse + def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 data_stream_total = request.context.get('data_stream_total') if data_stream_total is None: @@ -333,9 +323,10 @@ def send(self, request): elif should_update_counts and upload_stream_current is not None: upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) for pipeline_obj in [request, response]: - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, 'context'): + pipeline_obj.context['data_stream_total'] = data_stream_total + pipeline_obj.context['download_stream_current'] = download_stream_current + pipeline_obj.context['upload_stream_current'] = upload_stream_current if response_callback: response_callback(response) request.context['response_callback'] = 
response_callback @@ -350,7 +341,7 @@ class StorageContentValidation(SansIOHTTPPolicy): """ header_name = 'Content-MD5' - def __init__(self, **kwargs): # pylint: disable=unused-argument + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super(StorageContentValidation, self).__init__() @staticmethod @@ -378,8 +369,7 @@ def get_content_md5(data): return md5.digest() - def on_request(self, request): - # type: (PipelineRequest, Any) -> None + def on_request(self, request: "PipelineRequest") -> None: validate_content = request.context.options.pop('validate_content', False) if validate_content and request.http_request.method != 'GET': computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) @@ -387,7 +377,7 @@ def on_request(self, request): request.context['validate_content_md5'] = computed_md5 request.context['validate_content'] = validate_content - def on_response(self, request, response): + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = request.context.get('validate_content_md5') or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) @@ -404,7 +394,18 @@ class StorageRetryPolicy(HTTPPolicy): The base class for Exponential and Linear retries containing shared code. 
""" - def __init__(self, **kwargs): + total_retries: int + """The max number of retries.""" + connect_retries: int + """The max number of connect retries.""" + retry_read: int + """The max number of read retries.""" + retry_status: int + """The max number of status retries.""" + retry_to_secondary: bool + """Whether the secondary endpoint should be retried.""" + + def __init__(self, **kwargs: Any) -> None: self.total_retries = kwargs.pop('retry_total', 10) self.connect_retries = kwargs.pop('retry_connect', 3) self.read_retries = kwargs.pop('retry_read', 3) @@ -412,11 +413,11 @@ def __init__(self, **kwargs): self.retry_to_secondary = kwargs.pop('retry_to_secondary', False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings, request): + def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the next host location. + :param Dict[str, Any]] settings: The configurable values pertaining to the next host location. :param PipelineRequest request: A pipeline request object. """ if settings['hosts'] and all(settings['hosts'].values()): @@ -429,7 +430,7 @@ def _set_next_host_location(self, settings, request): updated = url._replace(netloc=settings['hosts'].get(settings['mode'])) request.url = updated.geturl() - def configure_retries(self, request): + def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: body_position = None if hasattr(request.http_request.body, 'read'): try: @@ -452,11 +453,11 @@ def configure_retries(self, request): 'history': [] } - def get_backoff_time(self, settings): # pylint: disable=unused-argument + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument """ Formula for computing the current backoff. Should be calculated by child class. 
- :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. :returns: The backoff time. :rtype: float """ @@ -468,15 +469,20 @@ def sleep(self, settings, transport): return transport.sleep(backoff) - def increment(self, settings, request, response=None, error=None): + def increment( + self, settings: Dict[str, Any], + request: "PipelineRequest", + response: Optional["PipelineResponse"] = None, + error: Optional[Union[ServiceRequestError, ServiceResponseError]] = None + ) -> bool: """Increment the retry counters. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the increment operation. - :param "PipelineRequest" request: A pipeline request object. - :param "PipelineResponse": A pipeline response object. + :param Dict[str, Any] settings: The configurable values pertaining to the increment operation. + :param PipelineRequest request: A pipeline request object. + :param Optional[PipelineResponse] response: A pipeline response object. :param error: An error encountered during the request, or None if the response was received successfully. - :paramtype error: Union[ServiceRequestError, ServiceResponseError] + :paramtype error: Optional[Union[ServiceRequestError, ServiceResponseError]] :returns: Whether the retry attempts are exhausted. 
:rtype: bool """ @@ -562,9 +568,23 @@ def send(self, request): class ExponentialRetry(StorageRetryPolicy): """Exponential retry.""" - def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, - retry_to_secondary=False, random_jitter_range=3, **kwargs): - ''' + initial_backoff: int + """The initial backoff interval, in seconds, for the first retry.""" + increment_base: int + """The base, in seconds, to increment the initial_backoff by after the + first retry.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + + def __init__( + self, initial_backoff: int = 15, + increment_base: int = 3, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, + **kwargs: Any + ) -> None: + """ Constructs an Exponential retry object. The initial_backoff is used for the first retry. Subsequent retries are retried after initial_backoff + increment_power^retry_count seconds. @@ -574,7 +594,7 @@ def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, :param int increment_base: The base, in seconds, to increment the initial_backoff by after the first retry. - :param int max_attempts: + :param int retry_total: The maximum number of retry attempts. :param bool retry_to_secondary: Whether the request should be retried to secondary, if able. This should @@ -583,22 +603,22 @@ def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, :param int random_jitter_range: A number in seconds which indicates a range to jitter/randomize for the back-off interval. For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. 
- ''' + """ self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range super(ExponentialRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings): + def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to get backoff time. + :param Dict[str, Any] settings: The configurable values pertaining to get backoff time. :returns: - An integer indicating how long to wait before retrying the request, + A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. - :rtype: int or None + :rtype: float """ random_generator = random.Random() backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) @@ -610,13 +630,24 @@ def get_backoff_time(self, settings): class LinearRetry(StorageRetryPolicy): """Linear retry.""" - def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_jitter_range=3, **kwargs): + initial_backoff: int + """The backoff interval, in seconds, between retries.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + + def __init__( + self, backoff: int = 15, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, + **kwargs: Any + ) -> None: """ Constructs a Linear retry object. :param int backoff: The backoff interval, in seconds, between retries. - :param int max_attempts: + :param int retry_total: The maximum number of retry attempts. :param bool retry_to_secondary: Whether the request should be retried to secondary, if able. 
This should @@ -631,15 +662,15 @@ def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_j super(LinearRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings): + def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. :returns: - An integer indicating how long to wait before retrying the request, + A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. - :rtype: int or None + :rtype: float """ random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility @@ -656,8 +687,7 @@ class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) - def on_challenge(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> bool + def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 8dc38f626357..bf03d3690598 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -6,12 +6,12 @@ # pylint: disable=invalid-overridden-method import asyncio -import random import logging -from 
typing import Any, TYPE_CHECKING +import random +from typing import Any, Dict, TYPE_CHECKING -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy from azure.core.exceptions import AzureError +from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy from .authentication import StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE @@ -19,7 +19,10 @@ if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential - from azure.core.pipeline import PipelineRequest, PipelineResponse + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import + PipelineRequest, + PipelineResponse + ) _LOGGER = logging.getLogger(__name__) @@ -45,8 +48,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument self._response_callback = kwargs.get('raw_response_hook') super(AsyncStorageResponseHook, self).__init__() - async def send(self, request): - # type: (PipelineRequest) -> PipelineResponse + async def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 data_stream_total = request.context.get('data_stream_total') if data_stream_total is None: @@ -80,12 +82,13 @@ async def send(self, request): elif should_update_counts and upload_stream_current is not None: upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) for pipeline_obj in [request, response]: - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, 'context'): + pipeline_obj.context['data_stream_total'] = data_stream_total + pipeline_obj.context['download_stream_current'] = download_stream_current + pipeline_obj.context['upload_stream_current'] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): - await 
response_callback(response) + await response_callback(response) # type: ignore else: response_callback(response) request.context['response_callback'] = response_callback @@ -144,9 +147,23 @@ async def send(self, request): class ExponentialRetry(AsyncStorageRetryPolicy): """Exponential retry.""" - def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, - retry_to_secondary=False, random_jitter_range=3, **kwargs): - ''' + initial_backoff: int + """The initial backoff interval, in seconds, for the first retry.""" + increment_base: int + """The base, in seconds, to increment the initial_backoff by after the + first retry.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + + def __init__( + self, + initial_backoff: int = 15, + increment_base: int = 3, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, **kwargs + ) -> None: + """ Constructs an Exponential retry object. The initial_backoff is used for the first retry. Subsequent retries are retried after initial_backoff + increment_power^retry_count seconds. For example, by default the first retry @@ -167,18 +184,18 @@ def __init__(self, initial_backoff=15, increment_base=3, retry_total=3, :param int random_jitter_range: A number in seconds which indicates a range to jitter/randomize for the back-off interval. For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. - ''' + """ self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range super(ExponentialRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings): + def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. 
- :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. :return: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. @@ -194,7 +211,18 @@ def get_backoff_time(self, settings): class LinearRetry(AsyncStorageRetryPolicy): """Linear retry.""" - def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_jitter_range=3, **kwargs): + initial_backoff: int + """The backoff interval, in seconds, between retries.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + + def __init__( + self, backoff: int = 15, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, + **kwargs: Any + ) -> None: """ Constructs a Linear retry object. @@ -215,11 +243,11 @@ def __init__(self, backoff=15, retry_total=3, retry_to_secondary=False, random_j super(LinearRetry, self).__init__( retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings): + def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Optional[Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. :return: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
@@ -240,8 +268,7 @@ class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) - async def on_challenge(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> bool + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py index 1f67cb8c72f4..313009997a8e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py @@ -3,26 +3,22 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- -from typing import NoReturn, TYPE_CHECKING import logging +from typing import NoReturn from xml.etree.ElementTree import Element -from azure.core.pipeline.policies import ContentDecodePolicy from azure.core.exceptions import ( + ClientAuthenticationError, + DecodeError, HttpResponseError, - ResourceNotFoundError, - ResourceModifiedError, ResourceExistsError, - ClientAuthenticationError, - DecodeError) + ResourceModifiedError, + ResourceNotFoundError, +) +from azure.core.pipeline.policies import ContentDecodePolicy -from .parser import _to_utc_datetime from .models import StorageErrorCode, UserDelegationKey, get_enum_value - - -if TYPE_CHECKING: - from datetime import datetime - from azure.core.exceptions import AzureError +from .parser import _to_utc_datetime _LOGGER = logging.getLogger(__name__) @@ -85,7 +81,7 @@ def return_raw_deserialized(response, *_): return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] -def process_storage_error(storage_error) -> NoReturn: # pylint:disable=too-many-statements +def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements raise_error = HttpResponseError serialized = False if not storage_error.response or storage_error.response.status_code in [200, 204]: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py index 67b4f0716d05..708211a00ba9 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py @@ -4,24 +4,25 @@ # license information. 
# -------------------------------------------------------------------------- -from typing import ( # pylint: disable=unused-import - Union, Optional, TYPE_CHECKING -) +from typing import Any, Optional, TYPE_CHECKING, Union +from urllib.parse import parse_qs from azure.storage.queue._shared import sign_string from azure.storage.queue._shared.constants import X_MS_VERSION from azure.storage.queue._shared.models import Services -from azure.storage.queue._shared.shared_access_signature import SharedAccessSignature, _SharedAccessHelper, \ - QueryStringConstants +from azure.storage.queue._shared.shared_access_signature import ( + QueryStringConstants, + SharedAccessSignature, + _SharedAccessHelper +) if TYPE_CHECKING: - from datetime import datetime from azure.storage.queue import ( - ResourceTypes, AccountSasPermissions, - QueueSasPermissions + QueueSasPermissions, + ResourceTypes ) - from typing import Any + from datetime import datetime class QueueSharedAccessSignature(SharedAccessSignature): ''' @@ -31,7 +32,7 @@ class QueueSharedAccessSignature(SharedAccessSignature): generate_*_shared_access_signature method directly. ''' - def __init__(self, account_name, account_key): + def __init__(self, account_name: str, account_key: str) -> None: ''' :param str account_name: The storage account name used to generate the shared access signatures. 
@@ -40,21 +41,28 @@ def __init__(self, account_name, account_key): ''' super(QueueSharedAccessSignature, self).__init__(account_name, account_key, x_ms_version=X_MS_VERSION) - def generate_queue(self, queue_name, permission=None, - expiry=None, start=None, policy_id=None, - ip=None, protocol=None): + def generate_queue( + self, queue_name: str, + permission: Optional[Union["QueueSasPermissions", str]] = None, + expiry: Optional[Union["datetime", str]] = None, + start: Optional[Union["datetime", str]] = None, + policy_id: Optional[str] = None, + ip: Optional[str] = None, + protocol: Optional[str] = None + ) -> str: ''' Generates a shared access signature for the queue. Use the returned signature with the sas_token parameter of QueueService. :param str queue_name: Name of queue. - :param QueueSasPermissions permission: + :param permission: The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions. Permissions must be ordered read, add, update, process. Required unless an id is given referencing a stored access policy which contains this field. This field must be omitted if it has been specified in an associated stored access policy. + :type permission: ~azure.storage.queue.QueueSasPermissions or str :param expiry: The time at which the shared access signature becomes invalid. Required unless an id is given referencing a stored access policy @@ -62,14 +70,14 @@ def generate_queue(self, queue_name, permission=None, been specified in an associated stored access policy. Azure will always convert values to UTC. If a date is passed in without timezone info, it is assumed to be UTC. - :type expiry: datetime or str + :type expiry: ~datetime.datetime or str :param start: The time at which the shared access signature becomes valid. If omitted, start time for this call is assumed to be the time when the storage service receives the request. Azure will always convert values to UTC. 
If a date is passed in without timezone info, it is assumed to be UTC. - :type start: datetime or str + :type start: ~datetime.datetime or str :param str policy_id: A unique value up to 64 characters in length that correlates to a stored access policy. @@ -95,7 +103,7 @@ def generate_queue(self, queue_name, permission=None, class _QueueSharedAccessHelper(_SharedAccessHelper): - def add_resource_signature(self, account_name, account_key, path): # pylint: disable=arguments-differ + def add_resource_signature(self, account_name: str, account_key: str, path: str): # pylint: disable=arguments-differ def get_value_to_append(query): return_value = self.query_dict.get(query) or '' return return_value + '\n' @@ -126,15 +134,15 @@ def get_value_to_append(query): def generate_account_sas( - account_name, # type: str - account_key, # type: str - resource_types, # type: Union[ResourceTypes, str] - permission, # type: Union[AccountSasPermissions, str] - expiry, # type: Union[datetime, str] - start=None, # type: Optional[Union[datetime, str]] - ip=None, # type: Optional[str] - **kwargs # type: "Any" - ): # type: (...) -> str + account_name: str, + account_key: str, + resource_types: Union["ResourceTypes", str], + permission: Union["AccountSasPermissions", str], + expiry: Optional[Union["datetime", str]], + start: Optional[Union["datetime", str]] = None, + ip: Optional[str] = None, + **kwargs: Any +) -> str: """Generates a shared access signature for the queue service. Use the returned signature with the credential parameter of any Queue Service. @@ -145,9 +153,10 @@ def generate_account_sas( The account key, also called shared key or access key, to generate the shared access signature. :param ~azure.storage.queue.ResourceTypes resource_types: Specifies the resource types that are accessible with the account SAS. - :param ~azure.storage.queue.AccountSasPermissions permission: + :param permission: The permissions associated with the shared access signature. 
The user is restricted to operations allowed by the permissions. + :type permission: ~azure.storage.queue.AccountSasPermissions or str :param expiry: The time at which the shared access signature becomes invalid. Azure will always convert values to UTC. If a date is passed in @@ -180,20 +189,20 @@ def generate_account_sas( start=start, ip=ip, **kwargs - ) # type: ignore + ) def generate_queue_sas( - account_name, # type: str - queue_name, # type: str - account_key, # type: str - permission=None, # type: Optional[Union[QueueSasPermissions, str]] - expiry=None, # type: Optional[Union[datetime, str]] - start=None, # type: Optional[Union[datetime, str]] - policy_id=None, # type: Optional[str] - ip=None, # type: Optional[str] - **kwargs # type: "Any" - ): # type: (...) -> str + account_name: str, + queue_name: str, + account_key: str, + permission: Optional[Union["QueueSasPermissions", str]] = None, + expiry: Optional[Union["datetime", str]] = None, + start: Optional[Union["datetime", str]] = None, + policy_id: Optional[str] = None, + ip: Optional[str] = None, + **kwargs: Any +) -> str: """Generates a shared access signature for a queue. Use the returned signature with the credential parameter of any Queue Service. @@ -204,12 +213,13 @@ def generate_queue_sas( The name of the queue. :param str account_key: The account key, also called shared key or access key, to generate the shared access signature. - :param ~azure.storage.queue.QueueSasPermissions permission: + :param permission: The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions. Required unless a policy_id is given referencing a stored access policy which contains this field. This field must be omitted if it has been specified in an associated stored access policy. + :type permission: ~azure.storage.queue.QueueSasPermissions or str :param expiry: The time at which the shared access signature becomes invalid. 
Required unless a policy_id is given referencing a stored access policy @@ -264,3 +274,13 @@ def generate_queue_sas( ip=ip, **kwargs ) + +def _is_credential_sastoken(credential: Any) -> bool: + if not credential or not isinstance(credential, str): + return False + + sas_values = QueryStringConstants.to_list() + parsed_query = parse_qs(credential.lstrip("?")) + if parsed_query and all(k in sas_values for k in parsed_query): + return True + return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index dcb8d32277f4..564558eec50b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -6,37 +6,49 @@ # pylint: disable=too-few-public-methods, too-many-instance-attributes # pylint: disable=super-init-not-called -from typing import List # pylint: disable=unused-import +from typing import Any, Callable, List, Optional, Tuple + from azure.core.async_paging import AsyncPageIterator from azure.core.exceptions import HttpResponseError -from .._shared.response_handlers import ( - process_storage_error, - return_context_and_deserialized) from .._models import QueueMessage, QueueProperties +from .._shared.response_handlers import process_storage_error, return_context_and_deserialized class MessagesPaged(AsyncPageIterator): """An iterable of Queue Messages. - :param callable command: Function to retrieve the next page of items. - :param int results_per_page: The maximum number of messages to retrieve per + :param Callable command: Function to retrieve the next page of items. + :param Optional[int] results_per_page: The maximum number of messages to retrieve per call. - :param int max_messages: The maximum number of messages to retrieve from + :param Optional[int] max_messages: The maximum number of messages to retrieve from the queue. 
""" - def __init__(self, command, results_per_page=None, continuation_token=None, max_messages=None): + + command: Callable + """Function to retrieve the next page of items.""" + results_per_page: Optional[int] = None + """The maximum number of messages to retrieve per call.""" + max_messages: Optional[int] = None + """The maximum number of messages to retrieve from the queue.""" + + def __init__( + self, command: Callable, + results_per_page: Optional[int] = None, + continuation_token: Optional[str] = None, + max_messages: Optional[int] = None + ) -> None: if continuation_token is not None: raise ValueError("This operation does not support continuation token") super(MessagesPaged, self).__init__( self._get_next_cb, - self._extract_data_cb, + self._extract_data_cb, # type: ignore [arg-type] ) self._command = command self.results_per_page = results_per_page self._max_messages = max_messages - async def _get_next_cb(self, continuation_token): + async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: if self._max_messages is not None: if self.results_per_page is None: @@ -48,7 +60,7 @@ async def _get_next_cb(self, continuation_token): except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, messages): + async def _extract_data_cb(self, messages: Any) -> Tuple[str, List[QueueMessage]]: # There is no concept of continuation token, so raising on my own condition if not messages: raise StopAsyncIteration("End of paging") @@ -60,24 +72,40 @@ async def _extract_data_cb(self, messages): class QueuePropertiesPaged(AsyncPageIterator): """An iterable of Queue properties. - :ivar str service_endpoint: The service URL. - :ivar str prefix: A queue name prefix being used to filter the list. - :ivar str marker: The continuation token of the current page of results. - :ivar int results_per_page: The maximum number of results retrieved per API call. 
- :ivar str next_marker: The continuation token to retrieve the next page of results. - :ivar str location_mode: The location mode being used to list results. The available - options include "primary" and "secondary". - :param callable command: Function to retrieve the next page of items. + :param Callable command: Function to retrieve the next page of items. :param str prefix: Filters the results to return only queues whose names begin with the specified prefix. - :param int results_per_page: The maximum number of queue names to retrieve per + :param Optional[int] results_per_page: The maximum number of queue names to retrieve per call. :param str continuation_token: An opaque continuation token. """ - def __init__(self, command, prefix=None, results_per_page=None, continuation_token=None): + + service_endpoint: Optional[str] + """The service URL.""" + prefix: Optional[str] + """A queue name prefix being used to filter the list.""" + marker: Optional[str] + """The continuation token of the current page of results.""" + results_per_page: Optional[int] = None + """The maximum number of results retrieved per API call.""" + next_marker: str + """The continuation token to retrieve the next page of results.""" + location_mode: Optional[str] + """The location mode being used to list results. 
The available options include "primary" and "secondary".""" + command: Callable + """Function to retrieve the next page of items.""" + _response: Any + """The raw response from the most recent service call.""" + + def __init__( + self, command: Callable, + prefix: Optional[str] = None, + results_per_page: Optional[int] = None, + continuation_token: Optional[str] = None + ) -> None: super(QueuePropertiesPaged, self).__init__( self._get_next_cb, - self._extract_data_cb, + self._extract_data_cb, # type: ignore [arg-type] continuation_token=continuation_token or "" ) self._command = command @@ -87,7 +115,7 @@ def __init__(self, command, prefix=None, results_per_page=None, continuation_tok self.results_per_page = results_per_page self.location_mode = None - async def _get_next_cb(self, continuation_token): + async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: try: return await self._command( marker=continuation_token or None, @@ -97,11 +125,12 @@ async def _get_next_cb(self, continuation_token): except HttpResponseError as error: process_storage_error(error) - async def _extract_data_cb(self, get_next_return): + async def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], List[QueueProperties]]: self.location_mode, self._response = get_next_return self.service_endpoint = self._response.service_endpoint self.prefix = self._response.prefix self.marker = self._response.marker self.results_per_page = self._response.max_results props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access - return self._response.next_marker or None, props_list + next_marker = self._response.next_marker + return next_marker or None, props_list diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 248c1dca5d4f..e1b554007561 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -8,36 +8,41 @@ import functools import warnings from typing import ( - Any, Dict, List, Optional, Union, - TYPE_CHECKING) + Any, cast, Dict, List, + Optional, Tuple, TYPE_CHECKING, Union +) +from typing_extensions import Self from azure.core.async_paging import AsyncItemPaged from azure.core.exceptions import HttpResponseError from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async +from ._models import MessagesPaged +from .._deserialize import deserialize_queue_creation, deserialize_queue_properties +from .._encryption import modify_user_agent_for_encryption, StorageEncryptionMixin +from .._generated.aio import AzureQueueStorage +from .._generated.models import QueueMessage as GenQueueMessage, SignedIdentifier +from .._message_encoding import NoDecodePolicy, NoEncodePolicy +from .._models import AccessPolicy, QueueMessage +from .._queue_client_helpers import _format_url, _from_queue_url, _parse_url from .._serialize import get_api_version -from .._shared.base_client_async import AsyncStorageAccountHostsMixin +from .._shared.base_client import StorageAccountHostsMixin +from .._shared.base_client_async import AsyncStorageAccountHostsMixin, parse_connection_str from .._shared.policies_async import ExponentialRetry from .._shared.request_handlers import add_metadata_headers, serialize_iso from .._shared.response_handlers import ( - return_response_headers, process_storage_error, return_headers_and_deserialized, + return_response_headers ) -from .._generated.aio import AzureQueueStorage -from .._generated.models import SignedIdentifier, QueueMessage as GenQueueMessage -from .._deserialize import deserialize_queue_properties, deserialize_queue_creation -from .._encryption import modify_user_agent_for_encryption, 
StorageEncryptionMixin -from .._models import QueueMessage, AccessPolicy -from .._queue_client import QueueClient as QueueClientBase -from ._models import MessagesPaged if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential + from azure.core.credentials_async import AsyncTokenCredential from .._models import QueueProperties -class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncryptionMixin): +class QueueClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] # pylint: disable=line-too-long """A client to interact with a specific Queue. :param str account_url: @@ -54,6 +59,7 @@ class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncrypt - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long :keyword str api_version: The Storage API version to use for requests. Default value is the most recent service version that is compatible with the current SDK. Setting to an older version may result in reduced feature compatibility. 
@@ -87,25 +93,111 @@ class QueueClient(AsyncStorageAccountHostsMixin, QueueClientBase, StorageEncrypt """ def __init__( - self, account_url: str, - queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> None: + self, account_url: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) - super(QueueClient, self).__init__( - account_url, queue_name=queue_name, credential=credential, loop=loop, **kwargs - ) - self._client = AzureQueueStorage(self.url, base_url=self.url, - pipeline=self._pipeline, loop=loop) # type: ignore - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + parsed_url, sas_token = _parse_url(account_url=account_url, queue_name=queue_name, credential=credential) + self.queue_name = queue_name + self._query_str, credential = self._format_query_string(sas_token, credential) + super(QueueClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) + + self._message_encode_policy = kwargs.get('message_encode_policy', None) or NoEncodePolicy() + self._message_decode_policy = kwargs.get('message_decode_policy', None) or NoDecodePolicy() + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) + self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) + def _format_url(self, hostname: str) -> str: + """Format the endpoint URL according to the current location + mode hostname. 
+ + :param str hostname: The current location mode hostname. + :returns: The formatted endpoint URL according to the specified location mode hostname. + :rtype: str + """ + return _format_url( + queue_name=self.queue_name, + hostname=hostname, + scheme=self.scheme, + query_str=self._query_str) + + @classmethod + def from_queue_url( + cls, queue_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: + """A client to interact with a specific Queue. + + :param str queue_url: The full URI to the queue, including SAS token if used. + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + If the resource URI already contains a SAS token, this will be ignored in favor of an explicit credential + - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long + :returns: A queue client. 
+ :rtype: ~azure.storage.queue.QueueClient + """ + account_url, queue_name = _from_queue_url(queue_url=queue_url) + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + + @classmethod + def from_connection_string( + cls, conn_str: str, + queue_name: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: + """Create QueueClient from a Connection String. + + :param str conn_str: + A connection string to an Azure Storage account. + :param queue_name: The queue name. + :type queue_name: str + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token, or the connection string already has shared + access key values. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + Credentials provided here will take precedence over those in the connection string. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long + :returns: A queue client. + :rtype: ~azure.storage.queue.QueueClient + + .. admonition:: Example: + + .. literalinclude:: ../samples/queue_samples_message.py + :start-after: [START create_queue_client_from_connection_string] + :end-before: [END create_queue_client_from_connection_string] + :language: python + :dedent: 8 + :caption: Create the queue client from connection string. 
+ """ + account_url, secondary, credential = parse_connection_str( + conn_str, credential, 'queue') + if 'secondary_hostname' not in kwargs: + kwargs['secondary_hostname'] = secondary + return cls(account_url, queue_name=queue_name, credential=credential, **kwargs) + @distributed_trace_async - async def create_queue(self, **kwargs): - # type: (Optional[Any]) -> None + async def create_queue( + self, *, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: """Creates a new queue in the storage account. If a queue with the same name already exists, the operation fails with @@ -134,20 +226,18 @@ async def create_queue(self, **kwargs): :dedent: 12 :caption: Create a queue. """ - metadata = kwargs.pop('metadata', None) timeout = kwargs.pop('timeout', None) headers = kwargs.pop("headers", {}) - headers.update(add_metadata_headers(metadata)) # type: ignore + headers.update(add_metadata_headers(metadata)) try: - return await self._client.queue.create( # type: ignore + return await self._client.queue.create( metadata=metadata, timeout=timeout, headers=headers, cls=deserialize_queue_creation, **kwargs ) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def delete_queue(self, **kwargs): - # type: (Optional[Any]) -> None + async def delete_queue(self, **kwargs: Any) -> None: """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -182,8 +272,7 @@ async def delete_queue(self, **kwargs): process_storage_error(error) @distributed_trace_async - async def get_queue_properties(self, **kwargs): - # type: (Optional[Any]) -> QueueProperties + async def get_queue_properties(self, **kwargs: Any) -> "QueueProperties": """Returns all user-defined metadata for the specified queue. The data returned does not include the queue's list of messages. 
@@ -204,19 +293,19 @@ async def get_queue_properties(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - response = await self._client.queue.get_properties( + response = cast("QueueProperties", await (self._client.queue.get_properties( timeout=timeout, cls=deserialize_queue_properties, **kwargs - ) + ))) except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name - return response # type: ignore + return response @distributed_trace_async async def set_queue_metadata( - self, metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any - ) -> Dict[str, Any]: + self, metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: """Sets user-defined metadata on the specified queue. Metadata is associated with the queue as name-value pairs. @@ -244,17 +333,16 @@ async def set_queue_metadata( """ timeout = kwargs.pop('timeout', None) headers = kwargs.pop("headers", {}) - headers.update(add_metadata_headers(metadata)) # type: ignore + headers.update(add_metadata_headers(metadata)) try: - await self._client.queue.set_metadata( # type: ignore + await self._client.queue.set_metadata( timeout=timeout, headers=headers, cls=return_response_headers, **kwargs ) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def get_queue_access_policy(self, **kwargs): - # type: (Optional[Any]) -> Dict[str, Any] + async def get_queue_access_policy(self, **kwargs: Any) -> Dict[str, AccessPolicy]: """Returns details about any stored access policies specified on the queue that may be used with Shared Access Signatures. 
@@ -269,16 +357,18 @@ async def get_queue_access_policy(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - _, identifiers = await self._client.queue.get_access_policy( + _, identifiers = cast(Tuple[Dict, List], await self._client.queue.get_access_policy( timeout=timeout, cls=return_headers_and_deserialized, **kwargs - ) + )) except HttpResponseError as error: process_storage_error(error) return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace_async - async def set_queue_access_policy(self, signed_identifiers, **kwargs): - # type: (Dict[str, AccessPolicy], Optional[Any]) -> None + async def set_queue_access_policy( + self, signed_identifiers: Dict[str, AccessPolicy], + **kwargs: Any + ) -> None: """Sets stored access policies for the queue that may be used with Shared Access Signatures. @@ -297,7 +387,7 @@ async def set_queue_access_policy(self, signed_identifiers, **kwargs): SignedIdentifier access policies to associate with the queue. This may contain up to 5 elements. An empty dict will clear the access policies set on the service. - :type signed_identifiers: dict(str, ~azure.storage.queue.AccessPolicy) + :type signed_identifiers: Dict[str, ~azure.storage.queue.AccessPolicy] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. 
@@ -326,19 +416,19 @@ async def set_queue_access_policy(self, signed_identifiers, **kwargs): value.start = serialize_iso(value.start) value.expiry = serialize_iso(value.expiry) identifiers.append(SignedIdentifier(id=key, access_policy=value)) - signed_identifiers = identifiers # type: ignore try: - await self._client.queue.set_access_policy(queue_acl=signed_identifiers or None, timeout=timeout, **kwargs) + await self._client.queue.set_access_policy(queue_acl=identifiers or None, timeout=timeout, **kwargs) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def send_message( # type: ignore - self, - content, # type: Any - **kwargs # type: Optional[Any] - ): - # type: (...) -> QueueMessage + async def send_message( + self, content: Any, + *, + visibility_timeout: Optional[int] = None, + time_to_live: Optional[int] = None, + **kwargs: Any + ) -> "QueueMessage": """Adds a new message to the back of the message queue. The visibility timeout specifies the time that the message will be @@ -352,7 +442,7 @@ async def send_message( # type: ignore If the key-encryption-key field is set on the local service object, this method will encrypt the content before uploading. - :param obj content: + :param Any content: Message content. Allowed type is determined by the encode_function set on the service. Default is str. The encoded message can be up to 64KB in size. @@ -388,8 +478,6 @@ async def send_message( # type: ignore :dedent: 16 :caption: Send messages. 
""" - visibility_timeout = kwargs.pop('visibility_timeout', None) - time_to_live = kwargs.pop('time_to_live', None) timeout = kwargs.pop('timeout', None) if self.key_encryption_key: modify_user_agent_for_encryption( @@ -399,7 +487,7 @@ async def send_message( # type: ignore kwargs) try: - self._config.message_encode_policy.configure( + self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function, @@ -411,11 +499,11 @@ async def send_message( # type: ignore Consider updating your encryption information/implementation. \ Retrying without encryption_version." ) - self._config.message_encode_policy.configure( + self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function) - encoded_content = self._config.message_encode_policy(content) + encoded_content = self._message_encode_policy(content) new_message = GenQueueMessage(message_text=encoded_content) try: @@ -426,18 +514,24 @@ async def send_message( # type: ignore timeout=timeout, **kwargs ) - queue_message = QueueMessage(content=content) - queue_message.id = enqueued[0].message_id - queue_message.inserted_on = enqueued[0].insertion_time - queue_message.expires_on = enqueued[0].expiration_time - queue_message.pop_receipt = enqueued[0].pop_receipt - queue_message.next_visible_on = enqueued[0].time_next_visible + queue_message = QueueMessage( + content=content, + id=enqueued[0].message_id, + inserted_on=enqueued[0].insertion_time, + expires_on=enqueued[0].expiration_time, + pop_receipt = enqueued[0].pop_receipt, + next_visible_on = enqueued[0].time_next_visible + ) return queue_message except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: + async def receive_message( + self, *, + visibility_timeout: 
Optional[int] = None, + **kwargs: Any + ) -> Optional[QueueMessage]: """Removes one message from the front of the queue. When the message is retrieved from the queue, the response includes the message @@ -475,7 +569,6 @@ async def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: :dedent: 12 :caption: Receive one message from the queue. """ - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( @@ -484,16 +577,17 @@ async def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function) + resolver=self.key_resolver_function + ) try: message = await self._client.messages.dequeue( number_of_messages=1, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self._message_decode_policy, **kwargs ) wrapped_message = QueueMessage._from_generated( # pylint: disable=protected-access @@ -503,8 +597,13 @@ async def receive_message(self, **kwargs: Any) -> Optional[QueueMessage]: process_storage_error(error) @distributed_trace - def receive_messages(self, **kwargs): - # type: (Optional[Any]) -> AsyncItemPaged[QueueMessage] + def receive_messages( + self, *, + messages_per_page: Optional[int] = None, + visibility_timeout: Optional[int] = None, + max_messages: Optional[int] = None, + **kwargs: Any + ) -> AsyncItemPaged[QueueMessage]: """Removes one or more messages from the front of the queue. When a message is retrieved from the queue, the response includes the message @@ -532,14 +631,14 @@ def receive_messages(self, **kwargs): larger than 7 days. The visibility timeout of a message cannot be set to a value later than the expiry time. 
visibility_timeout should be set to a value smaller than the time-to-live value. + :keyword int max_messages: + An integer that specifies the maximum number of messages to retrieve from the queue. :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-queue-service-operations. This value is not tracked or validated on the client. To configure client-side network timesouts see `here `_. - :keyword int max_messages: - An integer that specifies the maximum number of messages to retrieve from the queue. :return: Returns a message iterator of dict-like Message objects. :rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.storage.queue.QueueMessage] @@ -553,10 +652,7 @@ def receive_messages(self, **kwargs): :dedent: 16 :caption: Receive messages from the queue. """ - messages_per_page = kwargs.pop('messages_per_page', None) - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) - max_messages = kwargs.pop('max_messages', None) if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( self._config.user_agent_policy.user_agent, @@ -564,7 +660,7 @@ def receive_messages(self, **kwargs): self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function @@ -574,7 +670,7 @@ def receive_messages(self, **kwargs): self._client.messages.dequeue, visibilitytimeout=visibility_timeout, timeout=timeout, - cls=self._config.message_decode_policy, + cls=self._message_decode_policy, **kwargs ) if max_messages is not None and messages_per_page is not None: @@ -587,13 +683,13 @@ def receive_messages(self, **kwargs): @distributed_trace_async async def update_message( - self, - message, - pop_receipt=None, - 
content=None, - **kwargs - ): - # type: (Any, int, Optional[str], Optional[Any], Any) -> QueueMessage + self, message: Union[str, QueueMessage], + pop_receipt: Optional[str] = None, + content: Optional[Any] = None, + *, + visibility_timeout: Optional[int] = None, + **kwargs: Any + ) -> QueueMessage: """Updates the visibility timeout of a message. You can also use this operation to update the contents of a message. @@ -614,7 +710,7 @@ async def update_message( :param str pop_receipt: A valid pop receipt value returned from an earlier call to the :func:`~receive_messages` or :func:`~update_message` operation. - :param obj content: + :param Any content: Message content. Allowed type is determined by the encode_function set on the service. Default is str. :keyword int visibility_timeout: @@ -644,7 +740,6 @@ async def update_message( :dedent: 16 :caption: Update a message. """ - visibility_timeout = kwargs.pop('visibility_timeout', None) timeout = kwargs.pop('timeout', None) if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( @@ -653,14 +748,14 @@ async def update_message( self.encryption_version, kwargs) - try: + if isinstance(message, QueueMessage): message_id = message.id message_text = content or message.content receipt = pop_receipt or message.pop_receipt inserted_on = message.inserted_on expires_on = message.expires_on dequeue_count = message.dequeue_count - except AttributeError: + else: message_id = message message_text = content receipt = pop_receipt @@ -672,7 +767,7 @@ async def update_message( raise ValueError("pop_receipt must be present") if message_text is not None: try: - self._config.message_encode_policy.configure( + self._message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function, @@ -685,17 +780,17 @@ async def update_message( Consider updating your encryption information/implementation. \ Retrying without encryption_version." 
) - self._config.message_encode_policy.configure( + self._message_encode_policy.configure( self.require_encryption, self.key_encryption_key, self.key_resolver_function ) - encoded_message_text = self._config.message_encode_policy(message_text) + encoded_message_text = self._message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: - updated = None # type: ignore + updated = None try: - response = await self._client.message_id.update( + response = cast(QueueMessage, await self._client.message_id.update( queue_message=updated, visibilitytimeout=visibility_timeout or 0, timeout=timeout, @@ -703,21 +798,25 @@ async def update_message( cls=return_response_headers, queue_message_id=message_id, **kwargs + )) + new_message = QueueMessage( + content=message_text, + id=message_id, + inserted_on=inserted_on, + dequeue_count=dequeue_count, + expires_on=expires_on, + pop_receipt = response['popreceipt'], + next_visible_on = response['time_next_visible'] ) - new_message = QueueMessage(content=message_text) - new_message.id = message_id - new_message.inserted_on = inserted_on - new_message.expires_on = expires_on - new_message.dequeue_count = dequeue_count - new_message.pop_receipt = response["popreceipt"] - new_message.next_visible_on = response["time_next_visible"] return new_message except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def peek_messages(self, max_messages=None, **kwargs): - # type: (Optional[int], Optional[Any]) -> List[QueueMessage] + async def peek_messages( + self, max_messages: Optional[int] = None, + **kwargs: Any + ) -> List[QueueMessage]: """Retrieves one or more messages from the front of the queue, but does not alter the visibility of the message. 
@@ -768,14 +867,14 @@ async def peek_messages(self, max_messages=None, **kwargs): self.encryption_version, kwargs) - self._config.message_decode_policy.configure( + self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function ) try: messages = await self._client.messages.peek( - number_of_messages=max_messages, timeout=timeout, cls=self._config.message_decode_policy, **kwargs + number_of_messages=max_messages, timeout=timeout, cls=self._message_decode_policy, **kwargs ) wrapped_messages = [] for peeked in messages: @@ -785,8 +884,7 @@ async def peek_messages(self, max_messages=None, **kwargs): process_storage_error(error) @distributed_trace_async - async def clear_messages(self, **kwargs): - # type: (Optional[Any]) -> None + async def clear_messages(self, **kwargs: Any) -> None: """Deletes all messages from the specified queue. :keyword int timeout: @@ -812,8 +910,11 @@ async def clear_messages(self, **kwargs): process_storage_error(error) @distributed_trace_async - async def delete_message(self, message, pop_receipt=None, **kwargs): - # type: (Any, Optional[str], Any) -> None + async def delete_message( + self, message: Union[str, QueueMessage], + pop_receipt: Optional[str] = None, + **kwargs: Any + ) -> None: """Deletes the specified message. Normally after a client retrieves a message with the receive messages operation, @@ -849,10 +950,12 @@ async def delete_message(self, message, pop_receipt=None, **kwargs): :caption: Delete a message. 
""" timeout = kwargs.pop('timeout', None) - try: + + receipt: Optional[str] + if isinstance(message, QueueMessage): message_id = message.id receipt = pop_receipt or message.pop_receipt - except AttributeError: + else: message_id = message receipt = pop_receipt diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index ae4e9f4a07b4..39861cd9371b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -7,41 +7,37 @@ import functools from typing import ( - Any, Dict, List, Optional, Union, - TYPE_CHECKING) + Any, Dict, List, Optional, + TYPE_CHECKING, Union +) +from typing_extensions import Self from azure.core.async_paging import AsyncItemPaged from azure.core.exceptions import HttpResponseError from azure.core.pipeline import AsyncPipeline from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async +from ._models import QueuePropertiesPaged +from ._queue_client_async import QueueClient +from .._encryption import StorageEncryptionMixin +from .._generated.aio import AzureQueueStorage +from .._generated.models import StorageServiceProperties +from .._models import CorsRule, QueueProperties, service_properties_deserialize, service_stats_deserialize +from .._queue_service_client_helpers import _parse_url from .._serialize import get_api_version -from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper -from .._shared.policies_async import ExponentialRetry +from .._shared.base_client import StorageAccountHostsMixin +from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper, parse_connection_str from .._shared.models import LocationMode +from 
.._shared.policies_async import ExponentialRetry from .._shared.response_handlers import process_storage_error -from .._generated.aio import AzureQueueStorage -from .._generated.models import StorageServiceProperties -from .._encryption import StorageEncryptionMixin -from .._models import ( - service_stats_deserialize, - service_properties_deserialize, -) -from .._queue_service_client import QueueServiceClient as QueueServiceClientBase -from ._models import QueuePropertiesPaged -from ._queue_client_async import QueueClient if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential - from .._models import ( - CorsRule, - Metrics, - QueueProperties, - QueueAnalyticsLogging, - ) + from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential + from azure.core.credentials_async import AsyncTokenCredential + from .._models import Metrics, QueueAnalyticsLogging -class QueueServiceClient(AsyncStorageAccountHostsMixin, QueueServiceClientBase, StorageEncryptionMixin): +class QueueServiceClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin): # type: ignore [misc] # pylint: disable=line-too-long """A client to interact with the Queue Service at the account level. This client provides operations to retrieve and configure the account properties @@ -62,6 +58,7 @@ class QueueServiceClient(AsyncStorageAccountHostsMixin, QueueServiceClientBase, - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long :keyword str api_version: The Storage API version to use for requests. 
Default value is the most recent service version that is compatible with the current SDK. Setting to an older version may result in reduced feature compatibility. @@ -89,25 +86,70 @@ class QueueServiceClient(AsyncStorageAccountHostsMixin, QueueServiceClientBase, """ def __init__( - self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> None: + self, account_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) loop = kwargs.pop('loop', None) - super(QueueServiceClient, self).__init__( # type: ignore - account_url, - credential=credential, - loop=loop, - **kwargs) - self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) # type: ignore - self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access + parsed_url, sas_token = _parse_url(account_url=account_url, credential=credential) + self._query_str, credential = self._format_query_string(sas_token, credential) + super(QueueServiceClient, self).__init__(parsed_url, service='queue', credential=credential, **kwargs) + self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) + self._client._config.version = get_api_version(kwargs) # type: ignore [assignment] # pylint: disable=protected-access self._loop = loop self._configure_encryption(kwargs) + def _format_url(self, hostname: str) -> str: + """Format the endpoint URL according to the current location + mode hostname. + + :param str hostname: The current location mode hostname. + :returns: The formatted endpoint URL according to the specified location mode hostname. 
+ :rtype: str + """ + return f"{self.scheme}://{hostname}/{self._query_str}" + + @classmethod + def from_connection_string( + cls, conn_str: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: + """Create QueueServiceClient from a Connection String. + + :param str conn_str: + A connection string to an Azure Storage account. + :param credential: + The credentials with which to authenticate. This is optional if the + account URL already has a SAS token, or the connection string already has shared + access key values. The value can be a SAS token string, + an instance of a AzureSasCredential or AzureNamedKeyCredential from azure.core.credentials, + an account shared access key, or an instance of a TokenCredentials class from azure.identity. + Credentials provided here will take precedence over those in the connection string. + If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" + should be the storage account key. + :paramtype credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] # pylint: disable=line-too-long + :returns: A Queue service client. + :rtype: ~azure.storage.queue.QueueClient + + .. admonition:: Example: + + .. literalinclude:: ../samples/queue_samples_authentication.py + :start-after: [START auth_from_connection_string] + :end-before: [END auth_from_connection_string] + :language: python + :dedent: 8 + :caption: Creating the QueueServiceClient with a connection string. 
+ """ + account_url, secondary, credential = parse_connection_str( + conn_str, credential, 'queue') + if 'secondary_hostname' not in kwargs: + kwargs['secondary_hostname'] = secondary + return cls(account_url, credential=credential, **kwargs) + @distributed_trace_async - async def get_service_stats(self, **kwargs): - # type: (Optional[Any]) -> Dict[str, Any] + async def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: """Retrieves statistics related to replication for the Queue service. It is only available when read-access geo-redundant replication is enabled for @@ -133,15 +175,14 @@ async def get_service_stats(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - stats = await self._client.service.get_statistics( # type: ignore + stats = await self._client.service.get_statistics( timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs) return service_stats_deserialize(stats) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def get_service_properties(self, **kwargs): - # type: (Optional[Any]) -> Dict[str, Any] + async def get_service_properties(self, **kwargs: Any) -> Dict[str, Any]: """Gets the properties of a storage account's Queue service, including Azure Storage Analytics. 
@@ -162,20 +203,19 @@ async def get_service_properties(self, **kwargs): """ timeout = kwargs.pop('timeout', None) try: - service_props = await self._client.service.get_properties(timeout=timeout, **kwargs) # type: ignore + service_props = await self._client.service.get_properties(timeout=timeout, **kwargs) return service_properties_deserialize(service_props) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async - async def set_service_properties( # type: ignore - self, analytics_logging=None, # type: Optional[QueueAnalyticsLogging] - hour_metrics=None, # type: Optional[Metrics] - minute_metrics=None, # type: Optional[Metrics] - cors=None, # type: Optional[List[CorsRule]] - **kwargs - ): - # type: (...) -> None + async def set_service_properties( + self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, + hour_metrics: Optional["Metrics"] = None, + minute_metrics: Optional["Metrics"] = None, + cors: Optional[List[CorsRule]] = None, + **kwargs: Any + ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -197,7 +237,7 @@ async def set_service_properties( # type: ignore You can include up to five CorsRule elements in the list. If an empty list is specified, all CORS rules will be deleted, and CORS will be disabled for the service. - :type cors: list(~azure.storage.queue.CorsRule) + :type cors: Optional[List(~azure.storage.queue.CorsRule)] :keyword int timeout: The timeout parameter is expressed in seconds. 
@@ -215,19 +255,19 @@ async def set_service_properties( # type: ignore logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=cors + cors=CorsRule._to_generated(cors) # pylint: disable=protected-access ) try: - await self._client.service.set_properties(props, timeout=timeout, **kwargs) # type: ignore + await self._client.service.set_properties(props, timeout=timeout, **kwargs) except HttpResponseError as error: process_storage_error(error) @distributed_trace def list_queues( - self, name_starts_with=None, # type: Optional[str] - include_metadata=False, # type: Optional[bool] - **kwargs - ): # type: (...) -> AsyncItemPaged + self, name_starts_with: Optional[str] = None, + include_metadata: Optional[bool] = False, + **kwargs: Any + ) -> AsyncItemPaged: """Returns a generator to list the queues under the specified account. The generator will lazily follow the continuation tokens returned by @@ -276,12 +316,11 @@ def list_queues( ) @distributed_trace_async - async def create_queue( # type: ignore - self, name, # type: str - metadata=None, # type: Optional[Dict[str, str]] - **kwargs - ): - # type: (...) -> QueueClient + async def create_queue( + self, name: str, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> QueueClient: """Creates a new queue under the specified account. If a queue with the same name already exists, the operation fails. @@ -291,7 +330,7 @@ async def create_queue( # type: ignore :param metadata: A dict with name_value pairs to associate with the queue as metadata. Example: {'Category': 'test'} - :type metadata: dict(str, str) + :type metadata: Dict[str, str] :keyword int timeout: The timeout parameter is expressed in seconds. :return: A QueueClient for the newly created Queue. @@ -314,11 +353,10 @@ async def create_queue( # type: ignore return queue @distributed_trace_async - async def delete_queue( # type: ignore - self, queue, # type: Union[QueueProperties, str] - **kwargs - ): - # type: (...) 
-> None + async def delete_queue( + self, queue: Union["QueueProperties", str], + **kwargs: Any + ) -> None: """Deletes the specified queue and any messages it contains. When a queue is successfully deleted, it is immediately marked for deletion @@ -351,8 +389,10 @@ async def delete_queue( # type: ignore kwargs.setdefault('merge_span', True) await queue_client.delete_queue(timeout=timeout, **kwargs) - def get_queue_client(self, queue, **kwargs): - # type: (Union[QueueProperties, str], Optional[Any]) -> QueueClient + def get_queue_client( + self, queue: Union["QueueProperties", str], + **kwargs: Any + ) -> QueueClient: """Get a client to interact with the specified queue. The queue need not already exist. @@ -373,14 +413,15 @@ def get_queue_client(self, queue, **kwargs): :dedent: 8 :caption: Get the queue client. """ - try: - queue_name = queue.name - except AttributeError: + if isinstance(queue, QueueProperties): + if queue.name is not None: + queue_name = queue.name + else: queue_name = queue _pipeline = AsyncPipeline( - transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access - policies=self._pipeline._impl_policies # pylint: disable = protected-access + transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable=protected-access + policies=self._pipeline._impl_policies # type: ignore # pylint: disable=protected-access ) return QueueClient( diff --git a/sdk/storage/azure-storage-queue/mypy.ini b/sdk/storage/azure-storage-queue/mypy.ini index f8e299bc34f3..0d5fa8139e1d 100644 --- a/sdk/storage/azure-storage-queue/mypy.ini +++ b/sdk/storage/azure-storage-queue/mypy.ini @@ -1,13 +1,2 @@ -[mypy] -python_version = 3.7 -warn_return_any = True -warn_unused_configs = True -ignore_missing_imports = True - -# Per-module options: - [mypy-azure.storage.queue._generated.*] ignore_errors = True - -[mypy-azure.core.*] -ignore_errors = True diff --git a/sdk/storage/azure-storage-queue/pyproject.toml 
b/sdk/storage/azure-storage-queue/pyproject.toml index ff779a518143..78755ba24174 100644 --- a/sdk/storage/azure-storage-queue/pyproject.toml +++ b/sdk/storage/azure-storage-queue/pyproject.toml @@ -1,5 +1,5 @@ [tool.azure-sdk-build] -mypy = false +mypy = true pyright = false -type_check_samples = false -verifytypes = false +type_check_samples = true +verifytypes = true diff --git a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py index 3be395e90c39..a6440ebdf02a 100644 --- a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py +++ b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py @@ -63,6 +63,6 @@ messages = queue_client.peek_messages(max_messages=20, logging_enable=True) for message in messages: try: - print(' Message: {}'.format(base64.b64decode(message.content))) + print(' Message: {!r}'.format(base64.b64decode(message.content))) except binascii.Error: print(' Message: {}'.format(message.content)) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py index 76cc089c1e20..9e1fd88d0c54 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py @@ -31,6 +31,7 @@ from datetime import datetime, timedelta import os +import sys class QueueAuthSamples(object): @@ -46,6 +47,11 @@ class QueueAuthSamples(object): active_directory_tenant_id = os.getenv("ACTIVE_DIRECTORY_TENANT_ID") def authentication_by_connection_string(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: authentication_by_connection_string") + sys.exit(1) + # Instantiate a QueueServiceClient using a connection string # [START auth_from_connection_string] from azure.storage.queue import QueueServiceClient @@ -56,6 +62,11 @@ def authentication_by_connection_string(self): properties = queue_service.get_service_properties() def authentication_by_shared_key(self): + if self.account_url is None or self.access_key is None: + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_shared_key") + sys.exit(1) + # Instantiate a QueueServiceClient using a shared access key # [START create_queue_service_client] from azure.storage.queue import QueueServiceClient @@ -66,6 +77,15 @@ def authentication_by_shared_key(self): properties = queue_service.get_service_properties() def authentication_by_active_directory(self): + if (self.active_directory_tenant_id is None or + self.active_directory_application_id is None or + self.active_directory_application_secret is None or + self.account_url is None + ): + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_active_directory") + sys.exit(1) + # [START create_queue_service_client_token] # Get a token credential for authentication from azure.identity import ClientSecretCredential @@ -84,6 +104,15 @@ def authentication_by_active_directory(self): properties = queue_service.get_service_properties() def authentication_by_shared_access_signature(self): + if (self.connection_string is None or + self.account_name is None or + self.access_key is None or + self.account_url is None + ): + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: authentication_by_shared_access_signature") + sys.exit(1) + # Instantiate a QueueServiceClient using a connection string from azure.storage.queue import QueueServiceClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py index 5b7f97818cc7..ebcf80199410 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py @@ -32,6 +32,7 @@ from datetime import datetime, timedelta import asyncio import os +import sys class QueueAuthSamplesAsync(object): @@ -47,6 +48,11 @@ class QueueAuthSamplesAsync(object): active_directory_tenant_id = os.getenv("ACTIVE_DIRECTORY_TENANT_ID") async def authentication_by_connection_string_async(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_connection_string_async") + sys.exit(1) + # Instantiate a QueueServiceClient using a connection string # [START async_auth_from_connection_string] from azure.storage.queue.aio import QueueServiceClient @@ -58,6 +64,11 @@ async def authentication_by_connection_string_async(self): properties = await queue_service.get_service_properties() async def authentication_by_shared_key_async(self): + if self.account_url is None or self.access_key is None: + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: authentication_by_shared_key_async") + sys.exit(1) + # Instantiate a QueueServiceClient using a shared access key # [START async_create_queue_service_client] from azure.storage.queue.aio import QueueServiceClient @@ -69,6 +80,15 @@ async def authentication_by_shared_key_async(self): properties = await queue_service.get_service_properties() async def authentication_by_active_directory_async(self): + if (self.active_directory_tenant_id is None or + self.active_directory_application_id is None or + self.active_directory_application_secret is None or + self.account_url is None + ): + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: authentication_by_active_directory_async") + sys.exit(1) + # [START async_create_queue_service_client_token] # Get a token credential for authentication from azure.identity.aio import ClientSecretCredential @@ -88,6 +108,15 @@ async def authentication_by_active_directory_async(self): properties = await queue_service.get_service_properties() async def authentication_by_shared_access_signature_async(self): + if (self.connection_string is None or + self.account_name is None or + self.access_key is None or + self.account_url is None + ): + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: authentication_by_shared_access_signature_async") + sys.exit(1) + # Instantiate a QueueServiceClient using a connection string from azure.storage.queue.aio import QueueServiceClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py index 5bc6cf41f6b1..d798563b4e90 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py @@ -22,6 +22,7 @@ import os +import sys class QueueHelloWorldSamples(object): @@ -29,6 +30,11 @@ class QueueHelloWorldSamples(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") def create_client_with_connection_string(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: create_client_with_connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) @@ -37,6 +43,11 @@ def create_client_with_connection_string(self): properties = queue_service.get_service_properties() def queue_and_messages_example(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: queue_and_messages_example") + sys.exit(1) + # Instantiate the QueueClient from a connection string from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py index bc33a36a29a9..34aa77894fb2 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py @@ -23,6 +23,7 @@ import asyncio import os +import sys class QueueHelloWorldSamplesAsync(object): @@ -30,6 +31,11 @@ class QueueHelloWorldSamplesAsync(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") async def create_client_with_connection_string_async(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." + '\n' + + "Test: create_client_with_connection_string_async") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) @@ -39,6 +45,11 @@ async def create_client_with_connection_string_async(self): properties = await queue_service.get_service_properties() async def queue_and_messages_example_async(self): + if self.connection_string is None: + print("Missing required environment variable(s). Please see specific test for more details." 
+ '\n' + + "Test: queue_and_messages_example_async") + sys.exit(1) + # Instantiate the QueueClient from a connection string from azure.storage.queue.aio import QueueClient queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py index 04f5139b3fdd..44172f3d53c7 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py @@ -25,6 +25,7 @@ from datetime import datetime, timedelta import os +import sys class QueueMessageSamples(object): @@ -32,6 +33,10 @@ class QueueMessageSamples(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") def set_access_policy(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # [START create_queue_client_from_connection_string] from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") @@ -84,6 +89,10 @@ def set_access_policy(self): queue.delete_queue() def queue_metadata(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") @@ -106,6 +115,10 @@ def queue_metadata(self): queue.delete_queue() def send_and_receive_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") @@ -147,6 +160,10 @@ def send_and_receive_messages(self): queue.delete_queue() def 
list_message_pages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") @@ -180,6 +197,10 @@ def list_message_pages(self): queue.delete_queue() def receive_one_message_from_queue(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") @@ -199,6 +220,9 @@ def receive_one_message_from_queue(self): # We should see message 3 if we peek message3 = queue.peek_messages()[0] + if not message1 or not message2 or not message3: + raise ValueError("One of the messages are None.") + print(message1.content) print(message2.content) print(message3.content) @@ -208,6 +232,10 @@ def receive_one_message_from_queue(self): queue.delete_queue() def delete_and_clear_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") @@ -240,6 +268,10 @@ def delete_and_clear_messages(self): queue.delete_queue() def peek_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") @@ -272,6 +304,10 @@ def peek_messages(self): queue.delete_queue() def update_message(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a 
queue client from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") @@ -301,6 +337,10 @@ def update_message(self): queue.delete_queue() def receive_messages_with_max_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue9") @@ -348,4 +388,4 @@ def receive_messages_with_max_messages(self): sample.delete_and_clear_messages() sample.peek_messages() sample.update_message() - sample.receive_messages_with_max_messages() \ No newline at end of file + sample.receive_messages_with_max_messages() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py index f0d8affca3f1..c5aafcdfc80d 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py @@ -25,6 +25,7 @@ from datetime import datetime, timedelta import asyncio import os +import sys class QueueMessageSamplesAsync(object): @@ -32,6 +33,10 @@ class QueueMessageSamplesAsync(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") async def set_access_policy_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # [START async_create_queue_client_from_connection_string] from azure.storage.queue.aio import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") @@ -83,6 +88,10 @@ async def set_access_policy_async(self): await queue.delete_queue() async def queue_metadata_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # 
Instantiate a queue client from azure.storage.queue.aio import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") @@ -106,6 +115,10 @@ async def queue_metadata_async(self): await queue.delete_queue() async def send_and_receive_messages_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") @@ -134,7 +147,7 @@ async def send_and_receive_messages_async(self): # Receive messages by batch messages = queue.receive_messages(messages_per_page=5) async for msg_batch in messages.by_page(): - for msg in msg_batch: + async for msg in msg_batch: print(msg.content) await queue.delete_message(msg) # [END async_receive_messages] @@ -150,6 +163,10 @@ async def send_and_receive_messages_async(self): await queue.delete_queue() async def receive_one_message_from_queue(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") @@ -171,6 +188,9 @@ async def receive_one_message_from_queue(self): # We should see message 3 if we peek message3 = await queue.peek_messages() + if not message1 or not message2 or not message3: + raise ValueError("One of the messages are None.") + print(message1.content) print(message2.content) print(message3[0].content) @@ -180,6 +200,10 @@ async def receive_one_message_from_queue(self): await queue.delete_queue() async def delete_and_clear_messages_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient queue = 
QueueClient.from_connection_string(self.connection_string, "myqueue5") @@ -216,6 +240,10 @@ async def delete_and_clear_messages_async(self): await queue.delete_queue() async def peek_messages_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") @@ -251,6 +279,10 @@ async def peek_messages_async(self): await queue.delete_queue() async def update_message_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") @@ -281,6 +313,10 @@ async def update_message_async(self): await queue.delete_queue() async def receive_messages_with_max_messages(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate a queue client from azure.storage.queue.aio import QueueClient queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py index 5b6d24eb5cbd..d10c230db526 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py @@ -21,6 +21,7 @@ """ import os +import sys class QueueServiceSamples(object): @@ -28,6 +29,10 @@ class QueueServiceSamples(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") def queue_service_properties(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the 
QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) @@ -69,6 +74,10 @@ def queue_service_properties(self): # [END get_queue_service_properties] def queues_in_account(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) @@ -96,6 +105,10 @@ def queues_in_account(self): # [END qsc_delete_queue] def get_queue_client(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient, QueueClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) @@ -110,4 +123,4 @@ def get_queue_client(self): sample = QueueServiceSamples() sample.queue_service_properties() sample.queues_in_account() - sample.get_queue_client() \ No newline at end of file + sample.get_queue_client() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py index 01d79059ae2f..b57816180892 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py @@ -23,6 +23,7 @@ import asyncio import os +import sys class QueueServiceSamplesAsync(object): @@ -30,6 +31,10 @@ class QueueServiceSamplesAsync(object): connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING") async def queue_service_properties_async(self): + if self.connection_string is None: + print("Missing required environment variable: 
connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) @@ -72,6 +77,10 @@ async def queue_service_properties_async(self): # [END async_get_queue_service_properties] async def queues_in_account_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) @@ -100,6 +109,10 @@ async def queues_in_account_async(self): # [END async_qsc_delete_queue] async def get_queue_client_async(self): + if self.connection_string is None: + print("Missing required environment variable: connection_string") + sys.exit(1) + # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient, QueueClient queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py b/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py index 86c47c7dce1f..82d8a1532c59 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py @@ -70,8 +70,8 @@ def test_message_text_xml(self, **kwargs): queue = qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX)) # Asserts - assert isinstance(queue._config.message_encode_policy, NoEncodePolicy) - assert isinstance(queue._config.message_decode_policy, NoDecodePolicy) + assert isinstance(queue._message_encode_policy, NoEncodePolicy) + assert isinstance(queue._message_decode_policy, NoDecodePolicy) self._validate_encoding(queue, message) @QueuePreparer() @@ 
-225,8 +225,8 @@ def test_message_no_encoding(self): message_decode_policy=None) # Asserts - assert isinstance(queue._config.message_encode_policy, NoEncodePolicy) - assert isinstance(queue._config.message_decode_policy, NoDecodePolicy) + assert isinstance(queue._message_encode_policy, NoEncodePolicy) + assert isinstance(queue._message_decode_policy, NoDecodePolicy) # ------------------------------------------------------------------------------