Skip to content

Commit

Permalink
feat: add support for recursive queries (#407)
Browse files Browse the repository at this point in the history
* refactor: added BaseQuery.copy method

* 🦉 Updates from OwlBot

See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md

* responded to code review

* feat: added recursive query

* tidied up

* 🦉 Updates from OwlBot

See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md

* more tidying up

* fixed error with path compilation

* fixed async handling in system tests

* 🦉 Updates from OwlBot

See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md

* Update google/cloud/firestore_v1/base_collection.py

Co-authored-by: Christopher Wilcox <crwilcox@google.com>

* reverted error message changes

* 🦉 Updates from OwlBot

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* comment updates

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
Co-authored-by: Christopher Wilcox <crwilcox@google.com>
  • Loading branch information
3 people authored Aug 11, 2021
1 parent 0176cc7 commit eb45a36
Show file tree
Hide file tree
Showing 10 changed files with 367 additions and 4 deletions.
1 change: 1 addition & 0 deletions google/cloud/firestore_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def verify_path(path, is_collection) -> None:
if is_collection:
if num_elements % 2 == 0:
raise ValueError("A collection must have an odd number of path elements")

else:
if num_elements % 2 == 1:
raise ValueError("A document must have an even number of path elements")
Expand Down
18 changes: 17 additions & 1 deletion google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore

from google.cloud import firestore_v1
from google.cloud.firestore_v1.base_query import (
BaseCollectionGroup,
BaseQuery,
Expand All @@ -32,7 +33,7 @@
)

from google.cloud.firestore_v1 import async_document
from typing import AsyncGenerator
from typing import AsyncGenerator, Type

# Types needed only for Type Hints
from google.cloud.firestore_v1.transaction import Transaction
Expand Down Expand Up @@ -92,6 +93,9 @@ class AsyncQuery(BaseQuery):
When false, selects only collections that are immediate children
of the `parent` specified in the containing `RunQueryRequest`.
When true, selects all descendant collections.
recursive (Optional[bool]):
When true, returns all documents and all documents in any subcollections
below them. Defaults to false.
"""

def __init__(
Expand All @@ -106,6 +110,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=False,
recursive=False,
) -> None:
super(AsyncQuery, self).__init__(
parent=parent,
Expand All @@ -118,6 +123,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

async def get(
Expand Down Expand Up @@ -224,6 +230,14 @@ async def stream(
if snapshot is not None:
yield snapshot

@staticmethod
def _get_collection_reference_class() -> Type[
"firestore_v1.async_collection.AsyncCollectionReference"
]:
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference

return AsyncCollectionReference


class AsyncCollectionGroup(AsyncQuery, BaseCollectionGroup):
"""Represents a Collection Group in the Firestore API.
Expand All @@ -249,6 +263,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=True,
recursive=False,
) -> None:
super(AsyncCollectionGroup, self).__init__(
parent=parent,
Expand All @@ -261,6 +276,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

@staticmethod
Expand Down
8 changes: 7 additions & 1 deletion google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def document(self, document_id: str = None) -> DocumentReference:
if document_id is None:
document_id = _auto_id()

child_path = self._path + (document_id,)
# Append `self._path` and the passed document's ID as long as the first
# element in the path is not an empty string, which comes from setting the
# parent to "" for recursive queries.
child_path = self._path + (document_id,) if self._path[0] else (document_id,)
return self._client.document(*child_path)

def _parent_info(self) -> Tuple[Any, str]:
Expand Down Expand Up @@ -200,6 +203,9 @@ def list_documents(
]:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
return self._query().recursive()

def select(self, field_paths: Iterable[str]) -> BaseQuery:
"""Create a "select" query with this collection as parent.
Expand Down
65 changes: 64 additions & 1 deletion google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,17 @@
from google.cloud.firestore_v1.types import Cursor
from google.cloud.firestore_v1.types import RunQueryResponse
from google.cloud.firestore_v1.order import Order
from typing import Any, Dict, Generator, Iterable, NoReturn, Optional, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
NoReturn,
Optional,
Tuple,
Type,
Union,
)

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
Expand Down Expand Up @@ -144,6 +154,9 @@ class BaseQuery(object):
When false, selects only collections that are immediate children
of the `parent` specified in the containing `RunQueryRequest`.
When true, selects all descendant collections.
recursive (Optional[bool]):
When true, returns all documents and all documents in any subcollections
below them. Defaults to false.
"""

ASCENDING = "ASCENDING"
Expand All @@ -163,6 +176,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=False,
recursive=False,
) -> None:
self._parent = parent
self._projection = projection
Expand All @@ -174,6 +188,7 @@ def __init__(
self._start_at = start_at
self._end_at = end_at
self._all_descendants = all_descendants
self._recursive = recursive

def __eq__(self, other):
if not isinstance(other, self.__class__):
Expand Down Expand Up @@ -247,6 +262,7 @@ def _copy(
start_at: Optional[Tuple[dict, bool]] = _not_passed,
end_at: Optional[Tuple[dict, bool]] = _not_passed,
all_descendants: Optional[bool] = _not_passed,
recursive: Optional[bool] = _not_passed,
) -> "BaseQuery":
return self.__class__(
self._parent,
Expand All @@ -261,6 +277,7 @@ def _copy(
all_descendants=self._evaluate_param(
all_descendants, self._all_descendants
),
recursive=self._evaluate_param(recursive, self._recursive),
)

def _evaluate_param(self, value, fallback_value):
Expand Down Expand Up @@ -813,6 +830,46 @@ def stream(
def on_snapshot(self, callback) -> NoReturn:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
"""Returns a copy of this query whose iterator will yield all matching
documents as well as each of their descendent subcollections and documents.
This differs from the `all_descendents` flag, which only returns descendents
whose subcollection names match the parent collection's name. To return
all descendents, regardless of their subcollection name, use this.
"""
copied = self._copy(recursive=True, all_descendants=True)
if copied._parent and copied._parent.id:
original_collection_id = "/".join(copied._parent._path)

# Reset the parent to nothing so we can recurse through the entire
# database. This is required to have
# `CollectionSelector.collection_id` not override
# `CollectionSelector.all_descendants`, which happens if both are
# set.
copied._parent = copied._get_collection_reference_class()("")
copied._parent._client = self._parent._client

# But wait! We don't want to load the entire database; only the
# collection the user originally specified. To accomplish that, we
# add the following arcane filters.

REFERENCE_NAME_MIN_ID = "__id-9223372036854775808__"
start_at = f"{original_collection_id}/{REFERENCE_NAME_MIN_ID}"

# The backend interprets this null character is flipping the filter
# to mean the end of the range instead of the beginning.
nullChar = "\0"
end_at = f"{original_collection_id}{nullChar}/{REFERENCE_NAME_MIN_ID}"

copied = (
copied.order_by(field_path_module.FieldPath.document_id())
.start_at({field_path_module.FieldPath.document_id(): start_at})
.end_at({field_path_module.FieldPath.document_id(): end_at})
)

return copied

def _comparator(self, doc1, doc2) -> int:
_orders = self._orders

Expand Down Expand Up @@ -1073,6 +1130,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=True,
recursive=False,
) -> None:
if not all_descendants:
raise ValueError("all_descendants must be True for collection group query.")
Expand All @@ -1088,6 +1146,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

def _validate_partition_query(self):
Expand Down Expand Up @@ -1133,6 +1192,10 @@ def get_partitions(
) -> NoReturn:
raise NotImplementedError

@staticmethod
def _get_collection_reference_class() -> Type["BaseCollectionGroup"]:
raise NotImplementedError


class QueryPartition:
"""Represents a bounded partition of a collection group query.
Expand Down
15 changes: 14 additions & 1 deletion google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
a more common way to create a query than direct usage of the constructor.
"""

from google.cloud import firestore_v1
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
Expand All @@ -34,7 +35,7 @@

from google.cloud.firestore_v1 import document
from google.cloud.firestore_v1.watch import Watch
from typing import Any, Callable, Generator, List
from typing import Any, Callable, Generator, List, Type


class Query(BaseQuery):
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=False,
recursive=False,
) -> None:
super(Query, self).__init__(
parent=parent,
Expand All @@ -117,6 +119,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

def get(
Expand Down Expand Up @@ -254,6 +257,14 @@ def on_snapshot(docs, changes, read_time):
self, callback, document.DocumentSnapshot, document.DocumentReference
)

@staticmethod
def _get_collection_reference_class() -> Type[
"firestore_v1.collection.CollectionReference"
]:
from google.cloud.firestore_v1.collection import CollectionReference

return CollectionReference


class CollectionGroup(Query, BaseCollectionGroup):
"""Represents a Collection Group in the Firestore API.
Expand All @@ -279,6 +290,7 @@ def __init__(
start_at=None,
end_at=None,
all_descendants=True,
recursive=False,
) -> None:
super(CollectionGroup, self).__init__(
parent=parent,
Expand All @@ -291,6 +303,7 @@ def __init__(
start_at=start_at,
end_at=end_at,
all_descendants=all_descendants,
recursive=recursive,
)

@staticmethod
Expand Down
Loading

0 comments on commit eb45a36

Please sign in to comment.