From 63ed9de8578cbaad053cb49bae8578264d1a268b Mon Sep 17 00:00:00 2001 From: iscai-msft Date: Fri, 5 Jun 2020 11:56:53 -0400 Subject: [PATCH] add type annotation for empty class model definitions in service client --- autorest/codegen/models/client.py | 8 ++++++-- autorest/codegen/templates/service_client.py.jinja2 | 2 +- .../Head/head/_auto_rest_head_test_service.py | 4 ++-- .../head/aio/_auto_rest_head_test_service_async.py | 4 +++- .../_auto_rest_head_exception_test_service.py | 4 ++-- .../_auto_rest_head_exception_test_service_async.py | 4 +++- .../aio/operations_async/_string_operations_async.py | 10 +++++----- .../bodystring/operations/_string_operations.py | 10 +++++----- .../nonstringenums/_non_string_enums_client.py | 4 ++-- .../aio/_non_string_enums_client_async.py | 8 ++++++-- .../ObjectType/objecttype/_object_type_client.py | 4 ++-- .../objecttype/aio/_object_type_client_async.py | 8 ++++++-- 12 files changed, 43 insertions(+), 27 deletions(-) diff --git a/autorest/codegen/models/client.py b/autorest/codegen/models/client.py index 242e3369926..392fb471ab8 100644 --- a/autorest/codegen/models/client.py +++ b/autorest/codegen/models/client.py @@ -28,8 +28,6 @@ def imports(code_model, async_mode: bool) -> FileImport: file_import.add_from_import("msrest", "Deserializer", ImportType.AZURECORE) file_import.add_from_import("typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL) - # if code_model.options["credential"]: - # file_import.add_from_import("azure.core.credentials", "TokenCredential", ImportType.AZURECORE) any_optional_gp = any(not gp.required for gp in code_model.global_parameters) if any_optional_gp or code_model.base_url: @@ -43,4 +41,10 @@ def imports(code_model, async_mode: bool) -> FileImport: file_import.add_from_import( "azure.core", Client.pipeline_class(code_model, async_mode), ImportType.AZURECORE ) + + if not code_model.sorted_schemas: + # in this case, we have client_models = {} in the service client, which needs a type annotation + # this import will always be commented, so will always add it to the typing section + file_import.add_from_import("typing", "Dict", ImportType.STDLIB, TypingSection.TYPING) + return file_import diff --git a/autorest/codegen/templates/service_client.py.jinja2 b/autorest/codegen/templates/service_client.py.jinja2 index 15b46145e16..b28d967cf5f 100644 --- a/autorest/codegen/templates/service_client.py.jinja2 +++ b/autorest/codegen/templates/service_client.py.jinja2 @@ -79,7 +79,7 @@ class {{ code_model.class_name }}({{ base_class }}): {% if code_model.sorted_schemas %} client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} {% else %} - client_models = {} + client_models = {} # type: Dict[str, Any] {% endif %} self._serialize = Serializer(client_models) {% if not code_model.options['client_side_validation'] %} diff --git a/test/azure/Expected/AcceptanceTests/Head/head/_auto_rest_head_test_service.py b/test/azure/Expected/AcceptanceTests/Head/head/_auto_rest_head_test_service.py index 4390e303543..9b4dc1a3162 100644 --- a/test/azure/Expected/AcceptanceTests/Head/head/_auto_rest_head_test_service.py +++ b/test/azure/Expected/AcceptanceTests/Head/head/_auto_rest_head_test_service.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from typing import Any, Optional + from typing import Any, Dict, Optional from azure.core.credentials import TokenCredential @@ -44,7 +44,7 @@ def __init__( self._config = AutoRestHeadTestServiceConfiguration(credential, **kwargs) self._client = ARMPipelineClient(base_url=base_url, config=self._config, **kwargs) - client_models = {} + client_models = {} # type: Dict[str, Any] self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) diff --git a/test/azure/Expected/AcceptanceTests/Head/head/aio/_auto_rest_head_test_service_async.py b/test/azure/Expected/AcceptanceTests/Head/head/aio/_auto_rest_head_test_service_async.py index 027ef7a504e..cb5b9c7e23a 100644 --- a/test/azure/Expected/AcceptanceTests/Head/head/aio/_auto_rest_head_test_service_async.py +++ b/test/azure/Expected/AcceptanceTests/Head/head/aio/_auto_rest_head_test_service_async.py @@ -13,6 +13,8 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports + from typing import Dict + from azure.core.credentials_async import AsyncTokenCredential from ._configuration_async import AutoRestHeadTestServiceConfiguration @@ -41,7 +43,7 @@ def __init__( self._config = AutoRestHeadTestServiceConfiguration(credential, **kwargs) self._client = AsyncARMPipelineClient(base_url=base_url, config=self._config, **kwargs) - client_models = {} + client_models = {} # type: Dict[str, Any] self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) diff --git a/test/azure/Expected/AcceptanceTests/HeadExceptions/headexceptions/_auto_rest_head_exception_test_service.py b/test/azure/Expected/AcceptanceTests/HeadExceptions/headexceptions/_auto_rest_head_exception_test_service.py index 0f632c2c717..3862a805a4e 100644 --- a/test/azure/Expected/AcceptanceTests/HeadExceptions/headexceptions/_auto_rest_head_exception_test_service.py +++ b/test/azure/Expected/AcceptanceTests/HeadExceptions/headexceptions/_auto_rest_head_exception_test_service.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from typing import Any, Optional + from typing import Any, Dict, Optional from azure.core.credentials import TokenCredential @@ -44,7 +44,7 @@ def __init__( self._config = AutoRestHeadExceptionTestServiceConfiguration(credential, **kwargs) self._client = ARMPipelineClient(base_url=base_url, config=self._config, **kwargs) - client_models = {} + client_models = {} # type: Dict[str, Any] self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) diff --git a/test/azure/Expected/AcceptanceTests/HeadExceptions/headexceptions/aio/_auto_rest_head_exception_test_service_async.py b/test/azure/Expected/AcceptanceTests/HeadExceptions/headexceptions/aio/_auto_rest_head_exception_test_service_async.py index 69aeb12f268..294469e1f07 100644 --- a/test/azure/Expected/AcceptanceTests/HeadExceptions/headexceptions/aio/_auto_rest_head_exception_test_service_async.py +++ b/test/azure/Expected/AcceptanceTests/HeadExceptions/headexceptions/aio/_auto_rest_head_exception_test_service_async.py @@ -13,6 +13,8 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports + from typing import Dict + from azure.core.credentials_async import AsyncTokenCredential from ._configuration_async import AutoRestHeadExceptionTestServiceConfiguration @@ -41,7 +43,7 @@ def __init__( self._config = AutoRestHeadExceptionTestServiceConfiguration(credential, **kwargs) self._client = AsyncARMPipelineClient(base_url=base_url, config=self._config, **kwargs) - client_models = {} + client_models = {} # type: Dict[str, Any] self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) diff --git a/test/vanilla/Expected/AcceptanceTests/BodyString/bodystring/aio/operations_async/_string_operations_async.py b/test/vanilla/Expected/AcceptanceTests/BodyString/bodystring/aio/operations_async/_string_operations_async.py index 4ac4736575b..c71130a62f3 100644 --- a/test/vanilla/Expected/AcceptanceTests/BodyString/bodystring/aio/operations_async/_string_operations_async.py +++ b/test/vanilla/Expected/AcceptanceTests/BodyString/bodystring/aio/operations_async/_string_operations_async.py @@ -452,15 +452,15 @@ async def get_not_provided( async def get_base64_encoded( self, **kwargs - ) -> bytes: + ) -> bytearray: """Get value that is base64 encoded. :keyword callable cls: A custom type or function that will be passed the direct response - :return: bytes, or the result of cls(response) - :rtype: bytes + :return: bytearray, or the result of cls(response) + :rtype: bytearray :raises: ~azure.core.exceptions.HttpResponseError """ - cls = kwargs.pop('cls', None) # type: ClsType[bytes] + cls = kwargs.pop('cls', None) # type: ClsType[bytearray] error_map = {404: ResourceNotFoundError, 409: ResourceExistsError} error_map.update(kwargs.pop('error_map', {})) @@ -483,7 +483,7 @@ async def get_base64_encoded( error = self._deserialize(models.Error, response) raise HttpResponseError(response=response, model=error) - deserialized = self._deserialize('base64', pipeline_response) + deserialized = self._deserialize('bytearray', pipeline_response) if cls: return cls(pipeline_response, deserialized, {}) diff --git a/test/vanilla/Expected/AcceptanceTests/BodyString/bodystring/operations/_string_operations.py b/test/vanilla/Expected/AcceptanceTests/BodyString/bodystring/operations/_string_operations.py index 486edc4ad09..c377ba629b9 100644 --- a/test/vanilla/Expected/AcceptanceTests/BodyString/bodystring/operations/_string_operations.py +++ b/test/vanilla/Expected/AcceptanceTests/BodyString/bodystring/operations/_string_operations.py @@ -466,15 +466,15 @@ def get_base64_encoded( self, **kwargs # type: Any ): - # type: (...) -> bytes + # type: (...) -> bytearray """Get value that is base64 encoded. :keyword callable cls: A custom type or function that will be passed the direct response - :return: bytes, or the result of cls(response) - :rtype: bytes + :return: bytearray, or the result of cls(response) + :rtype: bytearray :raises: ~azure.core.exceptions.HttpResponseError """ - cls = kwargs.pop('cls', None) # type: ClsType[bytes] + cls = kwargs.pop('cls', None) # type: ClsType[bytearray] error_map = {404: ResourceNotFoundError, 409: ResourceExistsError} error_map.update(kwargs.pop('error_map', {})) @@ -497,7 +497,7 @@ def get_base64_encoded( error = self._deserialize(models.Error, response) raise HttpResponseError(response=response, model=error) - deserialized = self._deserialize('base64', pipeline_response) + deserialized = self._deserialize('bytearray', pipeline_response) if cls: return cls(pipeline_response, deserialized, {}) diff --git a/test/vanilla/Expected/AcceptanceTests/NonStringEnums/nonstringenums/_non_string_enums_client.py b/test/vanilla/Expected/AcceptanceTests/NonStringEnums/nonstringenums/_non_string_enums_client.py index 31d1553f67a..e0080412189 100644 --- a/test/vanilla/Expected/AcceptanceTests/NonStringEnums/nonstringenums/_non_string_enums_client.py +++ b/test/vanilla/Expected/AcceptanceTests/NonStringEnums/nonstringenums/_non_string_enums_client.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from typing import Any, Optional + from typing import Any, Dict, Optional from ._configuration import NonStringEnumsClientConfiguration from .operations import IntOperations @@ -42,7 +42,7 @@ def __init__( self._config = NonStringEnumsClientConfiguration(**kwargs) self._client = PipelineClient(base_url=base_url, config=self._config, **kwargs) - client_models = {} + client_models = {} # type: Dict[str, Any] self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) diff --git a/test/vanilla/Expected/AcceptanceTests/NonStringEnums/nonstringenums/aio/_non_string_enums_client_async.py b/test/vanilla/Expected/AcceptanceTests/NonStringEnums/nonstringenums/aio/_non_string_enums_client_async.py index caaa1c8467a..838b8d8f09e 100644 --- a/test/vanilla/Expected/AcceptanceTests/NonStringEnums/nonstringenums/aio/_non_string_enums_client_async.py +++ b/test/vanilla/Expected/AcceptanceTests/NonStringEnums/nonstringenums/aio/_non_string_enums_client_async.py @@ -6,11 +6,15 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING from azure.core import AsyncPipelineClient from msrest import Deserializer, Serializer +if TYPE_CHECKING: + # pylint: disable=unused-import,ungrouped-imports + from typing import Dict + from ._configuration_async import NonStringEnumsClientConfiguration from .operations_async import IntOperations from .operations_async import FloatOperations @@ -37,7 +41,7 @@ def __init__( self._config = NonStringEnumsClientConfiguration(**kwargs) self._client = AsyncPipelineClient(base_url=base_url, config=self._config, **kwargs) - client_models = {} + client_models = {} # type: Dict[str, Any] self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) diff --git a/test/vanilla/Expected/AcceptanceTests/ObjectType/objecttype/_object_type_client.py b/test/vanilla/Expected/AcceptanceTests/ObjectType/objecttype/_object_type_client.py index de17cfd4c6d..edcb78d3ddb 100644 --- a/test/vanilla/Expected/AcceptanceTests/ObjectType/objecttype/_object_type_client.py +++ b/test/vanilla/Expected/AcceptanceTests/ObjectType/objecttype/_object_type_client.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from typing import Any, Optional + from typing import Any, Dict, Optional from ._configuration import ObjectTypeClientConfiguration from .operations import ObjectTypeClientOperationsMixin @@ -37,7 +37,7 @@ def __init__( self._config = ObjectTypeClientConfiguration(**kwargs) self._client = PipelineClient(base_url=base_url, config=self._config, **kwargs) - client_models = {} + client_models = {} # type: Dict[str, Any] self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) diff --git a/test/vanilla/Expected/AcceptanceTests/ObjectType/objecttype/aio/_object_type_client_async.py b/test/vanilla/Expected/AcceptanceTests/ObjectType/objecttype/aio/_object_type_client_async.py index 411714f166c..5acb3664725 100644 --- a/test/vanilla/Expected/AcceptanceTests/ObjectType/objecttype/aio/_object_type_client_async.py +++ b/test/vanilla/Expected/AcceptanceTests/ObjectType/objecttype/aio/_object_type_client_async.py @@ -6,11 +6,15 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING from azure.core import AsyncPipelineClient from msrest import Deserializer, Serializer +if TYPE_CHECKING: + # pylint: disable=unused-import,ungrouped-imports + from typing import Dict + from ._configuration_async import ObjectTypeClientConfiguration from .operations_async import ObjectTypeClientOperationsMixin @@ -32,7 +36,7 @@ def __init__( self._config = ObjectTypeClientConfiguration(**kwargs) self._client = AsyncPipelineClient(base_url=base_url, config=self._config, **kwargs) - client_models = {} + client_models = {} # type: Dict[str, Any] self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models)