Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion autorest/codegen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .enum_schema import EnumSchema
from .base_schema import BaseSchema
from .constant_schema import ConstantSchema
from .imports import FileImport, ImportType
from .imports import FileImport, ImportType, TypingSection
from .lro_operation import LROOperation
from .paging_operation import PagingOperation
from .parameter import Parameter
Expand All @@ -36,6 +36,7 @@
"EnumSchema",
"FileImport",
"ImportType",
"TypingSection",
"PrimitiveSchema",
"LROOperation",
"Operation",
Expand Down
20 changes: 10 additions & 10 deletions autorest/codegen/models/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from enum import Enum, auto
from enum import Enum
from typing import Dict, Optional, Set


class ImportType(Enum):
STDLIB = auto()
THIRDPARTY = auto()
AZURECORE = auto()
LOCAL = auto()
class ImportType(str, Enum):
STDLIB = "stdlib"
THIRDPARTY = "thirdparty"
AZURECORE = "azurecore"
LOCAL = "local"

class TypingSection(Enum):
REGULAR = auto() # this import is always a typing import
CONDITIONAL = auto() # is a typing import when we're dealing with files that py2 will use, else regular
TYPING = auto() # never a typing import
class TypingSection(str, Enum):
REGULAR = "regular" # this import is always a typing import
CONDITIONAL = "conditional" # is a typing import when we're dealing with files that py2 will use, else regular
TYPING = "typing" # never a typing import


class FileImport:
Expand Down
37 changes: 32 additions & 5 deletions autorest/codegen/serializers/metadata_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# license information.
# --------------------------------------------------------------------------
import copy
from typing import List, Optional, Set, Tuple
import json
from typing import List, Optional, Set, Tuple, Dict
from jinja2 import Environment
from ..models import (
CodeModel,
Expand All @@ -13,16 +14,42 @@
LROOperation,
PagingOperation,
CredentialSchema,
ParameterList
ParameterList,
TypingSection,
ImportType
)
from .import_serializer import FileImportSerializer

def _correct_credential_parameter(global_parameters: ParameterList, async_mode: bool) -> None:
credential_param = [
gp for gp in global_parameters.parameters if isinstance(gp.schema, CredentialSchema)
][0]
credential_param.schema = CredentialSchema(async_mode=async_mode)

def _json_serialize_imports(
imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]
):
if not imports:
return None

json_serialize_imports = {}
# need to make name_import set -> list to make the dictionary json serializable
# not using an OrderedDict since we're iterating through a set and the order there varies
# going to sort the list instead

for typing_section_key, typing_section_value in imports.items():
json_import_type_dictionary = {}
for import_type_key, import_type_value in typing_section_value.items():
json_package_name_dictionary = {}
for package_name, name_imports in import_type_value.items():
name_import_ordered_list = []
if name_imports:
name_import_ordered_list = list(name_imports)
name_import_ordered_list.sort()
json_package_name_dictionary[package_name] = name_import_ordered_list
json_import_type_dictionary[import_type_key] = json_package_name_dictionary
json_serialize_imports[typing_section_key] = json_import_type_dictionary
return json.dumps(json_serialize_imports)


class MetadataSerializer:
def __init__(self, code_model: CodeModel, env: Environment) -> None:
Expand Down Expand Up @@ -99,11 +126,11 @@ def _is_paging(operation):
is_paging=_is_paging,
str=str,
sync_mixin_imports=(
FileImportSerializer(sync_mixin_imports, is_python_3_file=False)
_json_serialize_imports(sync_mixin_imports.imports)
if sync_mixin_imports else None
),
async_mixin_imports=(
FileImportSerializer(async_mixin_imports, is_python_3_file=True)
_json_serialize_imports(async_mixin_imports.imports)
if async_mixin_imports else None
)
)
29 changes: 26 additions & 3 deletions autorest/multiapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Optional, cast, Any
from .multiapi_serializer import MultiAPISerializer
from .serializers import MultiAPISerializer, FileImportSerializer
from .models import FileImport
from ..jsonrpc import AutorestAPI

from .. import Plugin
Expand Down Expand Up @@ -247,6 +248,20 @@ def _parse_package_name_input(self) -> str:
self.output_package_name = self.input_package_name
return module_name

def _merge_mixin_imports_across_versions(
self, paths_to_versions: List[Path], async_mode: bool
) -> FileImport:
imports = FileImport()
imports_to_load = "async_imports" if async_mode else "sync_imports"
for version_path in paths_to_versions:
metadata_json = json.loads(self._autorestapi.read_file(version_path / "_metadata.json"))
if not metadata_json.get('operation_mixins'):
continue
current_version_imports = FileImport(json.loads(metadata_json[imports_to_load]))
imports.merge(current_version_imports)

return imports

def process(self) -> bool:
_LOGGER.info("Generating multiapi client")
# If True, means the auto-profile will consider preview versions.
Expand Down Expand Up @@ -324,6 +339,14 @@ def process(self) -> bool:
versioned_operations_dict, mixin_operations, last_api_version, preview_mode, async_mode=True
)

sync_imports = self._merge_mixin_imports_across_versions(
paths_to_versions, async_mode=False
)

async_imports = self._merge_mixin_imports_across_versions(
paths_to_versions, async_mode=True
)

conf = {
"client_name": metadata_json["client"]["name"],
"package_name": self.output_package_name,
Expand All @@ -340,8 +363,8 @@ def process(self) -> bool:
),
"config": metadata_json["config"],
"global_parameters": metadata_json["global_parameters"],
"sync_imports": metadata_json["sync_imports"],
"async_imports": metadata_json["async_imports"]
"sync_imports": str(FileImportSerializer(sync_imports, is_python_3_file=False)),
"async_imports": str(FileImportSerializer(async_imports, is_python_3_file=True))
}

multiapi_serializer = MultiAPISerializer(
Expand Down
13 changes: 13 additions & 0 deletions autorest/multiapi/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from .imports import ImportType, FileImport, TypingSection

__all__ = [
"ImportType",
"FileImport",
"TypingSection"
]
78 changes: 78 additions & 0 deletions autorest/multiapi/models/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from enum import Enum
from typing import Dict, Optional, Set


class ImportType(str, Enum):
STDLIB = "stdlib"
THIRDPARTY = "thirdparty"
AZURECORE = "azurecore"
LOCAL = "local"

class TypingSection(str, Enum):
REGULAR = "regular" # this import is always a typing import
CONDITIONAL = "conditional" # is a typing import when we're dealing with files that py2 will use, else regular
TYPING = "typing" # never a typing import


class FileImport:
def __init__(
self, imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]] = None
) -> None:
# Basic implementation
# First level dict: TypingSection
# Second level dict: ImportType
# Third level dict: the package name.
# Fourth level set: None if this import is a "import", the name to import if it's a "from"
self._imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]] = imports or dict()

def _add_import(
self,
from_section: str,
import_type: ImportType,
name_import: Optional[str] = None,
typing_section: TypingSection = TypingSection.REGULAR
) -> None:
self._imports.setdefault(
typing_section, dict()
).setdefault(
import_type, dict()
).setdefault(
from_section, set()
).add(name_import)

def add_from_import(
self,
from_section: str,
name_import: str,
import_type: ImportType,
typing_section: TypingSection = TypingSection.REGULAR
) -> None:
"""Add an import to this import block.
"""
self._add_import(from_section, import_type, name_import, typing_section)

def add_import(
self,
name_import: str,
import_type: ImportType,
typing_section: TypingSection = TypingSection.REGULAR
) -> None:
# Implementation detail: a regular import is just a "from" with no from
self._add_import(name_import, import_type, None, typing_section)

@property
def imports(self) -> Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]:
return self._imports

def merge(self, file_import: "FileImport") -> None:
"""Merge the given file import format."""
for typing_section, import_type_dict in file_import.imports.items():
for import_type, package_list in import_type_dict.items():
for package_name, module_list in package_list.items():
for module_name in module_list:
self._add_import(package_name, import_type, module_name, typing_section)
13 changes: 13 additions & 0 deletions autorest/multiapi/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from .import_serializer import FileImportSerializer
from .multiapi_serializer import MultiAPISerializer

__all__ = [
"FileImportSerializer",
"MultiAPISerializer"
]
87 changes: 87 additions & 0 deletions autorest/multiapi/serializers/import_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from copy import deepcopy
from typing import Dict, Set, Optional, List
from ..models import ImportType, FileImport, TypingSection

def _serialize_package(package_name: str, module_list: Set[Optional[str]], delimiter: str) -> str:
buffer = []
if None in module_list:
buffer.append(f"import {package_name}")
if module_list != {None}:
buffer.append(
"from {} import {}".format(
package_name, ", ".join(sorted([mod for mod in module_list if mod is not None]))
)
)
return delimiter.join(buffer)

def _serialize_type(import_type_dict: Dict[str, Set[Optional[str]]], delimiter: str) -> str:
"""Serialize a given import type."""
import_list = []
for package_name in sorted(list(import_type_dict.keys())):
module_list = import_type_dict[package_name]
import_list.append(_serialize_package(package_name, module_list, delimiter))
return delimiter.join(import_list)

def _get_import_clauses(imports: Dict[ImportType, Dict[str, Set[Optional[str]]]], delimiter: str) -> List[str]:
import_clause = []
for import_type in ImportType:
if import_type in imports:
import_clause.append(_serialize_type(imports[import_type], delimiter))
return import_clause


class FileImportSerializer:
def __init__(self, file_import: FileImport, is_python_3_file: bool) -> None:
self._file_import = file_import
self.is_python_3_file = is_python_3_file

def _switch_typing_section_key(self, new_key: TypingSection):
switched_dictionary = {}
switched_dictionary[new_key] = self._file_import.imports[TypingSection.CONDITIONAL]
return switched_dictionary

def _get_imports_dict(self, baseline_typing_section: TypingSection, add_conditional_typing: bool):
# If this is a python 3 file, our regular imports include the CONDITIONAL category
# If this is not a python 3 file, our typing imports include the CONDITIONAL category
file_import_copy = deepcopy(self._file_import)
if add_conditional_typing and self._file_import.imports.get(TypingSection.CONDITIONAL):
# we switch the TypingSection key for the CONDITIONAL typing imports so we can merge
# the imports together
switched_imports_dictionary = self._switch_typing_section_key(baseline_typing_section)
switched_imports = FileImport(switched_imports_dictionary)
file_import_copy.merge(switched_imports)
return file_import_copy.imports.get(baseline_typing_section, {})

def _add_type_checking_import(self):
if (
self._file_import.imports.get(TypingSection.TYPING) or
(not self.is_python_3_file and self._file_import.imports.get(TypingSection.CONDITIONAL))
):
self._file_import.add_from_import("typing", "TYPE_CHECKING", ImportType.STDLIB)

def __str__(self) -> str:
self._add_type_checking_import()
regular_imports = ""
regular_imports_dict = self._get_imports_dict(
baseline_typing_section=TypingSection.REGULAR, add_conditional_typing=self.is_python_3_file
)

if regular_imports_dict:
regular_imports = "\n\n".join(
_get_import_clauses(regular_imports_dict, "\n")
)

typing_imports = ""
typing_imports_dict = self._get_imports_dict(
baseline_typing_section=TypingSection.TYPING, add_conditional_typing=not self.is_python_3_file
)
if typing_imports_dict:
typing_imports += "\n\nif TYPE_CHECKING:\n # pylint: disable=unused-import,ungrouped-imports\n "
typing_imports += "\n\n ".join(_get_import_clauses(typing_imports_dict, "\n "))

return regular_imports + typing_imports
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path
from jinja2 import Environment, PackageLoader

from ..jsonrpc import AutorestAPI
from ...jsonrpc import AutorestAPI


class MultiAPISerializer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
from azure.core.paging import ItemPaged
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.transport import HttpRequest, HttpResponse
from azure.core.polling import LROPoller, NoPolling, PollingMethod
from azure.core.polling.base_polling import LROBasePolling

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Callable, Dict, Generic, Iterable, Optional, TypeVar
from typing import Any, Callable, Dict, Generic, Iterable, Optional, TypeVar, Union


class MultiapiServiceClientOperationsMixin(object):
Expand Down
Loading