From 54e92a462305ee3def6906868e91c6e80de7dcbc Mon Sep 17 00:00:00 2001 From: iscai-msft Date: Thu, 19 Aug 2021 11:44:54 -0400 Subject: [PATCH] fix aiohttp kyes and values to be mutable --- .../azure-core/azure/core/rest/_aiohttp.py | 40 +++++++++++++++++-- .../azure/core/rest/_requests_basic.py | 14 ++++--- sdk/core/azure-core/azure/core/rest/_rest.py | 4 +- .../azure-core/azure/core/rest/_rest_py3.py | 4 +- .../async_tests/test_rest_headers_async.py | 4 +- 5 files changed, 54 insertions(+), 12 deletions(-) diff --git a/sdk/core/azure-core/azure/core/rest/_aiohttp.py b/sdk/core/azure-core/azure/core/rest/_aiohttp.py index cf61c1ecc511..0ddb80cdd3d7 100644 --- a/sdk/core/azure-core/azure/core/rest/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/rest/_aiohttp.py @@ -38,7 +38,6 @@ class _ItemsView(collections.abc.ItemsView): def __init__(self, ref): super().__init__(ref) self._ref = ref - self._items = [] def __iter__(self): for key, groups in groupby(self._ref.__iter__(), lambda x: x[0]): @@ -55,6 +54,41 @@ def __contains__(self, item): def __repr__(self): return f"dict_items({list(self.__iter__())})" +class _KeysView(collections.abc.KeysView): + def __init__(self, items): + super().__init__(items) + self._items = items + + def __iter__(self): + for key, _ in self._items: + yield key + + def __contains__(self, key): + for k in self.__iter__(): + if key.lower() == k.lower(): + return True + return False + def __repr__(self): + return f"dict_keys({list(self.__iter__())})" + +class _ValuesView(collections.abc.ValuesView): + def __init__(self, items): + super().__init__(items) + self._items = items + + def __iter__(self): + for _, value in self._items: + yield value + + def __contains__(self, value): + for v in self.__iter__(): + if value == v: + return True + return False + + def __repr__(self): + return f"dict_values({list(self.__iter__())})" + class _CIMultiDict(CIMultiDict): """Dictionary with the support for duplicate case-insensitive keys.""" @@ -64,7 +98,7 @@ def __iter__(self): def keys(self): """Return a new view of the dictionary's keys.""" - return dict(self.items()).keys() + return _KeysView(self.items()) def items(self): """Return a new view of the dictionary's items.""" @@ -72,7 +106,7 @@ def items(self): def values(self): """Return a new view of the dictionary's values.""" - return dict(self.items()).values() + return _ValuesView(self.items()) def __getitem__(self, key: str) -> str: return ", ".join(self.getall(key, [])) diff --git a/sdk/core/azure-core/azure/core/rest/_requests_basic.py b/sdk/core/azure-core/azure/core/rest/_requests_basic.py index c9649805d8b8..dc3ee631e506 100644 --- a/sdk/core/azure-core/azure/core/rest/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/rest/_requests_basic.py @@ -23,27 +23,31 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +import collections from typing import TYPE_CHECKING, cast +from requests.structures import CaseInsensitiveDict from ..exceptions import ResponseNotReadError, StreamConsumedError, StreamClosedError from ._rest import _HttpResponseBase, HttpResponse from ..pipeline.transport._requests_basic import StreamDownloadGenerator -from collections import ItemsView -from requests.structures import CaseInsensitiveDict -class _ItemsView(ItemsView): +class _ItemsView(collections.ItemsView): def __contains__(self, item): if not (isinstance(item, (list, tuple)) and len(item) == 2): - return False + return False # requests raises here, we just return False for k, v in self.__iter__(): if item[0].lower() == k.lower() and item[1] == v: return True return False def __repr__(self): - return 'ItemsView({0._mapping!r})'.format(self) + return 'ItemsView({})'.format(dict(self.__iter__())) class _CaseInsensitiveDict(CaseInsensitiveDict): + """Overriding default requests dict so we can unify + to not raise if users pass in incorrect items to contains. + Instead, we return False + """ def items(self): """Return a new view of the dictionary's items.""" diff --git a/sdk/core/azure-core/azure/core/rest/_rest.py b/sdk/core/azure-core/azure/core/rest/_rest.py index 295051a83c50..323314c854b8 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest.py +++ b/sdk/core/azure-core/azure/core/rest/_rest.py @@ -305,7 +305,9 @@ class HttpResponse(_HttpResponseBase): # pylint: disable=too-many-instance-attr :keyword request: The request that resulted in this response. :paramtype request: ~azure.core.rest.HttpRequest :ivar int status_code: The status code of this response - :ivar mapping headers: The response headers + :ivar mapping headers: The case-insensitive response headers. + While looking up headers is case-insensitive, when looking up + keys in `header.keys()`, we recommend using lowercase. :ivar str reason: The reason phrase for this response :ivar bytes content: The response content in bytes. :ivar str url: The URL that resulted in this response diff --git a/sdk/core/azure-core/azure/core/rest/_rest_py3.py b/sdk/core/azure-core/azure/core/rest/_rest_py3.py index af957f767a39..1c91eb1eeafb 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest_py3.py +++ b/sdk/core/azure-core/azure/core/rest/_rest_py3.py @@ -319,7 +319,9 @@ class HttpResponse(_HttpResponseBase): :keyword request: The request that resulted in this response. :paramtype request: ~azure.core.rest.HttpRequest :ivar int status_code: The status code of this response - :ivar mapping headers: The response headers + :ivar mapping headers: The case-insensitive response headers. + While looking up headers is case-insensitive, when looking up + keys in `header.keys()`, we recommend using lowercase. :ivar str reason: The reason phrase for this response :ivar bytes content: The response content in bytes. :ivar str url: The URL that resulted in this response diff --git a/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_headers_async.py b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_headers_async.py index 8d0debbac10c..637a6e4e27d2 100644 --- a/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_headers_async.py +++ b/sdk/core/azure-core/tests/testserver_tests/async_tests/test_rest_headers_async.py @@ -57,9 +57,9 @@ async def test_headers_response_keys(get_response_headers): assert list(h.keys()) == list(ref_dict.keys()) assert repr(h.keys()) == repr(ref_dict.keys()) assert "a" in h.keys() - assert "A" not in h.keys() + assert "A" in h.keys() assert "b" in h.keys() - assert "B" not in h.keys() + assert "B" in h.keys() assert set(h.keys()) == set(ref_dict.keys()) # test mutability