From 662dedf59382ce3e12e597fbae1a99d078900497 Mon Sep 17 00:00:00 2001 From: iscai-msft <43154838+iscai-msft@users.noreply.github.com> Date: Tue, 29 Jun 2021 14:29:02 -0400 Subject: [PATCH] [core] add provisional azure.core.rest (#19502) --- sdk/core/azure-core/CHANGELOG.md | 10 +- sdk/core/azure-core/README.md | 15 + .../azure-core/azure/core/_pipeline_client.py | 52 ++ .../azure/core/_pipeline_client_async.py | 62 ++- sdk/core/azure-core/azure/core/_version.py | 2 +- sdk/core/azure-core/azure/core/exceptions.py | 47 ++ .../azure-core/azure/core/pipeline/_tools.py | 37 ++ .../azure/core/pipeline/_tools_async.py | 35 ++ .../azure/core/pipeline/transport/_aiohttp.py | 15 +- .../azure/core/pipeline/transport/_base.py | 1 - .../pipeline/transport/_requests_asyncio.py | 13 +- .../pipeline/transport/_requests_basic.py | 17 +- .../core/pipeline/transport/_requests_trio.py | 13 +- .../azure-core/azure/core/rest/__init__.py | 51 ++ .../azure-core/azure/core/rest/_aiohttp.py | 87 +++ .../azure-core/azure/core/rest/_helpers.py | 306 +++++++++++ .../azure/core/rest/_helpers_py3.py | 101 ++++ .../azure/core/rest/_requests_asyncio.py | 83 +++ .../azure/core/rest/_requests_basic.py | 151 ++++++ .../azure/core/rest/_requests_trio.py | 77 +++ sdk/core/azure-core/azure/core/rest/_rest.py | 367 +++++++++++++ .../azure-core/azure/core/rest/_rest_py3.py | 504 ++++++++++++++++++ sdk/core/azure-core/doc/azure.core.rst | 12 + .../tests/async_tests/test_request_asyncio.py | 5 +- .../testserver_tests/async_tests/conftest.py | 100 ++++ .../async_tests/rest_client_async.py | 69 +++ .../test_rest_asyncio_transport.py | 43 ++ .../test_rest_context_manager_async.py | 82 +++ .../test_rest_http_request_async.py | 90 ++++ .../test_rest_http_response_async.py | 280 ++++++++++ .../test_rest_stream_responses_async.py | 206 +++++++ .../async_tests/test_rest_trio_transport.py | 41 ++ .../test_testserver_async.py | 0 .../tests/testserver_tests/conftest.py | 16 +- .../coretestserver/test_routes/basic.py | 6 +- .../coretestserver/test_routes/encoding.py | 12 +- .../tests/testserver_tests/rest_client.py | 83 +++ .../test_rest_context_manager.py | 78 +++ .../testserver_tests/test_rest_headers.py | 104 ++++ .../test_rest_http_request.py | 305 +++++++++++ .../test_rest_http_response.py | 298 +++++++++++ .../tests/testserver_tests/test_rest_query.py | 31 ++ .../test_rest_stream_responses.py | 230 ++++++++ 43 files changed, 4094 insertions(+), 43 deletions(-) create mode 100644 sdk/core/azure-core/azure/core/rest/__init__.py create mode 100644 sdk/core/azure-core/azure/core/rest/_aiohttp.py create mode 100644 sdk/core/azure-core/azure/core/rest/_helpers.py create mode 100644 sdk/core/azure-core/azure/core/rest/_helpers_py3.py create mode 100644 sdk/core/azure-core/azure/core/rest/_requests_asyncio.py create mode 100644 sdk/core/azure-core/azure/core/rest/_requests_basic.py create mode 100644 sdk/core/azure-core/azure/core/rest/_requests_trio.py create mode 100644 sdk/core/azure-core/azure/core/rest/_rest.py create mode 100644 sdk/core/azure-core/azure/core/rest/_rest_py3.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/async_tests/conftest.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/async_tests/rest_client_async.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_asyncio_transport.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_context_manager_async.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_http_request_async.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_http_response_async.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_stream_responses_async.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_trio_transport.py rename sdk/core/azure-core/tests/testserver_tests/{ => async_tests}/test_testserver_async.py (100%) create mode 100644 sdk/core/azure-core/tests/testserver_tests/rest_client.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/test_rest_context_manager.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/test_rest_headers.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/test_rest_http_request.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/test_rest_http_response.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/test_rest_query.py create mode 100644 sdk/core/azure-core/tests/testserver_tests/test_rest_stream_responses.py diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index ff3163d6f9b4..40b84e70df39 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -1,6 +1,14 @@ # Release History -## 1.15.1 (Unreleased) +## 1.16.0 (Unreleased) + +### New Features + +- Add new ***provisional*** methods `send_request` onto the `azure.core.PipelineClient` and `azure.core.AsyncPipelineClient`. This method takes in +requests and sends them through our pipelines. +- Add new ***provisional*** module `azure.core.rest`. `azure.core.rest` is our new public simple HTTP library in `azure.core` that users will use to create requests, and consume responses. +- Add new ***provisional*** errors `StreamConsumedError`, `StreamClosedError`, and `ResponseNotReadError` to `azure.core.exceptions`. These errors +are thrown if you mishandle streamed responses from the provisional `azure.core.rest` module ### Bug Fixes diff --git a/sdk/core/azure-core/README.md b/sdk/core/azure-core/README.md index 67c611ae596d..d098d1bf63dc 100644 --- a/sdk/core/azure-core/README.md +++ b/sdk/core/azure-core/README.md @@ -112,6 +112,21 @@ class TooManyRedirectsError(HttpResponseError): *kwargs* are keyword arguments to include with the exception. +#### **Provisional** StreamConsumedError +A **provisional** error thrown if you try to access the stream of the **provisional** +responses `azure.core.rest.HttpResponse` or `azure.core.rest.AsyncHttpResponse` once +the response stream has been consumed. + +#### **Provisional** StreamClosedError +A **provisional** error thrown if you try to access the stream of the **provisional** +responses `azure.core.rest.HttpResponse` or `azure.core.rest.AsyncHttpResponse` once +the response stream has been closed. + +#### **Provisional** ResponseNotReadError +A **provisional** error thrown if you try to access the `content` of the **provisional** +responses `azure.core.rest.HttpResponse` or `azure.core.rest.AsyncHttpResponse` before +reading in the response's bytes first. + ### Configurations When calling the methods, some properties can be configured by passing in as kwargs arguments. diff --git a/sdk/core/azure-core/azure/core/_pipeline_client.py b/sdk/core/azure-core/azure/core/_pipeline_client.py index 6f2376d36956..ec32e2b20964 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client.py @@ -40,6 +40,7 @@ RetryPolicy, ) from .pipeline.transport import RequestsTransport +from .pipeline._tools import to_rest_response as _to_rest_response try: from typing import TYPE_CHECKING @@ -58,10 +59,23 @@ Callable, Iterator, cast, + TypeVar ) # pylint: disable=unused-import + HTTPResponseType = TypeVar("HTTPResponseType") + HTTPRequestType = TypeVar("HTTPRequestType") _LOGGER = logging.getLogger(__name__) +def _prepare_request(request): + # returns the request ready to run through pipelines + # and a bool telling whether we ended up converting it + rest_request = False + try: + request_to_run = request._to_pipeline_transport_request() # pylint: disable=protected-access + rest_request = True + except AttributeError: + request_to_run = request + return rest_request, request_to_run class PipelineClient(PipelineClientBase): """Service client core methods. @@ -170,3 +184,41 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use transport = RequestsTransport(**kwargs) return Pipeline(transport, policies) + + + def send_request(self, request, **kwargs): + # type: (HTTPRequestType, Any) -> HTTPResponseType + """**Provisional** method that runs the network request through the client's chained policies. + + This method is marked as **provisional**, meaning it may be changed in a future release. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + + >>> response = client.send_request(request) + + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.HttpResponse + # """ + rest_request, request_to_run = _prepare_request(request) + return_pipeline_response = kwargs.pop("_return_pipeline_response", False) + pipeline_response = self._pipeline.run(request_to_run, **kwargs) # pylint: disable=protected-access + response = pipeline_response.http_response + if rest_request: + response = _to_rest_response(response) + try: + if not kwargs.get("stream", False): + response.read() + response.close() + except Exception as exc: + response.close() + raise exc + if return_pipeline_response: + pipeline_response.http_response = response + pipeline_response.http_request = request + return pipeline_response + return response diff --git a/sdk/core/azure-core/azure/core/_pipeline_client_async.py b/sdk/core/azure-core/azure/core/_pipeline_client_async.py index 1e17480ba4f9..357b3d9b917d 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client_async.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client_async.py @@ -26,6 +26,7 @@ import logging from collections.abc import Iterable +from typing import Any, Awaitable from .configuration import Configuration from .pipeline import AsyncPipeline from .pipeline.transport._base import PipelineClientBase @@ -36,16 +37,20 @@ RequestIdPolicy, AsyncRetryPolicy, ) +from ._pipeline_client import _prepare_request +from .pipeline._tools_async import to_rest_response as _to_rest_response try: - from typing import TYPE_CHECKING + from typing import TYPE_CHECKING, TypeVar except ImportError: TYPE_CHECKING = False +HTTPRequestType = TypeVar("HTTPRequestType") +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") + if TYPE_CHECKING: from typing import ( List, - Any, Dict, Union, IO, @@ -168,3 +173,56 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use transport = AioHttpTransport(**kwargs) return AsyncPipeline(transport, policies) + + async def _make_pipeline_call(self, request, **kwargs): + rest_request, request_to_run = _prepare_request(request) + return_pipeline_response = kwargs.pop("_return_pipeline_response", False) + pipeline_response = await self._pipeline.run( + request_to_run, **kwargs # pylint: disable=protected-access + ) + response = pipeline_response.http_response + if rest_request: + rest_response = _to_rest_response(response) + if not kwargs.get("stream"): + try: + # in this case, the pipeline transport response already called .load_body(), so + # the body is loaded. instead of doing response.read(), going to set the body + # to the internal content + rest_response._content = response.body() # pylint: disable=protected-access + await rest_response.close() + except Exception as exc: + await rest_response.close() + raise exc + response = rest_response + if return_pipeline_response: + pipeline_response.http_response = response + pipeline_response.http_request = request + return pipeline_response + return response + + def send_request( + self, + request: HTTPRequestType, + *, + stream: bool = False, + **kwargs: Any + ) -> Awaitable[AsyncHTTPResponseType]: + """**Provisional** method that runs the network request through the client's chained policies. + + This method is marked as **provisional**, meaning it may be changed in a future release. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + + >>> response = await client.send_request(request) + + + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.AsyncHttpResponse + """ + from .rest._rest_py3 import _AsyncContextManager + wrapped = self._make_pipeline_call(request, stream=stream, **kwargs) + return _AsyncContextManager(wrapped=wrapped) diff --git a/sdk/core/azure-core/azure/core/_version.py b/sdk/core/azure-core/azure/core/_version.py index d7a104234f8f..48bb9d819b66 100644 --- a/sdk/core/azure-core/azure/core/_version.py +++ b/sdk/core/azure-core/azure/core/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "1.15.1" +VERSION = "1.16.0" diff --git a/sdk/core/azure-core/azure/core/exceptions.py b/sdk/core/azure-core/azure/core/exceptions.py index 4af83e9b3683..0f827008a717 100644 --- a/sdk/core/azure-core/azure/core/exceptions.py +++ b/sdk/core/azure-core/azure/core/exceptions.py @@ -433,3 +433,50 @@ def __str__(self): if self._error_format: return str(self._error_format) return super(ODataV4Error, self).__str__() + +class StreamConsumedError(AzureError): + """**Provisional** error thrown if you try to access the stream of a response once consumed. + + This error is marked as **provisional**, meaning it may be changed in a future release. It is + thrown if you try to read / stream an ~azure.core.rest.HttpResponse or + ~azure.core.rest.AsyncHttpResponse once the response's stream has been consumed. + """ + def __init__(self, response): + message = ( + "You are attempting to read or stream the content from request {}. "\ + "You have likely already consumed this stream, so it can not be accessed anymore.".format( + response.request + ) + ) + super(StreamConsumedError, self).__init__(message) + +class StreamClosedError(AzureError): + """**Provisional** error thrown if you try to access the stream of a response once closed. + + This error is marked as **provisional**, meaning it may be changed in a future release. It is + thrown if you try to read / stream an ~azure.core.rest.HttpResponse or + ~azure.core.rest.AsyncHttpResponse once the response's stream has been closed. + """ + def __init__(self, response): + message = ( + "The content for response from request {} can no longer be read or streamed, since the "\ + "response has already been closed.".format(response.request) + ) + super(StreamClosedError, self).__init__(message) + +class ResponseNotReadError(AzureError): + """**Provisional** error thrown if you try to access a response's content without reading first. + + This error is marked as **provisional**, meaning it may be changed in a future release. It is + thrown if you try to access an ~azure.core.rest.HttpResponse or + ~azure.core.rest.AsyncHttpResponse's content without first reading the response's bytes in first. + """ + + def __init__(self, response): + message = ( + "You have not read in the bytes for the response from request {}. "\ + "Call .read() on the response first.".format( + response.request + ) + ) + super(ResponseNotReadError, self).__init__(message) diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools.py b/sdk/core/azure-core/azure/core/pipeline/_tools.py index 47453ad55721..a8beebd75b99 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools.py @@ -32,3 +32,40 @@ def await_result(func, *args, **kwargs): "Policy {} returned awaitable object in non-async pipeline.".format(func) ) return result + +def to_rest_request(pipeline_transport_request): + from ..rest import HttpRequest as RestHttpRequest + return RestHttpRequest( + method=pipeline_transport_request.method, + url=pipeline_transport_request.url, + headers=pipeline_transport_request.headers, + files=pipeline_transport_request.files, + data=pipeline_transport_request.data + ) + +def to_rest_response(pipeline_transport_response): + from .transport._requests_basic import RequestsTransportResponse + from ..rest._requests_basic import RestRequestsTransportResponse + from ..rest import HttpResponse + if isinstance(pipeline_transport_response, RequestsTransportResponse): + response_type = RestRequestsTransportResponse + else: + response_type = HttpResponse + response = response_type( + request=to_rest_request(pipeline_transport_response.request), + internal_response=pipeline_transport_response.internal_response, + ) + response._connection_data_block_size = pipeline_transport_response.block_size # pylint: disable=protected-access + return response + +def get_block_size(response): + try: + return response._connection_data_block_size # pylint: disable=protected-access + except AttributeError: + return response.block_size + +def get_internal_response(response): + try: + return response._internal_response # pylint: disable=protected-access + except AttributeError: + return response.internal_response diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py index d29988bd41ee..de59dfdd86ed 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py @@ -23,6 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +from ._tools import to_rest_request async def await_result(func, *args, **kwargs): """If func returns an awaitable, await it.""" @@ -31,3 +32,37 @@ async def await_result(func, *args, **kwargs): # type ignore on await: https://github.com/python/mypy/issues/7587 return await result # type: ignore return result + +def _get_response_type(pipeline_transport_response): + try: + from .transport import AioHttpTransportResponse + from ..rest._aiohttp import RestAioHttpTransportResponse + if isinstance(pipeline_transport_response, AioHttpTransportResponse): + return RestAioHttpTransportResponse + except ImportError: + pass + try: + from .transport import AsyncioRequestsTransportResponse + from ..rest._requests_asyncio import RestAsyncioRequestsTransportResponse + if isinstance(pipeline_transport_response, AsyncioRequestsTransportResponse): + return RestAsyncioRequestsTransportResponse + except ImportError: + pass + try: + from .transport import TrioRequestsTransportResponse + from ..rest._requests_trio import RestTrioRequestsTransportResponse + if isinstance(pipeline_transport_response, TrioRequestsTransportResponse): + return RestTrioRequestsTransportResponse + except ImportError: + pass + from ..rest import AsyncHttpResponse + return AsyncHttpResponse + +def to_rest_response(pipeline_transport_response): + response_type = _get_response_type(pipeline_transport_response) + response = response_type( + request=to_rest_request(pipeline_transport_response.request), + internal_response=pipeline_transport_response.internal_response, + ) + response._connection_data_block_size = pipeline_transport_response.block_size # pylint: disable=protected-access + return response diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index 9ae3f96434cb..e32d0d1c0aec 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -46,6 +46,7 @@ AsyncHttpTransport, AsyncHttpResponse, _ResponseStopIteration) +from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response # Matching requests, because why not? CONTENT_CHUNK_SIZE = 10 * 1024 @@ -215,22 +216,24 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompres self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = response.block_size + self.block_size = _get_block_size(response) self._decompress = decompress - self.content_length = int(response.internal_response.headers.get('Content-Length', 0)) + internal_response = _get_internal_response(response) + self.content_length = int(internal_response.headers.get('Content-Length', 0)) self._decompressor = None def __len__(self): return self.content_length async def __anext__(self): + internal_response = _get_internal_response(self.response) try: - chunk = await self.response.internal_response.content.read(self.block_size) + chunk = await internal_response.content.read(self.block_size) if not chunk: raise _ResponseStopIteration() if not self._decompress: return chunk - enc = self.response.internal_response.headers.get('Content-Encoding') + enc = internal_response.headers.get('Content-Encoding') if not enc: return chunk enc = enc.lower() @@ -242,13 +245,13 @@ async def __anext__(self): chunk = self._decompressor.decompress(chunk) return chunk except _ResponseStopIteration: - self.response.internal_response.close() + internal_response.close() raise StopAsyncIteration() except StreamConsumedError: raise except Exception as err: _LOGGER.warning("Unable to stream download: %s", err) - self.response.internal_response.close() + internal_response.close() raise class AioHttpTransportResponse(AsyncHttpResponse): diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 589d5549c584..c807e02d841a 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -473,7 +473,6 @@ def serialize(self): """ return _serialize_request(self) - class _HttpResponseBase(object): """Represent a HTTP response. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py index aab184cb3d8b..e41e4de91325 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py @@ -44,6 +44,7 @@ _iterate_response_content) from ._requests_basic import RequestsTransportResponse, _read_raw_stream from ._base_requests_async import RequestsAsyncTransportBase +from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response _LOGGER = logging.getLogger(__name__) @@ -145,14 +146,15 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = response.block_size + self.block_size = _get_block_size(response) decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = _get_internal_response(response) if decompress: - self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + self.iter_content_func = internal_response.iter_content(self.block_size) else: - self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size) + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) self.content_length = int(response.headers.get('Content-Length', 0)) def __len__(self): @@ -160,6 +162,7 @@ def __len__(self): async def __anext__(self): loop = _get_running_loop() + internal_response = _get_internal_response(self.response) try: chunk = await loop.run_in_executor( None, @@ -170,13 +173,13 @@ async def __anext__(self): raise _ResponseStopIteration() return chunk except _ResponseStopIteration: - self.response.internal_response.close() + internal_response.close() raise StopAsyncIteration() except requests.exceptions.StreamConsumedError: raise except Exception as err: _LOGGER.warning("Unable to stream download: %s", err) - self.response.internal_response.close() + internal_response.close() raise diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index b1b827424cdd..28b81d705c16 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -46,6 +46,7 @@ _HttpResponseBase ) from ._bigger_block_size_http_adapters import BiggerBlockSizeHTTPAdapter +from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response PipelineType = TypeVar("PipelineType") @@ -71,6 +72,10 @@ def _read_raw_stream(response, chunk_size=1): break yield chunk + # following behavior from requests iter_content, we set content consumed to True + # https://github.com/psf/requests/blob/master/requests/models.py#L774 + response._content_consumed = True # pylint: disable=protected-access + class _RequestsTransportResponseBase(_HttpResponseBase): """Base class for accessing response data. @@ -127,14 +132,15 @@ def __init__(self, pipeline, response, **kwargs): self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = response.block_size + self.block_size = _get_block_size(response) decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = _get_internal_response(response) if decompress: - self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + self.iter_content_func = internal_response.iter_content(self.block_size) else: - self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size) + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) self.content_length = int(response.headers.get('Content-Length', 0)) def __len__(self): @@ -144,19 +150,20 @@ def __iter__(self): return self def __next__(self): + internal_response = _get_internal_response(self.response) try: chunk = next(self.iter_content_func) if not chunk: raise StopIteration() return chunk except StopIteration: - self.response.internal_response.close() + internal_response.close() raise StopIteration() except requests.exceptions.StreamConsumedError: raise except Exception as err: _LOGGER.warning("Unable to stream download: %s", err) - self.response.internal_response.close() + internal_response.close() raise next = __next__ # Python 2 compatibility. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py index 7be76336979f..e21ee5115327 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py @@ -44,6 +44,7 @@ _iterate_response_content) from ._requests_basic import RequestsTransportResponse, _read_raw_stream from ._base_requests_async import RequestsAsyncTransportBase +from .._tools import get_block_size as _get_block_size, get_internal_response as _get_internal_response _LOGGER = logging.getLogger(__name__) @@ -61,20 +62,22 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> self.pipeline = pipeline self.request = response.request self.response = response - self.block_size = response.block_size + self.block_size = _get_block_size(response) decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + internal_response = _get_internal_response(response) if decompress: - self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + self.iter_content_func = internal_response.iter_content(self.block_size) else: - self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size) + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) self.content_length = int(response.headers.get('Content-Length', 0)) def __len__(self): return self.content_length async def __anext__(self): + internal_response = _get_internal_response(self.response) try: try: chunk = await trio.to_thread.run_sync( @@ -90,13 +93,13 @@ async def __anext__(self): raise _ResponseStopIteration() return chunk except _ResponseStopIteration: - self.response.internal_response.close() + internal_response.close() raise StopAsyncIteration() except requests.exceptions.StreamConsumedError: raise except Exception as err: _LOGGER.warning("Unable to stream download: %s", err) - self.response.internal_response.close() + internal_response.close() raise class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore diff --git a/sdk/core/azure-core/azure/core/rest/__init__.py b/sdk/core/azure-core/azure/core/rest/__init__.py new file mode 100644 index 000000000000..2fc73b837f14 --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/__init__.py @@ -0,0 +1,51 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +try: + from ._rest_py3 import ( + HttpRequest, + HttpResponse, + ) +except (SyntaxError, ImportError): + from ._rest import ( # type: ignore + HttpRequest, + HttpResponse, + ) + +__all__ = [ + "HttpRequest", + "HttpResponse", +] + +try: + from ._rest_py3 import ( # pylint: disable=unused-import + AsyncHttpResponse, + ) + __all__.extend([ + "AsyncHttpResponse", + ]) + +except (SyntaxError, ImportError): + pass diff --git a/sdk/core/azure-core/azure/core/rest/_aiohttp.py b/sdk/core/azure-core/azure/core/rest/_aiohttp.py new file mode 100644 index 000000000000..f25d9f7679b0 --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/_aiohttp.py @@ -0,0 +1,87 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +import asyncio +from typing import AsyncIterator +from multidict import CIMultiDict +from . import HttpRequest, AsyncHttpResponse +from ._helpers_py3 import iter_raw_helper, iter_bytes_helper +from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator + + +class RestAioHttpTransportResponse(AsyncHttpResponse): + def __init__( + self, + *, + request: HttpRequest, + internal_response, + ): + super().__init__(request=request, internal_response=internal_response) + self.status_code = internal_response.status + self.headers = CIMultiDict(internal_response.headers) # type: ignore + self.reason = internal_response.reason + self.content_type = internal_response.headers.get('content-type') + + async def iter_raw(self) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will not decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + async for part in iter_raw_helper(AioHttpStreamDownloadGenerator, self): + yield part + await self.close() + + async def iter_bytes(self) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + async for part in iter_bytes_helper( + AioHttpStreamDownloadGenerator, + self, + content=self._content + ): + yield part + await self.close() + + def __getstate__(self): + state = self.__dict__.copy() + # Remove the unpicklable entries. + state['internal_response'] = None # aiohttp response are not pickable (see headers comments) + state['headers'] = CIMultiDict(self.headers) # MultiDictProxy is not pickable + return state + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + self.is_closed = True + self._internal_response.close() + await asyncio.sleep(0) diff --git a/sdk/core/azure-core/azure/core/rest/_helpers.py b/sdk/core/azure-core/azure/core/rest/_helpers.py new file mode 100644 index 000000000000..27e11299fdb5 --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/_helpers.py @@ -0,0 +1,306 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import os +import codecs +import cgi +from enum import Enum +from json import dumps +import collections +from typing import ( + Optional, + Union, + Mapping, + Sequence, + List, + Tuple, + IO, + Any, + Dict, + Iterable, + Iterator, + cast, + Callable, +) +import xml.etree.ElementTree as ET +import six +try: + from urlparse import urlparse # type: ignore +except ImportError: + from urllib.parse import urlparse +try: + import cchardet as chardet +except ImportError: # pragma: no cover + import chardet # type: ignore +from ..exceptions import ResponseNotReadError + +################################### TYPES SECTION ######################### + +PrimitiveData = Optional[Union[str, int, float, bool]] + + +ParamsType = Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]] + +HeadersType = Mapping[str, str] + +FileContent = Union[str, bytes, IO[str], IO[bytes]] +FileType = Union[ + Tuple[Optional[str], FileContent], +] + +FilesType = Union[ + Mapping[str, FileType], + Sequence[Tuple[str, FileType]] +] + +ContentTypeBase = Union[str, bytes, Iterable[bytes]] + +class HttpVerbs(str, Enum): + GET = "GET" + PUT = "PUT" + POST = "POST" + HEAD = "HEAD" + PATCH = "PATCH" + DELETE = "DELETE" + MERGE = "MERGE" + +########################### ERRORS SECTION ################################# + + + +########################### HELPER SECTION ################################# + +def _verify_data_object(name, value): + if not isinstance(name, str): + raise TypeError( + "Invalid type for data name. Expected str, got {}: {}".format( + type(name), name + ) + ) + if value is not None and not isinstance(value, (str, bytes, int, float)): + raise TypeError( + "Invalid type for data value. Expected primitive type, got {}: {}".format( + type(name), name + ) + ) + +def _format_data(data): + # type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]] + """Format field data according to whether it is a stream or + a string for a form-data request. + + :param data: The request field data. + :type data: str or file-like object. + """ + if hasattr(data, "read"): + data = cast(IO, data) + data_name = None + try: + if data.name[0] != "<" and data.name[-1] != ">": + data_name = os.path.basename(data.name) + except (AttributeError, TypeError): + pass + return (data_name, data, "application/octet-stream") + return (None, cast(str, data)) + +def set_urlencoded_body(data, has_files): + body = {} + default_headers = {} + for f, d in data.items(): + if not d: + continue + if isinstance(d, list): + for item in d: + _verify_data_object(f, item) + else: + _verify_data_object(f, d) + body[f] = d + if not has_files: + # little hacky, but for files we don't send a content type with + # boundary so requests / aiohttp etc deal with it + default_headers["Content-Type"] = "application/x-www-form-urlencoded" + return default_headers, body + +def set_multipart_body(files): + formatted_files = { + f: _format_data(d) for f, d in files.items() if d is not None + } + return {}, formatted_files + +def set_xml_body(content): + headers = {} + bytes_content = ET.tostring(content, encoding="utf8") + body = bytes_content.replace(b"encoding='utf8'", b"encoding='utf-8'") + if body: + headers["Content-Length"] = str(len(body)) + return headers, body + +def _shared_set_content_body(content): + # type: (Any) -> Tuple[HeadersType, Optional[ContentTypeBase]] + headers = {} # type: HeadersType + + if isinstance(content, ET.Element): + # XML body + return set_xml_body(content) + if isinstance(content, (str, bytes)): + headers = {} + body = content + if isinstance(content, six.string_types): + headers["Content-Type"] = "text/plain" + if body: + headers["Content-Length"] = str(len(body)) + return headers, body + if isinstance(content, collections.Iterable): + return {}, content + return headers, None + +def set_content_body(content): + headers, body = _shared_set_content_body(content) + if body is not None: + return headers, body + raise TypeError( + "Unexpected type for 'content': '{}'. ".format(type(content)) + + "We expect 'content' to either be str, bytes, or an Iterable" + ) + +def set_json_body(json): + # type: (Any) -> Tuple[Dict[str, str], Any] + body = dumps(json) + return { + "Content-Type": "application/json", + "Content-Length": str(len(body)) + }, body + +def format_parameters(url, params): + """Format parameters into a valid query string. + It's assumed all parameters have already been quoted as + valid URL strings. + + :param dict params: A dictionary of parameters. + """ + query = urlparse(url).query + if query: + url = url.partition("?")[0] + existing_params = { + p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")] + } + params.update(existing_params) + query_params = [] + for k, v in params.items(): + if isinstance(v, list): + for w in v: + if w is None: + raise ValueError("Query parameter {} cannot be None".format(k)) + query_params.append("{}={}".format(k, w)) + else: + if v is None: + raise ValueError("Query parameter {} cannot be None".format(k)) + query_params.append("{}={}".format(k, v)) + query = "?" + "&".join(query_params) + url += query + return url + +def lookup_encoding(encoding): + # type: (str) -> bool + # including check for whether encoding is known taken from httpx + try: + codecs.lookup(encoding) + return True + except LookupError: + return False + +def parse_lines_from_text(text): + # largely taken from httpx's LineDecoder code + lines = [] + last_chunk_of_text = "" + while text: + text_length = len(text) + for idx in range(text_length): + curr_char = text[idx] + next_char = None if idx == len(text) - 1 else text[idx + 1] + if curr_char == "\n": + lines.append(text[: idx + 1]) + text = text[idx + 1: ] + break + if curr_char == "\r" and next_char == "\n": + # if it ends with \r\n, we only do \n + lines.append(text[:idx] + "\n") + text = text[idx + 2:] + break + if curr_char == "\r" and next_char is not None: + # if it's \r then a normal character, we switch \r to \n + lines.append(text[:idx] + "\n") + text = text[idx + 1:] + break + if next_char is None: + last_chunk_of_text += text + text = "" + break + if last_chunk_of_text.endswith("\r"): + # if ends with \r, we switch \r to \n + lines.append(last_chunk_of_text[:-1] + "\n") + elif last_chunk_of_text: + lines.append(last_chunk_of_text) + return lines + +def to_pipeline_transport_request_helper(rest_request): + from ..pipeline.transport import HttpRequest as PipelineTransportHttpRequest + return PipelineTransportHttpRequest( + method=rest_request.method, + url=rest_request.url, + headers=rest_request.headers, + files=rest_request._files, # pylint: disable=protected-access + data=rest_request._data # pylint: disable=protected-access + ) + +def from_pipeline_transport_request_helper(request_class, pipeline_transport_request): + return request_class( + method=pipeline_transport_request.method, + url=pipeline_transport_request.url, + headers=pipeline_transport_request.headers, + files=pipeline_transport_request.files, + data=pipeline_transport_request.data + ) + +def get_charset_encoding(response): + content_type = response.headers.get("Content-Type") + + if not content_type: + return None + _, params = cgi.parse_header(content_type) + encoding = params.get('charset') # -> utf-8 + if encoding is None: + if content_type in ("application/json", "application/rdap+json"): + # RFC 7159 states that the default encoding is UTF-8. + # RFC 7483 defines application/rdap+json + encoding = "utf-8" + else: + try: + encoding = chardet.detect(response.content)["encoding"] + except ResponseNotReadError: + pass + if encoding is None or not lookup_encoding(encoding): + return None + return encoding diff --git a/sdk/core/azure-core/azure/core/rest/_helpers_py3.py b/sdk/core/azure-core/azure/core/rest/_helpers_py3.py new file mode 100644 index 000000000000..90948012db2a --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/_helpers_py3.py @@ -0,0 +1,101 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import collections.abc +from typing import ( + AsyncIterable, + Dict, + Iterable, + Tuple, + Union, + Callable, + Optional, + AsyncIterator as AsyncIteratorType +) +from ..exceptions import StreamConsumedError, StreamClosedError + +from ._helpers import ( + _shared_set_content_body, + HeadersType +) +ContentType = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] + +def set_content_body(content: ContentType) -> Tuple[ + HeadersType, ContentType +]: + headers, body = _shared_set_content_body(content) + if body is not None: + return headers, body + if isinstance(content, collections.abc.AsyncIterable): + return {}, content + raise TypeError( + "Unexpected type for 'content': '{}'. ".format(type(content)) + + "We expect 'content' to either be str, bytes, or an Iterable / AsyncIterable" + ) + +def _stream_download_helper( + decompress: bool, + stream_download_generator: Callable, + response, +) -> AsyncIteratorType[bytes]: + if response.is_stream_consumed: + raise StreamConsumedError(response) + if response.is_closed: + raise StreamClosedError(response) + + response.is_stream_consumed = True + return stream_download_generator( + pipeline=None, + response=response, + decompress=decompress, + ) + +async def iter_bytes_helper( + stream_download_generator: Callable, + response, + content: Optional[bytes], +) -> AsyncIteratorType[bytes]: + if content: + chunk_size = response._connection_data_block_size # pylint: disable=protected-access + for i in range(0, len(content), chunk_size): + yield content[i : i + chunk_size] + else: + async for part in _stream_download_helper( + decompress=True, + stream_download_generator=stream_download_generator, + response=response, + ): + yield part + +async def iter_raw_helper( + stream_download_generator: Callable, + response, +) -> AsyncIteratorType[bytes]: + async for part in _stream_download_helper( + decompress=False, + stream_download_generator=stream_download_generator, + response=response, + ): + yield part diff --git a/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py b/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py new file mode 100644 index 000000000000..b21545a79804 --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py @@ -0,0 +1,83 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import AsyncIterator +import asyncio +from ._helpers_py3 import iter_bytes_helper, iter_raw_helper +from . import AsyncHttpResponse +from ._requests_basic import _RestRequestsTransportResponseBase, _has_content +from ..pipeline.transport._requests_asyncio import AsyncioStreamDownloadGenerator + +class RestAsyncioRequestsTransportResponse(AsyncHttpResponse, _RestRequestsTransportResponseBase): # type: ignore + """Asynchronous streaming of data from the response. + """ + + async def iter_raw(self) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will not decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + + async for part in iter_raw_helper(AsyncioStreamDownloadGenerator, self): + yield part + await self.close() + + async def iter_bytes(self) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + async for part in iter_bytes_helper( + AsyncioStreamDownloadGenerator, + self, + content=self.content if _has_content(self) else None + ): + yield part + await self.close() + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + self.is_closed = True + self._internal_response.close() + await asyncio.sleep(0) + + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + if not _has_content(self): + parts = [] + async for part in self.iter_bytes(): # type: ignore + parts.append(part) + self._internal_response._content = b"".join(parts) # pylint: disable=protected-access + return self.content diff --git a/sdk/core/azure-core/azure/core/rest/_requests_basic.py b/sdk/core/azure-core/azure/core/rest/_requests_basic.py new file mode 100644 index 000000000000..e8ef734e1275 --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/_requests_basic.py @@ -0,0 +1,151 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import TYPE_CHECKING, cast + +from ..exceptions import ResponseNotReadError, StreamConsumedError, StreamClosedError +from ._rest import _HttpResponseBase, HttpResponse +from ..pipeline.transport._requests_basic import StreamDownloadGenerator + +if TYPE_CHECKING: + from typing import Iterator, Optional + +def _has_content(response): + try: + response.content # pylint: disable=pointless-statement + return True + except ResponseNotReadError: + return False + +class _RestRequestsTransportResponseBase(_HttpResponseBase): + def __init__(self, **kwargs): + super(_RestRequestsTransportResponseBase, self).__init__(**kwargs) + self.status_code = self._internal_response.status_code + self.headers = self._internal_response.headers + self.reason = self._internal_response.reason + self.content_type = self._internal_response.headers.get('content-type') + + @property + def content(self): + # type: () -> bytes + if not self._internal_response._content_consumed: # pylint: disable=protected-access + # if we just call .content, requests will read in the content. + # we want to read it in our own way + raise ResponseNotReadError(self) + + try: + return self._internal_response.content + except RuntimeError: + # requests throws a RuntimeError if the content for a response is already consumed + raise ResponseNotReadError(self) + + @property + def encoding(self): + # type: () -> Optional[str] + retval = super(_RestRequestsTransportResponseBase, self).encoding + if not retval: + # There is a few situation where "requests" magic doesn't fit us: + # - https://github.com/psf/requests/issues/654 + # - https://github.com/psf/requests/issues/1737 + # - https://github.com/psf/requests/issues/2086 + from codecs import BOM_UTF8 + if self._internal_response.content[:3] == BOM_UTF8: + retval = "utf-8-sig" + if retval: + if retval == "utf-8": + retval = "utf-8-sig" + return retval + + @encoding.setter # type: ignore + def encoding(self, value): + # type: (str) -> None + # ignoring setter bc of known mypy issue https://github.com/python/mypy/issues/1465 + self._encoding = value + self._internal_response.encoding = value + + @property + def text(self): + # this will trigger errors if response is not read in + self.content # pylint: disable=pointless-statement + return self._internal_response.text + +def _stream_download_helper(decompress, response): + if response.is_stream_consumed: + raise StreamConsumedError(response) + if response.is_closed: + raise StreamClosedError(response) + + response.is_stream_consumed = True + stream_download = StreamDownloadGenerator( + pipeline=None, + response=response, + decompress=decompress, + ) + for part in stream_download: + yield part + +class RestRequestsTransportResponse(HttpResponse, _RestRequestsTransportResponseBase): + + def iter_bytes(self): + # type: () -> Iterator[bytes] + """Iterates over the response's bytes. Will decompress in the process + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + if _has_content(self): + chunk_size = cast(int, self._connection_data_block_size) + for i in range(0, len(self.content), chunk_size): + yield self.content[i : i + chunk_size] + else: + for part in _stream_download_helper( + decompress=True, + response=self, + ): + yield part + self.close() + + def iter_raw(self): + # type: () -> Iterator[bytes] + """Iterates over the response's bytes. Will not decompress in the process + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + for raw_bytes in _stream_download_helper( + decompress=False, + response=self, + ): + yield raw_bytes + self.close() + + def read(self): + # type: () -> bytes + """Read the response's bytes. + + :return: The read in bytes + :rtype: bytes + """ + if not _has_content(self): + self._internal_response._content = b"".join(self.iter_bytes()) # pylint: disable=protected-access + return self.content diff --git a/sdk/core/azure-core/azure/core/rest/_requests_trio.py b/sdk/core/azure-core/azure/core/rest/_requests_trio.py new file mode 100644 index 000000000000..9806380ef04f --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/_requests_trio.py @@ -0,0 +1,77 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from typing import AsyncIterator +import trio +from . import AsyncHttpResponse +from ._requests_basic import _RestRequestsTransportResponseBase, _has_content +from ._helpers_py3 import iter_bytes_helper, iter_raw_helper +from ..pipeline.transport._requests_trio import TrioStreamDownloadGenerator + +class RestTrioRequestsTransportResponse(AsyncHttpResponse, _RestRequestsTransportResponseBase): # type: ignore + """Asynchronous streaming of data from the response. + """ + async def iter_raw(self) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will not decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + async for part in iter_raw_helper(TrioStreamDownloadGenerator, self): + yield part + await self.close() + + async def iter_bytes(self) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + + async for part in iter_bytes_helper( + TrioStreamDownloadGenerator, + self, + content=self.content if _has_content(self) else None + ): + yield part + await self.close() + + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + if not _has_content(self): + parts = [] + async for part in self.iter_bytes(): # type: ignore + parts.append(part) + self._internal_response._content = b"".join(parts) # pylint: disable=protected-access + return self.content + + async def close(self) -> None: + self.is_closed = True + self._internal_response.close() + await trio.sleep(0) diff --git a/sdk/core/azure-core/azure/core/rest/_rest.py b/sdk/core/azure-core/azure/core/rest/_rest.py new file mode 100644 index 000000000000..24897c6f61dd --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/_rest.py @@ -0,0 +1,367 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import copy +from json import loads + +from typing import TYPE_CHECKING, cast + +from azure.core.exceptions import HttpResponseError + +from .._utils import _case_insensitive_dict +from ._helpers import ( + FilesType, + parse_lines_from_text, + set_content_body, + set_json_body, + set_multipart_body, + set_urlencoded_body, + format_parameters, + to_pipeline_transport_request_helper, + from_pipeline_transport_request_helper, + get_charset_encoding, +) +from ..exceptions import ResponseNotReadError +if TYPE_CHECKING: + from typing import ( + Iterable, + Optional, + Any, + Iterator, + Union, + Dict, + ) + from ._helpers import HeadersType + ByteStream = Iterable[bytes] + ContentType = Union[str, bytes, ByteStream] + + from ._helpers import HeadersType, ContentTypeBase as ContentType + + + +################################## CLASSES ###################################### + +class HttpRequest(object): + """Provisional object that represents an HTTP request. + + **This object is provisional**, meaning it may be changed in a future release. + + It should be passed to your client's `send_request` method. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + + >>> response = client.send_request(request) + + + :param str method: HTTP method (GET, HEAD, etc.) + :param str url: The url for your request + :keyword mapping params: Query parameters to be mapped into your URL. Your input + should be a mapping of query name to query value(s). + :keyword mapping headers: HTTP headers you want in your request. Your input should + be a mapping of header name to header value. + :keyword any json: A JSON serializable object. We handle JSON-serialization for your + object, so use this for more complicated data structures than `data`. + :keyword content: Content you want in your request body. Think of it as the kwarg you should input + if your data doesn't fit into `json`, `data`, or `files`. Accepts a bytes type, or a generator + that yields bytes. + :paramtype content: str or bytes or iterable[bytes] or asynciterable[bytes] + :keyword dict data: Form data you want in your request body. Use for form-encoded data, i.e. + HTML forms. + :keyword mapping files: Files you want to in your request body. Use for uploading files with + multipart encoding. Your input should be a mapping of file name to file content. + Use the `data` kwarg in addition if you want to include non-file data files as part of your request. + :ivar str url: The URL this request is against. + :ivar str method: The method type of this request. + :ivar mapping headers: The HTTP headers you passed in to your request + :ivar bytes content: The content passed in for the request + """ + + def __init__(self, method, url, **kwargs): + # type: (str, str, Any) -> None + + self.url = url + self.method = method + + params = kwargs.pop("params", None) + if params: + self.url = format_parameters(self.url, params) + self._files = None + self._data = None + + default_headers = self._set_body( + content=kwargs.pop("content", None), + data=kwargs.pop("data", None), + files=kwargs.pop("files", None), + json=kwargs.pop("json", None), + ) + self.headers = _case_insensitive_dict(default_headers) + self.headers.update(kwargs.pop("headers", {})) + + if kwargs: + raise TypeError( + "You have passed in kwargs '{}' that are not valid kwargs.".format( + "', '".join(list(kwargs.keys())) + ) + ) + + def _set_body(self, content, data, files, json): + # type: (Optional[ContentType], Optional[dict], Optional[FilesType], Any) -> HeadersType + """Sets the body of the request, and returns the default headers + """ + default_headers = {} + if data is not None and not isinstance(data, dict): + # should we warn? + content = data + if content is not None: + default_headers, self._data = set_content_body(content) + return default_headers + if json is not None: + default_headers, self._data = set_json_body(json) + return default_headers + if files: + default_headers, self._files = set_multipart_body(files) + if data: + default_headers, self._data = set_urlencoded_body(data, bool(files)) + return default_headers + + def _update_headers(self, default_headers): + # type: (Dict[str, str]) -> None + for name, value in default_headers.items(): + if name == "Transfer-Encoding" and "Content-Length" in self.headers: + continue + self.headers.setdefault(name, value) + + @property + def content(self): + # type: (...) -> Any + """Get's the request's content + + :return: The request's content + :rtype: any + """ + return self._data or self._files + + def __repr__(self): + # type: (...) -> str + return "".format( + self.method, self.url + ) + + def __deepcopy__(self, memo=None): + try: + request = HttpRequest( + method=self.method, + url=self.url, + headers=self.headers, + ) + request._data = copy.deepcopy(self._data, memo) + request._files = copy.deepcopy(self._files, memo) + return request + except (ValueError, TypeError): + return copy.copy(self) + + def _to_pipeline_transport_request(self): + return to_pipeline_transport_request_helper(self) + + @classmethod + def _from_pipeline_transport_request(cls, pipeline_transport_request): + return from_pipeline_transport_request_helper(cls, pipeline_transport_request) + +class _HttpResponseBase(object): # pylint: disable=too-many-instance-attributes + + def __init__(self, **kwargs): + # type: (Any) -> None + self.request = kwargs.pop("request") + self._internal_response = kwargs.pop("internal_response") + self.status_code = None + self.headers = {} # type: HeadersType + self.reason = None + self.is_closed = False + self.is_stream_consumed = False + self.content_type = None + self._json = None # this is filled in ContentDecodePolicy, when we deserialize + self._connection_data_block_size = None # type: Optional[int] + self._content = None # type: Optional[bytes] + + @property + def url(self): + # type: (...) -> str + """Returns the URL that resulted in this response""" + return self.request.url + + @property + def encoding(self): + # type: (...) -> Optional[str] + """Returns the response encoding. By default, is specified + by the response Content-Type header. + """ + try: + return self._encoding + except AttributeError: + return get_charset_encoding(self) + + @encoding.setter + def encoding(self, value): + # type: (str) -> None + """Sets the response encoding""" + self._encoding = value + + @property + def text(self): + # type: (...) -> str + """Returns the response body as a string""" + encoding = self.encoding + if encoding == "utf-8" or encoding is None: + encoding = "utf-8-sig" + return self.content.decode(encoding) + + def json(self): + # type: (...) -> Any + """Returns the whole body as a json object. + + :return: The JSON deserialized response body + :rtype: any + :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable: + """ + # this will trigger errors if response is not read in + self.content # pylint: disable=pointless-statement + if not self._json: + self._json = loads(self.text) + return self._json + + def raise_for_status(self): + # type: (...) -> None + """Raises an HttpResponseError if the response has an error status code. + + If response is good, does nothing. + """ + if cast(int, self.status_code) >= 400: + raise HttpResponseError(response=self) + + @property + def content(self): + # type: (...) -> bytes + """Return the response's content in bytes.""" + if self._content is None: + raise ResponseNotReadError(self) + return self._content + + def __repr__(self): + # type: (...) -> str + content_type_str = ( + ", Content-Type: {}".format(self.content_type) if self.content_type else "" + ) + return "".format( + self.status_code, self.reason, content_type_str + ) + +class HttpResponse(_HttpResponseBase): # pylint: disable=too-many-instance-attributes + """**Provisional** object that represents an HTTP response. + + **This object is provisional**, meaning it may be changed in a future release. + + It is returned from your client's `send_request` method if you pass in + an :class:`~azure.core.rest.HttpRequest` + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + + >>> response = client.send_request(request) + + + :keyword request: The request that resulted in this response. + :paramtype request: ~azure.core.rest.HttpRequest + :ivar int status_code: The status code of this response + :ivar mapping headers: The response headers + :ivar str reason: The reason phrase for this response + :ivar bytes content: The response content in bytes. + :ivar str url: The URL that resulted in this response + :ivar str encoding: The response encoding. Is settable, by default + is the response Content-Type header + :ivar str text: The response body as a string. + :ivar request: The request that resulted in this response. + :vartype request: ~azure.core.rest.HttpRequest + :ivar internal_response: The object returned from the HTTP library. + :ivar str content_type: The content type of the response + :ivar bool is_closed: Whether the network connection has been closed yet + :ivar bool is_stream_consumed: When getting a stream response, checks + whether the stream has been fully consumed + """ + + def __enter__(self): + # type: (...) -> HttpResponse + return self + + def close(self): + # type: (...) -> None + self.is_closed = True + self._internal_response.close() + + def __exit__(self, *args): + # type: (...) -> None + self.close() + + def read(self): + # type: (...) -> bytes + """ + Read the response's bytes. + + """ + if self._content is None: + self._content = b"".join(self.iter_bytes()) + return self.content + + def iter_raw(self): + # type: () -> Iterator[bytes] + """Iterate over the raw response bytes + """ + raise NotImplementedError() + + def iter_bytes(self): + # type: () -> Iterator[bytes] + """Iterate over the response bytes + """ + raise NotImplementedError() + + def iter_text(self): + # type: () -> Iterator[str] + """Iterate over the response text + """ + for byte in self.iter_bytes(): + text = byte.decode(self.encoding or "utf-8") + yield text + + def iter_lines(self): + # type: () -> Iterator[str] + for text in self.iter_text(): + lines = parse_lines_from_text(text) + for line in lines: + yield line + + def _close_stream(self): + # type: (...) -> None + self.is_stream_consumed = True + self.close() diff --git a/sdk/core/azure-core/azure/core/rest/_rest_py3.py b/sdk/core/azure-core/azure/core/rest/_rest_py3.py new file mode 100644 index 000000000000..27128c66a94e --- /dev/null +++ b/sdk/core/azure-core/azure/core/rest/_rest_py3.py @@ -0,0 +1,504 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import copy +import collections +import collections.abc +from json import loads +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Dict, + Iterable, Iterator, + Optional, + Type, + Union, +) + + +from azure.core.exceptions import HttpResponseError + +from .._utils import _case_insensitive_dict + +from ._helpers import ( + ParamsType, + FilesType, + HeadersType, + cast, + parse_lines_from_text, + set_json_body, + set_multipart_body, + set_urlencoded_body, + format_parameters, + to_pipeline_transport_request_helper, + from_pipeline_transport_request_helper, + get_charset_encoding +) +from ._helpers_py3 import set_content_body +from ..exceptions import ResponseNotReadError + +ContentType = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] + +class _AsyncContextManager(collections.abc.Awaitable): + + def __init__(self, wrapped: collections.abc.Awaitable): + super().__init__() + self.wrapped = wrapped + self.response = None + + def __await__(self): + return self.wrapped.__await__() + + async def __aenter__(self): + self.response = await self + return self.response + + async def __aexit__(self, *args): + await self.response.__aexit__(*args) + + async def close(self): + await self.response.close() + +################################## CLASSES ###################################### + +class HttpRequest: + """**Provisional** object that represents an HTTP request. + + **This object is provisional**, meaning it may be changed in a future release. + + It should be passed to your client's `send_request` method. + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + + >>> response = client.send_request(request) + + + :param str method: HTTP method (GET, HEAD, etc.) + :param str url: The url for your request + :keyword mapping params: Query parameters to be mapped into your URL. Your input + should be a mapping of query name to query value(s). + :keyword mapping headers: HTTP headers you want in your request. Your input should + be a mapping of header name to header value. + :keyword any json: A JSON serializable object. We handle JSON-serialization for your + object, so use this for more complicated data structures than `data`. + :keyword content: Content you want in your request body. Think of it as the kwarg you should input + if your data doesn't fit into `json`, `data`, or `files`. Accepts a bytes type, or a generator + that yields bytes. + :paramtype content: str or bytes or iterable[bytes] or asynciterable[bytes] + :keyword dict data: Form data you want in your request body. Use for form-encoded data, i.e. + HTML forms. + :keyword mapping files: Files you want to in your request body. Use for uploading files with + multipart encoding. Your input should be a mapping of file name to file content. + Use the `data` kwarg in addition if you want to include non-file data files as part of your request. + :ivar str url: The URL this request is against. + :ivar str method: The method type of this request. + :ivar mapping headers: The HTTP headers you passed in to your request + :ivar any content: The content passed in for the request + """ + + def __init__( + self, + method: str, + url: str, + *, + params: Optional[ParamsType] = None, + headers: Optional[HeadersType] = None, + json: Any = None, + content: Optional[ContentType] = None, + data: Optional[dict] = None, + files: Optional[FilesType] = None, + **kwargs + ): + self.url = url + self.method = method + + if params: + self.url = format_parameters(self.url, params) + self._files = None + self._data = None # type: Any + + default_headers = self._set_body( + content=content, + data=data, + files=files, + json=json, + ) + self.headers = _case_insensitive_dict(default_headers) + self.headers.update(headers or {}) + + if kwargs: + raise TypeError( + "You have passed in kwargs '{}' that are not valid kwargs.".format( + "', '".join(list(kwargs.keys())) + ) + ) + + def _set_body( + self, + content: Optional[ContentType], + data: Optional[dict], + files: Optional[FilesType], + json: Any, + ) -> HeadersType: + """Sets the body of the request, and returns the default headers + """ + default_headers = {} # type: HeadersType + if data is not None and not isinstance(data, dict): + # should we warn? + content = data + if content is not None: + default_headers, self._data = set_content_body(content) + return default_headers + if json is not None: + default_headers, self._data = set_json_body(json) + return default_headers + if files: + default_headers, self._files = set_multipart_body(files) + if data: + default_headers, self._data = set_urlencoded_body(data, has_files=bool(files)) + return default_headers + + @property + def content(self) -> Any: + """Get's the request's content + + :return: The request's content + :rtype: any + """ + return self._data or self._files + + def __repr__(self) -> str: + return "".format( + self.method, self.url + ) + + def __deepcopy__(self, memo=None) -> "HttpRequest": + try: + request = HttpRequest( + method=self.method, + url=self.url, + headers=self.headers, + ) + request._data = copy.deepcopy(self._data, memo) + request._files = copy.deepcopy(self._files, memo) + return request + except (ValueError, TypeError): + return copy.copy(self) + + def _to_pipeline_transport_request(self): + return to_pipeline_transport_request_helper(self) + + @classmethod + def _from_pipeline_transport_request(cls, pipeline_transport_request): + return from_pipeline_transport_request_helper(cls, pipeline_transport_request) + +class _HttpResponseBase: # pylint: disable=too-many-instance-attributes + + def __init__( + self, + *, + request: HttpRequest, + **kwargs + ): + self.request = request + self._internal_response = kwargs.pop("internal_response") + self.status_code = None + self.headers = {} # type: HeadersType + self.reason = None + self.is_closed = False + self.is_stream_consumed = False + self.content_type = None + self._connection_data_block_size = None + self._json = None # this is filled in ContentDecodePolicy, when we deserialize + self._content = None # type: Optional[bytes] + + @property + def url(self) -> str: + """Returns the URL that resulted in this response""" + return self.request.url + + @property + def encoding(self) -> Optional[str]: + """Returns the response encoding. By default, is specified + by the response Content-Type header. + """ + try: + return self._encoding + except AttributeError: + return get_charset_encoding(self) + + @encoding.setter + def encoding(self, value: str) -> None: + """Sets the response encoding""" + self._encoding = value + + @property + def text(self) -> str: + """Returns the response body as a string""" + encoding = self.encoding + if encoding == "utf-8" or encoding is None: + encoding = "utf-8-sig" + return self.content.decode(encoding) + + def json(self) -> Any: + """Returns the whole body as a json object. + + :return: The JSON deserialized response body + :rtype: any + :raises json.decoder.JSONDecodeError or ValueError (in python 2.7) if object is not JSON decodable: + """ + # this will trigger errors if response is not read in + self.content # pylint: disable=pointless-statement + if not self._json: + self._json = loads(self.text) + return self._json + + def raise_for_status(self) -> None: + """Raises an HttpResponseError if the response has an error status code. + + If response is good, does nothing. + """ + if cast(int, self.status_code) >= 400: + raise HttpResponseError(response=self) + + @property + def content(self) -> bytes: + """Return the response's content in bytes.""" + if self._content is None: + raise ResponseNotReadError(self) + return self._content + +class HttpResponse(_HttpResponseBase): + """**Provisional** object that represents an HTTP response. + + **This object is provisional**, meaning it may be changed in a future release. + + It is returned from your client's `send_request` method if you pass in + an :class:`~azure.core.rest.HttpRequest` + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + + >>> response = client.send_request(request) + + + :keyword request: The request that resulted in this response. + :paramtype request: ~azure.core.rest.HttpRequest + :ivar int status_code: The status code of this response + :ivar mapping headers: The response headers + :ivar str reason: The reason phrase for this response + :ivar bytes content: The response content in bytes. + :ivar str url: The URL that resulted in this response + :ivar str encoding: The response encoding. Is settable, by default + is the response Content-Type header + :ivar str text: The response body as a string. + :ivar request: The request that resulted in this response. + :vartype request: ~azure.core.rest.HttpRequest + :ivar internal_response: The object returned from the HTTP library. + :ivar str content_type: The content type of the response + :ivar bool is_closed: Whether the network connection has been closed yet + :ivar bool is_stream_consumed: When getting a stream response, checks + whether the stream has been fully consumed + """ + + def __enter__(self) -> "HttpResponse": + return self + + def close(self) -> None: + """Close the response + + :return: None + :rtype: None + """ + self.is_closed = True + self._internal_response.close() + + def __exit__(self, *args) -> None: + self.close() + + def read(self) -> bytes: + """Read the response's bytes. + + :return: The read in bytes + :rtype: bytes + """ + if self._content is None: + self._content = b"".join(self.iter_bytes()) + return self.content + + def iter_raw(self) -> Iterator[bytes]: + """Iterates over the response's bytes. Will not decompress in the process + + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + raise NotImplementedError() + + def iter_bytes(self) -> Iterator[bytes]: + """Iterates over the response's bytes. Will decompress in the process + + :return: An iterator of bytes from the response + :rtype: Iterator[str] + """ + raise NotImplementedError() + + def iter_text(self) -> Iterator[str]: + """Iterates over the text in the response. + + :return: An iterator of string. Each string chunk will be a text from the response + :rtype: Iterator[str] + """ + for byte in self.iter_bytes(): + text = byte.decode(self.encoding or "utf-8") + yield text + + def iter_lines(self) -> Iterator[str]: + """Iterates over the lines in the response. + + :return: An iterator of string. Each string chunk will be a line from the response + :rtype: Iterator[str] + """ + for text in self.iter_text(): + lines = parse_lines_from_text(text) + for line in lines: + yield line + + def __repr__(self) -> str: + content_type_str = ( + ", Content-Type: {}".format(self.content_type) if self.content_type else "" + ) + return "".format( + self.status_code, self.reason, content_type_str + ) + +class AsyncHttpResponse(_HttpResponseBase): + """**Provisional** object that represents an Async HTTP response. + + **This object is provisional**, meaning it may be changed in a future release. + + It is returned from your async client's `send_request` method if you pass in + an :class:`~azure.core.rest.HttpRequest` + + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest('GET', 'http://www.example.com') + + >>> response = await client.send_request(request) + + + :keyword request: The request that resulted in this response. + :paramtype request: ~azure.core.rest.HttpRequest + :keyword internal_response: The object returned from the HTTP library. + :ivar int status_code: The status code of this response + :ivar mapping headers: The response headers + :ivar str reason: The reason phrase for this response + :ivar bytes content: The response content in bytes. + :ivar str url: The URL that resulted in this response + :ivar str encoding: The response encoding. Is settable, by default + is the response Content-Type header + :ivar str text: The response body as a string. + :ivar request: The request that resulted in this response. + :vartype request: ~azure.core.rest.HttpRequest + :ivar internal_response: The object returned from the HTTP library. + :ivar str content_type: The content type of the response + :ivar bool is_closed: Whether the network connection has been closed yet + :ivar bool is_stream_consumed: When getting a stream response, checks + whether the stream has been fully consumed + """ + + async def read(self) -> bytes: + """Read the response's bytes into memory. + + :return: The response's bytes + :rtype: bytes + """ + if self._content is None: + parts = [] + async for part in self.iter_bytes(): + parts.append(part) + self._content = b"".join(parts) + return self._content + + async def iter_raw(self) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will not decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + raise NotImplementedError() + # getting around mypy behavior, see https://github.com/python/mypy/issues/10732 + yield # pylint: disable=unreachable + + async def iter_bytes(self) -> AsyncIterator[bytes]: + """Asynchronously iterates over the response's bytes. Will decompress in the process + + :return: An async iterator of bytes from the response + :rtype: AsyncIterator[bytes] + """ + raise NotImplementedError() + # getting around mypy behavior, see https://github.com/python/mypy/issues/10732 + yield # pylint: disable=unreachable + + async def iter_text(self) -> AsyncIterator[str]: + """Asynchronously iterates over the text in the response. + + :return: An async iterator of string. Each string chunk will be a text from the response + :rtype: AsyncIterator[str] + """ + async for byte in self.iter_bytes(): # type: ignore + text = byte.decode(self.encoding or "utf-8") + yield text + + async def iter_lines(self) -> AsyncIterator[str]: + """Asynchronously iterates over the lines in the response. + + :return: An async iterator of string. Each string chunk will be a line from the response + :rtype: AsyncIterator[str] + """ + async for text in self.iter_text(): + lines = parse_lines_from_text(text) + for line in lines: + yield line + + async def close(self) -> None: + """Close the response. + + :return: None + :rtype: None + """ + self.is_closed = True + await self._internal_response.close() + + async def __aexit__(self, *args) -> None: + await self.close() + + def __repr__(self) -> str: + content_type_str = ( + ", Content-Type: {}".format(self.content_type) if self.content_type else "" + ) + return "".format( + self.status_code, self.reason, content_type_str + ) diff --git a/sdk/core/azure-core/doc/azure.core.rst b/sdk/core/azure-core/doc/azure.core.rst index 36716ad49c62..c6f0ed92463b 100644 --- a/sdk/core/azure-core/doc/azure.core.rst +++ b/sdk/core/azure-core/doc/azure.core.rst @@ -72,3 +72,15 @@ azure.core.serialization :members: :undoc-members: :inherited-members: + +azure.core.rest +------------------- +***THIS MODULE IS PROVISIONAL*** + +This module is ***provisional***, meaning any of the objects and methods in this module may be changed. + +.. automodule:: azure.core.rest + :members: + :undoc-members: + :inherited-members: + diff --git a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py index 773a320804a4..92097a2265d8 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py @@ -26,8 +26,9 @@ async def __anext__(self): raise StopAsyncIteration async with AsyncioRequestsTransport() as transport: - req = HttpRequest('GET', 'http://httpbin.org/post', data=AsyncGen()) - await transport.send(req) + req = HttpRequest('GET', 'http://httpbin.org/anything', data=AsyncGen()) + response = await transport.send(req) + assert json.loads(response.text())['data'] == "azerty" @pytest.mark.asyncio async def test_send_data(): diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/conftest.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/conftest.py new file mode 100644 index 000000000000..17c93009a373 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/conftest.py @@ -0,0 +1,100 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import time +import pytest +import signal +import os +import subprocess +import sys +import random +from six.moves import urllib +from rest_client_async import AsyncTestRestClient + +def is_port_available(port_num): + req = urllib.request.Request("http://localhost:{}/health".format(port_num)) + try: + return urllib.request.urlopen(req).code != 200 + except Exception as e: + return True + +def get_port(): + count = 3 + for _ in range(count): + port_num = random.randrange(3000, 5000) + if is_port_available(port_num): + return port_num + raise TypeError("Tried {} times, can't find an open port".format(count)) + +@pytest.fixture +def port(): + return os.environ["FLASK_PORT"] + +def start_testserver(): + port = get_port() + os.environ["FLASK_APP"] = "coretestserver" + os.environ["FLASK_PORT"] = str(port) + cmd = "flask run -p {}".format(port) + if os.name == 'nt': #On windows, subprocess creation works without being in the shell + child_process = subprocess.Popen(cmd, env=dict(os.environ)) + else: + #On linux, have to set shell=True + child_process = subprocess.Popen(cmd, shell=True, preexec_fn=os.setsid, env=dict(os.environ)) + count = 5 + for _ in range(count): + if not is_port_available(port): + return child_process + time.sleep(1) + raise ValueError("Didn't start!") + +def terminate_testserver(process): + if os.name == 'nt': + process.kill() + else: + os.killpg(os.getpgid(process.pid), signal.SIGTERM) # Send the signal to all the process groups + +@pytest.fixture(autouse=True, scope="package") +def testserver(): + """Start the Autorest testserver.""" + server = start_testserver() + yield + terminate_testserver(server) + + +# Ignore collection of async tests for Python 2 +collect_ignore_glob = [] +if sys.version_info < (3, 5): + collect_ignore_glob.append("*_async.py") + +@pytest.fixture +def client(port): + return AsyncTestRestClient(port) + +import sys + +# Ignore collection of async tests for Python 2 +collect_ignore = [] +if sys.version_info < (3, 5): + collect_ignore.append("async_tests") diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/rest_client_async.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/rest_client_async.py new file mode 100644 index 000000000000..1f2e3568bb02 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/rest_client_async.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from copy import deepcopy +from azure.core import AsyncPipelineClient +from azure.core.pipeline import policies +from azure.core.configuration import Configuration + +class TestRestClientConfiguration(Configuration): + def __init__( + self, **kwargs + ): + # type: (...) -> None + super(TestRestClientConfiguration, self).__init__(**kwargs) + + kwargs.setdefault("sdk_moniker", "autorestswaggerbatfileservice/1.0.0b1") + self._configure(**kwargs) + + def _configure(self, **kwargs) -> None: + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) + self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.AsyncRedirectPolicy(**kwargs) + self.authentication_policy = kwargs.get("authentication_policy") + +class AsyncTestRestClient(object): + + def __init__(self, port, **kwargs): + self._config = TestRestClientConfiguration(**kwargs) + + self._client = AsyncPipelineClient( + base_url="http://localhost:{}".format(port), + config=self._config, + **kwargs + ) + + def send_request(self, request, **kwargs): + """Runs the network request through the client's chained policies. + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest("GET", "http://localhost:3000/helloWorld") + + >>> response = await client.send_request(request) + + For more information on this code flow, see https://aka.ms/azsdk/python/protocol/quickstart + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.AsyncHttpResponse + """ + request_copy = deepcopy(request) + request_copy.url = self._client.format_url(request_copy.url) + return self._client.send_request(request_copy, **kwargs) + + async def close(self) -> None: + await self._client.close() + + async def __aenter__(self): + await self._client.__aenter__() + return self + + async def __aexit__(self, *exc_details) -> None: + await self._client.__aexit__(*exc_details) \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_asyncio_transport.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_asyncio_transport.py new file mode 100644 index 000000000000..126487f92aa7 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_asyncio_transport.py @@ -0,0 +1,43 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import json + +from azure.core.pipeline.transport import AsyncioRequestsTransport +from azure.core.rest import HttpRequest +from rest_client_async import AsyncTestRestClient + +import pytest + + +@pytest.mark.asyncio +async def test_async_gen_data(port): + class AsyncGen: + def __init__(self): + self._range = iter([b"azerty"]) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._range) + except StopIteration: + raise StopAsyncIteration + + async with AsyncioRequestsTransport() as transport: + client = AsyncTestRestClient(port, transport=transport) + request = HttpRequest('GET', 'http://httpbin.org/anything', content=AsyncGen()) + response = await client.send_request(request) + assert response.json()['data'] == "azerty" + +@pytest.mark.asyncio +async def test_send_data(port): + async with AsyncioRequestsTransport() as transport: + client = AsyncTestRestClient(port, transport=transport) + request = HttpRequest('PUT', 'http://httpbin.org/anything', content=b"azerty") + response = await client.send_request(request) + + assert response.json()['data'] == "azerty" \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_context_manager_async.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_context_manager_async.py new file mode 100644 index 000000000000..4afcac42172f --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_context_manager_async.py @@ -0,0 +1,82 @@ +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from azure.core.exceptions import HttpResponseError, ResponseNotReadError +import pytest +from azure.core.rest import HttpRequest +from rest_client_async import AsyncTestRestClient + +@pytest.mark.asyncio +async def test_normal_call(client): + async def _raise_and_get_text(response): + response.raise_for_status() + assert response.text == "Hello, world!" + assert response.is_closed + request = HttpRequest("GET", url="/basic/string") + response = await client.send_request(request) + await _raise_and_get_text(response) + assert response.is_closed + + async with client.send_request(request) as response: + await _raise_and_get_text(response) + + response = client.send_request(request) + async with response as response: + await _raise_and_get_text(response) + +@pytest.mark.asyncio +async def test_stream_call(client): + async def _raise_and_get_text(response): + response.raise_for_status() + assert not response.is_closed + with pytest.raises(ResponseNotReadError): + response.text + await response.read() + assert response.text == "Hello, world!" + assert response.is_closed + request = HttpRequest("GET", url="/streams/basic") + response = await client.send_request(request, stream=True) + await _raise_and_get_text(response) + assert response.is_closed + + async with client.send_request(request, stream=True) as response: + await _raise_and_get_text(response) + assert response.is_closed + + response = client.send_request(request, stream=True) + async with response as response: + await _raise_and_get_text(response) + +# TODO: commenting until https://github.com/Azure/azure-sdk-for-python/issues/18086 is fixed + +# @pytest.mark.asyncio +# async def test_stream_with_error(client): +# request = HttpRequest("GET", url="/streams/error") +# async with client.send_request(request, stream=True) as response: +# assert not response.is_closed +# with pytest.raises(HttpResponseError) as e: +# response.raise_for_status() +# error = e.value +# assert error.status_code == 400 +# assert error.reason == "BAD REQUEST" +# assert "Operation returned an invalid status 'BAD REQUEST'" in str(error) +# with pytest.raises(ResponseNotReadError): +# error.error +# with pytest.raises(ResponseNotReadError): +# error.model +# with pytest.raises(ResponseNotReadError): +# response.json() +# with pytest.raises(ResponseNotReadError): +# response.content + +# # NOW WE READ THE RESPONSE +# await response.read() +# assert error.status_code == 400 +# assert error.reason == "BAD REQUEST" +# assert error.error.code == "BadRequest" +# assert error.error.message == "You made a bad request" +# assert error.model.code == "BadRequest" +# assert error.error.message == "You made a bad request" \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_http_request_async.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_http_request_async.py new file mode 100644 index 000000000000..67f9d419fb31 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_http_request_async.py @@ -0,0 +1,90 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +# NOTE: These tests are heavily inspired from the httpx test suite: https://github.com/encode/httpx/tree/master/tests +# Thank you httpx for your wonderful tests! +import pytest +from azure.core.rest import HttpRequest +from typing import AsyncGenerator +import collections.abc + +@pytest.fixture +def assert_aiterator_body(): + async def _comparer(request, final_value): + parts = [] + async for part in request.content: + parts.append(part) + content = b"".join(parts) + assert content == final_value + return _comparer + +def test_transfer_encoding_header(): + async def streaming_body(data): + yield data # pragma: nocover + + data = streaming_body(b"test 123") + + request = HttpRequest("POST", "http://example.org", data=data) + assert "Content-Length" not in request.headers + +def test_override_content_length_header(): + async def streaming_body(data): + yield data # pragma: nocover + + data = streaming_body(b"test 123") + headers = {"Content-Length": "0"} + + request = HttpRequest("POST", "http://example.org", data=data, headers=headers) + assert request.headers["Content-Length"] == "0" + +@pytest.mark.asyncio +async def test_aiterbale_content(assert_aiterator_body): + class Content: + async def __aiter__(self): + yield b"test 123" + + request = HttpRequest("POST", "http://example.org", content=Content()) + assert request.headers == {} + await assert_aiterator_body(request, b"test 123") + +@pytest.mark.asyncio +async def test_aiterator_content(assert_aiterator_body): + async def hello_world(): + yield b"Hello, " + yield b"world!" + + request = HttpRequest("POST", url="http://example.org", content=hello_world()) + assert not isinstance(request._data, collections.abc.Iterable) + assert isinstance(request._data, collections.abc.AsyncIterable) + + assert request.headers == {} + await assert_aiterator_body(request, b"Hello, world!") + + # Support 'data' for compat with requests. + request = HttpRequest("POST", url="http://example.org", data=hello_world()) + assert not isinstance(request._data, collections.abc.Iterable) + assert isinstance(request._data, collections.abc.AsyncIterable) + + assert request.headers == {} + await assert_aiterator_body(request, b"Hello, world!") + + # transfer encoding should not be set for GET requests + request = HttpRequest("GET", url="http://example.org", data=hello_world()) + assert not isinstance(request._data, collections.abc.Iterable) + assert isinstance(request._data, collections.abc.AsyncIterable) + + assert request.headers == {} + await assert_aiterator_body(request, b"Hello, world!") + +@pytest.mark.asyncio +async def test_read_content(assert_aiterator_body): + async def content(): + yield b"test 123" + + request = HttpRequest("POST", "http://example.org", content=content()) + await assert_aiterator_body(request, b"test 123") + # in this case, request._data is what we end up passing to the requests transport + assert isinstance(request._data, collections.abc.AsyncIterable) \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_http_response_async.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_http_response_async.py new file mode 100644 index 000000000000..317b74c3bac0 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_http_response_async.py @@ -0,0 +1,280 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +# NOTE: These tests are heavily inspired from the httpx test suite: https://github.com/encode/httpx/tree/master/tests +# Thank you httpx for your wonderful tests! +import io +import pytest +from azure.core.rest import HttpRequest +from azure.core.exceptions import HttpResponseError + +@pytest.fixture +def send_request(client): + async def _send_request(request): + response = await client.send_request(request, stream=False) + response.raise_for_status() + return response + return _send_request + +@pytest.mark.asyncio +async def test_response(send_request, port): + response = await send_request( + HttpRequest("GET", "/basic/string"), + ) + assert response.status_code == 200 + assert response.reason == "OK" + assert response.content == b"Hello, world!" + assert response.text == "Hello, world!" + assert response.request.method == "GET" + assert response.request.url == "http://localhost:{}/basic/string".format(port) + +@pytest.mark.asyncio +async def test_response_content(send_request): + response = await send_request( + request=HttpRequest("GET", "/basic/bytes"), + ) + assert response.status_code == 200 + assert response.reason == "OK" + content = await response.read() + assert content == b"Hello, world!" + assert response.text == "Hello, world!" + +@pytest.mark.asyncio +async def test_response_text(send_request): + response = await send_request( + request=HttpRequest("GET", "/basic/string"), + ) + assert response.status_code == 200 + assert response.reason == "OK" + content = await response.read() + assert content == b"Hello, world!" + assert response.text == "Hello, world!" + assert response.headers["Content-Length"] == '13' + assert response.headers['Content-Type'] == "text/plain; charset=utf-8" + +@pytest.mark.asyncio +async def test_response_html(send_request): + response = await send_request( + request=HttpRequest("GET", "/basic/html"), + ) + assert response.status_code == 200 + assert response.reason == "OK" + content = await response.read() + assert content == b"Hello, world!" + assert response.text == "Hello, world!" + +@pytest.mark.asyncio +async def test_raise_for_status(client): + # response = await client.send_request( + # HttpRequest("GET", "/basic/string"), + # ) + # response.raise_for_status() + + response = await client.send_request( + HttpRequest("GET", "/errors/403"), + ) + assert response.status_code == 403 + with pytest.raises(HttpResponseError): + response.raise_for_status() + + response = await client.send_request( + HttpRequest("GET", "/errors/500"), + retry_total=0, # takes too long with retires on 500 + ) + assert response.status_code == 500 + with pytest.raises(HttpResponseError): + response.raise_for_status() + +@pytest.mark.asyncio +async def test_response_repr(send_request): + response = await send_request( + HttpRequest("GET", "/basic/string") + ) + assert repr(response) == "" + +@pytest.mark.asyncio +async def test_response_content_type_encoding(send_request): + """ + Use the charset encoding in the Content-Type header if possible. + """ + response = await send_request( + request=HttpRequest("GET", "/encoding/latin-1") + ) + await response.read() + assert response.content_type == "text/plain; charset=latin-1" + assert response.content == b'Latin 1: \xff' + assert response.text == "Latin 1: รฟ" + assert response.encoding == "latin-1" + + +@pytest.mark.asyncio +async def test_response_autodetect_encoding(send_request): + """ + Autodetect encoding if there is no Content-Type header. + """ + response = await send_request( + request=HttpRequest("GET", "/encoding/latin-1") + ) + await response.read() + assert response.text == u'Latin 1: รฟ' + assert response.encoding == "latin-1" + + +@pytest.mark.asyncio +async def test_response_fallback_to_autodetect(send_request): + """ + Fallback to autodetection if we get an invalid charset in the Content-Type header. + """ + response = await send_request( + request=HttpRequest("GET", "/encoding/invalid-codec-name") + ) + await response.read() + assert response.headers["Content-Type"] == "text/plain; charset=invalid-codec-name" + assert response.text == "ใŠใฏใ‚ˆใ†ใ”ใ–ใ„ใพใ™ใ€‚" + assert response.encoding is None + + +@pytest.mark.asyncio +async def test_response_no_charset_with_ascii_content(send_request): + """ + A response with ascii encoded content should decode correctly, + even with no charset specified. + """ + response = await send_request( + request=HttpRequest("GET", "/encoding/no-charset"), + ) + + assert response.headers["Content-Type"] == "text/plain" + assert response.status_code == 200 + assert response.encoding == 'ascii' + content = await response.read() + assert content == b"Hello, world!" + assert response.text == "Hello, world!" + + +@pytest.mark.asyncio +async def test_response_no_charset_with_iso_8859_1_content(send_request): + """ + A response with ISO 8859-1 encoded content should decode correctly, + even with no charset specified. + """ + response = await send_request( + request=HttpRequest("GET", "/encoding/iso-8859-1"), + ) + await response.read() + assert response.text == u"Accented: ร–sterreich" + assert response.encoding == 'ISO-8859-1' + +# NOTE: aiohttp isn't liking this +# @pytest.mark.asyncio +# async def test_response_set_explicit_encoding(send_request): +# response = await send_request( +# request=HttpRequest("GET", "/encoding/latin-1-with-utf-8"), +# ) +# assert response.headers["Content-Type"] == "text/plain; charset=utf-8" +# response.encoding = "latin-1" +# await response.read() +# assert response.text == "Latin 1: รฟ" +# assert response.encoding == "latin-1" + +@pytest.mark.asyncio +async def test_json(send_request): + response = await send_request( + request=HttpRequest("GET", "/basic/json"), + ) + await response.read() + assert response.json() == {"greeting": "hello", "recipient": "world"} + assert response.encoding == 'utf-8' + +@pytest.mark.asyncio +async def test_json_with_specified_encoding(send_request): + response = await send_request( + request=HttpRequest("GET", "/encoding/json"), + ) + await response.read() + assert response.json() == {"greeting": "hello", "recipient": "world"} + assert response.encoding == "utf-16" + +@pytest.mark.asyncio +async def test_emoji(send_request): + response = await send_request( + request=HttpRequest("GET", "/encoding/emoji"), + ) + await response.read() + assert response.text == "๐Ÿ‘ฉ" + +@pytest.mark.asyncio +async def test_emoji_family_with_skin_tone_modifier(send_request): + response = await send_request( + request=HttpRequest("GET", "/encoding/emoji-family-skin-tone-modifier"), + ) + await response.read() + assert response.text == "๐Ÿ‘ฉ๐Ÿปโ€๐Ÿ‘ฉ๐Ÿฝโ€๐Ÿ‘ง๐Ÿพโ€๐Ÿ‘ฆ๐Ÿฟ SSN: 859-98-0987" + +@pytest.mark.asyncio +async def test_korean_nfc(send_request): + response = await send_request( + request=HttpRequest("GET", "/encoding/korean"), + ) + await response.read() + assert response.text == "์•„๊ฐ€" + +@pytest.mark.asyncio +async def test_urlencoded_content(send_request): + await send_request( + request=HttpRequest( + "POST", + "/urlencoded/pet/add/1", + data={ "pet_type": "dog", "pet_food": "meat", "name": "Fido", "pet_age": 42 } + ), + ) + +@pytest.mark.asyncio +async def test_multipart_files_content(send_request): + request = HttpRequest( + "POST", + "/multipart/basic", + files={"fileContent": io.BytesIO(b"")}, + ) + await send_request(request) + +@pytest.mark.asyncio +async def test_send_request_return_pipeline_response(client): + # we use return_pipeline_response for some cases in autorest + request = HttpRequest("GET", "/basic/string") + response = await client.send_request(request, _return_pipeline_response=True) + assert hasattr(response, "http_request") + assert hasattr(response, "http_response") + assert hasattr(response, "context") + assert response.http_response.text == "Hello, world!" + assert hasattr(response.http_request, "content") + +# @pytest.mark.asyncio +# async def test_multipart_encode_non_seekable_filelike(send_request): +# """ +# Test that special readable but non-seekable filelike objects are supported, +# at the cost of reading them into memory at most once. +# """ + +# class IteratorIO(io.IOBase): +# def __init__(self, iterator): +# self._iterator = iterator + +# def read(self, *args): +# return b"".join(self._iterator) + +# def data(): +# yield b"Hello" +# yield b"World" + +# fileobj = IteratorIO(data()) +# files = {"file": fileobj} +# request = HttpRequest( +# "POST", +# "/multipart/non-seekable-filelike", +# files=files, +# ) +# await send_request(request) \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_stream_responses_async.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_stream_responses_async.py new file mode 100644 index 000000000000..673148749719 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_stream_responses_async.py @@ -0,0 +1,206 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from azure.core.exceptions import HttpResponseError, ServiceRequestError +import functools +import os +import json +import pytest +from azure.core.rest import HttpRequest +from azure.core.exceptions import StreamClosedError, StreamConsumedError, ResponseNotReadError + +@pytest.mark.asyncio +async def test_iter_raw(client): + request = HttpRequest("GET", "/streams/basic") + async with client.send_request(request, stream=True) as response: + raw = b"" + async for part in response.iter_raw(): + raw += part + assert raw == b"Hello, world!" + +@pytest.mark.asyncio +async def test_iter_raw_on_iterable(client): + request = HttpRequest("GET", "/streams/iterable") + + async with client.send_request(request, stream=True) as response: + raw = b"" + async for part in response.iter_raw(): + raw += part + assert raw == b"Hello, world!" + +@pytest.mark.asyncio +async def test_iter_with_error(client): + request = HttpRequest("GET", "/errors/403") + + async with client.send_request(request, stream=True) as response: + try: + response.raise_for_status() + except HttpResponseError as e: + pass + assert response.is_closed + + try: + async with client.send_request(request, stream=True) as response: + response.raise_for_status() + except HttpResponseError as e: + pass + + assert response.is_closed + + request = HttpRequest("GET", "http://doesNotExist") + with pytest.raises(ServiceRequestError): + async with (await client.send_request(request, stream=True)): + raise ValueError("Should error before entering") + assert response.is_closed + +@pytest.mark.asyncio +async def test_iter_bytes(client): + request = HttpRequest("GET", "/streams/basic") + + async with client.send_request(request, stream=True) as response: + raw = b"" + async for chunk in response.iter_bytes(): + assert response.is_stream_consumed + assert not response.is_closed + raw += chunk + assert response.is_stream_consumed + assert response.is_closed + assert raw == b"Hello, world!" + +@pytest.mark.asyncio +async def test_iter_text(client): + request = HttpRequest("GET", "/basic/string") + + async with client.send_request(request, stream=True) as response: + content = "" + async for part in response.iter_text(): + content += part + assert content == "Hello, world!" + +@pytest.mark.asyncio +async def test_iter_lines(client): + request = HttpRequest("GET", "/basic/lines") + + async with client.send_request(request, stream=True) as response: + content = [] + async for line in response.iter_lines(): + content.append(line) + assert content == ["Hello,\n", "world!"] + + +@pytest.mark.asyncio +async def test_streaming_response(client): + request = HttpRequest("GET", "/streams/basic") + + async with client.send_request(request, stream=True) as response: + assert response.status_code == 200 + assert not response.is_closed + + content = await response.read() + + assert content == b"Hello, world!" + assert response.content == b"Hello, world!" + assert response.is_closed + +@pytest.mark.asyncio +async def test_cannot_read_after_stream_consumed(port, client): + request = HttpRequest("GET", "/streams/basic") + async with client.send_request(request, stream=True) as response: + content = b"" + async for chunk in response.iter_bytes(): + content += chunk + + with pytest.raises(StreamConsumedError) as ex: + await response.read() + assert "".format(port) in str(ex.value) + assert "You have likely already consumed this stream, so it can not be accessed anymore" in str(ex.value) + + +@pytest.mark.asyncio +async def test_cannot_read_after_response_closed(port, client): + request = HttpRequest("GET", "/streams/basic") + async with client.send_request(request, stream=True) as response: + pass + + with pytest.raises(StreamClosedError) as ex: + await response.read() + assert "".format(port) in str(ex.value) + assert "can no longer be read or streamed, since the response has already been closed" in str(ex.value) + +@pytest.mark.asyncio +async def test_decompress_plain_no_header(client): + # thanks to Xiang Yan for this test! + account_name = "coretests" + url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) + request = HttpRequest("GET", url) + async with client: + response = await client.send_request(request, stream=True) + with pytest.raises(ResponseNotReadError): + response.content + await response.read() + assert response.content == b"test" + +@pytest.mark.asyncio +async def test_compress_plain_no_header(client): + # thanks to Xiang Yan for this test! + account_name = "coretests" + url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) + request = HttpRequest("GET", url) + async with client: + response = await client.send_request(request, stream=True) + iter = response.iter_raw() + data = b"" + async for d in iter: + data += d + assert data == b"test" + +@pytest.mark.asyncio +async def test_iter_read_back_and_forth(client): + # thanks to McCoy Patiรฑo for this test! + + # while this test may look like it's exposing buggy behavior, this is httpx's behavior + # the reason why the code flow is like this, is because the 'iter_x' functions don't + # actually read the contents into the response, the output them. Once they're yielded, + # the stream is closed, so you have to catch the output when you iterate through it + request = HttpRequest("GET", "/basic/lines") + + async with client.send_request(request, stream=True) as response: + async for line in response.iter_lines(): + assert line + with pytest.raises(ResponseNotReadError): + response.text + with pytest.raises(StreamConsumedError): + await response.read() + with pytest.raises(ResponseNotReadError): + response.text + +@pytest.mark.asyncio +async def test_stream_with_return_pipeline_response(client): + request = HttpRequest("GET", "/basic/lines") + pipeline_response = await client.send_request(request, stream=True, _return_pipeline_response=True) + assert hasattr(pipeline_response, "http_request") + assert hasattr(pipeline_response.http_request, "content") + assert hasattr(pipeline_response, "http_response") + assert hasattr(pipeline_response, "context") + parts = [] + async for line in pipeline_response.http_response.iter_lines(): + parts.append(line) + assert parts == ['Hello,\n', 'world!'] + await client.close() + +@pytest.mark.asyncio +async def test_error_reading(client): + request = HttpRequest("GET", "/errors/403") + async with client.send_request(request, stream=True) as response: + await response.read() + assert response.content == b"" + response.content + + response = await client.send_request(request, stream=True) + with pytest.raises(HttpResponseError): + response.raise_for_status() + await response.read() + assert response.content == b"" + await client.close() diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_trio_transport.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_trio_transport.py new file mode 100644 index 000000000000..7e563ca3d6c5 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_trio_transport.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from azure.core.pipeline.transport import TrioRequestsTransport +from azure.core.rest import HttpRequest +from rest_client_async import AsyncTestRestClient + +import pytest + + +@pytest.mark.trio +async def test_async_gen_data(port): + class AsyncGen: + def __init__(self): + self._range = iter([b"azerty"]) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._range) + except StopIteration: + raise StopAsyncIteration + + async with TrioRequestsTransport() as transport: + client = AsyncTestRestClient(port, transport=transport) + request = HttpRequest('GET', 'http://httpbin.org/anything', content=AsyncGen()) + response = await client.send_request(request) + assert response.json()['data'] == "azerty" + +@pytest.mark.trio +async def test_send_data(port): + async with TrioRequestsTransport() as transport: + request = HttpRequest('PUT', 'http://httpbin.org/anything', content=b"azerty") + client = AsyncTestRestClient(port, transport=transport) + response = await client.send_request(request) + + assert response.json()['data'] == "azerty" \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/test_testserver_async.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_testserver_async.py similarity index 100% rename from sdk/core/azure-core/tests/testserver_tests/test_testserver_async.py rename to sdk/core/azure-core/tests/testserver_tests/async_tests/test_testserver_async.py diff --git a/sdk/core/azure-core/tests/testserver_tests/conftest.py b/sdk/core/azure-core/tests/testserver_tests/conftest.py index 10a99fb3ce21..422904288fd1 100644 --- a/sdk/core/azure-core/tests/testserver_tests/conftest.py +++ b/sdk/core/azure-core/tests/testserver_tests/conftest.py @@ -28,9 +28,15 @@ import signal import os import subprocess -import sys import random from six.moves import urllib +from rest_client import TestRestClient +import sys + +# Ignore collection of async tests for Python 2 +collect_ignore = [] +if sys.version_info < (3, 5): + collect_ignore.append("async_tests") def is_port_available(port_num): req = urllib.request.Request("http://localhost:{}/health".format(port_num)) @@ -81,8 +87,6 @@ def testserver(): yield terminate_testserver(server) - -# Ignore collection of async tests for Python 2 -collect_ignore_glob = [] -if sys.version_info < (3, 5): - collect_ignore_glob.append("*_async.py") +@pytest.fixture +def client(port): + return TestRestClient(port) diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/basic.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/basic.py index 4b7d5ae92ad4..0f0735522d11 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/basic.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/basic.py @@ -1,4 +1,4 @@ -# coding: utf-8 +# -*- coding: utf-8 -*- # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See LICENSE.txt in the project root for @@ -56,9 +56,9 @@ def complicated_json(): assert request.json['SpacesAfterUnicode'] == 'Text ' assert request.json['SpacesBeforeAndAfterByte'] == ' Text ' assert request.json['SpacesBeforeAndAfterUnicode'] == ' Text ' - assert request.json['ๅ•Š้ฝ„ไธ‚็‹›'] == '๊€•' + assert request.json[u'ๅ•Š้ฝ„ไธ‚็‹›'] == u'๊€•' assert request.json['RowKey'] == 'test2' - assert request.json['ๅ•Š้ฝ„ไธ‚็‹›็‹œ'] == 'hello' + assert request.json[u'ๅ•Š้ฝ„ไธ‚็‹›็‹œ'] == 'hello' assert request.json["singlequote"] == "a''''b" assert request.json["doublequote"] == 'a""""b' assert request.json["None"] == None diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/encoding.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/encoding.py index 12224e568ee5..104ef7608bd0 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/encoding.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/encoding.py @@ -1,4 +1,4 @@ -# coding: utf-8 +# -*- coding: utf-8 -*- # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See LICENSE.txt in the project root for @@ -15,7 +15,7 @@ @encoding_api.route('/latin-1', methods=['GET']) def latin_1(): r = Response( - "Latin 1: รฟ".encode("latin-1"), status=200 + u"Latin 1: รฟ".encode("latin-1"), status=200 ) r.headers["Content-Type"] = "text/plain; charset=latin-1" return r @@ -23,7 +23,7 @@ def latin_1(): @encoding_api.route('/latin-1-with-utf-8', methods=['GET']) def latin_1_charset_utf8(): r = Response( - "Latin 1: รฟ".encode("latin-1"), status=200 + u"Latin 1: รฟ".encode("latin-1"), status=200 ) r.headers["Content-Type"] = "text/plain; charset=utf-8" return r @@ -39,7 +39,7 @@ def latin_1_no_charset(): @encoding_api.route('/iso-8859-1', methods=['GET']) def iso_8859_1(): r = Response( - "Accented: ร–sterreich".encode("iso-8859-1"), status=200 + u"Accented: ร–sterreich".encode("iso-8859-1"), status=200 ) r.headers["Content-Type"] = "text/plain" return r @@ -47,14 +47,14 @@ def iso_8859_1(): @encoding_api.route('/emoji', methods=['GET']) def emoji(): r = Response( - "๐Ÿ‘ฉ", status=200 + u"๐Ÿ‘ฉ", status=200 ) return r @encoding_api.route('/emoji-family-skin-tone-modifier', methods=['GET']) def emoji_family_skin_tone_modifier(): r = Response( - "๐Ÿ‘ฉ๐Ÿปโ€๐Ÿ‘ฉ๐Ÿฝโ€๐Ÿ‘ง๐Ÿพโ€๐Ÿ‘ฆ๐Ÿฟ SSN: 859-98-0987", status=200 + u"๐Ÿ‘ฉ๐Ÿปโ€๐Ÿ‘ฉ๐Ÿฝโ€๐Ÿ‘ง๐Ÿพโ€๐Ÿ‘ฆ๐Ÿฟ SSN: 859-98-0987", status=200 ) return r diff --git a/sdk/core/azure-core/tests/testserver_tests/rest_client.py b/sdk/core/azure-core/tests/testserver_tests/rest_client.py new file mode 100644 index 000000000000..2d896ac3b6aa --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/rest_client.py @@ -0,0 +1,83 @@ + +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +from azure.core.pipeline import policies +from azure.core.configuration import Configuration +from azure.core import PipelineClient +from copy import deepcopy + + +class TestRestClientConfiguration(Configuration): + def __init__( + self, **kwargs + ): + # type: (...) -> None + super(TestRestClientConfiguration, self).__init__(**kwargs) + + kwargs.setdefault("sdk_moniker", "autorestswaggerbatfileservice/1.0.0b1") + self._configure(**kwargs) + + def _configure( + self, **kwargs + ): + # type: (...) -> None + self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) + self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) + self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs) + self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs) + self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs) + self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs) + self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs) + self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs) + self.authentication_policy = kwargs.get("authentication_policy") + +class TestRestClient(object): + + def __init__(self, port, **kwargs): + self._config = TestRestClientConfiguration(**kwargs) + self._client = PipelineClient( + base_url="http://localhost:{}/".format(port), + config=self._config, + **kwargs + ) + + def send_request(self, request, **kwargs): + """Runs the network request through the client's chained policies. + >>> from azure.core.rest import HttpRequest + >>> request = HttpRequest("GET", "http://localhost:3000/helloWorld") + + >>> response = client.send_request(request) + + For more information on this code flow, see https://aka.ms/azsdk/python/protocol/quickstart + :param request: The network request you want to make. Required. + :type request: ~azure.core.rest.HttpRequest + :keyword bool stream: Whether the response payload will be streamed. Defaults to False. + :return: The response of your network call. Does not do error handling on your response. + :rtype: ~azure.core.rest.HttpResponse + """ + request_copy = deepcopy(request) + request_copy.url = self._client.format_url(request_copy.url) + return self._client.send_request(request_copy, **kwargs) \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/test_rest_context_manager.py b/sdk/core/azure-core/tests/testserver_tests/test_rest_context_manager.py new file mode 100644 index 000000000000..34f802971537 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/test_rest_context_manager.py @@ -0,0 +1,78 @@ +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import pytest +from azure.core.rest import HttpRequest +from azure.core.exceptions import ResponseNotReadError + +def test_normal_call(client, port): + def _raise_and_get_text(response): + response.raise_for_status() + assert response.text == "Hello, world!" + assert response.is_closed + request = HttpRequest("GET", url="/basic/string") + response = client.send_request(request) + _raise_and_get_text(response) + assert response.is_closed + + with client.send_request(request) as response: + _raise_and_get_text(response) + + response = client.send_request(request) + with response as response: + _raise_and_get_text(response) + +def test_stream_call(client): + def _raise_and_get_text(response): + response.raise_for_status() + assert not response.is_closed + with pytest.raises(ResponseNotReadError): + response.text + response.read() + assert response.text == "Hello, world!" + assert response.is_closed + request = HttpRequest("GET", url="/streams/basic") + response = client.send_request(request, stream=True) + _raise_and_get_text(response) + assert response.is_closed + + with client.send_request(request, stream=True) as response: + _raise_and_get_text(response) + assert response.is_closed + + response = client.send_request(request, stream=True) + with response as response: + _raise_and_get_text(response) + +# TODO: commenting until https://github.com/Azure/azure-sdk-for-python/issues/18086 is fixed + +# def test_stream_with_error(client): +# request = HttpRequest("GET", url="/streams/error") +# with client.send_request(request, stream=True) as response: +# assert not response.is_closed +# with pytest.raises(HttpResponseError) as e: +# response.raise_for_status() +# error = e.value +# assert error.status_code == 400 +# assert error.reason == "BAD REQUEST" +# assert "Operation returned an invalid status 'BAD REQUEST'" in str(error) +# with pytest.raises(ResponseNotReadError): +# error.error +# with pytest.raises(ResponseNotReadError): +# error.model +# with pytest.raises(ResponseNotReadError): +# response.json() +# with pytest.raises(ResponseNotReadError): +# response.content + +# # NOW WE READ THE RESPONSE +# response.read() +# assert error.status_code == 400 +# assert error.reason == "BAD REQUEST" +# assert error.error.code == "BadRequest" +# assert error.error.message == "You made a bad request" +# assert error.model.code == "BadRequest" +# assert error.error.message == "You made a bad request" diff --git a/sdk/core/azure-core/tests/testserver_tests/test_rest_headers.py b/sdk/core/azure-core/tests/testserver_tests/test_rest_headers.py new file mode 100644 index 000000000000..30112c50c912 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/test_rest_headers.py @@ -0,0 +1,104 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import sys + +# NOTE: These tests are heavily inspired from the httpx test suite: https://github.com/encode/httpx/tree/master/tests +# Thank you httpx for your wonderful tests! +from azure.core.rest import HttpRequest + +def _get_headers(header_value): + request = HttpRequest(method="GET", url="http://example.org", headers=header_value) + return request.headers + +def test_headers(): + # headers still can't be list of tuples. Will uncomment once we add this support + # h = _get_headers([("a", "123"), ("a", "456"), ("b", "789")]) + # assert "a" in h + # assert "A" in h + # assert "b" in h + # assert "B" in h + # assert "c" not in h + # assert h["a"] == "123, 456" + # assert h.get("a") == "123, 456" + # assert h.get("nope", default=None) is None + # assert h.get_list("a") == ["123", "456"] + + # assert list(h.keys()) == ["a", "b"] + # assert list(h.values()) == ["123, 456", "789"] + # assert list(h.items()) == [("a", "123, 456"), ("b", "789")] + # assert list(h) == ["a", "b"] + # assert dict(h) == {"a": "123, 456", "b": "789"} + # assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])" + # assert h == [("a", "123"), ("b", "789"), ("a", "456")] + # assert h == [("a", "123"), ("A", "456"), ("b", "789")] + # assert h == {"a": "123", "A": "456", "b": "789"} + # assert h != "a: 123\nA: 456\nb: 789" + + h = _get_headers({"a": "123", "b": "789"}) + assert h["A"] == "123" + assert h["B"] == "789" + + +def test_header_mutations(): + h = _get_headers({}) + assert dict(h) == {} + h["a"] = "1" + assert dict(h) == {"a": "1"} + h["a"] = "2" + assert dict(h) == {"a": "2"} + h.setdefault("a", "3") + assert dict(h) == {"a": "2"} + h.setdefault("b", "4") + assert dict(h) == {"a": "2", "b": "4"} + del h["a"] + assert dict(h) == {"b": "4"} + + +def test_headers_insert_retains_ordering(): + h = _get_headers({"a": "a", "b": "b", "c": "c"}) + h["b"] = "123" + if sys.version_info >= (3, 6): + assert list(h.values()) == ["a", "123", "c"] + else: + assert set(list(h.values())) == set(["a", "123", "c"]) + + +def test_headers_insert_appends_if_new(): + h = _get_headers({"a": "a", "b": "b", "c": "c"}) + h["d"] = "123" + if sys.version_info >= (3, 6): + assert list(h.values()) == ["a", "b", "c", "123"] + else: + assert set(list(h.values())) == set(["a", "b", "c", "123"]) + + +def test_headers_insert_removes_all_existing(): + h = _get_headers([("a", "123"), ("a", "456")]) + h["a"] = "789" + assert dict(h) == {"a": "789"} + + +def test_headers_delete_removes_all_existing(): + h = _get_headers([("a", "123"), ("a", "456")]) + del h["a"] + assert dict(h) == {} + +def test_headers_not_override(): + request = HttpRequest("PUT", "http://example.org", json={"hello": "world"}, headers={"Content-Length": "5000", "Content-Type": "application/my-content-type"}) + assert request.headers["Content-Length"] == "5000" + assert request.headers["Content-Type"] == "application/my-content-type" + +# Can't support list of tuples. Will uncomment once we add that support + +# def test_multiple_headers(): +# """ +# `Headers.get_list` should support both split_commas=False and split_commas=True. +# """ +# h = _get_headers([("set-cookie", "a, b"), ("set-cookie", "c")]) +# assert h.get_list("Set-Cookie") == ["a, b", "c"] + +# h = _get_headers([("vary", "a, b"), ("vary", "c")]) +# assert h.get_list("Vary", split_commas=True) == ["a", "b", "c"] \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/test_rest_http_request.py b/sdk/core/azure-core/tests/testserver_tests/test_rest_http_request.py new file mode 100644 index 000000000000..a41a911c3a6d --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/test_rest_http_request.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +# NOTE: These tests are heavily inspired from the httpx test suite: https://github.com/encode/httpx/tree/master/tests +# Thank you httpx for your wonderful tests! +import io +import pytest +import sys +import collections +from typing import Generator +from azure.core.rest import HttpRequest + +@pytest.fixture +def assert_iterator_body(): + def _comparer(request, final_value): + content = b"".join([p for p in request.content]) + assert content == final_value + return _comparer + +def test_request_repr(): + request = HttpRequest("GET", "http://example.org") + assert repr(request) == "" + +def test_no_content(): + request = HttpRequest("GET", "http://example.org") + assert "Content-Length" not in request.headers + +def test_content_length_header(): + request = HttpRequest("POST", "http://example.org", content=b"test 123") + assert request.headers["Content-Length"] == "8" + + +def test_iterable_content(assert_iterator_body): + class Content: + def __iter__(self): + yield b"test 123" # pragma: nocover + + request = HttpRequest("POST", "http://example.org", content=Content()) + assert request.headers == {} + assert_iterator_body(request, b"test 123") + + +def test_generator_with_transfer_encoding_header(assert_iterator_body): + def content(): + yield b"test 123" # pragma: nocover + + request = HttpRequest("POST", "http://example.org", content=content()) + assert request.headers == {} + assert_iterator_body(request, b"test 123") + + +def test_generator_with_content_length_header(assert_iterator_body): + def content(): + yield b"test 123" # pragma: nocover + + headers = {"Content-Length": "8"} + request = HttpRequest( + "POST", "http://example.org", content=content(), headers=headers + ) + assert request.headers == {"Content-Length": "8"} + assert_iterator_body(request, b"test 123") + + +def test_url_encoded_data(): + request = HttpRequest("POST", "http://example.org", data={"test": "123"}) + + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert request.content == {'test': '123'} # httpx makes this just b'test=123'. set_formdata_body is still keeping it as a dict + + +def test_json_encoded_data(): + request = HttpRequest("POST", "http://example.org", json={"test": 123}) + + assert request.headers["Content-Type"] == "application/json" + assert request.content == '{"test": 123}' + + +def test_headers(): + request = HttpRequest("POST", "http://example.org", json={"test": 123}) + + assert request.headers == { + "Content-Type": "application/json", + "Content-Length": "13", + } + + +def test_ignore_transfer_encoding_header_if_content_length_exists(): + """ + `Transfer-Encoding` should be ignored if `Content-Length` has been set explicitly. + See https://github.com/encode/httpx/issues/1168 + """ + + def streaming_body(data): + yield data # pragma: nocover + + data = streaming_body(b"abcd") + + headers = {"Content-Length": "4"} + request = HttpRequest("POST", "http://example.org", data=data, headers=headers) + assert "Transfer-Encoding" not in request.headers + assert request.headers["Content-Length"] == "4" + +def test_override_accept_encoding_header(): + headers = {"Accept-Encoding": "identity"} + + request = HttpRequest("GET", "http://example.org", headers=headers) + assert request.headers["Accept-Encoding"] == "identity" + +"""Test request body""" +def test_empty_content(): + request = HttpRequest("GET", "http://example.org") + assert request.content is None + +def test_string_content(): + request = HttpRequest("PUT", "http://example.org", content="Hello, world!") + assert request.headers == {"Content-Length": "13", "Content-Type": "text/plain"} + assert request.content == "Hello, world!" + + # Support 'data' for compat with requests. + request = HttpRequest("PUT", "http://example.org", data="Hello, world!") + + assert request.headers == {"Content-Length": "13", "Content-Type": "text/plain"} + assert request.content == "Hello, world!" + + # content length should not be set for GET requests + + request = HttpRequest("GET", "http://example.org", data="Hello, world!") + + assert request.headers == {"Content-Length": "13", "Content-Type": "text/plain"} + assert request.content == "Hello, world!" + +@pytest.mark.skipif(sys.version_info < (3, 0), + reason="In 2.7, b'' is the same as a string, so will have text/plain content type") +def test_bytes_content(): + request = HttpRequest("PUT", "http://example.org", content=b"Hello, world!") + assert request.headers == {"Content-Length": "13"} + assert request.content == b"Hello, world!" + + # Support 'data' for compat with requests. + request = HttpRequest("PUT", "http://example.org", data=b"Hello, world!") + + assert request.headers == {"Content-Length": "13"} + assert request.content == b"Hello, world!" + + # should still be set regardless of method + + request = HttpRequest("GET", "http://example.org", data=b"Hello, world!") + + assert request.headers == {"Content-Length": "13"} + assert request.content == b"Hello, world!" + +def test_iterator_content(assert_iterator_body): + # NOTE: in httpx, content reads out the actual value. Don't do that (yet) in azure rest + def hello_world(): + yield b"Hello, " + yield b"world!" + + request = HttpRequest("POST", url="http://example.org", content=hello_world()) + assert isinstance(request.content, collections.Iterable) + + assert_iterator_body(request, b"Hello, world!") + assert request.headers == {} + + # Support 'data' for compat with requests. + request = HttpRequest("POST", url="http://example.org", data=hello_world()) + assert isinstance(request.content, collections.Iterable) + + assert_iterator_body(request, b"Hello, world!") + assert request.headers == {} + + # transfer encoding should still be set for GET requests + request = HttpRequest("GET", url="http://example.org", data=hello_world()) + assert isinstance(request.content, collections.Iterable) + + assert_iterator_body(request, b"Hello, world!") + assert request.headers == {} + + +def test_json_content(): + request = HttpRequest("POST", url="http://example.org", json={"Hello": "world!"}) + + assert request.headers == { + "Content-Length": "19", + "Content-Type": "application/json", + } + assert request.content == '{"Hello": "world!"}' + +def test_urlencoded_content(): + # NOTE: not adding content length setting and content testing bc we're not adding content length in the rest code + # that's dealt with later in the pipeline. + request = HttpRequest("POST", url="http://example.org", data={"Hello": "world!"}) + assert request.headers == { + "Content-Type": "application/x-www-form-urlencoded", + } + +@pytest.mark.parametrize(("key"), (1, 2.3, None)) +def test_multipart_invalid_key(key): + + data = {key: "abc"} + files = {"file": io.BytesIO(b"")} + with pytest.raises(TypeError) as e: + HttpRequest( + url="http://127.0.0.1:8000/", + method="POST", + data=data, + files=files, + ) + assert "Invalid type for data name" in str(e.value) + assert repr(key) in str(e.value) + + +@pytest.mark.skipif(sys.version_info < (3, 0), + reason="In 2.7, b'' is the same as a string, so check doesn't fail") +def test_multipart_invalid_key_binary_string(): + + data = {b"abc": "abc"} + files = {"file": io.BytesIO(b"")} + with pytest.raises(TypeError) as e: + HttpRequest( + url="http://127.0.0.1:8000/", + method="POST", + data=data, + files=files, + ) + assert "Invalid type for data name" in str(e.value) + assert repr(b"abc") in str(e.value) + +@pytest.mark.parametrize(("value"), (object(), {"key": "value"})) +def test_multipart_invalid_value(value): + + data = {"text": value} + files = {"file": io.BytesIO(b"")} + with pytest.raises(TypeError) as e: + HttpRequest("POST", "http://127.0.0.1:8000/", data=data, files=files) + assert "Invalid type for data value" in str(e.value) + +def test_empty_request(): + request = HttpRequest("POST", url="http://example.org", data={}, files={}) + + assert request.headers == {} + assert not request.content # in core, we don't convert urlencoded dict to bytes representation in content + +def test_read_content(assert_iterator_body): + def content(): + yield b"test 123" + + request = HttpRequest("POST", "http://example.org", content=content()) + assert_iterator_body(request, b"test 123") + # in this case, request._data is what we end up passing to the requests transport + assert isinstance(request._data, collections.Iterable) + +def test_complicated_json(client): + # thanks to Sean Kane for this test! + input = { + 'EmptyByte': '', + 'EmptyUnicode': '', + 'SpacesOnlyByte': ' ', + 'SpacesOnlyUnicode': ' ', + 'SpacesBeforeByte': ' Text', + 'SpacesBeforeUnicode': ' Text', + 'SpacesAfterByte': 'Text ', + 'SpacesAfterUnicode': 'Text ', + 'SpacesBeforeAndAfterByte': ' Text ', + 'SpacesBeforeAndAfterUnicode': ' Text ', + 'ๅ•Š้ฝ„ไธ‚็‹›': '๊€•', + 'RowKey': 'test2', + 'ๅ•Š้ฝ„ไธ‚็‹›็‹œ': 'hello', + "singlequote": "a''''b", + "doublequote": 'a""""b', + "None": None, + } + request = HttpRequest("POST", "/basic/complicated-json", json=input) + r = client.send_request(request) + r.raise_for_status() + +# NOTE: For files, we don't allow list of tuples yet, just dict. Will uncomment when we add this capability +# def test_multipart_multiple_files_single_input_content(): +# files = [ +# ("file", io.BytesIO(b"")), +# ("file", io.BytesIO(b"")), +# ] +# request = HttpRequest("POST", url="http://example.org", files=files) +# assert request.headers == { +# "Content-Length": "271", +# "Content-Type": "multipart/form-data; boundary=+++", +# } +# assert request.content == b"".join( +# [ +# b"--+++\r\n", +# b'Content-Disposition: form-data; name="file"; filename="upload"\r\n', +# b"Content-Type: application/octet-stream\r\n", +# b"\r\n", +# b"\r\n", +# b"--+++\r\n", +# b'Content-Disposition: form-data; name="file"; filename="upload"\r\n', +# b"Content-Type: application/octet-stream\r\n", +# b"\r\n", +# b"\r\n", +# b"--+++--\r\n", +# ] +# ) \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/test_rest_http_response.py b/sdk/core/azure-core/tests/testserver_tests/test_rest_http_response.py new file mode 100644 index 000000000000..83255119f4ab --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/test_rest_http_response.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +# NOTE: These tests are heavily inspired from the httpx test suite: https://github.com/encode/httpx/tree/master/tests +# Thank you httpx for your wonderful tests! +import io +import sys +import pytest +from azure.core.rest import HttpRequest +from azure.core.exceptions import HttpResponseError +import xml.etree.ElementTree as ET + +@pytest.fixture +def send_request(client): + def _send_request(request): + response = client.send_request(request, stream=False) + response.raise_for_status() + return response + return _send_request + +def test_response(send_request, port): + response = send_request( + request=HttpRequest("GET", "/basic/string"), + ) + assert response.status_code == 200 + assert response.reason == "OK" + assert response.text == "Hello, world!" + assert response.request.method == "GET" + assert response.request.url == "http://localhost:{}/basic/string".format(port) + + +def test_response_content(send_request): + response = send_request( + request=HttpRequest("GET", "/basic/bytes"), + ) + assert response.status_code == 200 + assert response.reason == "OK" + assert response.text == "Hello, world!" + + +def test_response_text(send_request): + response = send_request( + request=HttpRequest("GET", "/basic/string"), + ) + assert response.status_code == 200 + assert response.reason == "OK" + assert response.text == "Hello, world!" + assert response.headers["Content-Length"] == '13' + assert response.headers['Content-Type'] == "text/plain; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" + +def test_response_html(send_request): + response = send_request( + request=HttpRequest("GET", "/basic/html"), + ) + assert response.status_code == 200 + assert response.reason == "OK" + assert response.text == "Hello, world!" + +def test_raise_for_status(client): + response = client.send_request( + HttpRequest("GET", "/basic/string"), + ) + response.raise_for_status() + + response = client.send_request( + HttpRequest("GET", "/errors/403"), + ) + assert response.status_code == 403 + with pytest.raises(HttpResponseError): + response.raise_for_status() + + response = client.send_request( + HttpRequest("GET", "/errors/500"), + retry_total=0, # takes too long with retires on 500 + ) + assert response.status_code == 500 + with pytest.raises(HttpResponseError): + response.raise_for_status() + +def test_response_repr(send_request): + response = send_request( + request=HttpRequest("GET", "/basic/string") + ) + assert repr(response) == "" + +def test_response_content_type_encoding(send_request): + """ + Use the charset encoding in the Content-Type header if possible. + """ + response = send_request( + request=HttpRequest("GET", "/encoding/latin-1") + ) + assert response.content_type == "text/plain; charset=latin-1" + assert response.text == u"Latin 1: รฟ" + assert response.encoding == "latin-1" + + +def test_response_autodetect_encoding(send_request): + """ + Autodetect encoding if there is no Content-Type header. + """ + response = send_request( + request=HttpRequest("GET", "/encoding/latin-1") + ) + + assert response.text == u'Latin 1: รฟ' + assert response.encoding == "latin-1" + +@pytest.mark.skipif(sys.version_info < (3, 0), + reason="In 2.7, b'' is the same as a string, so will have text/plain content type") +def test_response_fallback_to_autodetect(send_request): + """ + Fallback to autodetection if we get an invalid charset in the Content-Type header. + """ + response = send_request( + request=HttpRequest("GET", "/encoding/invalid-codec-name") + ) + + assert response.headers["Content-Type"] == "text/plain; charset=invalid-codec-name" + assert response.text == u"ใŠใฏใ‚ˆใ†ใ”ใ–ใ„ใพใ™ใ€‚" + assert response.encoding is None + + +def test_response_no_charset_with_ascii_content(send_request): + """ + A response with ascii encoded content should decode correctly, + even with no charset specified. + """ + response = send_request( + request=HttpRequest("GET", "/encoding/no-charset"), + ) + + assert response.headers["Content-Type"] == "text/plain" + assert response.status_code == 200 + assert response.encoding == 'ascii' + assert response.text == "Hello, world!" + + +def test_response_no_charset_with_iso_8859_1_content(send_request): + """ + A response with ISO 8859-1 encoded content should decode correctly, + even with no charset specified. + """ + response = send_request( + request=HttpRequest("GET", "/encoding/iso-8859-1"), + ) + assert response.text == u"Accented: ร–sterreich" + assert response.encoding == 'ISO-8859-1' + +def test_response_set_explicit_encoding(send_request): + # Deliberately incorrect charset + response = send_request( + request=HttpRequest("GET", "/encoding/latin-1-with-utf-8"), + ) + assert response.headers["Content-Type"] == "text/plain; charset=utf-8" + response.encoding = "latin-1" + assert response.text == u"Latin 1: รฟ" + assert response.encoding == "latin-1" + +def test_json(send_request): + response = send_request( + request=HttpRequest("GET", "/basic/json"), + ) + assert response.json() == {"greeting": "hello", "recipient": "world"} + assert response.encoding == 'utf-8-sig' # for requests, we use utf-8-sig instead of utf-8 bc of requests behavior + +def test_json_with_specified_encoding(send_request): + response = send_request( + request=HttpRequest("GET", "/encoding/json"), + ) + assert response.json() == {"greeting": "hello", "recipient": "world"} + assert response.encoding == "utf-16" + +def test_emoji(send_request): + response = send_request( + request=HttpRequest("GET", "/encoding/emoji"), + ) + assert response.text == u"๐Ÿ‘ฉ" + +def test_emoji_family_with_skin_tone_modifier(send_request): + response = send_request( + request=HttpRequest("GET", "/encoding/emoji-family-skin-tone-modifier"), + ) + assert response.text == u"๐Ÿ‘ฉ๐Ÿปโ€๐Ÿ‘ฉ๐Ÿฝโ€๐Ÿ‘ง๐Ÿพโ€๐Ÿ‘ฆ๐Ÿฟ SSN: 859-98-0987" + +def test_korean_nfc(send_request): + response = send_request( + request=HttpRequest("GET", "/encoding/korean"), + ) + assert response.text == u"์•„๊ฐ€" + +def test_urlencoded_content(send_request): + send_request( + request=HttpRequest( + "POST", + "/urlencoded/pet/add/1", + data={ "pet_type": "dog", "pet_food": "meat", "name": "Fido", "pet_age": 42 } + ), + ) + +def test_multipart_files_content(send_request): + request = HttpRequest( + "POST", + "/multipart/basic", + files={"fileContent": io.BytesIO(b"")}, + ) + send_request(request) + +def test_multipart_data_and_files_content(send_request): + request = HttpRequest( + "POST", + "/multipart/data-and-files", + data={"message": "Hello, world!"}, + files={"fileContent": io.BytesIO(b"")}, + ) + send_request(request) + +@pytest.mark.skipif(sys.version_info < (3, 0), + reason="In 2.7, get requests error even if we use a pipelien transport") +def test_multipart_encode_non_seekable_filelike(send_request): + """ + Test that special readable but non-seekable filelike objects are supported, + at the cost of reading them into memory at most once. + """ + + class IteratorIO(io.IOBase): + def __init__(self, iterator): + self._iterator = iterator + + def read(self, *args): + return b"".join(self._iterator) + + def data(): + yield b"Hello" + yield b"World" + + fileobj = IteratorIO(data()) + files = {"file": fileobj} + request = HttpRequest( + "POST", + "/multipart/non-seekable-filelike", + files=files, + ) + send_request(request) + +def test_get_xml_basic(send_request): + request = HttpRequest( + "GET", + "/xml/basic", + ) + response = send_request(request) + parsed_xml = ET.fromstring(response.text) + assert parsed_xml.tag == 'slideshow' + attributes = parsed_xml.attrib + assert attributes['title'] == "Sample Slide Show" + assert attributes['date'] == "Date of publication" + assert attributes['author'] == "Yours Truly" + +def test_put_xml_basic(send_request): + + basic_body = """ + + + Wake up to WonderWidgets! + + + Overview + Why WonderWidgets are great + + Who buys WonderWidgets + +""" + + request = HttpRequest( + "PUT", + "/xml/basic", + content=ET.fromstring(basic_body), + ) + send_request(request) + +def test_send_request_return_pipeline_response(client): + # we use return_pipeline_response for some cases in autorest + request = HttpRequest("GET", "/basic/string") + response = client.send_request(request, _return_pipeline_response=True) + assert hasattr(response, "http_request") + assert hasattr(response, "http_response") + assert hasattr(response, "context") + assert response.http_response.text == "Hello, world!" + assert hasattr(response.http_request, "content") diff --git a/sdk/core/azure-core/tests/testserver_tests/test_rest_query.py b/sdk/core/azure-core/tests/testserver_tests/test_rest_query.py new file mode 100644 index 000000000000..7933e998f1e6 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/test_rest_query.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +# NOTE: These tests are heavily inspired from the httpx test suite: https://github.com/encode/httpx/tree/master/tests +# Thank you httpx for your wonderful tests! + +import pytest +from azure.core.rest import HttpRequest + +def _format_query_into_url(url, params): + request = HttpRequest(method="GET", url=url, params=params) + return request.url + +def test_request_url_with_params(): + url = _format_query_into_url(url="a/b/c?t=y", params={"g": "h"}) + assert url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + +def test_request_url_with_params_as_list(): + url = _format_query_into_url(url="a/b/c?t=y", params={"g": ["h","i"]}) + assert url in ["a/b/c?g=h&g=i&t=y", "a/b/c?t=y&g=h&g=i"] + +def test_request_url_with_params_with_none_in_list(): + with pytest.raises(ValueError): + _format_query_into_url(url="a/b/c?t=y", params={"g": ["h",None]}) + +def test_request_url_with_params_with_none(): + with pytest.raises(ValueError): + _format_query_into_url(url="a/b/c?t=y", params={"g": None}) \ No newline at end of file diff --git a/sdk/core/azure-core/tests/testserver_tests/test_rest_stream_responses.py b/sdk/core/azure-core/tests/testserver_tests/test_rest_stream_responses.py new file mode 100644 index 000000000000..61053ca7abb9 --- /dev/null +++ b/sdk/core/azure-core/tests/testserver_tests/test_rest_stream_responses.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import pytest +from azure.core.rest import HttpRequest +from azure.core.exceptions import StreamClosedError, StreamConsumedError, ResponseNotReadError +from azure.core.exceptions import HttpResponseError, ServiceRequestError + +def _assert_stream_state(response, open): + # if open is true, check the stream is open. + # if false, check if everything is closed + checks = [ + response._internal_response._content_consumed, + response.is_closed, + response.is_stream_consumed + ] + if open: + assert not any(checks) + else: + assert all(checks) + +def test_iter_raw(client): + request = HttpRequest("GET", "/streams/basic") + with client.send_request(request, stream=True) as response: + raw = b"" + for part in response.iter_raw(): + assert not response._internal_response._content_consumed + assert not response.is_closed + assert response.is_stream_consumed # we follow httpx behavior here + raw += part + assert raw == b"Hello, world!" + assert response._internal_response._content_consumed + assert response.is_closed + assert response.is_stream_consumed + +def test_iter_raw_on_iterable(client): + request = HttpRequest("GET", "/streams/iterable") + + with client.send_request(request, stream=True) as response: + raw = b"" + for part in response.iter_raw(): + raw += part + assert raw == b"Hello, world!" + +def test_iter_with_error(client): + request = HttpRequest("GET", "/errors/403") + + with client.send_request(request, stream=True) as response: + with pytest.raises(HttpResponseError): + response.raise_for_status() + assert response.is_closed + + with pytest.raises(HttpResponseError): + with client.send_request(request, stream=True) as response: + response.raise_for_status() + assert response.is_closed + + request = HttpRequest("GET", "http://doesNotExist") + with pytest.raises(ServiceRequestError): + with client.send_request(request, stream=True) as response: + raise ValueError("Should error before entering") + assert response.is_closed + +def test_iter_bytes(client): + request = HttpRequest("GET", "/streams/basic") + + with client.send_request(request, stream=True) as response: + raw = b"" + for chunk in response.iter_bytes(): + assert not response._internal_response._content_consumed + assert not response.is_closed + assert response.is_stream_consumed # we follow httpx behavior here + raw += chunk + assert response._internal_response._content_consumed + assert response.is_closed + assert response.is_stream_consumed + assert raw == b"Hello, world!" + +def test_iter_text(client): + request = HttpRequest("GET", "/basic/string") + + with client.send_request(request, stream=True) as response: + content = "" + for part in response.iter_text(): + content += part + assert content == "Hello, world!" + +def test_iter_lines(client): + request = HttpRequest("GET", "/basic/lines") + + with client.send_request(request, stream=True) as response: + content = [] + for line in response.iter_lines(): + content.append(line) + assert content == ["Hello,\n", "world!"] + +def test_sync_streaming_response(client): + request = HttpRequest("GET", "/streams/basic") + + with client.send_request(request, stream=True) as response: + assert response.status_code == 200 + assert not response.is_closed + + content = response.read() + + assert content == b"Hello, world!" + assert response.content == b"Hello, world!" + assert response.is_closed + +def test_cannot_read_after_stream_consumed(client, port): + request = HttpRequest("GET", "/streams/basic") + + with client.send_request(request, stream=True) as response: + content = b"" + for part in response.iter_bytes(): + content += part + + assert content == b"Hello, world!" + + with pytest.raises(StreamConsumedError) as ex: + response.read() + + assert "".format(port) in str(ex.value) + assert "You have likely already consumed this stream, so it can not be accessed anymore" in str(ex.value) + +def test_cannot_read_after_response_closed(port, client): + request = HttpRequest("GET", "/streams/basic") + + with client.send_request(request, stream=True) as response: + response.close() + with pytest.raises(StreamClosedError) as ex: + response.read() + # breaking up assert into multiple lines + assert "".format(port) in str(ex.value) + assert "can no longer be read or streamed, since the response has already been closed" in str(ex.value) + +def test_decompress_plain_no_header(client): + # thanks to Xiang Yan for this test! + account_name = "coretests" + url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) + request = HttpRequest("GET", url) + response = client.send_request(request, stream=True) + with pytest.raises(ResponseNotReadError): + response.content + response.read() + assert response.content == b"test" + +def test_compress_plain_no_header(client): + # thanks to Xiang Yan for this test! + account_name = "coretests" + url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) + request = HttpRequest("GET", url) + response = client.send_request(request, stream=True) + iter = response.iter_raw() + data = b"".join(list(iter)) + assert data == b"test" + +def test_decompress_compressed_no_header(client): + # thanks to Xiang Yan for this test! + account_name = "coretests" + url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) + request = HttpRequest("GET", url) + response = client.send_request(request, stream=True) + iter = response.iter_bytes() + data = b"".join(list(iter)) + assert data == b'\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\n+I-.\x01\x00\x0c~\x7f\xd8\x04\x00\x00\x00' + +def test_decompress_compressed_header(client): + # thanks to Xiang Yan for this test! + account_name = "coretests" + account_url = "https://{}.blob.core.windows.net".format(account_name) + url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) + request = HttpRequest("GET", url) + response = client.send_request(request, stream=True) + iter = response.iter_text() + data = "".join(list(iter)) + assert data == "test" + +def test_iter_read(client): + # thanks to McCoy Patiรฑo for this test! + request = HttpRequest("GET", "/basic/lines") + response = client.send_request(request, stream=True) + response.read() + iterator = response.iter_lines() + for line in iterator: + assert line + assert response.text + +def test_iter_read_back_and_forth(client): + # thanks to McCoy Patiรฑo for this test! + + # while this test may look like it's exposing buggy behavior, this is httpx's behavior + # the reason why the code flow is like this, is because the 'iter_x' functions don't + # actually read the contents into the response, the output them. Once they're yielded, + # the stream is closed, so you have to catch the output when you iterate through it + request = HttpRequest("GET", "/basic/lines") + response = client.send_request(request, stream=True) + iterator = response.iter_lines() + for line in iterator: + assert line + with pytest.raises(ResponseNotReadError): + response.text + with pytest.raises(StreamConsumedError): + response.read() + with pytest.raises(ResponseNotReadError): + response.text + +def test_stream_with_return_pipeline_response(client): + request = HttpRequest("GET", "/basic/lines") + pipeline_response = client.send_request(request, stream=True, _return_pipeline_response=True) + assert hasattr(pipeline_response, "http_request") + assert hasattr(pipeline_response, "http_response") + assert hasattr(pipeline_response, "context") + assert list(pipeline_response.http_response.iter_lines()) == ['Hello,\n', 'world!'] + +def test_error_reading(client): + request = HttpRequest("GET", "/errors/403") + with client.send_request(request, stream=True) as response: + response.read() + assert response.content == b"" + + response = client.send_request(request, stream=True) + with pytest.raises(HttpResponseError): + response.raise_for_status() + response.read() + assert response.content == b"" + # try giving a really slow response, see what happens