Skip to content

Commit

Permalink
[rest] add backcompat mixin to rest requests (Azure#20599)
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft authored Sep 22, 2021
1 parent 4b3397d commit 30b196e
Show file tree
Hide file tree
Showing 40 changed files with 1,873 additions and 854 deletions.
15 changes: 2 additions & 13 deletions sdk/core/azure-core/azure/core/_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,6 @@

_LOGGER = logging.getLogger(__name__)

def _prepare_request(request):
# returns the request ready to run through pipelines
# and a bool telling whether we ended up converting it
rest_request = False
try:
request_to_run = request._to_pipeline_transport_request() # pylint: disable=protected-access
rest_request = True
except AttributeError:
request_to_run = request
return rest_request, request_to_run

class PipelineClient(PipelineClientBase):
"""Service client core methods.
Expand Down Expand Up @@ -204,9 +193,9 @@ def send_request(self, request, **kwargs):
:return: The response of your network call. Does not do error handling on your response.
:rtype: ~azure.core.rest.HttpResponse
# """
rest_request, request_to_run = _prepare_request(request)
rest_request = hasattr(request, "content")
return_pipeline_response = kwargs.pop("_return_pipeline_response", False)
pipeline_response = self._pipeline.run(request_to_run, **kwargs) # pylint: disable=protected-access
pipeline_response = self._pipeline.run(request, **kwargs) # pylint: disable=protected-access
response = pipeline_response.http_response
if rest_request:
response = _to_rest_response(response)
Expand Down
5 changes: 2 additions & 3 deletions sdk/core/azure-core/azure/core/_pipeline_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
RequestIdPolicy,
AsyncRetryPolicy,
)
from ._pipeline_client import _prepare_request
from .pipeline._tools_async import to_rest_response as _to_rest_response

try:
Expand Down Expand Up @@ -175,10 +174,10 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
return AsyncPipeline(transport, policies)

async def _make_pipeline_call(self, request, **kwargs):
rest_request, request_to_run = _prepare_request(request)
rest_request = hasattr(request, "content")
return_pipeline_response = kwargs.pop("_return_pipeline_response", False)
pipeline_response = await self._pipeline.run(
request_to_run, **kwargs # pylint: disable=protected-access
request, **kwargs # pylint: disable=protected-access
)
response = pipeline_response.http_response
if rest_request:
Expand Down
121 changes: 10 additions & 111 deletions sdk/core/azure-core/azure/core/pipeline/transport/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from io import BytesIO
import json
import logging
import os
import time
import copy

Expand All @@ -50,7 +49,6 @@
TYPE_CHECKING,
Generic,
TypeVar,
cast,
IO,
List,
Union,
Expand All @@ -63,7 +61,7 @@
Type
)

from six.moves.http_client import HTTPConnection, HTTPResponse as _HTTPResponse
from six.moves.http_client import HTTPResponse as _HTTPResponse

from azure.core.exceptions import HttpResponseError
from azure.core.pipeline import (
Expand All @@ -75,6 +73,12 @@
)
from .._tools import await_result as _await_result
from ...utils._utils import _case_insensitive_dict
from ...utils._pipeline_transport_rest_shared import (
_format_parameters_helper,
_prepare_multipart_body_helper,
_serialize_request,
_format_data_helper,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -127,36 +131,6 @@ def _urljoin(base_url, stub_url):
parsed = parsed._replace(path=parsed.path.rstrip("/") + "/" + stub_url)
return parsed.geturl()


class _HTTPSerializer(HTTPConnection, object):
"""Hacking the stdlib HTTPConnection to serialize HTTP request as strings.
"""

def __init__(self, *args, **kwargs):
self.buffer = b""
kwargs.setdefault("host", "fakehost")
super(_HTTPSerializer, self).__init__(*args, **kwargs)

def putheader(self, header, *values):
if header in ["Host", "Accept-Encoding"]:
return
super(_HTTPSerializer, self).putheader(header, *values)

def send(self, data):
self.buffer += data


def _serialize_request(http_request):
serializer = _HTTPSerializer()
serializer.request(
method=http_request.method,
url=http_request.url,
body=http_request.body,
headers=http_request.headers,
)
return serializer.buffer


class HttpTransport(
AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType]
): # type: ignore
Expand Down Expand Up @@ -253,16 +227,7 @@ def _format_data(data):
:param data: The request field data.
:type data: str or file-like object.
"""
if hasattr(data, "read"):
data = cast(IO, data)
data_name = None
try:
if data.name[0] != "<" and data.name[-1] != ">":
data_name = os.path.basename(data.name)
except (AttributeError, TypeError):
pass
return (data_name, data, "application/octet-stream")
return (None, cast(str, data))
return _format_data_helper(data)

def format_parameters(self, params):
# type: (Dict[str, str]) -> None
Expand All @@ -272,26 +237,7 @@ def format_parameters(self, params):
:param dict params: A dictionary of parameters.
"""
query = urlparse(self.url).query
if query:
self.url = self.url.partition("?")[0]
existing_params = {
p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]
}
params.update(existing_params)
query_params = []
for k, v in params.items():
if isinstance(v, list):
for w in v:
if w is None:
raise ValueError("Query parameter {} cannot be None".format(k))
query_params.append("{}={}".format(k, w))
else:
if v is None:
raise ValueError("Query parameter {} cannot be None".format(k))
query_params.append("{}={}".format(k, v))
query = "?" + "&".join(query_params)
self.url = self.url + query
return _format_parameters_helper(self, params)

def set_streamed_data_body(self, data):
"""Set a streamable data body.
Expand Down Expand Up @@ -416,54 +362,7 @@ def prepare_multipart_body(self, content_index=0):
:returns: The updated index after all parts in this request have been added.
:rtype: int
"""
if not self.multipart_mixed_info:
return 0

requests = self.multipart_mixed_info[0] # type: List[HttpRequest]
boundary = self.multipart_mixed_info[2] # type: Optional[str]

# Update the main request with the body
main_message = Message()
main_message.add_header("Content-Type", "multipart/mixed")
if boundary:
main_message.set_boundary(boundary)

for req in requests:
part_message = Message()
if req.multipart_mixed_info:
content_index = req.prepare_multipart_body(content_index=content_index)
part_message.add_header("Content-Type", req.headers['Content-Type'])
payload = req.serialize()
# We need to remove the ~HTTP/1.1 prefix along with the added content-length
payload = payload[payload.index(b'--'):]
else:
part_message.add_header("Content-Type", "application/http")
part_message.add_header("Content-Transfer-Encoding", "binary")
part_message.add_header("Content-ID", str(content_index))
payload = req.serialize()
content_index += 1
part_message.set_payload(payload)
main_message.attach(part_message)

try:
from email.policy import HTTP

full_message = main_message.as_bytes(policy=HTTP)
eol = b"\r\n"
except ImportError: # Python 2.7
# Right now we decide to not support Python 2.7 on serialization, since
# it doesn't serialize a valid HTTP request (and our main scenario Storage refuses it)
raise NotImplementedError(
"Multipart request are not supported on Python 2.7"
)
# full_message = main_message.as_string()
# eol = b'\n'
_, _, body = full_message.split(eol, 2)
self.set_bytes_body(body)
self.headers["Content-Type"] = (
"multipart/mixed; boundary=" + main_message.get_boundary()
)
return content_index
return _prepare_multipart_body_helper(self, content_index)

def serialize(self):
# type: () -> bytes
Expand Down
Loading

0 comments on commit 30b196e

Please sign in to comment.