diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java index e9d0709343f5..446d097a8ed8 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.firestore; import static java.util.Objects.requireNonNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import com.google.firestore.v1.BatchGetDocumentsRequest; import com.google.firestore.v1.BatchGetDocumentsResponse; @@ -67,6 +68,7 @@ import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; @@ -502,6 +504,9 @@ public PartitionQuery.Builder partitionQuery() { */ @Immutable public static final class Write { + private @Nullable String projectId; + private @Nullable String databaseId; + private static final Write INSTANCE = new Write(); private Write() {} @@ -537,8 +542,20 @@ private Write() {} * @see google.firestore.v1.BatchWriteResponse */ + public Write withProjectId(String projectId) { + checkArgument(projectId != null, "projectId can not be null"); + this.projectId = projectId; + return this; + } + + public Write withDatabaseId(String databaseId) { + checkArgument(databaseId != null, "databaseId can not be null"); + this.databaseId = databaseId; + return this; + } + public BatchWriteWithSummary.Builder batchWrite() { - return new BatchWriteWithSummary.Builder(); + return new BatchWriteWithSummary.Builder().setProjectId(projectId).setDatabaseId(databaseId); } } @@ -1348,11 +1365,18 @@ public static final class BatchWriteWithSummary BatchWriteWithSummary, BatchWriteWithSummary.Builder> { - private BatchWriteWithSummary( + private final @Nullable String projectId; + private final @Nullable String databaseId; + + public BatchWriteWithSummary( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, - RpcQosOptions rpcQosOptions) { + RpcQosOptions rpcQosOptions, + @Nullable String projectId, + @Nullable String databaseId) { super(clock, firestoreStatefulComponentFactory, rpcQosOptions); + this.projectId = projectId; + this.databaseId = databaseId; } @Override @@ -1365,7 +1389,9 @@ public PCollection expand( clock, firestoreStatefulComponentFactory, rpcQosOptions, - CounterFactory.DEFAULT))); + CounterFactory.DEFAULT, + projectId, + databaseId))); } @Override @@ -1403,6 +1429,9 @@ public static final class Builder BatchWriteWithSummary, BatchWriteWithSummary.Builder> { + private @Nullable String projectId; + private @Nullable String databaseId; + private Builder() { super(); } @@ -1414,9 +1443,35 @@ private Builder( super(clock, firestoreStatefulComponentFactory, rpcQosOptions); } + /** Set the GCP project ID to be used by the Firestore client. */ + private Builder setProjectId(@Nullable String projectId) { + this.projectId = projectId; + return this; + } + + /** Set the Firestore database ID (e.g., "(default)"). */ + private Builder setDatabaseId(@Nullable String databaseId) { + this.databaseId = databaseId; + return this; + } + + @VisibleForTesting + @Nullable + String getProjectId() { + return this.projectId; + } + + @VisibleForTesting + @Nullable + String getDatabaseId() { + return this.databaseId; + } + public BatchWriteWithDeadLetterQueue.Builder withDeadLetterQueue() { return new BatchWriteWithDeadLetterQueue.Builder( - clock, firestoreStatefulComponentFactory, rpcQosOptions); + clock, firestoreStatefulComponentFactory, rpcQosOptions) + .setProjectId(projectId) + .setDatabaseId(databaseId); } @Override @@ -1429,7 +1484,8 @@ BatchWriteWithSummary buildSafe( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions) { - return new BatchWriteWithSummary(clock, firestoreStatefulComponentFactory, rpcQosOptions); + return new BatchWriteWithSummary( + clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); } } } @@ -1474,11 +1530,18 @@ public static final class BatchWriteWithDeadLetterQueue BatchWriteWithDeadLetterQueue, BatchWriteWithDeadLetterQueue.Builder> { + private final @Nullable String projectId; + private final @Nullable String databaseId; + private BatchWriteWithDeadLetterQueue( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, - RpcQosOptions rpcQosOptions) { + RpcQosOptions rpcQosOptions, + @Nullable String projectId, + @Nullable String databaseId) { super(clock, firestoreStatefulComponentFactory, rpcQosOptions); + this.projectId = projectId; + this.databaseId = databaseId; } @Override @@ -1490,7 +1553,9 @@ public PCollection expand(PCollection { + private @Nullable String projectId; + private @Nullable String databaseId; + private Builder() { super(); } + private Builder setProjectId(@Nullable String projectId) { + this.projectId = projectId; + return this; + } + + private Builder setDatabaseId(@Nullable String databaseId) { + this.databaseId = databaseId; + return this; + } + + @VisibleForTesting + @Nullable + String getProjectId() { + return this.projectId; + } + + @VisibleForTesting + @Nullable + String getDatabaseId() { + return this.databaseId; + } + private Builder( JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, @@ -1550,7 +1640,7 @@ BatchWriteWithDeadLetterQueue buildSafe( FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions) { return new BatchWriteWithDeadLetterQueue( - clock, firestoreStatefulComponentFactory, rpcQosOptions); + clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId); } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java index 09378d4f80c5..70c2b91ffbfd 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java @@ -72,8 +72,16 @@ static final class BatchWriteFnWithSummary extends BaseBatchWriteFn extends ExplicitlyWindowedFirestore // bundle scoped state private transient FirestoreStub firestoreStub; private transient DatabaseRootName databaseRootName; + private final @Nullable String configuredProjectId; + private final @Nullable String configuredDatabaseId; @VisibleForTesting transient Queue<@NonNull WriteElement> writes = new PriorityQueue<>(WriteElement.COMPARATOR); @@ -171,12 +189,16 @@ abstract static class BaseBatchWriteFn extends ExplicitlyWindowedFirestore JodaClock clock, FirestoreStatefulComponentFactory firestoreStatefulComponentFactory, RpcQosOptions rpcQosOptions, - CounterFactory counterFactory) { + CounterFactory counterFactory, + @Nullable String configuredProjectId, + @Nullable String configuredDatabaseId) { this.clock = clock; this.firestoreStatefulComponentFactory = firestoreStatefulComponentFactory; this.rpcQosOptions = rpcQosOptions; this.counterFactory = counterFactory; this.rpcAttemptContext = V1FnRpcAttemptContext.BatchWrite; + this.configuredProjectId = configuredProjectId; + this.configuredDatabaseId = configuredDatabaseId; } @Override @@ -202,11 +224,19 @@ public void setup() { @Override public final void startBundle(StartBundleContext c) { - String project = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject(); + String project = + configuredProjectId != null + ? configuredProjectId + : c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject(); + if (project == null) { project = c.getPipelineOptions().as(GcpOptions.class).getProject(); } - String databaseId = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb(); + + String databaseId = + configuredDatabaseId != null + ? configuredDatabaseId + : c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb(); databaseRootName = DatabaseRootName.of( requireNonNull( diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java index 2948be7658a9..35d0ea9482d3 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java @@ -223,6 +223,7 @@ protected BatchWriteFnWithDeadLetterQueue getFn( RpcQosOptions rpcQosOptions, CounterFactory counterFactory, DistributionFactory distributionFactory) { - return new BatchWriteFnWithDeadLetterQueue(clock, ff, rpcQosOptions, counterFactory); + return new BatchWriteFnWithDeadLetterQueue( + clock, ff, rpcQosOptions, counterFactory, null, null); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java index 70c4ce5046a5..3e37e3975bf5 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java @@ -201,7 +201,8 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception { when(callable.call(requestCaptor1.capture())).thenReturn(response1); BaseBatchWriteFn fn = - new BatchWriteFnWithSummary(clock, ff, options, CounterFactory.DEFAULT); + new BatchWriteFnWithSummary( + clock, ff, options, CounterFactory.DEFAULT, "testing-project", "(default)"); fn.setup(); fn.startBundle(startBundleContext); fn.processElement(processContext, window); // write0 @@ -238,6 +239,15 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception { verifyNoMoreInteractions(callable); } + @Test + public void testWithProjectId_thenWithDatabaseId() { + FirestoreV1.Write beamWrite = + FirestoreIO.v1().write().withProjectId("my-project").withDatabaseId("(default)"); + + assertEquals("my-project", beamWrite.batchWrite().getProjectId()); + assertEquals("(default)", beamWrite.batchWrite().getDatabaseId()); + } + @Override protected BatchWriteFnWithSummary getFn( JodaClock clock, @@ -245,6 +255,6 @@ protected BatchWriteFnWithSummary getFn( RpcQosOptions rpcQosOptions, CounterFactory counterFactory, DistributionFactory distributionFactory) { - return new BatchWriteFnWithSummary(clock, ff, rpcQosOptions, counterFactory); + return new BatchWriteFnWithSummary(clock, ff, rpcQosOptions, counterFactory, null, null); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java index 509797892e04..8695080cb885 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java @@ -92,10 +92,12 @@ abstract class BaseFirestoreIT { .build(); protected static String project; + protected static String databaseId; @Before public void setup() { project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + databaseId = "firestoredb"; } private static Instant toWriteTime(WriteResult result) { @@ -441,7 +443,14 @@ protected final void runWriteTest( testPipeline .apply(Create.of(Collections.singletonList(documentIds))) .apply(createWrite) - .apply(FirestoreIO.v1().write().batchWrite().withRpcQosOptions(RPC_QOS_OPTIONS).build()); + .apply( + FirestoreIO.v1() + .write() + .withProjectId(project) + .withDatabaseId(databaseId) + .batchWrite() + .withRpcQosOptions(RPC_QOS_OPTIONS) + .build()); testPipeline.run(TestPipeline.testingPipelineOptions()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreV1IT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreV1IT.java index 204aa67619bd..dce3ca4b8753 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreV1IT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreV1IT.java @@ -116,6 +116,8 @@ public void batchWrite_partialFailureOutputsToDeadLetterQueue() .apply( FirestoreIO.v1() .write() + .withProjectId(project) + .withDatabaseId(databaseId) .batchWrite() .withDeadLetterQueue() .withRpcQosOptions(options)