diff --git a/CHANGES.md b/CHANGES.md index a2d47753afe5..1e445ccafb58 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -81,7 +81,9 @@ ## Bugfixes +* Fixed a condition where retrying queries would yield an incorrect cursor in the Java SDK Firestore Connector ([#22089](https://github.com/apache/beam/issues/22089)). * Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). + ## Known Issues * ([#X](https://github.com/apache/beam/issues/X)). diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java index 0745da1efedb..405ab65f941d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java @@ -33,6 +33,7 @@ import com.google.cloud.firestore.v1.stub.FirestoreStub; import com.google.firestore.v1.BatchGetDocumentsRequest; import com.google.firestore.v1.BatchGetDocumentsResponse; +import com.google.firestore.v1.BatchGetDocumentsResponse.ResultCase; import com.google.firestore.v1.Cursor; import com.google.firestore.v1.ListCollectionIdsRequest; import com.google.firestore.v1.ListCollectionIdsResponse; @@ -43,15 +44,11 @@ import com.google.firestore.v1.RunQueryRequest; import com.google.firestore.v1.RunQueryResponse; import com.google.firestore.v1.StructuredQuery; -import com.google.firestore.v1.StructuredQuery.Direction; -import com.google.firestore.v1.StructuredQuery.FieldReference; import com.google.firestore.v1.StructuredQuery.Order; import com.google.firestore.v1.Value; import com.google.protobuf.Message; import com.google.protobuf.ProtocolStringList; import java.io.Serializable; -import java.util.List; -import java.util.Map; import java.util.Objects; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; 
import org.apache.beam.sdk.io.gcp.firestore.FirestoreDoFn.ImplicitlyWindowedFirestoreDoFn; @@ -109,46 +106,31 @@ protected ServerStreamingCallable getCallable protected RunQueryRequest setStartFrom( RunQueryRequest element, RunQueryResponse runQueryResponse) { StructuredQuery query = element.getStructuredQuery(); - StructuredQuery.Builder builder; - List orderByList = query.getOrderByList(); - // if the orderByList is empty that means the default sort of "__name__ ASC" will be used - // Before we can set the cursor to the last document name read, we need to explicitly add - // the order of "__name__ ASC" because a cursor value must map to an order by - if (orderByList.isEmpty()) { - builder = - query - .toBuilder() - .addOrderBy( - Order.newBuilder() - .setField(FieldReference.newBuilder().setFieldPath("__name__").build()) - .setDirection(Direction.ASCENDING) - .build()) - .setStartAt( - Cursor.newBuilder() - .setBefore(false) - .addValues( - Value.newBuilder() - .setReferenceValue(runQueryResponse.getDocument().getName()) - .build())); - } else { - Cursor.Builder cursor = Cursor.newBuilder().setBefore(false); - Map fieldsMap = runQueryResponse.getDocument().getFieldsMap(); - for (Order order : orderByList) { - String fieldPath = order.getField().getFieldPath(); - Value value = fieldsMap.get(fieldPath); - if (value != null) { - cursor.addValues(value); - } else if ("__name__".equals(fieldPath)) { - cursor.addValues( - Value.newBuilder() - .setReferenceValue(runQueryResponse.getDocument().getName()) - .build()); - } + StructuredQuery.Builder builder = query.toBuilder(); + builder.addAllOrderBy(QueryUtils.getImplicitOrderBy(query)); + Cursor.Builder cursor = Cursor.newBuilder().setBefore(false); + for (Order order : builder.getOrderByList()) { + Value value = + QueryUtils.lookupDocumentValue( + runQueryResponse.getDocument(), order.getField().getFieldPath()); + if (value == null) { + throw new IllegalStateException( + String.format( + "Failed to build query 
resumption token, field '%s' not found in doc with __name__ '%s'", + order.getField().getFieldPath(), runQueryResponse.getDocument().getName())); } - builder = query.toBuilder().setStartAt(cursor.build()); + cursor.addValues(value); } + builder.setStartAt(cursor.build()); return element.toBuilder().setStructuredQuery(builder.build()).build(); } + + @Override + protected @Nullable RunQueryResponse resumptionValue( + @Nullable RunQueryResponse previousValue, RunQueryResponse nextValue) { + // We need a document to resume, may be null if reporting partial progress. + return nextValue.hasDocument() ? nextValue : previousValue; + } } /** @@ -380,6 +362,13 @@ protected BatchGetDocumentsRequest setStartFrom( "Unable to determine BatchGet resumption point. Most recently received doc __name__ '%s'", foundName != null ? foundName : missing)); } + + @Override + protected @Nullable BatchGetDocumentsResponse resumptionValue( + @Nullable BatchGetDocumentsResponse previousValue, BatchGetDocumentsResponse newValue) { + // No sense in resuming from an empty result. + return newValue.getResultCase() == ResultCase.RESULT_NOT_SET ? previousValue : newValue; + } } /** @@ -407,6 +396,8 @@ protected StreamingFirestoreV1ReadFn( protected abstract InT setStartFrom(InT element, OutT out); + protected abstract @Nullable OutT resumptionValue(@Nullable OutT previousValue, OutT newValue); + @Override public final void processElement(ProcessContext c) throws Exception { @SuppressWarnings( @@ -421,14 +412,14 @@ public final void processElement(ProcessContext c) throws Exception { } Instant start = clock.instant(); + InT request = + lastReceivedValue == null ? element : setStartFrom(element, lastReceivedValue); try { - InT request = - lastReceivedValue == null ? 
element : setStartFrom(element, lastReceivedValue); attempt.recordRequestStart(start); ServerStream serverStream = getCallable(firestoreStub).call(request); attempt.recordRequestSuccessful(clock.instant()); for (OutT out : serverStream) { - lastReceivedValue = out; + lastReceivedValue = resumptionValue(lastReceivedValue, out); attempt.recordStreamValue(clock.instant()); c.output(out); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/QueryUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/QueryUtils.java new file mode 100644 index 000000000000..a1a4af07510c --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/QueryUtils.java @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.firestore; + +import com.google.firestore.v1.Document; +import com.google.firestore.v1.StructuredQuery; +import com.google.firestore.v1.StructuredQuery.Direction; +import com.google.firestore.v1.StructuredQuery.FieldFilter; +import com.google.firestore.v1.StructuredQuery.FieldFilter.Operator; +import com.google.firestore.v1.StructuredQuery.FieldReference; +import com.google.firestore.v1.StructuredQuery.Filter; +import com.google.firestore.v1.StructuredQuery.Order; +import com.google.firestore.v1.StructuredQuery.UnaryFilter; +import com.google.firestore.v1.Value; +import com.google.firestore.v1.Value.ValueTypeCase; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Ascii; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.UnsignedBytes; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Contains several internal utility functions for Firestore query handling, such as filling + * implicit ordering or escaping field references. 
+ */ +class QueryUtils { + + private static final ImmutableSet INEQUALITY_FIELD_FILTER_OPS = + ImmutableSet.of( + FieldFilter.Operator.LESS_THAN, + FieldFilter.Operator.LESS_THAN_OR_EQUAL, + FieldFilter.Operator.GREATER_THAN, + FieldFilter.Operator.GREATER_THAN_OR_EQUAL, + FieldFilter.Operator.NOT_EQUAL, + FieldFilter.Operator.NOT_IN); + private static final ImmutableSet INEQUALITY_UNARY_FILTER_OPS = + ImmutableSet.of(UnaryFilter.Operator.IS_NOT_NAN, UnaryFilter.Operator.IS_NOT_NULL); + + /** + * Populates implicit orderBy of a query in accordance with our documentation. * Required + * inequality fields are appended in field name order. * __name__ is appended if not specified. + * See here + * for more details. + * + * @param query The StructuredQuery of the original request. + * @return A list of additional orderBy fields, excluding the explicit ones. + */ + static List getImplicitOrderBy(StructuredQuery query) { + List expectedImplicitOrders = new ArrayList<>(); + if (query.hasWhere()) { + fillInequalityFields(query.getWhere(), expectedImplicitOrders); + } + Collections.sort(expectedImplicitOrders); + if (expectedImplicitOrders.stream().noneMatch(OrderByFieldPath::isDocumentName)) { + expectedImplicitOrders.add(OrderByFieldPath.fromString("__name__")); + } + for (Order order : query.getOrderByList()) { + OrderByFieldPath orderField = OrderByFieldPath.fromString(order.getField().getFieldPath()); + expectedImplicitOrders.remove(orderField); + } + + List additionalOrders = new ArrayList<>(); + if (!expectedImplicitOrders.isEmpty()) { + Direction lastDirection = + query.getOrderByCount() == 0 + ? 
Direction.ASCENDING + : query.getOrderByList().get(query.getOrderByCount() - 1).getDirection(); + + for (OrderByFieldPath field : expectedImplicitOrders) { + additionalOrders.add( + Order.newBuilder() + .setDirection(lastDirection) + .setField( + FieldReference.newBuilder().setFieldPath(field.getOriginalString()).build()) + .build()); + } + } + + return additionalOrders; + } + + private static void fillInequalityFields(Filter filter, List result) { + switch (filter.getFilterTypeCase()) { + case FIELD_FILTER: + if (INEQUALITY_FIELD_FILTER_OPS.contains(filter.getFieldFilter().getOp())) { + OrderByFieldPath fieldPath = + OrderByFieldPath.fromString(filter.getFieldFilter().getField().getFieldPath()); + if (!result.contains(fieldPath)) { + result.add(fieldPath); + } + } + break; + case COMPOSITE_FILTER: + filter.getCompositeFilter().getFiltersList().forEach(f -> fillInequalityFields(f, result)); + break; + case UNARY_FILTER: + if (INEQUALITY_UNARY_FILTER_OPS.contains(filter.getUnaryFilter().getOp())) { + OrderByFieldPath fieldPath = + OrderByFieldPath.fromString(filter.getUnaryFilter().getField().getFieldPath()); + if (!result.contains(fieldPath)) { + result.add(fieldPath); + } + } + break; + default: + break; + } + } + + static @Nullable Value lookupDocumentValue(Document document, String fieldPath) { + OrderByFieldPath resolvedPath = OrderByFieldPath.fromString(fieldPath); + // __name__ is a special field and doesn't exist in (top-level) valueMap (see + // https://firebase.google.com/docs/firestore/reference/rest/v1/projects.databases.documents#Document). 
+ if (resolvedPath.isDocumentName()) { + return Value.newBuilder().setReferenceValue(document.getName()).build(); + } + return findMapValue(new ArrayList<>(resolvedPath.getSegments()), document.getFieldsMap()); + } + + private static @Nullable Value findMapValue(List segments, Map valueMap) { + if (segments.isEmpty()) { + return null; + } + String field = segments.remove(0); + Value value = valueMap.get(field); + if (segments.isEmpty()) { + return value; + } + // Field path traversal is not done, recurse into map values. + if (value == null || !value.getValueTypeCase().equals(ValueTypeCase.MAP_VALUE)) { + return null; + } + return findMapValue(segments, value.getMapValue().getFieldsMap()); + } + + private static class OrderByFieldPath implements Comparable { + + private static final String UNQUOTED_NAME_REGEX_STRING = "([a-zA-Z_][a-zA-Z_0-9]*)"; + private static final String QUOTED_NAME_REGEX_STRING = "(`(?:[^`\\\\]|(?:\\\\.))+`)"; + // After each segment follows a dot and more characters, or the end of the string. 
+ private static final Pattern FIELD_PATH_SEGMENT_REGEX = + Pattern.compile( + String.format( + "(?:%s|%s)(\\..+|$)", UNQUOTED_NAME_REGEX_STRING, QUOTED_NAME_REGEX_STRING), + Pattern.DOTALL); + + public static OrderByFieldPath fromString(String fieldPath) { + if (fieldPath.isEmpty()) { + throw new IllegalArgumentException("Could not resolve empty field path"); + } + String originalString = fieldPath; + List segments = new ArrayList<>(); + while (!fieldPath.isEmpty()) { + Matcher segmentMatcher = FIELD_PATH_SEGMENT_REGEX.matcher(fieldPath); + boolean foundMatch = segmentMatcher.lookingAt(); + if (!foundMatch) { + throw new IllegalArgumentException("OrderBy field path was malformed"); + } + String fieldName; + if ((fieldName = segmentMatcher.group(1)) != null) { + segments.add(fieldName); + } else if ((fieldName = segmentMatcher.group(2)) != null) { + String unescaped = unescapeFieldName(fieldName.substring(1, fieldName.length() - 1)); + segments.add(unescaped); + } else { + throw new IllegalArgumentException("OrderBy field path was malformed"); + } + fieldPath = fieldPath.substring(fieldName.length()); + // Due to the regex, any non-empty fieldPath will have a dot before the next nested field. 
+ if (fieldPath.startsWith(".")) { + fieldPath = fieldPath.substring(1); + } + } + return new OrderByFieldPath(originalString, ImmutableList.copyOf(segments)); + } + + private final String originalString; + private final ImmutableList segments; + + private OrderByFieldPath(String originalString, ImmutableList segments) { + this.originalString = originalString; + this.segments = segments; + } + + public String getOriginalString() { + return originalString; + } + + public boolean isDocumentName() { + return segments.size() == 1 && "__name__".equals(segments.get(0)); + } + + public ImmutableList getSegments() { + return segments; + } + + @Override + public boolean equals(@Nullable Object other) { + if (other instanceof OrderByFieldPath) { + return this.segments.equals(((OrderByFieldPath) other).getSegments()); + } + return super.equals(other); + } + + @Override + public int hashCode() { + return Objects.hash(segments); + } + + @Override + public int compareTo(OrderByFieldPath other) { + // Inspired by com.google.cloud.firestore.FieldPath. 
+ int length = Math.min(this.getSegments().size(), other.getSegments().size()); + for (int i = 0; i < length; i++) { + byte[] thisField = this.getSegments().get(i).getBytes(StandardCharsets.UTF_8); + byte[] otherField = other.getSegments().get(i).getBytes(StandardCharsets.UTF_8); + int cmp = UnsignedBytes.lexicographicalComparator().compare(thisField, otherField); + if (cmp != 0) { + return cmp; + } + } + return Integer.compare(this.getSegments().size(), other.getSegments().size()); + } + + private static String unescapeFieldName(String fieldName) { + if (fieldName.isEmpty()) { + throw new IllegalArgumentException("quoted identifier cannot be empty"); + } + StringBuilder buf = new StringBuilder(); + for (int i = 0; i < fieldName.length(); i++) { + char c = fieldName.charAt(i); + // Roughly speaking, there are 4 cases we care about: + // - carriage returns: \r and \r\n + // - unescaped quotes: ` + // - non-escape sequences + // - escape sequences + if (c == '`') { + throw new IllegalArgumentException("quoted identifier cannot contain unescaped quote"); + } else if (c == '\r') { + buf.append('\n'); + // Convert '\r\n' into '\n' + if (i + 1 < fieldName.length() && fieldName.charAt(i + 1) == '\n') { + i++; + } + } else if (c != '\\') { + buf.append(c); + } else if (i + 1 >= fieldName.length()) { + throw new IllegalArgumentException("illegal trailing backslash"); + } else { + i++; + switch (fieldName.charAt(i)) { + case 'a': + buf.appendCodePoint(Ascii.BEL); // "Alert" control character + break; + case 'b': + buf.append('\b'); + break; + case 'f': + buf.append('\f'); + break; + case 'n': + buf.append('\n'); + break; + case 'r': + buf.append('\r'); + break; + case 't': + buf.append('\t'); + break; + case 'v': + buf.appendCodePoint(Ascii.VT); // vertical tab + break; + case '?': + buf.append('?'); // artifact of ancient C grammar + break; + case '\\': + buf.append('\\'); + break; + case '\'': + buf.append('\''); + break; + case '"': + buf.append('\"'); + break; + case 
'`': + buf.append('`'); + break; + case '0': + case '1': + case '2': + case '3': + if (i + 3 > fieldName.length()) { + throw new IllegalArgumentException("illegal octal escape sequence"); + } + buf.appendCodePoint(unescapeOctal(fieldName.substring(i, i + 3))); + i += 3; + break; + case 'x': + case 'X': + i++; + if (i + 2 > fieldName.length()) { + throw new IllegalArgumentException("illegal hex escape sequence"); + } + buf.appendCodePoint(unescapeHex(fieldName.substring(i, i + 2))); + i += 2; + break; + case 'u': + i++; + if (i + 4 > fieldName.length()) { + throw new IllegalArgumentException("illegal unicode escape sequence"); + } + buf.appendCodePoint(unescapeHex(fieldName.substring(i, i + 4))); + i += 4; + break; + case 'U': + i++; + if (i + 8 > fieldName.length()) { + throw new IllegalArgumentException("illegal unicode escape sequence"); + } + buf.appendCodePoint(unescapeHex(fieldName.substring(i, i + 8))); + i += 8; + break; + default: + throw new IllegalArgumentException("illegal escape"); + } + } + } + return buf.toString(); + } + + private static int unescapeOctal(String str) { + int ch = 0; + for (int i = 0; i < str.length(); i++) { + ch = 8 * ch + octalValue(str.charAt(i)); + } + if (!Character.isValidCodePoint(ch)) { + throw new IllegalArgumentException("illegal codepoint"); + } + return ch; + } + + private static int unescapeHex(String str) { + int ch = 0; + for (int i = 0; i < str.length(); i++) { + ch = 16 * ch + hexValue(str.charAt(i)); + } + if (!Character.isValidCodePoint(ch)) { + throw new IllegalArgumentException("illegal codepoint"); + } + return ch; + } + + private static int octalValue(char d) { + if (d >= '0' && d <= '7') { + return d - '0'; + } else { + throw new IllegalArgumentException("illegal octal digit"); + } + } + + private static int hexValue(char d) { + if (d >= '0' && d <= '9') { + return d - '0'; + } else if (d >= 'a' && d <= 'f') { + return 10 + d - 'a'; + } else if (d >= 'A' && d <= 'F') { + return 10 + d - 'A'; + } else { + throw 
new IllegalArgumentException("illegal hex digit"); + } + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnRunQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnRunQueryTest.java index 23b18d0a6253..6bb5999344bc 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnRunQueryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnRunQueryTest.java @@ -32,6 +32,7 @@ import com.google.cloud.firestore.v1.stub.FirestoreStub; import com.google.firestore.v1.Cursor; import com.google.firestore.v1.Document; +import com.google.firestore.v1.MapValue; import com.google.firestore.v1.RunQueryRequest; import com.google.firestore.v1.RunQueryResponse; import com.google.firestore.v1.StructuredQuery; @@ -60,18 +61,19 @@ public final class FirestoreV1FnRunQueryTest extends BaseFirestoreV1ReadFnTest { @Mock private ServerStreamingCallable callable; - @Mock private ServerStream responseStream1; - @Mock private ServerStream responseStream2; + @Mock private ServerStream responseStream; + @Mock private ServerStream retryResponseStream; @Test public void endToEnd() throws Exception { - TestData testData = TestData.fieldEqualsBar().setProjectId(projectId).build(); + TestData testData = + new TestData.Builder().setFilter(TestData.FIELD_EQUALS_BAR).setProjectId(projectId).build(); List responses = ImmutableList.of(testData.response1, testData.response2, testData.response3); - when(responseStream1.iterator()).thenReturn(responses.iterator()); + when(responseStream.iterator()).thenReturn(responses.iterator()); - when(callable.call(testData.request)).thenReturn(responseStream1); + when(callable.call(testData.request)).thenReturn(responseStream); when(stub.runQueryCallable()).thenReturn(callable); @@ -97,16 +99,32 @@ public void endToEnd() throws 
Exception { @Override public void resumeFromLastReadValue() throws Exception { + buildAndRunQueryRetryTest("foo", "bar"); + } + + @Test + public void resumeFromLastReadValue_nestedOrderBy() throws Exception { + buildAndRunQueryRetryTest("baz.qux", "val"); + } + + @Test + public void resumeFromLastReadValue_nestedOrderBySimpleEscaping() throws Exception { + buildAndRunQueryRetryTest("`quux.quuz`", "123"); + } + + @Test + public void resumeFromLastReadValue_nestedOrderByComplexEscaping() throws Exception { + buildAndRunQueryRetryTest("`fo\\`o.m\\`ap`.`bar.key`", "bar.val"); + } + + @Test + public void resumeFromLastReadValue_withNoOrderBy() throws Exception { TestData testData = - TestData.fieldEqualsBar() + new TestData.Builder() + .setFilter(TestData.FIELD_NOT_EQUALS_FOO) .setProjectId(projectId) - .setOrderFunction( - f -> - Collections.singletonList( - Order.newBuilder().setDirection(Direction.ASCENDING).setField(f).build())) .build(); - - RunQueryRequest request2 = + RunQueryRequest expectedRetryRequest = RunQueryRequest.newBuilder() .setParent(String.format("projects/%s/databases/(default)/document", projectId)) .setStructuredQuery( @@ -117,65 +135,38 @@ public void resumeFromLastReadValue() throws Exception { .setStartAt( Cursor.newBuilder() .setBefore(false) - .addValues(Value.newBuilder().setStringValue("bar")))) + .addValues(Value.newBuilder().setStringValue("bar")) + .addValues( + Value.newBuilder() + .setReferenceValue(testData.response2.getDocument().getName()))) + .addOrderBy( + // Implicit orderBy adds order for inequality filters + Order.newBuilder() + .setField(FieldReference.newBuilder().setFieldPath("foo")) + .setDirection(Direction.ASCENDING) + .build()) + .addOrderBy( + Order.newBuilder() + .setField(FieldReference.newBuilder().setFieldPath("__name__")) + .setDirection(Direction.ASCENDING))) .build(); - List responses = - ImmutableList.of(testData.response1, testData.response2, testData.response3); - when(responseStream1.iterator()) - 
.thenReturn( - new AbstractIterator() { - private int invocationCount = 1; - - @Override - protected RunQueryResponse computeNext() { - int count = invocationCount++; - if (count == 1) { - return responses.get(0); - } else if (count == 2) { - return responses.get(1); - } else { - throw RETRYABLE_ERROR; - } - } - }); - - when(callable.call(testData.request)).thenReturn(responseStream1); - doNothing().when(attempt).checkCanRetry(any(), eq(RETRYABLE_ERROR)); - when(responseStream2.iterator()).thenReturn(ImmutableList.of(responses.get(2)).iterator()); - when(callable.call(request2)).thenReturn(responseStream2); - - when(stub.runQueryCallable()).thenReturn(callable); - - when(ff.getFirestoreStub(any())).thenReturn(stub); - when(ff.getRpcQos(any())).thenReturn(rpcQos); - when(rpcQos.newReadAttempt(any())).thenReturn(attempt); - when(attempt.awaitSafeToProceed(any())).thenReturn(true); - - ArgumentCaptor responsesCaptor = - ArgumentCaptor.forClass(RunQueryResponse.class); - - doNothing().when(processContext).output(responsesCaptor.capture()); - - when(processContext.element()).thenReturn(testData.request); - - RunQueryFn fn = new RunQueryFn(clock, ff, rpcQosOptions); - - runFunction(fn); - - List allValues = responsesCaptor.getAllValues(); - assertEquals(responses, allValues); - - verify(callable, times(1)).call(testData.request); - verify(callable, times(1)).call(request2); - verify(attempt, times(3)).recordStreamValue(any()); + runQueryRetryTest(testData, expectedRetryRequest); } - @Test - public void resumeFromLastReadValue_withNoOrderBy() throws Exception { - TestData testData = TestData.fieldEqualsBar().setProjectId(projectId).build(); - - RunQueryRequest request2 = + private void buildAndRunQueryRetryTest(String fieldName, String fieldValue) throws Exception { + TestData testData = + new TestData.Builder() + .setFilter(TestData.FIELD_EQUALS_BAR) + .setProjectId(projectId) + .setOrderFunction( + f -> { + FieldReference f2 = 
FieldReference.newBuilder().setFieldPath(fieldName).build(); + return Collections.singletonList( + Order.newBuilder().setDirection(Direction.ASCENDING).setField(f2).build()); + }) + .build(); + RunQueryRequest expectedRetryRequest = RunQueryRequest.newBuilder() .setParent(String.format("projects/%s/databases/(default)/document", projectId)) .setStructuredQuery( @@ -186,6 +177,7 @@ public void resumeFromLastReadValue_withNoOrderBy() throws Exception { .setStartAt( Cursor.newBuilder() .setBefore(false) + .addValues(Value.newBuilder().setStringValue(fieldValue)) .addValues( Value.newBuilder() .setReferenceValue(testData.response2.getDocument().getName()))) @@ -195,9 +187,12 @@ public void resumeFromLastReadValue_withNoOrderBy() throws Exception { .setDirection(Direction.ASCENDING))) .build(); - List responses = - ImmutableList.of(testData.response1, testData.response2, testData.response3); - when(responseStream1.iterator()) + runQueryRetryTest(testData, expectedRetryRequest); + } + + private void runQueryRetryTest(TestData testData, RunQueryRequest expectedRetryRequest) + throws Exception { + when(responseStream.iterator()) .thenReturn( new AbstractIterator() { private int invocationCount = 1; @@ -206,19 +201,20 @@ public void resumeFromLastReadValue_withNoOrderBy() throws Exception { protected RunQueryResponse computeNext() { int count = invocationCount++; if (count == 1) { - return responses.get(0); + return testData.response1; } else if (count == 2) { - return responses.get(1); + return testData.response2; } else { throw RETRYABLE_ERROR; } } }); - when(callable.call(testData.request)).thenReturn(responseStream1); + when(callable.call(testData.request)).thenReturn(responseStream); doNothing().when(attempt).checkCanRetry(any(), eq(RETRYABLE_ERROR)); - when(responseStream2.iterator()).thenReturn(ImmutableList.of(testData.response3).iterator()); - when(callable.call(request2)).thenReturn(responseStream2); + when(retryResponseStream.iterator()) + 
.thenReturn(ImmutableList.of(testData.response3).iterator()); + when(callable.call(expectedRetryRequest)).thenReturn(retryResponseStream); when(stub.runQueryCallable()).thenReturn(callable); @@ -238,12 +234,13 @@ protected RunQueryResponse computeNext() { runFunction(fn); - List allValues = responsesCaptor.getAllValues(); - assertEquals(responses, allValues); - verify(callable, times(1)).call(testData.request); - verify(callable, times(1)).call(request2); + verify(callable, times(1)).call(expectedRetryRequest); verify(attempt, times(3)).recordStreamValue(any()); + + List allValues = responsesCaptor.getAllValues(); + assertEquals( + ImmutableList.of(testData.response1, testData.response2, testData.response3), allValues); } @Override @@ -283,62 +280,96 @@ protected RunQueryFn getFn( private static final class TestData { + static final FieldReference FILTER_FIELD_PATH = + FieldReference.newBuilder().setFieldPath("foo").build(); + static final Filter FIELD_EQUALS_BAR = + Filter.newBuilder() + .setFieldFilter( + FieldFilter.newBuilder() + .setField(FILTER_FIELD_PATH) + .setOp(Operator.EQUAL) + .setValue(Value.newBuilder().setStringValue("bar")) + .build()) + .build(); + static final Filter FIELD_NOT_EQUALS_FOO = + Filter.newBuilder() + .setFieldFilter( + FieldFilter.newBuilder() + .setField(FILTER_FIELD_PATH) + .setOp(Operator.NOT_EQUAL) + .setValue(Value.newBuilder().setStringValue("foo")) + .build()) + .build(); + private final RunQueryRequest request; private final RunQueryResponse response1; private final RunQueryResponse response2; private final RunQueryResponse response3; - public TestData(String projectId, Function> orderFunction) { - String fieldPath = "foo"; - FieldReference foo = FieldReference.newBuilder().setFieldPath(fieldPath).build(); + public TestData( + String projectId, Function> orderFunction, Filter filter) { StructuredQuery.Builder builder = StructuredQuery.newBuilder() .addFrom( CollectionSelector.newBuilder() .setAllDescendants(false) 
.setCollectionId("collection")) - .setWhere( - Filter.newBuilder() - .setFieldFilter( - FieldFilter.newBuilder() - .setField(foo) - .setOp(Operator.EQUAL) - .setValue(Value.newBuilder().setStringValue("bar")) - .build())); - - orderFunction.apply(foo).forEach(builder::addOrderBy); + .setWhere(filter); + + orderFunction.apply(FILTER_FIELD_PATH).forEach(builder::addOrderBy); request = RunQueryRequest.newBuilder() .setParent(String.format("projects/%s/databases/(default)/document", projectId)) .setStructuredQuery(builder) .build(); - response1 = newResponse(fieldPath, 1); - response2 = newResponse(fieldPath, 2); - response3 = newResponse(fieldPath, 3); + response1 = newResponse(1); + response2 = newResponse(2); + response3 = newResponse(3); } - private static RunQueryResponse newResponse(String field, int docNumber) { + /** + * Returns single-document response like this: { "__name__": "doc-{docNumber}", "foo": "bar", + * "fo`o.m`ap": { "bar.key": "bar.val" }, "baz" : { "qux" : "val" }, "quux.quuz" : "123" }. 
+ */ + private static RunQueryResponse newResponse(int docNumber) { String docId = String.format("doc-%d", docNumber); return RunQueryResponse.newBuilder() .setDocument( Document.newBuilder() .setName(docId) .putAllFields( - ImmutableMap.of(field, Value.newBuilder().setStringValue("bar").build())) - .build()) + ImmutableMap.of( + "foo", + Value.newBuilder().setStringValue("bar").build(), + "fo`o.m`ap", + Value.newBuilder() + .setMapValue( + MapValue.newBuilder() + .putFields( + "bar.key", + Value.newBuilder().setStringValue("bar.val").build()) + .build()) + .build(), + "baz", + Value.newBuilder() + .setMapValue( + MapValue.newBuilder() + .putFields( + "qux", Value.newBuilder().setStringValue("val").build()) + .build()) + .build(), + "quux.quuz", + Value.newBuilder().setStringValue("123").build()))) .build(); } - private static Builder fieldEqualsBar() { - return new Builder(); - } - @SuppressWarnings("initialization.fields.uninitialized") // fields set via builder methods private static final class Builder { private String projectId; private Function> orderFunction; + private Filter filter; public Builder() { orderFunction = f -> Collections.emptyList(); @@ -354,10 +385,16 @@ public Builder setOrderFunction(Function> orderFunct return this; } + public Builder setFilter(Filter filter) { + this.filter = filter; + return this; + } + private TestData build() { return new TestData( requireNonNull(projectId, "projectId must be non null"), - requireNonNull(orderFunction, "orderFunction must be non null")); + requireNonNull(orderFunction, "orderFunction must be non null"), + requireNonNull(filter, "filter must be non-null")); } } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/QueryUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/QueryUtilsTest.java new file mode 100644 index 000000000000..ce5ac6dd849d --- /dev/null +++ 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io.gcp.firestore;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;

import com.google.firestore.v1.Document;
import com.google.firestore.v1.MapValue;
import com.google.firestore.v1.StructuredQuery;
import com.google.firestore.v1.StructuredQuery.CollectionSelector;
import com.google.firestore.v1.StructuredQuery.CompositeFilter;
import com.google.firestore.v1.StructuredQuery.Direction;
import com.google.firestore.v1.StructuredQuery.FieldFilter;
import com.google.firestore.v1.StructuredQuery.FieldReference;
import com.google.firestore.v1.StructuredQuery.Filter;
import com.google.firestore.v1.StructuredQuery.Order;
import com.google.firestore.v1.StructuredQuery.UnaryFilter;
import com.google.firestore.v1.Value;
import java.util.List;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.junit.Before;
import org.junit.Test;

/**
 * Unit tests for {@link QueryUtils#getImplicitOrderBy} and {@link QueryUtils#lookupDocumentValue}.
 *
 * <p>Every test starts from the fixtures built in {@link #setUp()}: a document with one nested map
 * field whose names contain dots and backticks, and a query with a nested AND filter and a single
 * explicit {@code ORDER BY b DESC}.
 */
public class QueryUtilsTest {

  private Document testDocument;
  private StructuredQuery testQuery;

  /** Rebuilds the shared document and query fixtures before each test. */
  @Before
  public void setUp() {
    // { "__name__": "doc-123", "fo`o.m`ap": { "bar.key": "bar.val" } }
    testDocument =
        Document.newBuilder()
            .setName("doc-123")
            .putAllFields(
                ImmutableMap.of(
                    "fo`o.m`ap",
                    Value.newBuilder()
                        .setMapValue(
                            MapValue.newBuilder()
                                .putFields(
                                    "bar.key", Value.newBuilder().setStringValue("bar.val").build())
                                .build())
                        .build()))
            .build();

    // WHERE (`z€a`.a.a != "" AND `b` > "") AND c == "" AND `z$` > "456" AND `z` > "123" AND z IS
    // NOT NAN
    Filter where =
        allOf(
            allOf(
                fieldFilter("`z€a`.a.a", FieldFilter.Operator.NOT_EQUAL, ""),
                fieldFilter("`b`", FieldFilter.Operator.GREATER_THAN, "")),
            fieldFilter("c", FieldFilter.Operator.EQUAL, ""),
            fieldFilter("`z$`", FieldFilter.Operator.GREATER_THAN, "456"),
            fieldFilter("`z`", FieldFilter.Operator.GREATER_THAN, "123"),
            isNotNan("z"));
    testQuery =
        StructuredQuery.newBuilder()
            .addFrom(
                CollectionSelector.newBuilder()
                    .setAllDescendants(false)
                    .setCollectionId("collection"))
            .setWhere(where)
            .addOrderBy(descending("b"))
            .build();
  }

  /**
   * Checks the implicit orders computed for the fixture query. {@code b} is already ordered
   * explicitly and {@code c} is only compared for equality; neither appears in the expected list.
   */
  @Test
  public void getImplicitOrderBy_success() {
    // WHERE (`z€a`.a.a != "" AND `b` > "") AND c == "" AND `z$` > "456" AND `z` > "123" AND z IS
    // NOT NAN ORDER BY b DESC
    // -> (ORDER BY b DESC) + `z` DESC, `z$` DESC, `z€a`.a.a DESC, __name__ DESC
    List<Order> expected =
        ImmutableList.of(
            descending("`z`"),
            descending("`z$`"),
            descending("`z€a`.a.a"),
            descending("__name__"));
    List<Order> actual = QueryUtils.getImplicitOrderBy(testQuery);
    assertEquals(expected, actual);
  }

  /**
   * Checks that when {@code __name__} itself appears in an inequality filter it is expected first
   * among the implicit orders rather than appended at the end.
   */
  @Test
  public void getImplicitOrderBy_nameInWhere() {
    StructuredQuery.Builder builder = testQuery.toBuilder();
    builder
        .getWhereBuilder()
        .getCompositeFilterBuilder()
        .addFilters(fieldFilter("__name__", FieldFilter.Operator.NOT_EQUAL, ""));
    StructuredQuery query = builder.build();
    // WHERE (`z€a`.a.a != "" AND `b` > "") AND c == "" AND `z$` > "456" AND `z` > "123" AND z IS
    // NOT NAN AND __name__ != "" ORDER BY b DESC
    // -> (ORDER BY b DESC) + __name__ DESC, `z` DESC, `z$` DESC, `z€a`.a.a DESC
    List<Order> expected =
        ImmutableList.of(
            descending("__name__"),
            descending("`z`"),
            descending("`z$`"),
            descending("`z€a`.a.a"));
    List<Order> actual = QueryUtils.getImplicitOrderBy(query);
    assertEquals(expected, actual);
  }

  /** Checks that a filter referencing an empty field path is rejected. */
  @Test
  public void getImplicitOrderBy_malformedWhereThrows() {
    StructuredQuery malformed = testQuery.toBuilder().setWhere(isNotNan("")).build();
    assertThrows(IllegalArgumentException.class, () -> QueryUtils.getImplicitOrderBy(malformed));
  }

  /** Checks that {@code __name__} resolves to a reference value holding the document name. */
  @Test
  public void lookupDocumentValue_findsName() {
    assertEquals(
        Value.newBuilder().setReferenceValue("doc-123").build(),
        QueryUtils.lookupDocumentValue(testDocument, "__name__"));
  }

  /** Checks lookup through a map field when both path segments are backtick-quoted. */
  @Test
  public void lookupDocumentValue_nestedField() {
    assertEquals(
        Value.newBuilder().setStringValue("bar.val").build(),
        QueryUtils.lookupDocumentValue(testDocument, "`fo\\`o.m\\`ap`.`bar.key`"));
  }

  /** Checks that a path naming no field in the document yields {@code null}. */
  @Test
  public void lookupDocumentValue_returnsNullIfNotFound() {
    assertNull(QueryUtils.lookupDocumentValue(testDocument, "foobar"));
  }

  /** Checks that an empty field path is rejected. */
  @Test
  public void lookupDocumentValue_invalidThrows() {
    assertThrows(
        IllegalArgumentException.class, () -> QueryUtils.lookupDocumentValue(testDocument, ""));
  }

  /** Returns a reference to the field at {@code fieldPath}. */
  private static FieldReference field(String fieldPath) {
    return FieldReference.newBuilder().setFieldPath(fieldPath).build();
  }

  /** Returns a filter comparing {@code fieldPath} to the string {@code value} using {@code op}. */
  private static Filter fieldFilter(String fieldPath, FieldFilter.Operator op, String value) {
    return Filter.newBuilder()
        .setFieldFilter(
            FieldFilter.newBuilder()
                .setField(field(fieldPath))
                .setOp(op)
                .setValue(Value.newBuilder().setStringValue(value)))
        .build();
  }

  /** Returns the unary filter {@code fieldPath IS NOT NAN}. */
  private static Filter isNotNan(String fieldPath) {
    return Filter.newBuilder()
        .setUnaryFilter(
            UnaryFilter.newBuilder()
                .setField(field(fieldPath))
                .setOp(UnaryFilter.Operator.IS_NOT_NAN))
        .build();
  }

  /** Returns a composite filter joining {@code filters}, in the given order, with AND. */
  private static Filter allOf(Filter... filters) {
    CompositeFilter.Builder composite =
        CompositeFilter.newBuilder().setOp(CompositeFilter.Operator.AND);
    for (Filter filter : filters) {
      composite.addFilters(filter);
    }
    return Filter.newBuilder().setCompositeFilter(composite).build();
  }

  /** Returns a descending order on {@code fieldPath}. */
  private static Order descending(String fieldPath) {
    return Order.newBuilder()
        .setField(field(fieldPath))
        .setDirection(Direction.DESCENDING)
        .build();
  }
}