Skip to content

Commit

Permalink
[rest] switch base responses to ABCs (#20448)
Browse files Browse the repository at this point in the history
* switch to protocol

* update changelog

* add initial tests

* switch from protocol to abc

* improve HttpResponse docstrings

* lint

* HeadersType -> MutableMapping[str, str]

* remove iter_text and iter_lines

* update tests

* improve docstrings

* have base impls handle more code

* add set_read_checks

* commit to restart pipelines

* address xiang's comments

* lint

* clear json cache when encoding is updated

* make sure content type is empty string if doesn't exist

* update content_type to be None if there is no content type header

* fix passing encoding to text method error

* update is_stream_consumed docs

* remove erroneous committed code
  • Loading branch information
iscai-msft authored Sep 24, 2021
1 parent cc7e454 commit 1a9b633
Show file tree
Hide file tree
Showing 20 changed files with 919 additions and 567 deletions.
2 changes: 2 additions & 0 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
- The `text` property on `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` has changed to a method, which also takes
an `encoding` parameter.
- Removed `iter_text` and `iter_lines` from `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse`
- `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` are now abstract base classes. They should not be initialized directly, instead
your transport responses should inherit from them and implement them.

### Bugs Fixed

Expand Down
33 changes: 26 additions & 7 deletions sdk/core/azure-core/azure/core/_pipeline_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# --------------------------------------------------------------------------

import logging
from collections.abc import Iterable
import collections.abc
from typing import Any, Awaitable
from .configuration import Configuration
from .pipeline import AsyncPipeline
Expand Down Expand Up @@ -62,6 +62,26 @@

_LOGGER = logging.getLogger(__name__)

class _AsyncContextManager(collections.abc.Awaitable):

def __init__(self, wrapped: collections.abc.Awaitable):
super().__init__()
self.wrapped = wrapped
self.response = None

def __await__(self):
return self.wrapped.__await__()

async def __aenter__(self):
self.response = await self
return self.response

async def __aexit__(self, *args):
await self.response.__aexit__(*args)

async def close(self):
await self.response.close()


class AsyncPipelineClient(PipelineClientBase):
"""Service client core methods.
Expand Down Expand Up @@ -125,7 +145,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
config.proxy_policy,
ContentDecodePolicy(**kwargs)
]
if isinstance(per_call_policies, Iterable):
if isinstance(per_call_policies, collections.abc.Iterable):
policies.extend(per_call_policies)
else:
policies.append(per_call_policies)
Expand All @@ -134,7 +154,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy])
if isinstance(per_retry_policies, Iterable):
if isinstance(per_retry_policies, collections.abc.Iterable):
policies.extend(per_retry_policies)
else:
policies.append(per_retry_policies)
Expand All @@ -143,13 +163,13 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or HttpLoggingPolicy(**kwargs)])
else:
if isinstance(per_call_policies, Iterable):
if isinstance(per_call_policies, collections.abc.Iterable):
per_call_policies_list = list(per_call_policies)
else:
per_call_policies_list = [per_call_policies]
per_call_policies_list.extend(policies)
policies = per_call_policies_list
if isinstance(per_retry_policies, Iterable):
if isinstance(per_retry_policies, collections.abc.Iterable):
per_retry_policies_list = list(per_retry_policies)
else:
per_retry_policies_list = [per_retry_policies]
Expand Down Expand Up @@ -188,7 +208,7 @@ async def _make_pipeline_call(self, request, **kwargs):
# the body is loaded. instead of doing response.read(), going to set the body
# to the internal content
rest_response._content = response.body() # pylint: disable=protected-access
await rest_response.close()
await rest_response._set_read_checks() # pylint: disable=protected-access
except Exception as exc:
await rest_response.close()
raise exc
Expand Down Expand Up @@ -222,6 +242,5 @@ def send_request(
:return: The response of your network call. Does not do error handling on your response.
:rtype: ~azure.core.rest.AsyncHttpResponse
"""
from .rest._rest_py3 import _AsyncContextManager
wrapped = self._make_pipeline_call(request, stream=stream, **kwargs)
return _AsyncContextManager(wrapped=wrapped)
7 changes: 3 additions & 4 deletions sdk/core/azure-core/azure/core/pipeline/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,20 @@ def to_rest_request(pipeline_transport_request):
def to_rest_response(pipeline_transport_response):
from .transport._requests_basic import RequestsTransportResponse
from ..rest._requests_basic import RestRequestsTransportResponse
from ..rest import HttpResponse
if isinstance(pipeline_transport_response, RequestsTransportResponse):
response_type = RestRequestsTransportResponse
else:
response_type = HttpResponse
raise ValueError("Unknown transport response")
response = response_type(
request=to_rest_request(pipeline_transport_response.request),
internal_response=pipeline_transport_response.internal_response,
block_size=pipeline_transport_response.block_size
)
response._connection_data_block_size = pipeline_transport_response.block_size # pylint: disable=protected-access
return response

def get_block_size(response):
try:
return response._connection_data_block_size # pylint: disable=protected-access
return response._block_size # pylint: disable=protected-access
except AttributeError:
return response.block_size

Expand Down
5 changes: 2 additions & 3 deletions sdk/core/azure-core/azure/core/pipeline/_tools_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,13 @@ def _get_response_type(pipeline_transport_response):
return RestTrioRequestsTransportResponse
except ImportError:
pass
from ..rest import AsyncHttpResponse
return AsyncHttpResponse
raise ValueError("Unknown transport response")

def to_rest_response(pipeline_transport_response):
response_type = _get_response_type(pipeline_transport_response)
response = response_type(
request=to_rest_request(pipeline_transport_response.request),
internal_response=pipeline_transport_response.internal_response,
block_size=pipeline_transport_response.block_size,
)
response._connection_data_block_size = pipeline_transport_response.block_size # pylint: disable=protected-access
return response
56 changes: 20 additions & 36 deletions sdk/core/azure-core/azure/core/rest/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
from itertools import groupby
from typing import AsyncIterator
from multidict import CIMultiDict
from . import HttpRequest, AsyncHttpResponse
from ._helpers_py3 import iter_raw_helper, iter_bytes_helper
from ._http_response_impl_async import AsyncHttpResponseImpl
from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator

class _ItemsView(collections.abc.ItemsView):
Expand Down Expand Up @@ -115,42 +114,26 @@ def get(self, key, default=None):
values = ", ".join(values)
return values or default

class RestAioHttpTransportResponse(AsyncHttpResponse):
class RestAioHttpTransportResponse(AsyncHttpResponseImpl):
def __init__(
self,
*,
request: HttpRequest,
internal_response,
decompress: bool = True,
**kwargs
):
super().__init__(request=request, internal_response=internal_response)
self.status_code = internal_response.status
self.headers = _CIMultiDict(internal_response.headers) # type: ignore
self.reason = internal_response.reason
self.content_type = internal_response.headers.get('content-type')

async def iter_raw(self) -> AsyncIterator[bytes]:
"""Asynchronously iterates over the response's bytes. Will not decompress in the process
:return: An async iterator of bytes from the response
:rtype: AsyncIterator[bytes]
"""
async for part in iter_raw_helper(AioHttpStreamDownloadGenerator, self):
yield part
await self.close()

async def iter_bytes(self) -> AsyncIterator[bytes]:
"""Asynchronously iterates over the response's bytes. Will decompress in the process
:return: An async iterator of bytes from the response
:rtype: AsyncIterator[bytes]
"""
async for part in iter_bytes_helper(
AioHttpStreamDownloadGenerator,
self,
content=self._content
):
yield part
await self.close()
headers = _CIMultiDict(internal_response.headers)
super().__init__(
internal_response=internal_response,
status_code=internal_response.status,
headers=headers,
content_type=headers.get('content-type'),
reason=internal_response.reason,
stream_download_generator=AioHttpStreamDownloadGenerator,
content=None,
**kwargs
)
self._decompress = decompress

def __getstate__(self):
state = self.__dict__.copy()
Expand All @@ -165,6 +148,7 @@ async def close(self) -> None:
:return: None
:rtype: None
"""
self.is_closed = True
self._internal_response.close()
await asyncio.sleep(0)
if not self.is_closed:
self._is_closed = True
self._internal_response.close()
await asyncio.sleep(0)
8 changes: 3 additions & 5 deletions sdk/core/azure-core/azure/core/rest/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@
Union,
Mapping,
Sequence,
List,
Tuple,
IO,
Any,
Dict,
Iterable,
MutableMapping,
)
import xml.etree.ElementTree as ET
import six
Expand All @@ -66,8 +66,6 @@

ParamsType = Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]]

HeadersType = Mapping[str, str]

FileContent = Union[str, bytes, IO[str], IO[bytes]]
FileType = Union[
Tuple[Optional[str], FileContent],
Expand Down Expand Up @@ -129,8 +127,8 @@ def set_xml_body(content):
return headers, body

def _shared_set_content_body(content):
# type: (Any) -> Tuple[HeadersType, Optional[ContentTypeBase]]
headers = {} # type: HeadersType
# type: (Any) -> Tuple[MutableMapping[str, str], Optional[ContentTypeBase]]
headers = {} # type: MutableMapping[str, str]

if isinstance(content, ET.Element):
# XML body
Expand Down
57 changes: 3 additions & 54 deletions sdk/core/azure-core/azure/core/rest/_helpers_py3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,14 @@
Iterable,
Tuple,
Union,
Callable,
Optional,
AsyncIterator as AsyncIteratorType
MutableMapping,
)
from ..exceptions import StreamConsumedError, StreamClosedError

from ._helpers import (
_shared_set_content_body,
HeadersType
)
from ._helpers import _shared_set_content_body
ContentType = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]

def set_content_body(content: ContentType) -> Tuple[
HeadersType, ContentType
MutableMapping[str, str], ContentType
]:
headers, body = _shared_set_content_body(content)
if body is not None:
Expand All @@ -54,48 +48,3 @@ def set_content_body(content: ContentType) -> Tuple[
"Unexpected type for 'content': '{}'. ".format(type(content)) +
"We expect 'content' to either be str, bytes, or an Iterable / AsyncIterable"
)

def _stream_download_helper(
decompress: bool,
stream_download_generator: Callable,
response,
) -> AsyncIteratorType[bytes]:
if response.is_stream_consumed:
raise StreamConsumedError(response)
if response.is_closed:
raise StreamClosedError(response)

response.is_stream_consumed = True
return stream_download_generator(
pipeline=None,
response=response,
decompress=decompress,
)

async def iter_bytes_helper(
stream_download_generator: Callable,
response,
content: Optional[bytes],
) -> AsyncIteratorType[bytes]:
if content:
chunk_size = response._connection_data_block_size # pylint: disable=protected-access
for i in range(0, len(content), chunk_size):
yield content[i : i + chunk_size]
else:
async for part in _stream_download_helper(
decompress=True,
stream_download_generator=stream_download_generator,
response=response,
):
yield part

async def iter_raw_helper(
stream_download_generator: Callable,
response,
) -> AsyncIteratorType[bytes]:
async for part in _stream_download_helper(
decompress=False,
stream_download_generator=stream_download_generator,
response=response,
):
yield part
Loading

0 comments on commit 1a9b633

Please sign in to comment.