Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions sdk/core/azure-core/azure/core/_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,6 @@

_LOGGER = logging.getLogger(__name__)

def _prepare_request(request):
# returns the request ready to run through pipelines
# and a bool telling whether we ended up converting it
rest_request = False
try:
request_to_run = request._to_pipeline_transport_request() # pylint: disable=protected-access
rest_request = True
except AttributeError:
request_to_run = request
return rest_request, request_to_run

class PipelineClient(PipelineClientBase):
"""Service client core methods.

Expand Down Expand Up @@ -204,9 +193,9 @@ def send_request(self, request, **kwargs):
:return: The response of your network call. Does not do error handling on your response.
:rtype: ~azure.core.rest.HttpResponse
# """
rest_request, request_to_run = _prepare_request(request)
rest_request = hasattr(request, "content")
return_pipeline_response = kwargs.pop("_return_pipeline_response", False)
pipeline_response = self._pipeline.run(request_to_run, **kwargs) # pylint: disable=protected-access
pipeline_response = self._pipeline.run(request, **kwargs) # pylint: disable=protected-access
response = pipeline_response.http_response
if rest_request:
response = _to_rest_response(response)
Expand Down
5 changes: 2 additions & 3 deletions sdk/core/azure-core/azure/core/_pipeline_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
RequestIdPolicy,
AsyncRetryPolicy,
)
from ._pipeline_client import _prepare_request
from .pipeline._tools_async import to_rest_response as _to_rest_response

try:
Expand Down Expand Up @@ -175,10 +174,10 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
return AsyncPipeline(transport, policies)

async def _make_pipeline_call(self, request, **kwargs):
rest_request, request_to_run = _prepare_request(request)
rest_request = hasattr(request, "content")
return_pipeline_response = kwargs.pop("_return_pipeline_response", False)
pipeline_response = await self._pipeline.run(
request_to_run, **kwargs # pylint: disable=protected-access
request, **kwargs # pylint: disable=protected-access
)
response = pipeline_response.http_response
if rest_request:
Expand Down
121 changes: 10 additions & 111 deletions sdk/core/azure-core/azure/core/pipeline/transport/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from io import BytesIO
import json
import logging
import os
import time
import copy

Expand All @@ -50,7 +49,6 @@
TYPE_CHECKING,
Generic,
TypeVar,
cast,
IO,
List,
Union,
Expand All @@ -63,7 +61,7 @@
Type
)

from six.moves.http_client import HTTPConnection, HTTPResponse as _HTTPResponse
from six.moves.http_client import HTTPResponse as _HTTPResponse

from azure.core.exceptions import HttpResponseError
from azure.core.pipeline import (
Expand All @@ -75,6 +73,12 @@
)
from .._tools import await_result as _await_result
from ...utils._utils import _case_insensitive_dict
from ...utils._pipeline_transport_rest_shared import (
_format_parameters_helper,
_prepare_multipart_body_helper,
_serialize_request,
_format_data_helper,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -127,36 +131,6 @@ def _urljoin(base_url, stub_url):
parsed = parsed._replace(path=parsed.path.rstrip("/") + "/" + stub_url)
return parsed.geturl()


class _HTTPSerializer(HTTPConnection, object):
"""Hacking the stdlib HTTPConnection to serialize HTTP request as strings.
"""

def __init__(self, *args, **kwargs):
self.buffer = b""
kwargs.setdefault("host", "fakehost")
super(_HTTPSerializer, self).__init__(*args, **kwargs)

def putheader(self, header, *values):
if header in ["Host", "Accept-Encoding"]:
return
super(_HTTPSerializer, self).putheader(header, *values)

def send(self, data):
self.buffer += data


def _serialize_request(http_request):
serializer = _HTTPSerializer()
serializer.request(
method=http_request.method,
url=http_request.url,
body=http_request.body,
headers=http_request.headers,
)
return serializer.buffer


class HttpTransport(
AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType]
): # type: ignore
Expand Down Expand Up @@ -253,16 +227,7 @@ def _format_data(data):
:param data: The request field data.
:type data: str or file-like object.
"""
if hasattr(data, "read"):
data = cast(IO, data)
data_name = None
try:
if data.name[0] != "<" and data.name[-1] != ">":
data_name = os.path.basename(data.name)
except (AttributeError, TypeError):
pass
return (data_name, data, "application/octet-stream")
return (None, cast(str, data))
return _format_data_helper(data)

def format_parameters(self, params):
# type: (Dict[str, str]) -> None
Expand All @@ -272,26 +237,7 @@ def format_parameters(self, params):

:param dict params: A dictionary of parameters.
"""
query = urlparse(self.url).query
if query:
self.url = self.url.partition("?")[0]
existing_params = {
p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]
}
params.update(existing_params)
query_params = []
for k, v in params.items():
if isinstance(v, list):
for w in v:
if w is None:
raise ValueError("Query parameter {} cannot be None".format(k))
query_params.append("{}={}".format(k, w))
else:
if v is None:
raise ValueError("Query parameter {} cannot be None".format(k))
query_params.append("{}={}".format(k, v))
query = "?" + "&".join(query_params)
self.url = self.url + query
return _format_parameters_helper(self, params)

def set_streamed_data_body(self, data):
"""Set a streamable data body.
Expand Down Expand Up @@ -416,54 +362,7 @@ def prepare_multipart_body(self, content_index=0):
:returns: The updated index after all parts in this request have been added.
:rtype: int
"""
if not self.multipart_mixed_info:
return 0

requests = self.multipart_mixed_info[0] # type: List[HttpRequest]
boundary = self.multipart_mixed_info[2] # type: Optional[str]

# Update the main request with the body
main_message = Message()
main_message.add_header("Content-Type", "multipart/mixed")
if boundary:
main_message.set_boundary(boundary)

for req in requests:
part_message = Message()
if req.multipart_mixed_info:
content_index = req.prepare_multipart_body(content_index=content_index)
part_message.add_header("Content-Type", req.headers['Content-Type'])
payload = req.serialize()
# We need to remove the ~HTTP/1.1 prefix along with the added content-length
payload = payload[payload.index(b'--'):]
else:
part_message.add_header("Content-Type", "application/http")
part_message.add_header("Content-Transfer-Encoding", "binary")
part_message.add_header("Content-ID", str(content_index))
payload = req.serialize()
content_index += 1
part_message.set_payload(payload)
main_message.attach(part_message)

try:
from email.policy import HTTP

full_message = main_message.as_bytes(policy=HTTP)
eol = b"\r\n"
except ImportError: # Python 2.7
# Right now we decide to not support Python 2.7 on serialization, since
# it doesn't serialize a valid HTTP request (and our main scenario Storage refuses it)
raise NotImplementedError(
"Multipart request are not supported on Python 2.7"
)
# full_message = main_message.as_string()
# eol = b'\n'
_, _, body = full_message.split(eol, 2)
self.set_bytes_body(body)
self.headers["Content-Type"] = (
"multipart/mixed; boundary=" + main_message.get_boundary()
)
return content_index
return _prepare_multipart_body_helper(self, content_index)

def serialize(self):
# type: () -> bytes
Expand Down
Loading