Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to wrappers.py #1835

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions elasticsearch_dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,25 @@

import collections.abc
from copy import copy
from typing import Any, ClassVar, Dict, List, Optional, Type, Union
from typing import Any, ClassVar, Dict, Generic, List, Optional, Type, TypeVar, Union

from typing_extensions import Self
from typing_extensions import Self, TypeAlias

from .exceptions import UnknownDslObject, ValidationException

JSONType = Union[int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]]
# Usefull types

JSONType: TypeAlias = Union[
int, bool, str, float, List["JSONType"], Dict[str, "JSONType"]
]


# Type variables for internals

_KeyT = TypeVar("_KeyT")
_ValT = TypeVar("_ValT")

# Constants

SKIP_VALUES = ("", None)
EXPAND__TO_DOT = True
Expand Down Expand Up @@ -110,18 +122,20 @@ def to_list(self):
return self._l_


class AttrDict:
class AttrDict(Generic[_KeyT, _ValT]):
"""
Helper class to provide attribute like access (read and write) to
dictionaries. Used to provide a convenient way to access both results and
nested dsl dicts.
"""

def __init__(self, d):
_d_: Dict[_KeyT, _ValT]

def __init__(self, d: Dict[_KeyT, _ValT]):
# assign the inner dict manually to prevent __setattr__ from firing
super().__setattr__("_d_", d)

def __contains__(self, key):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this type be _KeyT instead of object?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept is as object to be consistent with the dict.__contains__() typing, conceptually I think is also not a type error to ask if an object of a wrong type is in the dictionary, is just gonna be false.

def __contains__(self, key: object) -> bool:
return key in self._d_

def __nonzero__(self):
Expand Down
84 changes: 73 additions & 11 deletions elasticsearch_dsl/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,78 @@
# under the License.

import operator
from typing import (
Any,
Callable,
ClassVar,
Dict,
Literal,
Mapping,
Optional,
Protocol,
Tuple,
TypeVar,
Union,
cast,
)

from typing_extensions import TypeAlias

from .utils import AttrDict

__all__ = ["Range"]

class SupportsDunderLT(Protocol):
def __lt__(self, other: Any, /) -> Any: ...


class SupportsDunderGT(Protocol):
def __gt__(self, other: Any, /) -> Any: ...


class SupportsDunderLE(Protocol):
def __le__(self, other: Any, /) -> Any: ...


class SupportsDunderGE(Protocol):
def __ge__(self, other: Any, /) -> Any: ...


class Range(AttrDict):
OPS = {
SupportsComparison: TypeAlias = Union[
SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT
]

ComparisonOperators: TypeAlias = Literal["lt", "lte", "gt", "gte"]
RangeValT = TypeVar("RangeValT", bound=SupportsComparison)

__all__ = ["Range", "SupportsComparison"]


class Range(AttrDict[ComparisonOperators, RangeValT]):
OPS: ClassVar[
Mapping[
ComparisonOperators,
Callable[[SupportsComparison, SupportsComparison], bool],
]
] = {
"lt": operator.lt,
"lte": operator.le,
"gt": operator.gt,
"gte": operator.ge,
}

def __init__(self, *args, **kwargs):
if args and (len(args) > 1 or kwargs or not isinstance(args[0], dict)):
def __init__(
self,
d: Optional[Dict[ComparisonOperators, RangeValT]] = None,
/,
**kwargs: RangeValT,
):
if d is not None and (kwargs or not isinstance(d, dict)):
raise ValueError(
"Range accepts a single dictionary or a set of keyword arguments."
)
data = args[0] if args else kwargs

# Cast here since mypy is inferring d as an `object` type for some reason
data = cast(Dict[str, RangeValT], d) if d is not None else kwargs

for k in data:
if k not in self.OPS:
Expand All @@ -47,30 +99,40 @@ def __init__(self, *args, **kwargs):
if "lt" in data and "lte" in data:
raise ValueError("You cannot specify both lt and lte for Range.")

super().__init__(args[0] if args else kwargs)
# Here we use cast() since we now the keys are in the allowed values, but mypy does
# not infer it.
super().__init__(cast(Dict[ComparisonOperators, RangeValT], data))

def __repr__(self):
def __repr__(self) -> str:
return "Range(%s)" % ", ".join("%s=%r" % op for op in self._d_.items())

def __contains__(self, item):
def __contains__(self, item: object) -> bool:
if isinstance(item, str):
return super().__contains__(item)

item_supports_comp = any(hasattr(item, f"__{op}__") for op in self.OPS)
if not item_supports_comp:
return False

# Cast to tell mypy whe have checked it and its ok to use the comparison methods
# on `item`
item = cast(SupportsComparison, item)

for op in self.OPS:
if op in self._d_ and not self.OPS[op](item, self._d_[op]):
return False
return True

@property
def upper(self):
def upper(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
if "lt" in self._d_:
return self._d_["lt"], False
if "lte" in self._d_:
return self._d_["lte"], True
return None, False

@property
def lower(self):
def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]:
if "gt" in self._d_:
return self._d_["gt"], False
if "gte" in self._d_:
Expand Down
2 changes: 2 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
TYPED_FILES = (
"elasticsearch_dsl/function.py",
"elasticsearch_dsl/query.py",
"elasticsearch_dsl/wrappers.py",
"tests/test_query.py",
"tests/test_wrappers.py",
)


Expand Down
26 changes: 21 additions & 5 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# under the License.

from datetime import datetime, timedelta
from typing import Any, Mapping, Optional, Sequence

import pytest

from elasticsearch_dsl import Range
from elasticsearch_dsl.wrappers import SupportsComparison


@pytest.mark.parametrize(
Expand All @@ -34,7 +36,9 @@
({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()),
],
)
def test_range_contains(kwargs, item):
def test_range_contains(
kwargs: Mapping[str, SupportsComparison], item: SupportsComparison
) -> None:
assert item in Range(**kwargs)


Expand All @@ -48,7 +52,9 @@ def test_range_contains(kwargs, item):
({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()),
],
)
def test_range_not_contains(kwargs, item):
def test_range_not_contains(
kwargs: Mapping[str, SupportsComparison], item: SupportsComparison
) -> None:
assert item not in Range(**kwargs)


Expand All @@ -62,7 +68,9 @@ def test_range_not_contains(kwargs, item):
((), {"gt": 1, "gte": 1}),
],
)
def test_range_raises_value_error_on_wrong_params(args, kwargs):
def test_range_raises_value_error_on_wrong_params(
args: Sequence[Any], kwargs: Mapping[str, SupportsComparison]
) -> None:
with pytest.raises(ValueError):
Range(*args, **kwargs)

Expand All @@ -76,7 +84,11 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs):
(Range(lt=42), None, False),
],
)
def test_range_lower(range, lower, inclusive):
def test_range_lower(
range: Range[SupportsComparison],
lower: Optional[SupportsComparison],
inclusive: bool,
) -> None:
assert (lower, inclusive) == range.lower


Expand All @@ -89,5 +101,9 @@ def test_range_lower(range, lower, inclusive):
(Range(gt=42), None, False),
],
)
def test_range_upper(range, upper, inclusive):
def test_range_upper(
range: Range[SupportsComparison],
upper: Optional[SupportsComparison],
inclusive: bool,
) -> None:
assert (upper, inclusive) == range.upper
Loading