serviceFactory);
+
abstract Write build();
}
+    /**
+     * Builds the {@link SpannerOptions} used to create the Spanner client, applying the
+     * configured service factory when one has been set.
+     */
+    SpannerOptions getSpannerOptions() {
+      SpannerOptions.Builder optionsBuilder = SpannerOptions.newBuilder();
+      if (getServiceFactory() == null) {
+        return optionsBuilder.build();
+      }
+      optionsBuilder.setServiceFactory(getServiceFactory());
+      return optionsBuilder.build();
+    }
+
+    /**
+     * Creates a copy of this {@link SpannerIO.Write} that targets the given Cloud Spanner
+     * project.
+     *
+     * <p>This object is left unchanged.
+     */
+    public Write withProjectId(String projectId) {
+      return toBuilder()
+          .setProjectId(projectId)
+          .build();
+    }
+
/**
* Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner
* instance.
@@ -126,6 +176,15 @@ public Write withInstanceId(String instanceId) {
return toBuilder().setInstanceId(instanceId).build();
}
+    /**
+     * Creates a copy of this {@link SpannerIO.Write} whose batches are flushed once their
+     * estimated size reaches {@code batchSizeBytes} bytes.
+     *
+     * <p>This object is left unchanged.
+     */
+    public Write withBatchSizeBytes(long batchSizeBytes) {
+      return toBuilder()
+          .setBatchSizeBytes(batchSizeBytes)
+          .build();
+    }
+
/**
* Returns a new {@link SpannerIO.Write} that will write to the specified Cloud Spanner
* database.
@@ -136,12 +195,24 @@ public Write withDatabaseId(String databaseId) {
return toBuilder().setDatabaseId(databaseId).build();
}
+    /**
+     * Returns a new {@link SpannerIO.Write} whose {@link SpannerOptions} are built with the given
+     * service factory, so that the {@link Spanner} client comes from that factory.
+     *
+     * <p>Does not modify this object.
+     */
+    @VisibleForTesting
+    Write withServiceFactory(ServiceFactory<Spanner, SpannerOptions> serviceFactory) {
+      return toBuilder().setServiceFactory(serviceFactory).build();
+    }
+
+ /**
+ * Verifies that both the instance id and the database id have been configured.
+ *
+ * @throws NullPointerException if either id is missing
+ */
@Override
- public PDone expand(PCollection<Mutation> input) {
- input.apply("Write mutations to Spanner",
- ParDo.of(new SpannerWriterFn(
- getInstanceId(), getDatabaseId(), SPANNER_MUTATIONS_PER_COMMIT_LIMIT)));
+ public void validate(PipelineOptions options) {
+ checkNotNull(
+ getInstanceId(),
+ "SpannerIO.write() requires instance id to be set with withInstanceId method");
+ checkNotNull(
+ getDatabaseId(),
+ "SpannerIO.write() requires database id to be set with withDatabaseId method");
+ }
+
+ /** Applies a {@link SpannerWriteFn} configured by this transform to every input mutation. */
+ @Override
+ public PDone expand(PCollection<Mutation> input) {
+ input.apply("Write mutations to Cloud Spanner", ParDo.of(new SpannerWriteFn(this)));
return PDone.in(input.getPipeline());
}
@@ -149,64 +220,69 @@ public PDone expand(PCollection input) {
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
builder
- .addIfNotNull(DisplayData.item("instanceId", getInstanceId())
- .withLabel("Output Instance"))
- .addIfNotNull(DisplayData.item("databaseId", getDatabaseId())
- .withLabel("Output Database"));
+ .addIfNotNull(DisplayData.item("projectId", getProjectId()).withLabel("Output Project"))
+ .addIfNotNull( // the ids are registered with addIfNotNull because they may be unset
+ DisplayData.item("instanceId", getInstanceId()).withLabel("Output Instance"))
+ .addIfNotNull(
+ DisplayData.item("databaseId", getDatabaseId()).withLabel("Output Database"))
+ .add(DisplayData.item("batchSizeBytes", getBatchSizeBytes()) // always reported
+ .withLabel("Batch Size in Bytes"));
+ if (getServiceFactory() != null) { // report the factory's class name only when one is set
+ builder.addIfNotNull(
+ DisplayData.item("serviceFactory", getServiceFactory().getClass().getName())
+ .withLabel("Service Factory"));
+ }
}
}
- /**
- * {@link DoFn} that writes {@link Mutation}s to Google Cloud Spanner. Mutations are written in
- * batches, where the maximum batch size is {@link SpannerIO#SPANNER_MUTATIONS_PER_COMMIT_LIMIT}.
- *
- * Commits are non-transactional. If a commit fails, it will be retried (up to
- * {@link SpannerWriterFn#MAX_RETRIES} times). This means that the mutation operation should be
- * idempotent.
- *
- * See Google Cloud Spanner documentation.
- */
+ /**
+ * Batches together and writes mutations to Google Cloud Spanner.
+ *
+ * <p>Incoming mutations are buffered until their estimated size reaches the batch size limit
+ * of the {@link Write} spec. A commit that fails with an {@link AbortedException} is retried up
+ * to {@link #MAX_RETRIES} times with backoff.
+ */
@VisibleForTesting
- static class SpannerWriterFn extends DoFn<Mutation, Void> {
- private static final Logger LOG = LoggerFactory.getLogger(SpannerWriterFn.class);
+ static class SpannerWriteFn extends DoFn<Mutation, Void> {
+ private static final Logger LOG = LoggerFactory.getLogger(SpannerWriteFn.class);
+ private final Write spec;
private transient Spanner spanner;
- private final String instanceId;
- private final String databaseId;
- private final int batchSize;
private transient DatabaseClient dbClient;
// Current batch of mutations to be written.
- private final List<Mutation> mutations = new ArrayList<>();
+ private List<Mutation> mutations;
+ // Estimated size in bytes of the mutations currently held in the batch.
+ private long batchSizeBytes = 0;
private static final int MAX_RETRIES = 5;
private static final FluentBackoff BUNDLE_WRITE_BACKOFF =
FluentBackoff.DEFAULT
- .withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5));
+ .withMaxRetries(MAX_RETRIES)
+ .withInitialBackoff(Duration.standardSeconds(5));
@VisibleForTesting
- SpannerWriterFn(String instanceId, String databaseId, int batchSize) {
- this.instanceId = checkNotNull(instanceId, "instanceId");
- this.databaseId = checkNotNull(databaseId, "databaseId");
- this.batchSize = batchSize;
+ SpannerWriteFn(Write spec) {
+ this.spec = spec;
}
@Setup
public void setup() throws Exception {
- SpannerOptions options = SpannerOptions.newBuilder().build();
- spanner = options.getService();
- dbClient = spanner.getDatabaseClient(
- DatabaseId.of(options.getProjectId(), instanceId, databaseId));
+ spanner = spec.getSpannerOptions().getService(); // client built from the Write spec's options
+ dbClient =
+ spanner.getDatabaseClient(
+ DatabaseId.of(projectId(), spec.getInstanceId(), spec.getDatabaseId()));
+ mutations = new ArrayList<>(); // start with an empty batch
+ batchSizeBytes = 0; // nothing accumulated yet
}
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
Mutation m = c.element();
mutations.add(m);
- int columnCount = m.asMap().size();
- if ((mutations.size() + 1) * columnCount >= batchSize) {
+ batchSizeBytes += MutationSizeEstimator.sizeOf(m); // running size estimate of the batch
+ if (batchSizeBytes >= spec.getBatchSizeBytes()) { // limit reached: write the batch now
flushBatch();
}
}
+    /** Returns the project id of the spec, or the default project id when the spec has none. */
+    private String projectId() {
+      String configured = spec.getProjectId();
+      if (configured != null) {
+        return configured;
+      }
+      return ServiceOptions.getDefaultProjectId();
+    }
+
@FinishBundle
public void finishBundle() throws Exception {
if (!mutations.isEmpty()) {
@@ -217,20 +293,20 @@ public void finishBundle() throws Exception {
@Teardown
public void teardown() throws Exception {
if (spanner == null) {
- return;
+ return; // no client to close
}
spanner.closeAsync().get();
+ spanner = null; // a later teardown call now returns early at the null check
}
/**
* Writes a batch of mutations to Cloud Spanner.
*
- * If a commit fails, it will be retried up to {@link #MAX_RETRIES} times.
- * If the retry limit is exceeded, the last exception from Cloud Spanner will be
- * thrown.
+ * <p>If a commit fails, it will be retried up to {@link #MAX_RETRIES} times. If the retry limit
+ * is exceeded, the last exception from Cloud Spanner will be thrown.
*
* @throws AbortedException if the commit fails or IOException or InterruptedException if
- * backing off between retries fails.
+ * backing off between retries fails.
*/
private void flushBatch() throws AbortedException, IOException, InterruptedException {
LOG.debug("Writing batch of {} mutations", mutations.size());
@@ -247,8 +323,8 @@ private void flushBatch() throws AbortedException, IOException, InterruptedExcep
} catch (AbortedException exception) {
// Only log the code and message for potentially-transient errors. The entire exception
// will be propagated upon the last retry.
- LOG.error("Error writing to Spanner ({}): {}", exception.getCode(),
- exception.getMessage());
+ LOG.error(
+ "Error writing to Spanner ({}): {}", exception.getCode(), exception.getMessage());
if (!BackOffUtils.next(sleeper, backoff)) {
LOG.error("Aborting after {} retries.", MAX_RETRIES);
throw exception;
@@ -256,20 +332,16 @@ private void flushBatch() throws AbortedException, IOException, InterruptedExcep
}
}
LOG.debug("Successfully wrote {} mutations", mutations.size());
- mutations.clear();
+ mutations = new ArrayList<>();
+ batchSizeBytes = 0;
}
@Override
public void populateDisplayData(Builder builder) {
super.populateDisplayData(builder);
- builder
- .addIfNotNull(DisplayData.item("instanceId", instanceId)
- .withLabel("Instance"))
- .addIfNotNull(DisplayData.item("databaseId", databaseId)
- .withLabel("Database"));
+ spec.populateDisplayData(builder); // delegate to the display data of the Write spec
}
}
private SpannerIO() {} // Prevent construction.
-
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java
new file mode 100644
index 000000000000..03eb28ed943d
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationSizeEstimatorTest.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.ByteArray;
+import com.google.cloud.Date;
+import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.Mutation;
+import java.util.Arrays;
+import org.junit.Test;
+
+/** A set of unit tests for {@link MutationSizeEstimator}. */
+public class MutationSizeEstimatorTest {
+
+  @Test
+  public void primitives() throws Exception {
+    Mutation asBool = Mutation.newInsertOrUpdateBuilder("test").set("one").to(false).build();
+    Mutation asDouble = Mutation.newInsertOrUpdateBuilder("test").set("one").to(2.9).build();
+    Mutation asLong = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build();
+    // Expected estimates: 8 for an integer or floating point value, 1 for a boolean.
+    assertThat(MutationSizeEstimator.sizeOf(asLong), is(8L));
+    assertThat(MutationSizeEstimator.sizeOf(asDouble), is(8L));
+    assertThat(MutationSizeEstimator.sizeOf(asBool), is(1L));
+  }
+
+  @Test
+  public void primitiveArrays() throws Exception {
+    Mutation boolArray =
+        Mutation.newInsertOrUpdateBuilder("test")
+            .set("one")
+            .toBoolArray(new boolean[] {true, true, false, true})
+            .build();
+    Mutation doubleArray =
+        Mutation.newInsertOrUpdateBuilder("test")
+            .set("one")
+            .toFloat64Array(new double[] {1., 2.})
+            .build();
+    Mutation longArray =
+        Mutation.newInsertOrUpdateBuilder("test")
+            .set("one")
+            .toInt64Array(new long[] {1L, 2L, 3L})
+            .build();
+    // Expected estimates: per-element size times element count (3 x 8, 2 x 8, 4 x 1).
+    assertThat(MutationSizeEstimator.sizeOf(longArray), is(24L));
+    assertThat(MutationSizeEstimator.sizeOf(doubleArray), is(16L));
+    assertThat(MutationSizeEstimator.sizeOf(boolArray), is(4L));
+  }
+
+  @Test
+  public void strings() throws Exception {
+    Mutation empty = Mutation.newInsertOrUpdateBuilder("test").set("one").to("").build();
+    Mutation absent =
+        Mutation.newInsertOrUpdateBuilder("test").set("one").to((String) null).build();
+    Mutation abc = Mutation.newInsertOrUpdateBuilder("test").set("one").to("abc").build();
+    Mutation array =
+        Mutation.newInsertOrUpdateBuilder("test")
+            .set("one")
+            .toStringArray(Arrays.asList("one", "two", null))
+            .build();
+    // Expected: 3 for "abc", 6 for the array ("one" + "two"); empty and null values add nothing.
+    assertThat(MutationSizeEstimator.sizeOf(abc), is(3L));
+    assertThat(MutationSizeEstimator.sizeOf(array), is(6L));
+    assertThat(MutationSizeEstimator.sizeOf(empty), is(0L));
+    assertThat(MutationSizeEstimator.sizeOf(absent), is(0L));
+  }
+
+  @Test
+  public void bytes() throws Exception {
+    Mutation sixBytes =
+        Mutation.newInsertOrUpdateBuilder("test")
+            .set("one")
+            .to(ByteArray.fromBase64("abcdabcd"))
+            .build();
+    Mutation noBytes =
+        Mutation.newInsertOrUpdateBuilder("test").set("one").to(ByteArray.fromBase64("")).build();
+    Mutation absent =
+        Mutation.newInsertOrUpdateBuilder("test").set("one").to((ByteArray) null).build();
+    // Eight base64 characters decode to six bytes; empty and null values add nothing.
+    assertThat(MutationSizeEstimator.sizeOf(sixBytes), is(6L));
+    assertThat(MutationSizeEstimator.sizeOf(noBytes), is(0L));
+    assertThat(MutationSizeEstimator.sizeOf(absent), is(0L));
+  }
+
+  @Test
+  public void dates() throws Exception {
+    Mutation day =
+        Mutation.newInsertOrUpdateBuilder("test")
+            .set("one")
+            .to(Date.fromYearMonthDay(2017, 10, 10))
+            .build();
+    Mutation absentDay =
+        Mutation.newInsertOrUpdateBuilder("test").set("one").to((Date) null).build();
+    Mutation days =
+        Mutation.newInsertOrUpdateBuilder("test")
+            .set("one")
+            .toDateArray(
+                Arrays.asList(
+                    null,
+                    Date.fromYearMonthDay(2017, 1, 1),
+                    null,
+                    Date.fromYearMonthDay(2017, 1, 2)))
+            .build();
+    Mutation instant =
+        Mutation.newInsertOrUpdateBuilder("test").set("one").to(Timestamp.now()).build();
+    Mutation absentInstant =
+        Mutation.newInsertOrUpdateBuilder("test").set("one").to((Timestamp) null).build();
+    Mutation instants =
+        Mutation.newInsertOrUpdateBuilder("test")
+            .set("one")
+            .toTimestampArray(Arrays.asList(Timestamp.now(), null))
+            .build();
+    // Expected: 12 per timestamp or date, null or not, also for every array element.
+    assertThat(MutationSizeEstimator.sizeOf(instant), is(12L));
+    assertThat(MutationSizeEstimator.sizeOf(day), is(12L));
+    assertThat(MutationSizeEstimator.sizeOf(absentInstant), is(12L));
+    assertThat(MutationSizeEstimator.sizeOf(absentDay), is(12L));
+    assertThat(MutationSizeEstimator.sizeOf(instants), is(24L));
+    assertThat(MutationSizeEstimator.sizeOf(days), is(48L));
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java
new file mode 100644
index 000000000000..5bdfea5522b2
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOTest.java
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static org.mockito.Mockito.argThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.withSettings;
+
+import com.google.api.core.ApiFuture;
+import com.google.cloud.ServiceFactory;
+import com.google.cloud.spanner.DatabaseClient;
+import com.google.cloud.spanner.DatabaseId;
+import com.google.cloud.spanner.Mutation;
+import com.google.cloud.spanner.Spanner;
+import com.google.cloud.spanner.SpannerOptions;
+import com.google.common.collect.Iterables;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import javax.annotation.concurrent.GuardedBy;
+
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFnTester;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentMatcher;
+import org.mockito.Matchers;
+
+
+/**
+ * Unit tests for {@link SpannerIO}.
+ */
+@RunWith(JUnit4.class)
+public class SpannerIOTest implements Serializable {
+ @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+ @Rule public transient ExpectedException thrown = ExpectedException.none();
+
+ private FakeServiceFactory serviceFactory;
+
+  @Before
+  @SuppressWarnings("unchecked")
+  public void setUp() throws Exception {
+    this.serviceFactory = new FakeServiceFactory(); // new mocks for every test
+  }
+
+  @Test
+  public void emptyTransform() throws Exception {
+    // Nothing is configured; validate() checks the instance id before the database id.
+    thrown.expect(NullPointerException.class);
+    thrown.expectMessage("requires instance id to be set with");
+    SpannerIO.write().validate(null);
+  }
+
+  @Test
+  public void emptyInstanceId() throws Exception {
+    SpannerIO.Write databaseOnly = SpannerIO.write().withDatabaseId("123");
+    thrown.expectMessage("requires instance id to be set with");
+    thrown.expect(NullPointerException.class);
+    databaseOnly.validate(null);
+  }
+
+  @Test
+  public void emptyDatabaseId() throws Exception {
+    SpannerIO.Write instanceOnly = SpannerIO.write().withInstanceId("123");
+    thrown.expectMessage("requires database id to be set with");
+    thrown.expect(NullPointerException.class);
+    instanceOnly.validate(null);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void singleMutationPipeline() throws Exception {
+    Mutation mutation = Mutation.newInsertOrUpdateBuilder("test").set("one").to(2).build();
+    PCollection<Mutation> mutations = pipeline.apply(Create.of(mutation));
+    // The fake service factory hands the write a mocked Spanner client.
+    mutations.apply(
+        SpannerIO.write()
+            .withProjectId("test-project")
+            .withInstanceId("test-instance")
+            .withDatabaseId("test-database")
+            .withServiceFactory(serviceFactory));
+    pipeline.run();
+    verify(serviceFactory.mockSpanner())
+        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
+    verify(serviceFactory.mockDatabaseClient(), times(1))
+        .writeAtLeastOnce(argThat(new IterableOfSize(1)));
+  }
+
+  @Test
+  public void batching() throws Exception {
+    Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build();
+    Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build();
+    SpannerIO.Write write =
+        SpannerIO.write()
+            .withProjectId("test-project")
+            .withInstanceId("test-instance")
+            .withDatabaseId("test-database")
+            .withBatchSizeBytes(1000000000) // large enough to hold both mutations in one batch
+            .withServiceFactory(serviceFactory);
+    SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write);
+    DoFnTester<Mutation, Void> fnTester = DoFnTester.of(writerFn);
+    fnTester.processBundle(Arrays.asList(one, two));
+
+    verify(serviceFactory.mockSpanner())
+        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
+    verify(serviceFactory.mockDatabaseClient(), times(1))
+        .writeAtLeastOnce(argThat(new IterableOfSize(2)));
+  }
+
+  @Test
+  public void batchingGroups() throws Exception {
+    Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build();
+    Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build();
+    Mutation three = Mutation.newInsertOrUpdateBuilder("test").set("three").to(3).build();
+
+    // Leave room for one more mutation: the limit is only reached when the second one is added.
+    long batchSize = MutationSizeEstimator.sizeOf(one) + 1;
+
+    SpannerIO.Write write =
+        SpannerIO.write()
+            .withProjectId("test-project")
+            .withInstanceId("test-instance")
+            .withDatabaseId("test-database")
+            .withBatchSizeBytes(batchSize)
+            .withServiceFactory(serviceFactory);
+    SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write);
+    DoFnTester<Mutation, Void> fnTester = DoFnTester.of(writerFn);
+    fnTester.processBundle(Arrays.asList(one, two, three));
+
+    verify(serviceFactory.mockSpanner())
+        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
+    verify(serviceFactory.mockDatabaseClient(), times(1))
+        .writeAtLeastOnce(argThat(new IterableOfSize(2)));
+    verify(serviceFactory.mockDatabaseClient(), times(1))
+        .writeAtLeastOnce(argThat(new IterableOfSize(1)));
+  }
+
+  @Test
+  public void noBatching() throws Exception {
+    Mutation one = Mutation.newInsertOrUpdateBuilder("test").set("one").to(1).build();
+    Mutation two = Mutation.newInsertOrUpdateBuilder("test").set("two").to(2).build();
+    SpannerIO.Write write =
+        SpannerIO.write()
+            .withProjectId("test-project")
+            .withInstanceId("test-instance")
+            .withDatabaseId("test-database")
+            .withBatchSizeBytes(0) // turn off batching: every mutation is written on its own.
+            .withServiceFactory(serviceFactory);
+    SpannerIO.SpannerWriteFn writerFn = new SpannerIO.SpannerWriteFn(write);
+    DoFnTester<Mutation, Void> fnTester = DoFnTester.of(writerFn);
+    fnTester.processBundle(Arrays.asList(one, two));
+
+    verify(serviceFactory.mockSpanner())
+        .getDatabaseClient(DatabaseId.of("test-project", "test-instance", "test-database"));
+    verify(serviceFactory.mockDatabaseClient(), times(2))
+        .writeAtLeastOnce(argThat(new IterableOfSize(1)));
+  }
+
+  private static class FakeServiceFactory
+      implements ServiceFactory<Spanner, SpannerOptions>, Serializable {
+    // Marked as static so they can be returned by the service factory, which is serializable.
+    private static final Object lock = new Object();
+
+    @GuardedBy("lock")
+    private static final List<Spanner> mockSpanners = new ArrayList<>();
+
+    @GuardedBy("lock")
+    private static final List<DatabaseClient> mockDatabaseClients = new ArrayList<>();
+
+    @GuardedBy("lock")
+    private static int count = 0;
+
+    private final int index; // position of this factory's mocks in the static lists
+
+    public FakeServiceFactory() {
+      synchronized (lock) {
+        index = count++;
+        mockSpanners.add(mock(Spanner.class, withSettings().serializable()));
+        mockDatabaseClients.add(mock(DatabaseClient.class, withSettings().serializable()));
+      }
+      ApiFuture<Void> voidFuture = mock(ApiFuture.class, withSettings().serializable());
+      when(mockSpanner().getDatabaseClient(Matchers.any(DatabaseId.class)))
+          .thenReturn(mockDatabaseClient());
+      when(mockSpanner().closeAsync()).thenReturn(voidFuture);
+    }
+
+    DatabaseClient mockDatabaseClient() {
+      synchronized (lock) {
+        return mockDatabaseClients.get(index);
+      }
+    }
+
+    Spanner mockSpanner() {
+      synchronized (lock) {
+        return mockSpanners.get(index);
+      }
+    }
+
+    @Override
+    public Spanner create(SpannerOptions serviceOptions) {
+      return mockSpanner(); // the options are ignored; the pre-built mock is always returned
+    }
+  }
+
+  private static class IterableOfSize extends ArgumentMatcher<Iterable<Mutation>> {
+    private final int size; // exact element count an Iterable must have to match
+
+    private IterableOfSize(int size) {
+      this.size = size;
+    }
+
+    @Override
+    public boolean matches(Object argument) {
+      return argument instanceof Iterable && Iterables.size((Iterable<?>) argument) == size;
+    }
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
new file mode 100644
index 000000000000..064c65eedcef
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.spanner.Database;
+import com.google.cloud.spanner.DatabaseAdminClient;
+import com.google.cloud.spanner.DatabaseClient;
+import com.google.cloud.spanner.DatabaseId;
+import com.google.cloud.spanner.Mutation;
+import com.google.cloud.spanner.Operation;
+import com.google.cloud.spanner.ResultSet;
+import com.google.cloud.spanner.Spanner;
+import com.google.cloud.spanner.SpannerOptions;
+import com.google.cloud.spanner.Statement;
+import com.google.spanner.admin.database.v1.CreateDatabaseMetadata;
+import java.util.Collections;
+import org.apache.beam.sdk.io.GenerateSequence;
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** End-to-end test of Cloud Spanner Sink. */
+@RunWith(JUnit4.class)
+public class SpannerWriteIT {
+ @Rule public final transient TestPipeline p = TestPipeline.create();
+
+ /** Pipeline options for this test; every option declares a default value. */
+ public interface SpannerTestPipelineOptions extends TestPipelineOptions {
+ @Description("Project ID for Spanner")
+ @Default.String("apache-beam-testing")
+ String getProjectId();
+ void setProjectId(String value);
+
+ @Description("Instance ID to write to in Spanner")
+ @Default.String("beam-test")
+ String getInstanceId();
+ void setInstanceId(String value);
+
+ @Description("Database ID to write to in Spanner")
+ @Default.String("beam-testdb")
+ String getDatabaseId();
+ void setDatabaseId(String value);
+
+ @Description("Table name")
+ @Default.String("users")
+ String getTable();
+ void setTable(String value);
+ }
+
+ private Spanner spanner;
+ private DatabaseAdminClient databaseAdminClient;
+ private SpannerTestPipelineOptions options;
+
+  @Before
+  public void setUp() throws Exception {
+    PipelineOptionsFactory.register(SpannerTestPipelineOptions.class);
+    options = TestPipeline.testingPipelineOptions().as(SpannerTestPipelineOptions.class);
+
+    spanner = SpannerOptions.newBuilder().setProjectId(options.getProjectId()).build().getService();
+
+    databaseAdminClient = spanner.getDatabaseAdminClient();
+
+    // Delete the database if it exists, then recreate it with the table under test.
+    databaseAdminClient.dropDatabase(options.getInstanceId(), options.getDatabaseId());
+
+    Operation<Database, CreateDatabaseMetadata> op =
+        databaseAdminClient.createDatabase(
+            options.getInstanceId(),
+            options.getDatabaseId(),
+            Collections.singleton(
+                "CREATE TABLE "
+                    + options.getTable()
+                    + " ("
+                    + " Key INT64,"
+                    + " Value STRING(MAX),"
+                    + ") PRIMARY KEY (Key)"));
+    op.waitFor();
+  }
+
+  @Test
+  public void testWrite() throws Exception {
+    p.apply(GenerateSequence.from(0).to(100))
+        .apply(ParDo.of(new GenerateMutations(options.getTable())))
+        .apply(
+            SpannerIO.write()
+                .withProjectId(options.getProjectId())
+                .withInstanceId(options.getInstanceId())
+                .withDatabaseId(options.getDatabaseId()));
+    p.run();
+    DatabaseClient databaseClient =
+        spanner.getDatabaseClient(
+            DatabaseId.of(
+                options.getProjectId(), options.getInstanceId(), options.getDatabaseId()));
+    // try-with-resources closes the ResultSet even when one of the assertions fails.
+    try (ResultSet resultSet =
+        databaseClient
+            .singleUse()
+            .executeQuery(Statement.of("SELECT COUNT(*) FROM " + options.getTable()))) {
+      assertThat(resultSet.next(), is(true));
+      assertThat(resultSet.getLong(0), equalTo(100L));
+      assertThat(resultSet.next(), is(false));
+    }
+  }
+
+ @After
+ public void tearDown() throws Exception { // drops the test database, then closes the client
+ databaseAdminClient.dropDatabase(options.getInstanceId(), options.getDatabaseId());
+ spanner.closeAsync().get(); // wait for the asynchronous close to complete
+ }
+
+  private static class GenerateMutations extends DoFn<Long, Mutation> {
+    private final String table; // table every generated mutation is written to
+    private final int valueSize = 100; // length of the random "Value" string
+
+    public GenerateMutations(String table) {
+      this.table = table;
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      Mutation.WriteBuilder builder = Mutation.newInsertOrUpdateBuilder(table);
+      Long key = c.element();
+      builder.set("Key").to(key);
+      builder.set("Value").to(RandomStringUtils.random(valueSize, true, true));
+      Mutation mutation = builder.build();
+      c.output(mutation);
+    }
+  }
+}