Skip to content

Commit

Permalink
fix aiohttp kyes and values to be mutable
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft committed Aug 19, 2021
1 parent c6ec359 commit 54e92a4
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 12 deletions.
40 changes: 37 additions & 3 deletions sdk/core/azure-core/azure/core/rest/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class _ItemsView(collections.abc.ItemsView):
def __init__(self, ref):
super().__init__(ref)
self._ref = ref
self._items = []

def __iter__(self):
for key, groups in groupby(self._ref.__iter__(), lambda x: x[0]):
Expand All @@ -55,6 +54,41 @@ def __contains__(self, item):
def __repr__(self):
return f"dict_items({list(self.__iter__())})"

class _KeysView(collections.abc.KeysView):
def __init__(self, items):
super().__init__(items)
self._items = items

def __iter__(self):
for key, _ in self._items:
yield key

def __contains__(self, key):
for k in self.__iter__():
if key.lower() == k.lower():
return True
return False
def __repr__(self):
return f"dict_keys({list(self.__iter__())})"

class _ValuesView(collections.abc.ValuesView):
def __init__(self, items):
super().__init__(items)
self._items = items

def __iter__(self):
for _, value in self._items:
yield value

def __contains__(self, value):
for v in self.__iter__():
if value == v:
return True
return False

def __repr__(self):
return f"dict_values({list(self.__iter__())})"


class _CIMultiDict(CIMultiDict):
"""Dictionary with the support for duplicate case-insensitive keys."""
Expand All @@ -64,15 +98,15 @@ def __iter__(self):

def keys(self):
"""Return a new view of the dictionary's keys."""
return dict(self.items()).keys()
return _KeysView(self.items())

def items(self):
"""Return a new view of the dictionary's items."""
return _ItemsView(super().items())

def values(self):
"""Return a new view of the dictionary's values."""
return dict(self.items()).values()
return _ValuesView(self.items())

def __getitem__(self, key: str) -> str:
return ", ".join(self.getall(key, []))
Expand Down
14 changes: 9 additions & 5 deletions sdk/core/azure-core/azure/core/rest/_requests_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,31 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import collections
from typing import TYPE_CHECKING, cast
from requests.structures import CaseInsensitiveDict

from ..exceptions import ResponseNotReadError, StreamConsumedError, StreamClosedError
from ._rest import _HttpResponseBase, HttpResponse
from ..pipeline.transport._requests_basic import StreamDownloadGenerator
from collections import ItemsView
from requests.structures import CaseInsensitiveDict

class _ItemsView(ItemsView):
class _ItemsView(collections.ItemsView):
def __contains__(self, item):
if not (isinstance(item, (list, tuple)) and len(item) == 2):
return False
return False # requests raises here, we just return False
for k, v in self.__iter__():
if item[0].lower() == k.lower() and item[1] == v:
return True
return False

def __repr__(self):
return 'ItemsView({0._mapping!r})'.format(self)
return 'ItemsView({})'.format(dict(self.__iter__()))

class _CaseInsensitiveDict(CaseInsensitiveDict):
"""Overriding default requests dict so we can unify
to not raise if users pass in incorrect items to contains.
Instead, we return False
"""

def items(self):
"""Return a new view of the dictionary's items."""
Expand Down
4 changes: 3 additions & 1 deletion sdk/core/azure-core/azure/core/rest/_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ class HttpResponse(_HttpResponseBase): # pylint: disable=too-many-instance-attr
:keyword request: The request that resulted in this response.
:paramtype request: ~azure.core.rest.HttpRequest
:ivar int status_code: The status code of this response
:ivar mapping headers: The response headers
:ivar mapping headers: The case-insensitive response headers.
While looking up headers is case-insensitive, when looking up
keys in `header.keys()`, we recommend using lowercase.
:ivar str reason: The reason phrase for this response
:ivar bytes content: The response content in bytes.
:ivar str url: The URL that resulted in this response
Expand Down
4 changes: 3 additions & 1 deletion sdk/core/azure-core/azure/core/rest/_rest_py3.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ class HttpResponse(_HttpResponseBase):
:keyword request: The request that resulted in this response.
:paramtype request: ~azure.core.rest.HttpRequest
:ivar int status_code: The status code of this response
:ivar mapping headers: The response headers
:ivar mapping headers: The case-insensitive response headers.
While looking up headers is case-insensitive, when looking up
keys in `header.keys()`, we recommend using lowercase.
:ivar str reason: The reason phrase for this response
:ivar bytes content: The response content in bytes.
:ivar str url: The URL that resulted in this response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ async def test_headers_response_keys(get_response_headers):
assert list(h.keys()) == list(ref_dict.keys())
assert repr(h.keys()) == repr(ref_dict.keys())
assert "a" in h.keys()
assert "A" not in h.keys()
assert "A" in h.keys()
assert "b" in h.keys()
assert "B" not in h.keys()
assert "B" in h.keys()
assert set(h.keys()) == set(ref_dict.keys())

# test mutability
Expand Down

0 comments on commit 54e92a4

Please sign in to comment.