Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
19 changes: 19 additions & 0 deletions autorest/codegen/models/code_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,25 @@ def generate_single_parameter_from_multiple_media_types_operation(self) -> None:
if operation.multiple_media_type_parameters:
operation.convert_multiple_media_type_parameters()

@property
def need_vendored_code(self) -> bool:
return self.need_request_converter or self.need_format_url

@property
def need_request_converter(self) -> bool:
if not self.options["show_operations"]:
return False
if not self.options["version_tolerant"]:
return True
for og in self.operation_groups:
if any(o for o in og.operations if o.use_pipeline_transport):
return True
return False

@property
def need_format_url(self) -> bool:
return any(rq for rq in self.rest.request_builders if rq.parameters.path)

@property
def has_lro_operations(self) -> bool:
return any([
Expand Down
7 changes: 6 additions & 1 deletion autorest/codegen/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,12 @@ def imports(self, code_model, async_mode: bool) -> FileImport:
alias="rest"
)
if code_model.options["builders_visibility"] == "embedded" and not async_mode:
file_import.merge(self.request_builder.imports())
file_import.merge(self.request_builder.imports(code_model))
if code_model.need_request_converter:
relative_path = "..." if async_mode else ".."
file_import.add_from_import(
f"{relative_path}_vendor", "_convert_request", ImportType.LOCAL
)
return file_import

def convert_multiple_media_type_parameters(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/models/paging_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def _imports_shared(self, code_model, async_mode: bool) -> FileImport:
file_import.add_from_import("typing", "AsyncIterable", ImportType.STDLIB, TypingSection.CONDITIONAL)
else:
file_import.add_from_import("typing", "Iterable", ImportType.STDLIB, TypingSection.CONDITIONAL)
if self.next_request_builder:
file_import.merge(self.next_request_builder.imports())
if self.next_request_builder and code_model.options["builders_visibility"] == "embedded" and not async_mode:
file_import.merge(self.next_request_builder.imports(code_model))
return file_import

def imports_for_multiapi(self, code_model, async_mode: bool) -> FileImport:
Expand Down
7 changes: 5 additions & 2 deletions autorest/codegen/models/request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def operation_group_name(self) -> str:
def builder_group_name(self) -> str:
return self.yaml_data["language"]["python"]["builderGroupName"]

def imports(self) -> FileImport:
def imports(self, code_model) -> FileImport:
file_import = FileImport()
for parameter in self.parameters:
file_import.merge(parameter.imports())
Expand All @@ -67,8 +67,11 @@ def imports(self) -> FileImport:
ImportType.AZURECORE,
)
if self.parameters.path:
relative_path = ".."
if not code_model.options["builders_visibility"] == "embedded" and self.operation_group_name:
relative_path = "..." if self.operation_group_name else ".."
file_import.add_from_import(
"azure.core.pipeline.transport._base", "_format_url_section", ImportType.AZURECORE
f"{relative_path}_vendor", "_format_url_section", ImportType.LOCAL
)
file_import.add_from_import(
"typing", "Any", ImportType.STDLIB, typing_section=TypingSection.CONDITIONAL
Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/models/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def __init__(
super(Rest, self). __init__(yaml_data=yaml_data)
self.request_builders = request_builders

def imports(self) -> FileImport:
def imports(self, code_model) -> FileImport:
file_import = FileImport()
for request_builder in self.request_builders:
file_import.merge(request_builder.imports())
file_import.merge(request_builder.imports(code_model))
return file_import

@classmethod
Expand Down
6 changes: 6 additions & 0 deletions autorest/codegen/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ def _serialize_and_write_top_level_folder(
general_serializer.serialize_service_client_file()
)

if code_model.need_vendored_code:
self._autorestapi.write_file(
namespace_path / Path("_vendor.py"),
general_serializer.serialize_vendor_file()
)

self._serialize_and_write_version_file(code_model, namespace_path, general_serializer)

# write the empty py.typed file
Expand Down
9 changes: 5 additions & 4 deletions autorest/codegen/serializers/builder_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,11 +775,12 @@ def _call_request_builder_helper(
retval.append(f" {kwarg}={kwarg},")
template_url = template_url or f"self.{builder.name}.metadata['url']"
retval.append(f" template_url={template_url},")

convert_to_legacy = ""
retval.append(f")")
if not self.code_model.options["version_tolerant"] or builder.use_pipeline_transport:
convert_to_legacy = "._to_pipeline_transport_request()"
retval.append(f"){convert_to_legacy}")
pass_files = ""
if "files" in builder.body_kwargs_to_pass_to_request_builder:
pass_files = ", files"
retval.append(f"request = _convert_request(request{pass_files})")
if builder.parameters.path:
retval.extend(self.serialize_path(builder))
retval.append(
Expand Down
21 changes: 21 additions & 0 deletions autorest/codegen/serializers/general_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,27 @@ def serialize_service_client_file(self) -> str:
),
)

def serialize_vendor_file(self) -> str:
template = self.env.get_template("vendor.py.jinja2")

# configure imports
file_import = FileImport()
if self.code_model.need_request_converter:
file_import.add_from_import(
"azure.core.pipeline.transport",
"HttpRequest",
ImportType.AZURECORE,
)

return template.render(
code_model=self.code_model,
imports=FileImportSerializer(
file_import,
is_python_3_file=self.async_mode,
)
)


def serialize_config_file(self) -> str:

package_name = self.code_model.options['package_name']
Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/serializers/rest_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def serialize_request_builders(self) -> str:
return template.render(
code_model=self.code_model,
request_builders=self.request_builders,
imports=FileImportSerializer(self.code_model.rest.imports(), is_python_3_file=True),
imports=FileImportSerializer(self.code_model.rest.imports(self.code_model), is_python_3_file=True),
is_python_3_file=True,
request_builder_serializer=RequestBuilderPython3Serializer(self.code_model),
)
Expand All @@ -46,7 +46,7 @@ def serialize_request_builders(self) -> str:
return template.render(
code_model=self.code_model,
request_builders=self.request_builders,
imports=FileImportSerializer(self.code_model.rest.imports(), is_python_3_file=False),
imports=FileImportSerializer(self.code_model.rest.imports(self.code_model), is_python_3_file=False),
is_python_3_file=False,
request_builder_serializer=RequestBuilderGenericSerializer(self.code_model),
)
26 changes: 26 additions & 0 deletions autorest/codegen/templates/vendor.py.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{{ code_model.options['license_header'] }}

{{ imports }}

{% if code_model.need_request_converter %}
def _convert_request(request, files=None):
data = request.content if not files else None
request = HttpRequest(method=request.method, url=request.url, headers=request.headers, data=data)
if files:
request.set_formdata_body(files)
return request
{% endif %}
{% if code_model.need_format_url %}

def _format_url_section(template, **kwargs):
components = template.split("/")
while components:
try:
return template.format(**kwargs)
except KeyError as key:
formatted_components = template.split("/")
components = [
c for c in formatted_components if "{{{}}}".format(key.args[0]) not in c
]
template = "/".join(components)
{% endif %}
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from ._paging_operations import PagingOperations
from azure.core.pipeline.transport import HttpRequest
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some weird diffs going on here. Basically, I realized I had a duplicate out-of-date CustomPollerPager in my version tolerant folder, so I deleted it, which for some reason has caused these weird diffs


__all__ = [
'PagingOperations',
]
def _convert_request(request, files=None):
data = request.content if not files else None
request = HttpRequest(method=request.method, url=request.url, headers=request.headers, data=data)
if files:
request.set_formdata_body(files)
return request
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from azure.core.rest import HttpRequest
from azure.core.tracing.decorator_async import distributed_trace_async

from ..._vendor import _convert_request
from ...operations._http_success_operations import build_head200_request, build_head204_request, build_head404_request

T = TypeVar('T')
Expand Down Expand Up @@ -59,7 +60,8 @@ async def head200(

request = build_head200_request(
template_url=self.head200.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = await self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down Expand Up @@ -96,7 +98,8 @@ async def head204(

request = build_head204_request(
template_url=self.head204.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = await self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down Expand Up @@ -133,7 +136,8 @@ async def head404(

request = build_head404_request(
template_url=self.head404.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = await self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from azure.core.tracing.decorator import distributed_trace
from msrest import Serializer

from .._vendor import _convert_request

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
Expand Down Expand Up @@ -108,7 +110,8 @@ def head200(

request = build_head200_request(
template_url=self.head200.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down Expand Up @@ -146,7 +149,8 @@ def head204(

request = build_head204_request(
template_url=self.head204.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down Expand Up @@ -184,7 +188,8 @@ def head404(

request = build_head404_request(
template_url=self.head404.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from ._paging_operations import PagingOperations
from azure.core.pipeline.transport import HttpRequest

__all__ = [
'PagingOperations',
]
def _convert_request(request, files=None):
data = request.content if not files else None
request = HttpRequest(method=request.method, url=request.url, headers=request.headers, data=data)
if files:
request.set_formdata_body(files)
return request
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from azure.core.rest import HttpRequest
from azure.core.tracing.decorator_async import distributed_trace_async

from ..._vendor import _convert_request
from ...operations._http_success_operations import build_head200_request, build_head204_request, build_head404_request

T = TypeVar('T')
Expand Down Expand Up @@ -59,7 +60,8 @@ async def head200(

request = build_head200_request(
template_url=self.head200.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = await self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down Expand Up @@ -96,7 +98,8 @@ async def head204(

request = build_head204_request(
template_url=self.head204.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = await self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down Expand Up @@ -133,7 +136,8 @@ async def head404(

request = build_head404_request(
template_url=self.head404.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = await self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from azure.core.tracing.decorator import distributed_trace
from msrest import Serializer

from .._vendor import _convert_request

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
Expand Down Expand Up @@ -108,7 +110,8 @@ def head200(

request = build_head200_request(
template_url=self.head200.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down Expand Up @@ -146,7 +149,8 @@ def head204(

request = build_head204_request(
template_url=self.head204.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down Expand Up @@ -184,7 +188,8 @@ def head404(

request = build_head404_request(
template_url=self.head404.metadata['url'],
)._to_pipeline_transport_request()
)
request = _convert_request(request)
request.url = self._client.format_url(request.url)

pipeline_response = self._client.send_request(request, stream=False, _return_pipeline_response=True, **kwargs)
Expand Down
Loading