Skip to content

Commit

Permalink
feat: allow 'Collection.where(__name__, in, [hello, world])' (#501)
Browse files Browse the repository at this point in the history
Closes #421.

Supersedes #496.
  • Loading branch information
tseaver authored Dec 23, 2021
1 parent c4878c3 commit 7d71244
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
19 changes: 16 additions & 3 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,29 @@ def where(self, field_path: str, op_string: str, value) -> BaseQuery:
field_path (str): A field path (``.``-delimited list of
field names) for the field to filter on.
op_string (str): A comparison operation in the form of a string.
Acceptable values are ``<``, ``<=``, ``==``, ``>=``
and ``>``.
Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``,
and ``in``.
value (Any): The value to compare the field against in the filter.
If ``value`` is :data:`None` or a NaN, then ``==`` is the only
allowed operation.
allowed operation. If ``op_string`` is ``in``, ``value``
must be a sequence of values.
Returns:
:class:`~google.cloud.firestore_v1.query.Query`:
A filtered query.
"""
if field_path == "__name__" and op_string == "in":
wrapped_names = []

for name in value:

if isinstance(name, str):
name = self.document(name)

wrapped_names.append(name)

value = wrapped_names

query = self._query()
return query.where(field_path, op_string, value)

Expand Down
39 changes: 39 additions & 0 deletions tests/unit/v1/test_base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,45 @@ def test_basecollectionreference_where(mock_query):
assert query == mock_query.where.return_value


@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True)
def test_basecollectionreference_where_w___name___w_value_as_list_of_str(mock_query):
from google.cloud.firestore_v1.base_collection import BaseCollectionReference

with mock.patch.object(BaseCollectionReference, "_query") as _query:
_query.return_value = mock_query

client = _make_client()
collection = _make_base_collection_reference("collection", client=client)
field_path = "__name__"
op_string = "in"
names = ["hello", "world"]

query = collection.where(field_path, op_string, names)

expected_refs = [collection.document(name) for name in names]
mock_query.where.assert_called_once_with(field_path, op_string, expected_refs)
assert query == mock_query.where.return_value


@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True)
def test_basecollectionreference_where_w___name___w_value_as_list_of_docref(mock_query):
from google.cloud.firestore_v1.base_collection import BaseCollectionReference

with mock.patch.object(BaseCollectionReference, "_query") as _query:
_query.return_value = mock_query

client = _make_client()
collection = _make_base_collection_reference("collection", client=client)
field_path = "__name__"
op_string = "in"
refs = [collection.document("hello"), collection.document("world")]

query = collection.where(field_path, op_string, refs)

mock_query.where.assert_called_once_with(field_path, op_string, refs)
assert query == mock_query.where.return_value


@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True)
def test_basecollectionreference_order_by(mock_query):
from google.cloud.firestore_v1.base_query import BaseQuery
Expand Down

0 comments on commit 7d71244

Please sign in to comment.