From 6a37d3df1f61294bbe0f4e5361e19750691e8429 Mon Sep 17 00:00:00 2001
From: nielm
Date: Tue, 18 Jan 2022 17:11:42 +0100
Subject: [PATCH] [BEAM-13665] Make SpannerIO projectID optional again

Fixes regression introduced by PR #15493 which inadvertently caused an
NPE when the projectID was not specified for a SpannerIO read or write.
A projectID whose value resolves to null or an empty string is now
treated as unspecified as well.

Adds unit test for reading/writing both with and without projectID
---
 .../sdk/io/gcp/spanner/BatchSpannerRead.java   | 13 ++++++++++---
 .../beam/sdk/io/gcp/spanner/SpannerIO.java     | 15 ++++++++++++---
 .../sdk/io/gcp/spanner/SpannerIOReadTest.java  | 17 ++++++++++++++---
 .../sdk/io/gcp/spanner/SpannerIOWriteTest.java | 15 +++++++++++++++
 4 files changed, 51 insertions(+), 9 deletions(-)

Reviewer notes (free text before the first diff; not part of the change):
- Line structure restored from the hunk headers; code indentation is reconstructed.
- Type arguments in angle brackets are missing from this copy (e.g. "DoFn, Void>");
  restore them from the repository before applying.
- The blob ids on the "index" lines predate the hand edit to the two setup() hunks.

diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java
index 1f4a01d5f135..c61c57e0f716 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java
@@ -23,6 +23,7 @@
 import com.google.cloud.spanner.Partition;
 import com.google.cloud.spanner.ResultSet;
 import com.google.cloud.spanner.SpannerException;
+import com.google.cloud.spanner.SpannerOptions;
 import com.google.cloud.spanner.Struct;
 import com.google.cloud.spanner.TimestampBound;
 import java.util.HashMap;
@@ -151,6 +152,7 @@ private static class ReadFromPartitionFn extends DoFn {
     private final PCollectionView txView;
 
     private transient SpannerAccessor spannerAccessor;
+    private transient String projectId;
 
     public ReadFromPartitionFn(
         SpannerConfig config, PCollectionView txView) {
@@ -161,6 +163,13 @@ public ReadFromPartitionFn(
     @Setup
     public void setup() throws Exception {
       spannerAccessor = SpannerAccessor.getOrCreate(config);
+      // Use the default project when the ID is unset or resolves to null/empty.
+      projectId =
+          this.config.getProjectId() == null
+                  || this.config.getProjectId().get() == null
+                  || this.config.getProjectId().get().isEmpty()
+              ? SpannerOptions.getDefaultProjectId()
+              : this.config.getProjectId().get();
     }
 
     @Teardown
@@ -172,9 +181,7 @@ public void teardown() throws Exception {
     public void processElement(ProcessContext c) throws Exception {
       ServiceCallMetric serviceCallMetric =
           createServiceCallMetric(
-              this.config.getProjectId().toString(),
-              this.config.getDatabaseId().toString(),
-              this.config.getInstanceId().toString());
+              projectId, this.config.getDatabaseId().get(), this.config.getInstanceId().get());
       Transaction tx = c.sideInput(txView);
 
       BatchReadOnlyTransaction batchTx =
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
index 466fdb56aa26..9d5d97face57 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
@@ -1610,6 +1610,7 @@ static class WriteToSpannerFn extends DoFn, Void> {
 
     // Fluent Backoff is not serializable so create at runtime in setup().
     private transient FluentBackoff bundleWriteBackoff;
+    private transient String projectId;
 
     WriteToSpannerFn(
         SpannerConfig spannerConfig, FailureMode failureMode, TupleTag failedTag) {
@@ -1625,6 +1626,14 @@ public void setup() {
           FluentBackoff.DEFAULT
               .withMaxCumulativeBackoff(spannerConfig.getMaxCumulativeBackoff().get())
               .withInitialBackoff(spannerConfig.getMaxCumulativeBackoff().get().dividedBy(60));
+
+      // Use the default project when the ID is unset or resolves to null/empty.
+      projectId =
+          this.spannerConfig.getProjectId() == null
+                  || this.spannerConfig.getProjectId().get() == null
+                  || this.spannerConfig.getProjectId().get().isEmpty()
+              ? SpannerOptions.getDefaultProjectId()
+              : this.spannerConfig.getProjectId().get();
     }
 
     @Teardown
@@ -1680,9 +1689,9 @@ private void spannerWriteWithRetryIfSchemaChange(Iterable batch)
       for (int retry = 1; ; retry++) {
         ServiceCallMetric serviceCallMetric =
             createServiceCallMetric(
-                this.spannerConfig.getProjectId().toString(),
-                this.spannerConfig.getDatabaseId().toString(),
-                this.spannerConfig.getInstanceId().toString(),
+                projectId,
+                this.spannerConfig.getDatabaseId().get(),
+                this.spannerConfig.getInstanceId().get(),
                 "Write");
         try {
           spannerAccessor
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
index 8e1c8330a7b6..67cffe4613ec 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
@@ -192,12 +192,23 @@ private SpannerConfig getSpannerConfig() {
   }
 
   @Test
-  public void runRead() throws Exception {
+  public void runReadTestWithProjectId() throws Exception {
+    runReadTest(getSpannerConfig());
+  }
+
+  @Test
+  public void runReadTestWithDefaultProject() throws Exception {
+    runReadTest(
+        SpannerConfig.create()
+            .withInstanceId("123")
+            .withDatabaseId("aaa")
+            .withServiceFactory(serviceFactory));
+  }
+
+  private void runReadTest(SpannerConfig spannerConfig) throws Exception {
     Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
     TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp);
 
-    SpannerConfig spannerConfig = getSpannerConfig();
-
     PCollection one =
         pipeline.apply(
             "read q",
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
index ea7ca685166b..506abdd30e8d 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
@@ -256,6 +256,21 @@ public void singleMutationPipeline() throws Exception {
     verifyBatches(batch(m(2L)));
   }
 
+  @Test
+  public void singleMutationPipelineNoProjectId() throws Exception {
+    Mutation mutation = m(2L);
+    PCollection mutations = pipeline.apply(Create.of(mutation));
+
+    mutations.apply(
+        SpannerIO.write()
+            .withInstanceId("test-instance")
+            .withDatabaseId("test-database")
+            .withServiceFactory(serviceFactory));
+    pipeline.run();
+
+    verifyBatches(batch(m(2L)));
+  }
+
   @Test
   public void singleMutationGroupPipeline() throws Exception {
     PCollection mutations =