Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions firestore/google/cloud/firestore_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.cloud.client import ClientWithProject

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1 import query
from google.cloud.firestore_v1 import types
from google.cloud.firestore_v1.batch import WriteBatch
from google.cloud.firestore_v1.collection import CollectionReference
Expand Down Expand Up @@ -179,6 +180,31 @@ def collection(self, *collection_path):

return CollectionReference(*path, client=self)

def collection_group(self, collection_id):
"""
Creates and returns a new Query that includes all documents in the
database that are contained in a collection or subcollection with the
given collection_id.

.. code-block:: python

>>> query = firestore.collection_group('mygroup')

@param {string} collectionId Identifies the collections to query over.
Every collection or subcollection with this ID as the last segment of its
path will be included. Cannot contain a slash.
@returns {Query} The created Query.
"""
if "/" in collection_id:
raise ValueError(
"Invalid collection_id "
+ collection_id
+ ". Collection IDs must not contain '/'."
)

collection = self.collection(collection_id)
return query.Query(collection, all_descendants=True)

def document(self, *document_path):
"""Get a reference to a document in a collection.

Expand Down Expand Up @@ -215,6 +241,13 @@ def document(self, *document_path):
else:
path = document_path

# DocumentReference takes a relative path. Strip the database string if present.
base_path = self._database_string + "/documents/"
joined_path = _helpers.DOCUMENT_PATH_DELIMITER.join(path)
if joined_path.startswith(base_path):
joined_path = joined_path[len(base_path) :]
path = joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER)

return DocumentReference(*path, client=self)

@staticmethod
Expand Down
55 changes: 51 additions & 4 deletions firestore/google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ class Query(object):
any matching documents will be included in the result set.
When the query is formed, the document values
will be used in the order given by ``orders``.
all_descendants (Optional[bool]): When false, selects only collections
that are immediate children of the `parent` specified in the
containing `RunQueryRequest`. When true, selects all descendant
collections.
"""

ASCENDING = "ASCENDING"
Expand All @@ -128,6 +132,7 @@ def __init__(
offset=None,
start_at=None,
end_at=None,
all_descendants=False,
):
self._parent = parent
self._projection = projection
Expand All @@ -137,6 +142,7 @@ def __init__(
self._offset = offset
self._start_at = start_at
self._end_at = end_at
self._all_descendants = all_descendants

def __eq__(self, other):
if not isinstance(other, self.__class__):
Expand All @@ -150,6 +156,7 @@ def __eq__(self, other):
and self._offset == other._offset
and self._start_at == other._start_at
and self._end_at == other._end_at
and self._all_descendants == other._all_descendants
)

@property
Expand Down Expand Up @@ -203,6 +210,7 @@ def select(self, field_paths):
offset=self._offset,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

def where(self, field_path, op_string, value):
Expand Down Expand Up @@ -270,6 +278,7 @@ def where(self, field_path, op_string, value):
offset=self._offset,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

@staticmethod
Expand Down Expand Up @@ -321,6 +330,7 @@ def order_by(self, field_path, direction=ASCENDING):
offset=self._offset,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

def limit(self, count):
Expand All @@ -346,6 +356,7 @@ def limit(self, count):
offset=self._offset,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

def offset(self, num_to_skip):
Expand All @@ -372,6 +383,7 @@ def offset(self, num_to_skip):
offset=num_to_skip,
start_at=self._start_at,
end_at=self._end_at,
all_descendants=self._all_descendants,
)

def _cursor_helper(self, document_fields, before, start):
Expand Down Expand Up @@ -418,6 +430,7 @@ def _cursor_helper(self, document_fields, before, start):
"orders": self._orders,
"limit": self._limit,
"offset": self._offset,
"all_descendants": self._all_descendants,
}
if start:
query_kwargs["start_at"] = cursor_pair
Expand Down Expand Up @@ -679,7 +692,7 @@ def _to_protobuf(self):
"select": projection,
"from": [
query_pb2.StructuredQuery.CollectionSelector(
collection_id=self._parent.id
collection_id=self._parent.id, all_descendants=self._all_descendants
)
],
"where": self._filters_pb(),
Expand Down Expand Up @@ -739,9 +752,14 @@ def stream(self, transaction=None):
)

for response in response_iterator:
snapshot = _query_response_to_snapshot(
response, self._parent, expected_prefix
)
if self._all_descendants:
snapshot = _collection_group_query_response_to_snapshot(
response, self._parent
)
else:
snapshot = _query_response_to_snapshot(
response, self._parent, expected_prefix
)
if snapshot is not None:
yield snapshot

Expand Down Expand Up @@ -968,3 +986,32 @@ def _query_response_to_snapshot(response_pb, collection, expected_prefix):
update_time=response_pb.document.update_time,
)
return snapshot


def _collection_group_query_response_to_snapshot(response_pb, collection):
"""Parse a query response protobuf to a document snapshot.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little surprised that this is a different method from _query_response_to_snapshot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was related to expectations built into this method. @tseaver made this to address that. It is possible we can combine them by refactoring that code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@schmidt-sebastian The implementation of _query_response_to_snapshot mandates (via _helpers.get_doc_id) that the result document is a direct child of the query's parent collection. I'm unsure whether this restriction is "important", or just an artifact of prior development iterations. If it is an artifact, then we could just use the new implementation in _collection_group_query_response_to_snapshot for all queries.


Args:
response_pb (google.cloud.proto.firestore.v1.\
firestore_pb2.RunQueryResponse): A
collection (~.firestore_v1.collection.CollectionReference): A
reference to the collection that initiated the query.

Returns:
Optional[~.firestore.document.DocumentSnapshot]: A
snapshot of the data returned in the query. If ``response_pb.document``
is not set, the snapshot will be :data:`None`.
"""
if not response_pb.HasField("document"):
return None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Python not return empty DocumentSnapshots for non-existing documents? In either case, this is a query and this condition should never trigger (fingers crossed).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@schmidt-sebastian DocumentReference.get does return an empty snapshot for a non-existent document. For the query case, the back-end can return a response without a document when reporting partial results: see #6924 and PR #7206 for the history.

Copy link

@schmidt-sebastian schmidt-sebastian Apr 23, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reference = collection._client.document(response_pb.document.name)
data = _helpers.decode_dict(response_pb.document.fields, collection._client)
snapshot = document.DocumentSnapshot(
reference,
data,
exists=True,
read_time=response_pb.read_time,
create_time=response_pb.document.create_time,
update_time=response_pb.document.update_time,
)
return snapshot
114 changes: 114 additions & 0 deletions firestore/tests/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,120 @@ def test_query_unary(client, cleanup):
assert math.isnan(data1[field_name])


def test_collection_group_queries(client, cleanup):
collection_group = "b" + unique_resource_id("-")

doc_paths = [
"abc/123/" + collection_group + "/cg-doc1",
"abc/123/" + collection_group + "/cg-doc2",
collection_group + "/cg-doc3",
collection_group + "/cg-doc4",
"def/456/" + collection_group + "/cg-doc5",
collection_group + "/virtual-doc/nested-coll/not-cg-doc",
"x" + collection_group + "/not-cg-doc",
collection_group + "x/not-cg-doc",
"abc/123/" + collection_group + "x/not-cg-doc",
"abc/123/x" + collection_group + "/not-cg-doc",
"abc/" + collection_group,
]

batch = client.batch()
for doc_path in doc_paths:
doc_ref = client.document(doc_path)
batch.set(doc_ref, {"x": 1})

batch.commit()

query = client.collection_group(collection_group)
snapshots = list(query.stream())
found = [snapshot.id for snapshot in snapshots]
expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"]
assert found == expected


def test_collection_group_queries_startat_endat(client, cleanup):
collection_group = "b" + unique_resource_id("-")

doc_paths = [
"a/a/" + collection_group + "/cg-doc1",
"a/b/a/b/" + collection_group + "/cg-doc2",
"a/b/" + collection_group + "/cg-doc3",
"a/b/c/d/" + collection_group + "/cg-doc4",
"a/c/" + collection_group + "/cg-doc5",
collection_group + "/cg-doc6",
"a/b/nope/nope",
]

batch = client.batch()
for doc_path in doc_paths:
doc_ref = client.document(doc_path)
batch.set(doc_ref, {"x": doc_path})

batch.commit()

query = (
client.collection_group(collection_group)
.order_by("__name__")
.start_at([client.document("a/b")])
.end_at([client.document("a/b0")])
)
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])

query = (
client.collection_group(collection_group)
.order_by("__name__")
.start_after([client.document("a/b")])
.end_before([client.document("a/b/" + collection_group + "/cg-doc3")])
)
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2"])


def test_collection_group_queries_filters(client, cleanup):
collection_group = "b" + unique_resource_id("-")

doc_paths = [
"a/a/" + collection_group + "/cg-doc1",
"a/b/a/b/" + collection_group + "/cg-doc2",
"a/b/" + collection_group + "/cg-doc3",
"a/b/c/d/" + collection_group + "/cg-doc4",
"a/c/" + collection_group + "/cg-doc5",
collection_group + "/cg-doc6",
"a/b/nope/nope",
]

batch = client.batch()

for index, doc_path in enumerate(doc_paths):
doc_ref = client.document(doc_path)
batch.set(doc_ref, {"x": index})

batch.commit()

query = (
client.collection_group(collection_group)
.where("__name__", ">=", client.document("a/b"))
.where("__name__", "<=", client.document("a/b0"))
)
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])

query = (
client.collection_group(collection_group)
.where("__name__", ">", client.document("a/b"))
.where(
"__name__", "<", client.document("a/b/{}/cg-doc3".format(collection_group))
)
)
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2"])


def test_get_all(client, cleanup):
collection_name = "get-all" + unique_resource_id("-")

Expand Down
30 changes: 29 additions & 1 deletion firestore/tests/unit/v1/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,21 @@ def test_collection_factory_nested(self):
self.assertIs(collection2._client, client)
self.assertIsInstance(collection2, CollectionReference)

def test_collection_group(self):
client = self._make_default_one()
query = client.collection_group("collectionId").where("foo", "==", u"bar")

assert query._all_descendants
assert query._field_filters[0].field.field_path == "foo"
assert query._field_filters[0].value.string_value == u"bar"
assert query._field_filters[0].op == query._field_filters[0].EQUAL
assert query._parent.id == "collectionId"

def test_collection_group_no_slashes(self):
client = self._make_default_one()
with self.assertRaises(ValueError):
client.collection_group("foo/bar")

def test_document_factory(self):
from google.cloud.firestore_v1.document import DocumentReference

Expand All @@ -148,7 +163,20 @@ def test_document_factory(self):
self.assertIs(document2._client, client)
self.assertIsInstance(document2, DocumentReference)

def test_document_factory_nested(self):
def test_document_factory_w_absolute_path(self):
from google.cloud.firestore_v1.document import DocumentReference

parts = ("rooms", "roomA")
client = self._make_default_one()
doc_path = "/".join(parts)
to_match = client.document(doc_path)
document1 = client.document(to_match._document_path)

self.assertEqual(document1._path, parts)
self.assertIs(document1._client, client)
self.assertIsInstance(document1, DocumentReference)

def test_document_factory_w_nested_path(self):
from google.cloud.firestore_v1.document import DocumentReference

client = self._make_default_one()
Expand Down
Loading