From d4aaca9b0d55eb970855229c0d2c66c50401a0df Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 25 Oct 2019 14:54:31 -0400 Subject: [PATCH 1/8] tests(firestore): use one testcase per valid operator --- firestore/tests/unit/v1/test_query.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/firestore/tests/unit/v1/test_query.py b/firestore/tests/unit/v1/test_query.py index a4911fecb44f..bd30af36b1b8 100644 --- a/firestore/tests/unit/v1/test_query.py +++ b/firestore/tests/unit/v1/test_query.py @@ -1464,18 +1464,37 @@ def _call_fut(op_string): return _enum_from_op_string(op_string) - def test_success(self): + @staticmethod + def _get_op_class(): from google.cloud.firestore_v1.gapic import enums - op_class = enums.StructuredQuery.FieldFilter.Operator + return enums.StructuredQuery.FieldFilter.Operator + + def test_lt(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut("<"), op_class.LESS_THAN) + + def test_le(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut("<="), op_class.LESS_THAN_OR_EQUAL) + + def test_eq(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut("=="), op_class.EQUAL) + + def test_ge(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut(">="), op_class.GREATER_THAN_OR_EQUAL) + + def test_gt(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut(">"), op_class.GREATER_THAN) + + def test_array_contains(self): + op_class = self._get_op_class() self.assertEqual(self._call_fut("array_contains"), op_class.ARRAY_CONTAINS) - def test_failure(self): + def test_invalid(self): with self.assertRaises(ValueError): self._call_fut("?") From b2a6d28c61805b1bee0988f37fd6fc48785c471a Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 25 Oct 2019 14:59:49 -0400 Subject: [PATCH 2/8] feat(firestore): add support for 'IN' operator --- firestore/google/cloud/firestore_v1/query.py | 1 + firestore/tests/unit/v1/test_query.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/firestore/google/cloud/firestore_v1/query.py b/firestore/google/cloud/firestore_v1/query.py index 6f4c498c0725..1b935c892270 100644 --- a/firestore/google/cloud/firestore_v1/query.py +++ b/firestore/google/cloud/firestore_v1/query.py @@ -43,6 +43,7 @@ ">=": _operator_enum.GREATER_THAN_OR_EQUAL, ">": _operator_enum.GREATER_THAN, "array_contains": _operator_enum.ARRAY_CONTAINS, + "in": _operator_enum.IN, } _BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." _BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' diff --git a/firestore/tests/unit/v1/test_query.py b/firestore/tests/unit/v1/test_query.py index bd30af36b1b8..45281030a7d8 100644 --- a/firestore/tests/unit/v1/test_query.py +++ b/firestore/tests/unit/v1/test_query.py @@ -1494,6 +1494,10 @@ def test_array_contains(self): op_class = self._get_op_class() self.assertEqual(self._call_fut("array_contains"), op_class.ARRAY_CONTAINS) + def test_in(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut("in"), op_class.IN) + def test_invalid(self): with self.assertRaises(ValueError): self._call_fut("?") From 5982527260da65b5ab90fa7185bde63e9e52fb28 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 25 Oct 2019 15:01:06 -0400 Subject: [PATCH 3/8] feat(firestore): add support for 'ARRAY_CONTAINS_ANY' operator --- firestore/google/cloud/firestore_v1/query.py | 1 + firestore/tests/unit/v1/test_query.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/firestore/google/cloud/firestore_v1/query.py b/firestore/google/cloud/firestore_v1/query.py index 1b935c892270..d4e1f7f07324 100644 --- a/firestore/google/cloud/firestore_v1/query.py +++ b/firestore/google/cloud/firestore_v1/query.py @@ -44,6 +44,7 @@ ">": _operator_enum.GREATER_THAN, "array_contains": _operator_enum.ARRAY_CONTAINS, "in": _operator_enum.IN, + "array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY, } _BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}." _BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values' diff --git a/firestore/tests/unit/v1/test_query.py b/firestore/tests/unit/v1/test_query.py index 45281030a7d8..8927dd87945e 100644 --- a/firestore/tests/unit/v1/test_query.py +++ b/firestore/tests/unit/v1/test_query.py @@ -1498,6 +1498,12 @@ def test_in(self): op_class = self._get_op_class() self.assertEqual(self._call_fut("in"), op_class.IN) + def test_array_contains_any(self): + op_class = self._get_op_class() + self.assertEqual(self._call_fut( + "array_contains_any"), op_class.ARRAY_CONTAINS_ANY + ) + def test_invalid(self): with self.assertRaises(ValueError): self._call_fut("?") From 28e7640c2e3c9d35203458c53ff91c03ca9c38ac Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 25 Oct 2019 15:35:33 -0400 Subject: [PATCH 4/8] tests(firestore): split up monolithic 'test_query_stream' systest --- firestore/tests/system/test_system.py | 122 +++++++++++++++----------- 1 file changed, 72 insertions(+), 50 deletions(-) diff --git a/firestore/tests/system/test_system.py b/firestore/tests/system/test_system.py index f2d30c94a171..41637d8830e5 100644 --- a/firestore/tests/system/test_system.py +++ b/firestore/tests/system/test_system.py @@ -492,11 +492,13 @@ def test_collection_add(client, cleanup): assert set(collection3.list_documents()) == {document_ref5} -def test_query_stream(client, cleanup): +@pytest.fixture +def query_docs(client): collection_id = "qs" + UNIQUE_RESOURCE_ID sub_collection = "child" + UNIQUE_RESOURCE_ID collection = client.collection(collection_id, "doc", sub_collection) + cleanup = [] stored = {} num_vals = 5 allowed_vals = six.moves.xrange(num_vals) @@ -509,34 +511,45 @@ def test_query_stream(client, cleanup): } _, doc_ref = collection.add(document_data) # Add to clean-up. - cleanup(doc_ref.delete) + cleanup.append(doc_ref.delete) stored[doc_ref.id] = document_data - # 0. Limit to snapshots where ``a==1``. - query0 = collection.where("a", "==", 1) - values0 = {snapshot.id: snapshot.to_dict() for snapshot in query0.stream()} - assert len(values0) == num_vals - for key, value in six.iteritems(values0): + yield collection, stored, allowed_vals + + for operation in cleanup: + operation() + + +def test_query_stream_w_simple_field(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("a", "==", 1) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == len(allowed_vals) + for key, value in six.iteritems(values): assert stored[key] == value assert value["a"] == 1 - # 1. Order by ``b``. - query1 = collection.order_by("b", direction=query0.DESCENDING) - values1 = [(snapshot.id, snapshot.to_dict()) for snapshot in query1.stream()] - assert len(values1) == len(stored) - b_vals1 = [] - for key, value in values1: + +def test_query_stream_w_order_by(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.order_by("b", direction=firestore.Query.DESCENDING) + values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()] + assert len(values) == len(stored) + b_vals = [] + for key, value in values: assert stored[key] == value - b_vals1.append(value["b"]) + b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. - assert sorted(b_vals1, reverse=True) == b_vals1 + assert sorted(b_vals, reverse=True) == b_vals + - # 2. Limit to snapshots where ``stats.sum > 1`` (a field path). - query2 = collection.where("stats.sum", ">", 4) - values2 = {snapshot.id: snapshot.to_dict() for snapshot in query2.stream()} - assert len(values2) == 10 +def test_query_stream_w_field_path(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("stats.sum", ">", 4) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == 10 ab_pairs2 = set() - for key, value in six.iteritems(values2): + for key, value in six.iteritems(values): assert stored[key] == value ab_pairs2.add((value["a"], value["b"])) @@ -550,63 +563,72 @@ def test_query_stream(client, cleanup): ) assert expected_ab_pairs == ab_pairs2 - # 3. Use a start and end cursor. - query3 = ( + +def test_query_stream_w_start_end_cursor(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = ( collection.order_by("a") .start_at({"a": num_vals - 2}) .end_before({"a": num_vals - 1}) ) - values3 = [(snapshot.id, snapshot.to_dict()) for snapshot in query3.stream()] - assert len(values3) == num_vals - for key, value in values3: + values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()] + assert len(values) == num_vals + for key, value in values: assert stored[key] == value assert value["a"] == num_vals - 2 - b_vals1.append(value["b"]) - - # 4. Send a query with no results. - query4 = collection.where("b", "==", num_vals + 100) - values4 = list(query4.stream()) - assert len(values4) == 0 - - # 5. Select a subset of fields. - query5 = collection.where("b", "<=", 1) - query5 = query5.select(["a", "stats.product"]) - values5 = {snapshot.id: snapshot.to_dict() for snapshot in query5.stream()} - assert len(values5) == num_vals * 2 # a ANY, b in (0, 1) - for key, value in six.iteritems(values5): + + +def test_query_stream_wo_results(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("b", "==", num_vals + 100) + values = list(query.stream()) + assert len(values) == 0 + + +def test_query_stream_w_projection(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("b", "<=", 1).select(["a", "stats.product"]) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == num_vals * 2 # a ANY, b in (0, 1) + for key, value in six.iteritems(values): expected = { "a": stored[key]["a"], "stats": {"product": stored[key]["stats"]["product"]}, } assert expected == value - # 6. Add multiple filters via ``where()``. - query6 = collection.where("stats.product", ">", 5) - query6 = query6.where("stats.product", "<", 10) - values6 = {snapshot.id: snapshot.to_dict() for snapshot in query6.stream()} +def test_query_stream_w_multiple_filters(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("stats.product", ">", 5).where("stats.product", "<", 10) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} matching_pairs = [ (a_val, b_val) for a_val in allowed_vals for b_val in allowed_vals if 5 < a_val * b_val < 10 ] - assert len(values6) == len(matching_pairs) - for key, value in six.iteritems(values6): + assert len(values) == len(matching_pairs) + for key, value in six.iteritems(values): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs - # 7. Skip the first three results, when ``b==2`` - query7 = collection.where("b", "==", 2) + +def test_query_stream_w_offset(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) offset = 3 - query7 = query7.offset(offset) - values7 = {snapshot.id: snapshot.to_dict() for snapshot in query7.stream()} + query = collection.where("b", "==", 2).offset(offset) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} # NOTE: We don't check the ``a``-values, since that would require # an ``order_by('a')``, which combined with the ``b == 2`` # filter would necessitate an index. - assert len(values7) == num_vals - offset - for key, value in six.iteritems(values7): + assert len(values) == num_vals - offset + for key, value in six.iteritems(values): assert stored[key] == value assert value["b"] == 2 From d206f67bcfe9dc87b5fd96a99ac5291a7ee76e79 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 25 Oct 2019 15:43:31 -0400 Subject: [PATCH 5/8] tests(firestore): add systest for 'array_contains' operator --- firestore/tests/system/test_system.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/firestore/tests/system/test_system.py b/firestore/tests/system/test_system.py index 41637d8830e5..308bffb93697 100644 --- a/firestore/tests/system/test_system.py +++ b/firestore/tests/system/test_system.py @@ -507,6 +507,7 @@ def query_docs(client): document_data = { "a": a_val, "b": b_val, + "c": [a_val, num_vals * 100], "stats": {"sum": a_val + b_val, "product": a_val * b_val}, } _, doc_ref = collection.add(document_data) @@ -520,7 +521,7 @@ def query_docs(client): operation() -def test_query_stream_w_simple_field(query_docs): +def test_query_stream_w_simple_field_eq_op(query_docs): collection, stored, allowed_vals = query_docs query = collection.where("a", "==", 1) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -530,6 +531,16 @@ def test_query_stream_w_simple_field(query_docs): assert value["a"] == 1 +def test_query_stream_w_simple_field_array_contains_op(query_docs): + collection, stored, allowed_vals = query_docs + query = collection.where("c", "array_contains", 1) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == len(allowed_vals) + for key, value in six.iteritems(values): + assert stored[key] == value + assert value["a"] == 1 + + def test_query_stream_w_order_by(query_docs): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) From cfa3cc6b373850f0486f07e51e969a08537cbc4a Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 25 Oct 2019 15:45:59 -0400 Subject: [PATCH 6/8] tests(firestore): add systest for 'in' operator --- firestore/tests/system/test_system.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/firestore/tests/system/test_system.py b/firestore/tests/system/test_system.py index 308bffb93697..6bf56102670e 100644 --- a/firestore/tests/system/test_system.py +++ b/firestore/tests/system/test_system.py @@ -541,6 +541,17 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs): assert value["a"] == 1 +def test_query_stream_w_simple_field_in_op(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("a", "in", [1, num_vals + 100]) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == len(allowed_vals) + for key, value in six.iteritems(values): + assert stored[key] == value + assert value["a"] == 1 + + def test_query_stream_w_order_by(query_docs): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) From 74bb3d438cca41810eccfe8ca8f233628f5fc21c Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 25 Oct 2019 15:49:33 -0400 Subject: [PATCH 7/8] tests(firestore): add systest for 'array_contains_any' operator --- firestore/tests/system/test_system.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/firestore/tests/system/test_system.py b/firestore/tests/system/test_system.py index 6bf56102670e..71ac07fcee74 100644 --- a/firestore/tests/system/test_system.py +++ b/firestore/tests/system/test_system.py @@ -552,6 +552,17 @@ def test_query_stream_w_simple_field_in_op(query_docs): assert value["a"] == 1 +def test_query_stream_w_simple_field_array_contains_any_op(query_docs): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + query = collection.where("c", "array_contains_any", [1, num_vals * 200]) + values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} + assert len(values) == len(allowed_vals) + for key, value in six.iteritems(values): + assert stored[key] == value + assert value["a"] == 1 + + def test_query_stream_w_order_by(query_docs): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) From 1fb97885a1955f5921fef862e680991ac9e04f9e Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 25 Oct 2019 16:03:37 -0400 Subject: [PATCH 8/8] chore: lint --- firestore/tests/unit/v1/test_query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firestore/tests/unit/v1/test_query.py b/firestore/tests/unit/v1/test_query.py index 8927dd87945e..bdb0e922d00b 100644 --- a/firestore/tests/unit/v1/test_query.py +++ b/firestore/tests/unit/v1/test_query.py @@ -1500,8 +1500,8 @@ def test_in(self): def test_array_contains_any(self): op_class = self._get_op_class() - self.assertEqual(self._call_fut( - "array_contains_any"), op_class.ARRAY_CONTAINS_ANY + self.assertEqual( + self._call_fut("array_contains_any"), op_class.ARRAY_CONTAINS_ANY ) def test_invalid(self):