diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java index 65d201d71..1a2d9a27f 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java @@ -29,6 +29,7 @@ */ package com.google.api.gax.grpc; +import com.google.api.core.ApiFunction; import com.google.api.core.BetaApi; import com.google.api.core.InternalExtensionOnly; import com.google.api.gax.core.ExecutorProvider; @@ -75,6 +76,9 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP @Nullable private final Boolean keepAliveWithoutCalls; @Nullable private final Integer poolSize; + @Nullable + private final ApiFunction channelConfigurator; + private InstantiatingGrpcChannelProvider(Builder builder) { this.processorCount = builder.processorCount; this.executorProvider = builder.executorProvider; @@ -87,6 +91,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) { this.keepAliveTimeout = builder.keepAliveTimeout; this.keepAliveWithoutCalls = builder.keepAliveWithoutCalls; this.poolSize = builder.poolSize; + this.channelConfigurator = builder.channelConfigurator; } @Override @@ -242,6 +247,9 @@ private ManagedChannel createSingleChannel() throws IOException { if (interceptorProvider != null) { builder.intercept(interceptorProvider.getInterceptors()); } + if (channelConfigurator != null) { + builder = channelConfigurator.apply(builder); + } return builder.build(); } @@ -297,6 +305,7 @@ public static final class Builder { @Nullable private Duration keepAliveTimeout; @Nullable private Boolean keepAliveWithoutCalls; @Nullable private Integer poolSize; + @Nullable private ApiFunction channelConfigurator; private Builder() { processorCount = Runtime.getRuntime().availableProcessors(); @@ -314,6 +323,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) { this.keepAliveTimeout = provider.keepAliveTimeout; this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls; this.poolSize = provider.poolSize; + this.channelConfigurator = provider.channelConfigurator; } /** Sets the number of available CPUs, used internally for testing. */ @@ -469,6 +479,25 @@ public Builder setChannelsPerCpu(double multiplier, int maxChannels) { public InstantiatingGrpcChannelProvider build() { return new InstantiatingGrpcChannelProvider(this); } + + /** + * Add a callback that can intercept channel creation. + * + *

This can be used for advanced configuration like setting the netty event loop. The + * callback will be invoked with a fully configured channel builder, which the callback can + * augment or replace. + */ + @BetaApi("Surface for advanced channel configuration is not yet stable") + public Builder setChannelConfigurator( + @Nullable ApiFunction channelConfigurator) { + this.channelConfigurator = channelConfigurator; + return this; + } + + @Nullable + public ApiFunction getChannelConfigurator() { + return channelConfigurator; + } } private static void validateEndpoint(String endpoint) { diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java index 34cde1b60..d62be234e 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java @@ -33,10 +33,13 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; +import com.google.api.core.ApiFunction; import com.google.api.gax.core.ExecutorProvider; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder; import com.google.api.gax.rpc.HeaderProvider; import com.google.api.gax.rpc.TransportChannelProvider; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; import java.io.IOException; import java.util.Collections; import java.util.concurrent.ScheduledExecutorService; @@ -44,6 +47,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import org.threeten.bp.Duration; @@ -168,4 +172,39 @@ private void testWithInterceptors(int numChannels) throws Exception { channelProvider.getTransportChannel().shutdownNow(); Mockito.verify(interceptorProvider, Mockito.times(numChannels)).getInterceptors(); } + + @Test + public void testChannelConfigurator() throws IOException { + final int numChannels = 5; + + // Create a mock configurator that will insert mock channels + @SuppressWarnings("unchecked") + ApiFunction channelConfigurator = + Mockito.mock(ApiFunction.class); + + ArgumentCaptor channelBuilderCaptor = + ArgumentCaptor.forClass(ManagedChannelBuilder.class); + + ManagedChannelBuilder swappedBuilder = Mockito.mock(ManagedChannelBuilder.class); + ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class); + Mockito.when(swappedBuilder.build()).thenReturn(fakeChannel); + + Mockito.when(channelConfigurator.apply(channelBuilderCaptor.capture())) + .thenReturn(swappedBuilder); + + // Invoke the provider + InstantiatingGrpcChannelProvider.newBuilder() + .setEndpoint("localhost:8080") + .setHeaderProvider(Mockito.mock(HeaderProvider.class)) + .setExecutorProvider(Mockito.mock(ExecutorProvider.class)) + .setChannelConfigurator(channelConfigurator) + .setPoolSize(numChannels) + .build() + .getTransportChannel(); + + // Make sure that the provider passed in a configured channel + assertThat(channelBuilderCaptor.getValue()).isNotNull(); + // And that it was replaced with the mock + Mockito.verify(swappedBuilder, Mockito.times(numChannels)).build(); + } }