diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java index e84324a81455..2a95791f6c80 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java @@ -185,6 +185,10 @@ private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) { } String userAgentString = USER_AGENT_PREFIX + "/" + ReleaseInfo.getReleaseInfo().getVersion(); builder.setHeaderProvider(FixedHeaderProvider.create("user-agent", userAgentString)); + String databaseRole = spannerConfig.getDatabaseRole(); + if (databaseRole != null && !databaseRole.isEmpty()) { + builder.setDatabaseRole(databaseRole); + } SpannerOptions options = builder.build(); Spanner spanner = options.getService(); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java index 05c8c8926d9c..c10af8429ece 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java @@ -73,6 +73,8 @@ public abstract class SpannerConfig implements Serializable { public abstract @Nullable ValueProvider getRpcPriority(); + public abstract @Nullable String getDatabaseRole(); + @VisibleForTesting abstract @Nullable ServiceFactory getServiceFactory(); @@ -145,6 +147,8 @@ abstract Builder setExecuteStreamingSqlRetrySettings( abstract Builder setRpcPriority(ValueProvider rpcPriority); + abstract Builder setDatabaseRole(String databaseRole); + public abstract SpannerConfig build(); } @@ -256,4 +260,9 @@ public SpannerConfig withRpcPriority(ValueProvider rpcPriority) { checkNotNull(rpcPriority, "withRpcPriority(rpcPriority) called with null input."); return toBuilder().setRpcPriority(rpcPriority).build(); } + + /** Specifies the Cloud Spanner database role. */ + public SpannerConfig withDatabaseRole(String databaseRole) { + return toBuilder().setDatabaseRole(databaseRole).build(); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java index 8ce5d681b84d..ef9f59a2d6e2 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java @@ -99,4 +99,46 @@ public void testRefCountedSpannerAccessorDifferentDbsOnlyOnce() { .getDatabaseClient(eq(DatabaseId.of("project", "test2", "test2"))); verify(serviceFactory.mockSpanner(), times(2)).close(); } + + @Test + public void testCreateWithValidDatabaseRole() { + SpannerConfig config1 = + SpannerConfig.create() + .toBuilder() + .setServiceFactory(serviceFactory) + .setProjectId(StaticValueProvider.of("project")) + .setInstanceId(StaticValueProvider.of("test1")) + .setDatabaseId(StaticValueProvider.of("test1")) + .setDatabaseRole("test-role") + .build(); + + SpannerAccessor acc1 = SpannerAccessor.getOrCreate(config1); + acc1.close(); + + // getDatabaseClient and close() only called once. + verify(serviceFactory.mockSpanner(), times(1)) + .getDatabaseClient(DatabaseId.of("project", "test1", "test1")); + verify(serviceFactory.mockSpanner(), times(1)).close(); + } + + @Test + public void testCreateWithEmptyDatabaseRole() { + SpannerConfig config1 = + SpannerConfig.create() + .toBuilder() + .setServiceFactory(serviceFactory) + .setProjectId(StaticValueProvider.of("project")) + .setInstanceId(StaticValueProvider.of("test1")) + .setDatabaseId(StaticValueProvider.of("test1")) + .setDatabaseRole("") + .build(); + + SpannerAccessor acc1 = SpannerAccessor.getOrCreate(config1); + acc1.close(); + + // getDatabaseClient and close() only called once. + verify(serviceFactory.mockSpanner(), times(1)) + .getDatabaseClient(DatabaseId.of("project", "test1", "test1")); + verify(serviceFactory.mockSpanner(), times(1)).close(); + } }