diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
index e9d0709343f5..446d097a8ed8 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1.java
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.gcp.firestore;
import static java.util.Objects.requireNonNull;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import com.google.firestore.v1.BatchGetDocumentsRequest;
import com.google.firestore.v1.BatchGetDocumentsResponse;
@@ -67,6 +68,7 @@
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Instant;
@@ -502,6 +504,9 @@ public PartitionQuery.Builder partitionQuery() {
*/
@Immutable
public static final class Write {
+ private @Nullable String projectId;
+ private @Nullable String databaseId;
+
private static final Write INSTANCE = new Write();
private Write() {}
@@ -537,8 +542,20 @@ private Write() {}
* @see google.firestore.v1.BatchWriteResponse
*/
+ public Write withProjectId(String projectId) {
+ checkArgument(projectId != null, "projectId can not be null");
+ this.projectId = projectId;
+ return this;
+ }
+
+ public Write withDatabaseId(String databaseId) {
+ checkArgument(databaseId != null, "databaseId can not be null");
+ this.databaseId = databaseId;
+ return this;
+ }
+
public BatchWriteWithSummary.Builder batchWrite() {
- return new BatchWriteWithSummary.Builder();
+ return new BatchWriteWithSummary.Builder().setProjectId(projectId).setDatabaseId(databaseId);
}
}
@@ -1348,11 +1365,18 @@ public static final class BatchWriteWithSummary
BatchWriteWithSummary,
BatchWriteWithSummary.Builder> {
- private BatchWriteWithSummary(
+ private final @Nullable String projectId;
+ private final @Nullable String databaseId;
+
+ public BatchWriteWithSummary(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
- RpcQosOptions rpcQosOptions) {
+ RpcQosOptions rpcQosOptions,
+ @Nullable String projectId,
+ @Nullable String databaseId) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
+ this.projectId = projectId;
+ this.databaseId = databaseId;
}
@Override
@@ -1365,7 +1389,9 @@ public PCollection expand(
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
- CounterFactory.DEFAULT)));
+ CounterFactory.DEFAULT,
+ projectId,
+ databaseId)));
}
@Override
@@ -1403,6 +1429,9 @@ public static final class Builder
BatchWriteWithSummary,
BatchWriteWithSummary.Builder> {
+ private @Nullable String projectId;
+ private @Nullable String databaseId;
+
private Builder() {
super();
}
@@ -1414,9 +1443,35 @@ private Builder(
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
}
+ /** Set the GCP project ID to be used by the Firestore client. */
+ private Builder setProjectId(@Nullable String projectId) {
+ this.projectId = projectId;
+ return this;
+ }
+
+ /** Set the Firestore database ID (e.g., "(default)"). */
+ private Builder setDatabaseId(@Nullable String databaseId) {
+ this.databaseId = databaseId;
+ return this;
+ }
+
+ @VisibleForTesting
+ @Nullable
+ String getProjectId() {
+ return this.projectId;
+ }
+
+ @VisibleForTesting
+ @Nullable
+ String getDatabaseId() {
+ return this.databaseId;
+ }
+
public BatchWriteWithDeadLetterQueue.Builder withDeadLetterQueue() {
return new BatchWriteWithDeadLetterQueue.Builder(
- clock, firestoreStatefulComponentFactory, rpcQosOptions);
+ clock, firestoreStatefulComponentFactory, rpcQosOptions)
+ .setProjectId(projectId)
+ .setDatabaseId(databaseId);
}
@Override
@@ -1429,7 +1484,8 @@ BatchWriteWithSummary buildSafe(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
- return new BatchWriteWithSummary(clock, firestoreStatefulComponentFactory, rpcQosOptions);
+ return new BatchWriteWithSummary(
+ clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId);
}
}
}
@@ -1474,11 +1530,18 @@ public static final class BatchWriteWithDeadLetterQueue
BatchWriteWithDeadLetterQueue,
BatchWriteWithDeadLetterQueue.Builder> {
+ private final @Nullable String projectId;
+ private final @Nullable String databaseId;
+
private BatchWriteWithDeadLetterQueue(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
- RpcQosOptions rpcQosOptions) {
+ RpcQosOptions rpcQosOptions,
+ @Nullable String projectId,
+ @Nullable String databaseId) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
+ this.projectId = projectId;
+ this.databaseId = databaseId;
}
@Override
@@ -1490,7 +1553,9 @@ public PCollection expand(PCollection {
+ private @Nullable String projectId;
+ private @Nullable String databaseId;
+
private Builder() {
super();
}
+ private Builder setProjectId(@Nullable String projectId) {
+ this.projectId = projectId;
+ return this;
+ }
+
+ private Builder setDatabaseId(@Nullable String databaseId) {
+ this.databaseId = databaseId;
+ return this;
+ }
+
+ @VisibleForTesting
+ @Nullable
+ String getProjectId() {
+ return this.projectId;
+ }
+
+ @VisibleForTesting
+ @Nullable
+ String getDatabaseId() {
+ return this.databaseId;
+ }
+
private Builder(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
@@ -1550,7 +1640,7 @@ BatchWriteWithDeadLetterQueue buildSafe(
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
return new BatchWriteWithDeadLetterQueue(
- clock, firestoreStatefulComponentFactory, rpcQosOptions);
+ clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId);
}
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java
index 09378d4f80c5..70c2b91ffbfd 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java
@@ -72,8 +72,16 @@ static final class BatchWriteFnWithSummary extends BaseBatchWriteFn extends ExplicitlyWindowedFirestore
// bundle scoped state
private transient FirestoreStub firestoreStub;
private transient DatabaseRootName databaseRootName;
+ private final @Nullable String configuredProjectId;
+ private final @Nullable String configuredDatabaseId;
@VisibleForTesting
transient Queue<@NonNull WriteElement> writes = new PriorityQueue<>(WriteElement.COMPARATOR);
@@ -171,12 +189,16 @@ abstract static class BaseBatchWriteFn extends ExplicitlyWindowedFirestore
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions,
- CounterFactory counterFactory) {
+ CounterFactory counterFactory,
+ @Nullable String configuredProjectId,
+ @Nullable String configuredDatabaseId) {
this.clock = clock;
this.firestoreStatefulComponentFactory = firestoreStatefulComponentFactory;
this.rpcQosOptions = rpcQosOptions;
this.counterFactory = counterFactory;
this.rpcAttemptContext = V1FnRpcAttemptContext.BatchWrite;
+ this.configuredProjectId = configuredProjectId;
+ this.configuredDatabaseId = configuredDatabaseId;
}
@Override
@@ -202,11 +224,19 @@ public void setup() {
@Override
public final void startBundle(StartBundleContext c) {
- String project = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject();
+ String project =
+ configuredProjectId != null
+ ? configuredProjectId
+ : c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject();
+
if (project == null) {
project = c.getPipelineOptions().as(GcpOptions.class).getProject();
}
- String databaseId = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb();
+
+ String databaseId =
+ configuredDatabaseId != null
+ ? configuredDatabaseId
+ : c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb();
databaseRootName =
DatabaseRootName.of(
requireNonNull(
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java
index 2948be7658a9..35d0ea9482d3 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java
@@ -223,6 +223,7 @@ protected BatchWriteFnWithDeadLetterQueue getFn(
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory,
DistributionFactory distributionFactory) {
- return new BatchWriteFnWithDeadLetterQueue(clock, ff, rpcQosOptions, counterFactory);
+ return new BatchWriteFnWithDeadLetterQueue(
+ clock, ff, rpcQosOptions, counterFactory, null, null);
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java
index 70c4ce5046a5..3e37e3975bf5 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java
@@ -201,7 +201,8 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception {
when(callable.call(requestCaptor1.capture())).thenReturn(response1);
BaseBatchWriteFn fn =
- new BatchWriteFnWithSummary(clock, ff, options, CounterFactory.DEFAULT);
+ new BatchWriteFnWithSummary(
+ clock, ff, options, CounterFactory.DEFAULT, "testing-project", "(default)");
fn.setup();
fn.startBundle(startBundleContext);
fn.processElement(processContext, window); // write0
@@ -238,6 +239,15 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception {
verifyNoMoreInteractions(callable);
}
+ @Test
+ public void testWithProjectId_thenWithDatabaseId() {
+ FirestoreV1.Write beamWrite =
+ FirestoreIO.v1().write().withProjectId("my-project").withDatabaseId("(default)");
+
+ assertEquals("my-project", beamWrite.batchWrite().getProjectId());
+ assertEquals("(default)", beamWrite.batchWrite().getDatabaseId());
+ }
+
@Override
protected BatchWriteFnWithSummary getFn(
JodaClock clock,
@@ -245,6 +255,6 @@ protected BatchWriteFnWithSummary getFn(
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory,
DistributionFactory distributionFactory) {
- return new BatchWriteFnWithSummary(clock, ff, rpcQosOptions, counterFactory);
+ return new BatchWriteFnWithSummary(clock, ff, rpcQosOptions, counterFactory, null, null);
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java
index 509797892e04..8695080cb885 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/BaseFirestoreIT.java
@@ -92,10 +92,12 @@ abstract class BaseFirestoreIT {
.build();
protected static String project;
+ protected static String databaseId;
@Before
public void setup() {
project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
+ databaseId = "firestoredb";
}
private static Instant toWriteTime(WriteResult result) {
@@ -441,7 +443,14 @@ protected final void runWriteTest(
testPipeline
.apply(Create.of(Collections.singletonList(documentIds)))
.apply(createWrite)
- .apply(FirestoreIO.v1().write().batchWrite().withRpcQosOptions(RPC_QOS_OPTIONS).build());
+ .apply(
+ FirestoreIO.v1()
+ .write()
+ .withProjectId(project)
+ .withDatabaseId(databaseId)
+ .batchWrite()
+ .withRpcQosOptions(RPC_QOS_OPTIONS)
+ .build());
testPipeline.run(TestPipeline.testingPipelineOptions());
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreV1IT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreV1IT.java
index 204aa67619bd..dce3ca4b8753 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreV1IT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreV1IT.java
@@ -116,6 +116,8 @@ public void batchWrite_partialFailureOutputsToDeadLetterQueue()
.apply(
FirestoreIO.v1()
.write()
+ .withProjectId(project)
+ .withDatabaseId(databaseId)
.batchWrite()
.withDeadLetterQueue()
.withRpcQosOptions(options)