diff --git a/pom.xml b/pom.xml index c3a6b73dfb1a..da2da26b9b30 100644 --- a/pom.xml +++ b/pom.xml @@ -125,7 +125,7 @@ 0.5.160304 20.0 1.2.0 - 0.1.0 + 0.1.9 1.3 2.8.8 3.0.1 @@ -175,8 +175,8 @@ - release @@ -493,7 +493,7 @@ beam-sdks-java-io-hadoop-input-format ${project.version} - + org.apache.beam beam-runners-core-construction-java @@ -737,13 +737,13 @@ google-auth-library-credentials ${google-auth.version} - + com.google.auth google-auth-library-oauth2-http ${google-auth.version} - com.google.guava @@ -808,12 +808,24 @@ + + com.google.api.grpc + proto-google-cloud-spanner-admin-database-v1 + ${grpc-google-common-protos.version} + + + + com.google.api.grpc + proto-google-common-protos + ${grpc-google-common-protos.version} + + com.google.apis google-api-services-storage ${storage.version} - com.google.guava @@ -900,7 +912,7 @@ google-api-services-dataflow ${dataflow.version} - com.google.guava @@ -914,7 +926,7 @@ google-api-services-clouddebugger ${clouddebugger.version} - com.google.guava @@ -1015,7 +1027,7 @@ byte-buddy 1.6.8 - + org.springframework spring-expression @@ -1122,7 +1134,7 @@ maven-antrun-plugin 1.8 - + org.apache.maven.plugins maven-checkstyle-plugin @@ -1393,7 +1405,7 @@ - org.eclipse.m2e @@ -1730,7 +1742,7 @@ ${basedir}/sdks/python - + ${basedir} @@ -1739,8 +1751,8 @@ README.md - - + + diff --git a/sdks/java/io/google-cloud-platform/pom.xml b/sdks/java/io/google-cloud-platform/pom.xml index 9143ccf553cf..8b5382092ec6 100644 --- a/sdks/java/io/google-cloud-platform/pom.xml +++ b/sdks/java/io/google-cloud-platform/pom.xml @@ -86,11 +86,6 @@ grpc-core - - com.google.api.grpc - grpc-google-common-protos - - com.google.apis google-api-services-bigquery @@ -248,6 +243,21 @@ true + + com.google.api.grpc + proto-google-cloud-spanner-admin-database-v1 + + + + com.google.api.grpc + proto-google-common-protos + + + + org.apache.commons + commons-lang3 + + org.apache.beam diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimator.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimator.java new file mode 100644 index 000000000000..61652e736e90 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimator.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import com.google.cloud.ByteArray; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.Value; + +/** Estimates the logical size of {@link com.google.cloud.spanner.Mutation}. */ +class MutationSizeEstimator { + + // Prevent construction. 
+ private MutationSizeEstimator() {} + + /** Estimates a size of mutation in bytes. */ + static long sizeOf(Mutation m) { + long result = 0; + for (Value v : m.getValues()) { + switch (v.getType().getCode()) { + case ARRAY: + result += estimateArrayValue(v); + break; + case STRUCT: + throw new IllegalArgumentException("Structs are not supported in mutation."); + default: + result += estimatePrimitiveValue(v); + } + } + return result; + } + + private static long estimatePrimitiveValue(Value v) { + switch (v.getType().getCode()) { + case BOOL: + return 1; + case INT64: + case FLOAT64: + return 8; + case DATE: + case TIMESTAMP: + return 12; + case STRING: + return v.isNull() ? 0 : v.getString().length(); + case BYTES: + return v.isNull() ? 0 : v.getBytes().length(); + } + throw new IllegalArgumentException("Unsupported type " + v.getType()); + } + + private static long estimateArrayValue(Value v) { + switch (v.getType().getArrayElementType().getCode()) { + case BOOL: + return v.getBoolArray().size(); + case INT64: + return 8 * v.getInt64Array().size(); + case FLOAT64: + return 8 * v.getFloat64Array().size(); + case STRING: + long totalLength = 0; + for (String s : v.getStringArray()) { + if (s == null) { + continue; + } + totalLength += s.length(); + } + return totalLength; + case BYTES: + totalLength = 0; + for (ByteArray bytes : v.getBytesArray()) { + if (bytes == null) { + continue; + } + totalLength += bytes.length(); + } + return totalLength; + case DATE: + return 12 * v.getDateArray().size(); + case TIMESTAMP: + return 12 * v.getTimestampArray().size(); + } + throw new IllegalArgumentException("Unsupported type " + v.getType()); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index ec119311c106..c5325bb5a85f 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -20,6 +20,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.auto.value.AutoValue; +import com.google.cloud.ServiceFactory; +import com.google.cloud.ServiceOptions; import com.google.cloud.spanner.AbortedException; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.DatabaseId; @@ -32,6 +34,7 @@ import java.util.List; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -48,8 +51,8 @@ import org.slf4j.LoggerFactory; /** - * Experimental {@link PTransform Transforms} for reading from and writing to - * Google Cloud Spanner. + * Experimental {@link PTransform Transforms} for reading from and writing to Google Cloud Spanner. * *
 * <h3>Reading from Cloud Spanner</h3>
 *
 *
@@ -72,21 +75,35 @@
 * mutations.apply(
 *     "Write", SpannerIO.write().withInstanceId("instance").withDatabaseId("database"));
 * }
+ *
+ * <p>The default batch size is 1 MB. To override it, use {@link
+ * Write#withBatchSizeBytes(long)}. Setting the batch size to a small value or zero
+ * practically disables batching.
+ *
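+ * <p>For example, a minimal sketch of overriding the batch size (the instance and database
+ * ids below are placeholders):
+ *
+ * <pre>{@code
+ * mutations.apply("Write", SpannerIO.write()
+ *     .withInstanceId("instance")
+ *     .withDatabaseId("database")
+ *     .withBatchSizeBytes(10 * 1024 * 1024)); // flush roughly every 10 MB of mutations
+ * }</pre>
+ *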
+ * <p>The transform does not provide the same transactional guarantees as Cloud Spanner. In
+ * particular,
+ *
+ * <ul>
+ *   <li>Mutations are not submitted atomically;
+ *   <li>A mutation is applied at least once;
+ *   <li>If the pipeline was unexpectedly stopped, mutations that were already applied will not
+ *       get rolled back.
+ * </ul>
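+ *
+ * <p>Because a mutation may be applied more than once, writes should be idempotent; for
+ * example, use insert-or-update mutations rather than plain inserts.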
*/ @Experimental(Experimental.Kind.SOURCE_SINK) public class SpannerIO { - @VisibleForTesting - static final int SPANNER_MUTATIONS_PER_COMMIT_LIMIT = 20000; + private static final long DEFAULT_BATCH_SIZE_BYTES = 1024 * 1024; // 1 MB /** - * Creates an unitialized instance of {@link Write}. Before use, the {@link Write} must be + * Creates an uninitialized instance of {@link Write}. Before use, the {@link Write} must be * configured with a {@link Write#withInstanceId} and {@link Write#withDatabaseId} that identify * the Cloud Spanner database being written. */ @Experimental public static Write write() { - return new AutoValue_SpannerIO_Write.Builder().build(); + return new AutoValue_SpannerIO_Write.Builder() + .setBatchSizeBytes(DEFAULT_BATCH_SIZE_BYTES) + .build(); } /** @@ -98,24 +115,57 @@ public static Write write() { @AutoValue public abstract static class Write extends PTransform, PDone> { + @Nullable + abstract String getProjectId(); + @Nullable abstract String getInstanceId(); @Nullable abstract String getDatabaseId(); + abstract long getBatchSizeBytes(); + + @Nullable + @VisibleForTesting + abstract ServiceFactory getServiceFactory(); + abstract Builder toBuilder(); @AutoValue.Builder abstract static class Builder { + abstract Builder setProjectId(String projectId); + abstract Builder setInstanceId(String instanceId); abstract Builder setDatabaseId(String databaseId); + abstract Builder setBatchSizeBytes(long batchSizeBytes); + + @VisibleForTesting + abstract Builder setServiceFactory(ServiceFactory serviceFactory); + abstract Write build(); } + SpannerOptions getSpannerOptions() { + SpannerOptions.Builder builder = SpannerOptions.newBuilder(); + if (getServiceFactory() != null) { + builder.setServiceFactory(getServiceFactory()); + } + return builder.build(); + } + + /** + * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner project. + * + *
+ * <p>Does not modify this object.
+ */
+ public Write withProjectId(String projectId) {
+   return toBuilder().setProjectId(projectId).build();
+ }
+
 /**
  * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner
  * instance.
@@ -126,6 +176,15 @@ public Write withInstanceId(String instanceId) {
   return toBuilder().setInstanceId(instanceId).build();
 }

+ /**
+ * Returns a new {@link SpannerIO.Write} with a new batch size limit.
+ *
+ *
+ * <p>
Does not modify this object. + */ + public Write withBatchSizeBytes(long batchSizeBytes) { + return toBuilder().setBatchSizeBytes(batchSizeBytes).build(); + } + /** * Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner * database. @@ -136,12 +195,24 @@ public Write withDatabaseId(String databaseId) { return toBuilder().setDatabaseId(databaseId).build(); } + @VisibleForTesting + Write withServiceFactory(ServiceFactory serviceFactory) { + return toBuilder().setServiceFactory(serviceFactory).build(); + } + @Override - public PDone expand(PCollection input) { - input.apply("Write mutations to Spanner", - ParDo.of(new SpannerWriterFn( - getInstanceId(), getDatabaseId(), SPANNER_MUTATIONS_PER_COMMIT_LIMIT))); + public void validate(PipelineOptions options) { + checkNotNull( + getInstanceId(), + "SpannerIO.write() requires instance id to be set with withInstanceId method"); + checkNotNull( + getDatabaseId(), + "SpannerIO.write() requires database id to be set with withDatabaseId method"); + } + @Override + public PDone expand(PCollection input) { + input.apply("Write mutations to Cloud Spanner", ParDo.of(new SpannerWriteFn(this))); return PDone.in(input.getPipeline()); } @@ -149,64 +220,69 @@ public PDone expand(PCollection input) { public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); builder - .addIfNotNull(DisplayData.item("instanceId", getInstanceId()) - .withLabel("Output Instance")) - .addIfNotNull(DisplayData.item("databaseId", getDatabaseId()) - .withLabel("Output Database")); + .addIfNotNull(DisplayData.item("projectId", getProjectId()).withLabel("Output Project")) + .addIfNotNull( + DisplayData.item("instanceId", getInstanceId()).withLabel("Output Instance")) + .addIfNotNull( + DisplayData.item("databaseId", getDatabaseId()).withLabel("Output Database")) + .add(DisplayData.item("batchSizeBytes", getBatchSizeBytes()) + .withLabel("Batch Size in Bytes")); + if (getServiceFactory() != null) { + builder.addIfNotNull( + DisplayData.item("serviceFactory", getServiceFactory().getClass().getName()) + .withLabel("Service Factory")); + } } } - /** - * {@link DoFn} that writes {@link Mutation}s to Google Cloud Spanner. Mutations are written in - * batches, where the maximum batch size is {@link SpannerIO#SPANNER_MUTATIONS_PER_COMMIT_LIMIT}. - * - *
- * <p>Commits are non-transactional. If a commit fails, it will be retried (up to
- * {@link SpannerWriterFn#MAX_RETRIES} times). This means that the mutation operation should be
- * idempotent.
- *
- * <p>
See Google Cloud Spanner documentation. - */ + /** Batches together and writes mutations to Google Cloud Spanner. */ @VisibleForTesting - static class SpannerWriterFn extends DoFn { - private static final Logger LOG = LoggerFactory.getLogger(SpannerWriterFn.class); + static class SpannerWriteFn extends DoFn { + private static final Logger LOG = LoggerFactory.getLogger(SpannerWriteFn.class); + private final Write spec; private transient Spanner spanner; - private final String instanceId; - private final String databaseId; - private final int batchSize; private transient DatabaseClient dbClient; // Current batch of mutations to be written. - private final List mutations = new ArrayList<>(); + private List mutations; + private long batchSizeBytes = 0; private static final int MAX_RETRIES = 5; private static final FluentBackoff BUNDLE_WRITE_BACKOFF = FluentBackoff.DEFAULT - .withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5)); + .withMaxRetries(MAX_RETRIES) + .withInitialBackoff(Duration.standardSeconds(5)); @VisibleForTesting - SpannerWriterFn(String instanceId, String databaseId, int batchSize) { - this.instanceId = checkNotNull(instanceId, "instanceId"); - this.databaseId = checkNotNull(databaseId, "databaseId"); - this.batchSize = batchSize; + SpannerWriteFn(Write spec) { + this.spec = spec; } @Setup public void setup() throws Exception { - SpannerOptions options = SpannerOptions.newBuilder().build(); - spanner = options.getService(); - dbClient = spanner.getDatabaseClient( - DatabaseId.of(options.getProjectId(), instanceId, databaseId)); + spanner = spec.getSpannerOptions().getService(); + dbClient = + spanner.getDatabaseClient( + DatabaseId.of(projectId(), spec.getInstanceId(), spec.getDatabaseId())); + mutations = new ArrayList<>(); + batchSizeBytes = 0; } @ProcessElement public void processElement(ProcessContext c) throws Exception { Mutation m = c.element(); mutations.add(m); - int columnCount = m.asMap().size(); - if ((mutations.size() + 1) * columnCount >= batchSize) { + batchSizeBytes += MutationSizeEstimator.sizeOf(m); + if (batchSizeBytes >= spec.getBatchSizeBytes()) { flushBatch(); } } + private String projectId() { + return spec.getProjectId() == null + ? ServiceOptions.getDefaultProjectId() + : spec.getProjectId(); + } + @FinishBundle public void finishBundle() throws Exception { if (!mutations.isEmpty()) { @@ -217,20 +293,20 @@ public void finishBundle() throws Exception { @Teardown public void teardown() throws Exception { if (spanner == null) { - return; + return; } spanner.closeAsync().get(); + spanner = null; } /** * Writes a batch of mutations to Cloud Spanner. * - *
- * <p>If a commit fails, it will be retried up to {@link #MAX_RETRIES} times.
- * If the retry limit is exceeded, the last exception from Cloud Spanner will be
- * thrown.
+ *
+ * <p>
If a commit fails, it will be retried up to {@link #MAX_RETRIES} times. If the retry limit + * is exceeded, the last exception from Cloud Spanner will be thrown. * * @throws AbortedException if the commit fails or IOException or InterruptedException if - * backing off between retries fails. + * backing off between retries fails. */ private void flushBatch() throws AbortedException, IOException, InterruptedException { LOG.debug("Writing batch of {} mutations", mutations.size()); @@ -247,8 +323,8 @@ private void flushBatch() throws AbortedException, IOException, InterruptedExcep } catch (AbortedException exception) { // Only log the code and message for potentially-transient errors. The entire exception // will be propagated upon the last retry. - LOG.error("Error writing to Spanner ({}): {}", exception.getCode(), - exception.getMessage()); + LOG.error( + "Error writing to Spanner ({}): {}", exception.getCode(), exception.getMessage()); if (!BackOffUtils.next(sleeper, backoff)) { LOG.error("Aborting after {} retries.", MAX_RETRIES); throw exception; @@ -256,20 +332,16 @@ private void flushBatch() throws AbortedException, IOException, InterruptedExcep } } LOG.debug("Successfully wrote {} mutations", mutations.size()); - mutations.clear(); + mutations = new ArrayList<>(); + batchSizeBytes = 0; } @Override public void populateDisplayData(Builder builder) { super.populateDisplayData(builder); - builder - .addIfNotNull(DisplayData.item("instanceId", instanceId) - .withLabel("Instance")) - .addIfNotNull(DisplayData.item("databaseId", databaseId) - .withLabel("Database")); + spec.populateDisplayData(builder); } } private SpannerIO() {} // Prevent construction. - } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java new file mode 100644 index 000000000000..03eb28ed943d --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; + +import com.google.cloud.ByteArray; +import com.google.cloud.Date; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.Mutation; +import java.util.Arrays; +import org.junit.Test; + +/** A set of unit tests for {@link MutationSizeEstimator}. 
*/ +public class MutationSizeEstimatorTest { + + @Test + public void primitives() throws Exception { + Mutation int64 = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); + Mutation float64 = Mutation.newInsertOrUpdateBuilder("test").set("one").to(2.9).build(); + Mutation bool = Mutation.newInsertOrUpdateBuilder("test").set("one").to(false).build(); + + assertThat(MutationSizeEstimator.sizeOf(int64), is(8L)); + assertThat(MutationSizeEstimator.sizeOf(float64), is(8L)); + assertThat(MutationSizeEstimator.sizeOf(bool), is(1L)); + } + + @Test + public void primitiveArrays() throws Exception { + Mutation int64 = + Mutation.newInsertOrUpdateBuilder("test") + .set("one") + .toInt64Array(new long[] {1L, 2L, 3L}) + .build(); + Mutation float64 = + Mutation.newInsertOrUpdateBuilder("test") + .set("one") + .toFloat64Array(new double[] {1., 2.}) + .build(); + Mutation bool = + Mutation.newInsertOrUpdateBuilder("test") + .set("one") + .toBoolArray(new boolean[] {true, true, false, true}) + .build(); + + assertThat(MutationSizeEstimator.sizeOf(int64), is(24L)); + assertThat(MutationSizeEstimator.sizeOf(float64), is(16L)); + assertThat(MutationSizeEstimator.sizeOf(bool), is(4L)); + } + + @Test + public void strings() throws Exception { + Mutation emptyString = Mutation.newInsertOrUpdateBuilder("test").set("one").to("").build(); + Mutation nullString = + Mutation.newInsertOrUpdateBuilder("test").set("one").to((String) null).build(); + Mutation sampleString = Mutation.newInsertOrUpdateBuilder("test").set("one").to("abc").build(); + Mutation sampleArray = + Mutation.newInsertOrUpdateBuilder("test") + .set("one") + .toStringArray(Arrays.asList("one", "two", null)) + .build(); + + assertThat(MutationSizeEstimator.sizeOf(emptyString), is(0L)); + assertThat(MutationSizeEstimator.sizeOf(nullString), is(0L)); + assertThat(MutationSizeEstimator.sizeOf(sampleString), is(3L)); + assertThat(MutationSizeEstimator.sizeOf(sampleArray), is(6L)); + } + + @Test + public void bytes() throws Exception { + Mutation empty = + Mutation.newInsertOrUpdateBuilder("test").set("one").to(ByteArray.fromBase64("")).build(); + Mutation nullValue = + Mutation.newInsertOrUpdateBuilder("test").set("one").to((ByteArray) null).build(); + Mutation sample = + Mutation.newInsertOrUpdateBuilder("test") + .set("one") + .to(ByteArray.fromBase64("abcdabcd")) + .build(); + + assertThat(MutationSizeEstimator.sizeOf(empty), is(0L)); + assertThat(MutationSizeEstimator.sizeOf(nullValue), is(0L)); + assertThat(MutationSizeEstimator.sizeOf(sample), is(6L)); + } + + @Test + public void dates() throws Exception { + Mutation timestamp = + Mutation.newInsertOrUpdateBuilder("test").set("one").to(Timestamp.now()).build(); + Mutation nullTimestamp = + Mutation.newInsertOrUpdateBuilder("test").set("one").to((Timestamp) null).build(); + Mutation date = + Mutation.newInsertOrUpdateBuilder("test") + .set("one") + .to(Date.fromYearMonthDay(2017, 10, 10)) + .build(); + Mutation nullDate = + Mutation.newInsertOrUpdateBuilder("test").set("one").to((Date) null).build(); + Mutation timestampArray = + Mutation.newInsertOrUpdateBuilder("test") + .set("one") + .toTimestampArray(Arrays.asList(Timestamp.now(), null)) + .build(); + Mutation dateArray = + Mutation.newInsertOrUpdateBuilder("test") + .set("one") + .toDateArray( + Arrays.asList( + null, + Date.fromYearMonthDay(2017, 1, 1), + null, + Date.fromYearMonthDay(2017, 1, 2))) + .build(); + + assertThat(MutationSizeEstimator.sizeOf(timestamp), is(12L)); + assertThat(MutationSizeEstimator.sizeOf(date), 
is(12L)); + assertThat(MutationSizeEstimator.sizeOf(nullTimestamp), is(12L)); + assertThat(MutationSizeEstimator.sizeOf(nullDate), is(12L)); + assertThat(MutationSizeEstimator.sizeOf(timestampArray), is(24L)); + assertThat(MutationSizeEstimator.sizeOf(dateArray), is(48L)); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java new file mode 100644 index 000000000000..5bdfea5522b2 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner; + +import static org.mockito.Mockito.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.api.core.ApiFuture; +import com.google.cloud.ServiceFactory; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import com.google.common.collect.Iterables; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import javax.annotation.concurrent.GuardedBy; + +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFnTester; +import org.apache.beam.sdk.values.PCollection; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentMatcher; +import org.mockito.Matchers; + + +/** + * Unit tests for {@link SpannerIO}. 
+ */ +@RunWith(JUnit4.class) +public class SpannerIOTest implements Serializable { + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + private FakeServiceFactory serviceFactory; + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + serviceFactory = new FakeServiceFactory(); + } + + @Test + public void emptyTransform() throws Exception { + SpannerIO.Write write = SpannerIO.write(); + thrown.expect(NullPointerException.class); + thrown.expectMessage("requires instance id to be set with"); + write.validate(null); + } + + @Test + public void emptyInstanceId() throws Exception { + SpannerIO.Write write = SpannerIO.write().withDatabaseId("123"); + thrown.expect(NullPointerException.class); + thrown.expectMessage("requires instance id to be set with"); + write.validate(null); + } + + @Test + public void emptyDatabaseId() throws Exception { + SpannerIO.Write write = SpannerIO.write().withInstanceId("123"); + thrown.expect(NullPointerException.class); + thrown.expectMessage("requires database id to be set with"); + write.validate(null); + } + + @Test + @Category(NeedsRunner.class) + public void singleMutationPipeline() throws Exception { + Mutation mutation = Mutation.newInsertOrUpdateBuilder("test").set("one").to(2).build(); + PCollection mutations = pipeline.apply(Create.of(mutation)); + + mutations.apply( + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .withServiceFactory(serviceFactory)); + pipeline.run(); + verify(serviceFactory.mockSpanner()) + .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database")); + verify(serviceFactory.mockDatabaseClient(), times(1)) + .writeAtLeastOnce(argThat(new IterableOfSize(1))); + } + + @Test + public void batching() throws Exception { + Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); + Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build(); + SpannerIO.Write write = + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .withBatchSizeBytes(1000000000) + .withServiceFactory(serviceFactory); + SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write); + DoFnTester fnTester = DoFnTester.of(writerFn); + fnTester.processBundle(Arrays.asList(one, two)); + + verify(serviceFactory.mockSpanner()) + .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database")); + verify(serviceFactory.mockDatabaseClient(), times(1)) + .writeAtLeastOnce(argThat(new IterableOfSize(2))); + } + + @Test + public void batchingGroups() throws Exception { + Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); + Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build(); + Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build(); + + // Have a room to accumulate one more item. 
+ long batchSize = MutationSizeEstimator.sizeOf(one) + 1; + + SpannerIO.Write write = + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .withBatchSizeBytes(batchSize) + .withServiceFactory(serviceFactory); + SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write); + DoFnTester fnTester = DoFnTester.of(writerFn); + fnTester.processBundle(Arrays.asList(one, two, three)); + + verify(serviceFactory.mockSpanner()) + .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database")); + verify(serviceFactory.mockDatabaseClient(), times(1)) + .writeAtLeastOnce(argThat(new IterableOfSize(2))); + verify(serviceFactory.mockDatabaseClient(), times(1)) + .writeAtLeastOnce(argThat(new IterableOfSize(1))); + } + + @Test + public void noBatching() throws Exception { + Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build(); + Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build(); + SpannerIO.Write write = + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .withBatchSizeBytes(0) // turn off batching. + .withServiceFactory(serviceFactory); + SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write); + DoFnTester fnTester = DoFnTester.of(writerFn); + fnTester.processBundle(Arrays.asList(one, two)); + + verify(serviceFactory.mockSpanner()) + .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database")); + verify(serviceFactory.mockDatabaseClient(), times(2)) + .writeAtLeastOnce(argThat(new IterableOfSize(1))); + } + + private static class FakeServiceFactory + implements ServiceFactory, Serializable { + // Marked as static so they could be returned by serviceFactory, which is serializable. 
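+    // (Each factory instance records an index into these static lists, so a copy that is
+    // serialized and deserialized within the same JVM resolves the same mock objects that
+    // the test thread later verifies against.)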
+ private static final Object lock = new Object(); + + @GuardedBy("lock") + private static final List mockSpanners = new ArrayList<>(); + + @GuardedBy("lock") + private static final List mockDatabaseClients = new ArrayList<>(); + + @GuardedBy("lock") + private static int count = 0; + + private final int index; + + public FakeServiceFactory() { + synchronized (lock) { + index = count++; + mockSpanners.add(mock(Spanner.class, withSettings().serializable())); + mockDatabaseClients.add(mock(DatabaseClient.class, withSettings().serializable())); + } + ApiFuture voidFuture = mock(ApiFuture.class, withSettings().serializable()); + when(mockSpanner().getDatabaseClient(Matchers.any(DatabaseId.class))) + .thenReturn(mockDatabaseClient()); + when(mockSpanner().closeAsync()).thenReturn(voidFuture); + } + + DatabaseClient mockDatabaseClient() { + synchronized (lock) { + return mockDatabaseClients.get(index); + } + } + + Spanner mockSpanner() { + synchronized (lock) { + return mockSpanners.get(index); + } + } + + @Override + public Spanner create(SpannerOptions serviceOptions) { + return mockSpanner(); + } + } + + private static class IterableOfSize extends ArgumentMatcher> { + private final int size; + + private IterableOfSize(int size) { + this.size = size; + } + + @Override + public boolean matches(Object argument) { + return argument instanceof Iterable && Iterables.size((Iterable) argument) == size; + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java new file mode 100644 index 000000000000..064c65eedcef --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.spanner; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; + +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.DatabaseAdminClient; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.Operation; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.Spanner; +import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.Statement; +import com.google.spanner.admin.database.v1.CreateDatabaseMetadata; +import java.util.Collections; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** End-to-end test of Cloud Spanner Sink. */ +@RunWith(JUnit4.class) +public class SpannerWriteIT { + @Rule public final transient TestPipeline p = TestPipeline.create(); + + /** Pipeline options for this test. */ + public interface SpannerTestPipelineOptions extends TestPipelineOptions { + @Description("Project ID for Spanner") + @Default.String("apache-beam-testing") + String getProjectId(); + void setProjectId(String value); + + @Description("Instance ID to write to in Spanner") + @Default.String("beam-test") + String getInstanceId(); + void setInstanceId(String value); + + @Description("Database ID to write to in Spanner") + @Default.String("beam-testdb") + String getDatabaseId(); + void setDatabaseId(String value); + + @Description("Table name") + @Default.String("users") + String getTable(); + void setTable(String value); + } + + private Spanner spanner; + private DatabaseAdminClient databaseAdminClient; + private SpannerTestPipelineOptions options; + + @Before + public void setUp() throws Exception { + PipelineOptionsFactory.register(SpannerTestPipelineOptions.class); + options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class); + + spanner = SpannerOptions.newBuilder().setProjectId(options.getProjectId()).build().getService(); + + databaseAdminClient = spanner.getDatabaseAdminClient(); + + // Delete database if exists. 
+ databaseAdminClient.dropDatabase(options.getInstanceId(), options.getDatabaseId()); + + Operation op = + databaseAdminClient.createDatabase( + options.getInstanceId(), + options.getDatabaseId(), + Collections.singleton( + "CREATE TABLE " + + options.getTable() + + " (" + + " Key INT64," + + " Value STRING(MAX)," + + ") PRIMARY KEY (Key)")); + op.waitFor(); + } + + @Test + public void testWrite() throws Exception { + p.apply(GenerateSequence.from(0).to(100)) + .apply(ParDo.of(new GenerateMutations(options.getTable()))) + .apply( + SpannerIO.write() + .withProjectId(options.getProjectId()) + .withInstanceId(options.getInstanceId()) + .withDatabaseId(options.getDatabaseId())); + + p.run(); + DatabaseClient databaseClient = + spanner.getDatabaseClient( + DatabaseId.of( + options.getProjectId(), options.getInstanceId(), options.getDatabaseId())); + + ResultSet resultSet = + databaseClient + .singleUse() + .executeQuery(Statement.of("SELECT COUNT(*) FROM " + options.getTable())); + assertThat(resultSet.next(), is(true)); + assertThat(resultSet.getLong(0), equalTo(100L)); + assertThat(resultSet.next(), is(false)); + } + + @After + public void tearDown() throws Exception { + databaseAdminClient.dropDatabase(options.getInstanceId(), options.getDatabaseId()); + spanner.closeAsync().get(); + } + + private static class GenerateMutations extends DoFn { + private final String table; + private final int valueSize = 100; + + public GenerateMutations(String table) { + this.table = table; + } + + @ProcessElement + public void processElement(ProcessContext c) { + Mutation.WriteBuilder builder = Mutation.newInsertOrUpdateBuilder(table); + Long key = c.element(); + builder.set("Key").to(key); + builder.set("Value").to(RandomStringUtils.random(valueSize, true, true)); + Mutation mutation = builder.build(); + c.output(mutation); + } + } +}