Skip to content
This repository was archived by the owner on Sep 26, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
*/
package com.google.api.gax.grpc;

import com.google.api.core.ApiFunction;
import com.google.api.core.BetaApi;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.gax.core.ExecutorProvider;
Expand Down Expand Up @@ -75,6 +76,9 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final Boolean keepAliveWithoutCalls;
@Nullable private final Integer poolSize;

@Nullable
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;

private InstantiatingGrpcChannelProvider(Builder builder) {
this.processorCount = builder.processorCount;
this.executorProvider = builder.executorProvider;
Expand All @@ -87,6 +91,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.keepAliveTimeout = builder.keepAliveTimeout;
this.keepAliveWithoutCalls = builder.keepAliveWithoutCalls;
this.poolSize = builder.poolSize;
this.channelConfigurator = builder.channelConfigurator;
}

@Override
Expand Down Expand Up @@ -242,6 +247,9 @@ private ManagedChannel createSingleChannel() throws IOException {
if (interceptorProvider != null) {
builder.intercept(interceptorProvider.getInterceptors());
}
if (channelConfigurator != null) {
builder = channelConfigurator.apply(builder);
}

return builder.build();
}
Expand Down Expand Up @@ -297,6 +305,7 @@ public static final class Builder {
@Nullable private Duration keepAliveTimeout;
@Nullable private Boolean keepAliveWithoutCalls;
@Nullable private Integer poolSize;
@Nullable private ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;

private Builder() {
processorCount = Runtime.getRuntime().availableProcessors();
Expand All @@ -314,6 +323,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.keepAliveTimeout = provider.keepAliveTimeout;
this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls;
this.poolSize = provider.poolSize;
this.channelConfigurator = provider.channelConfigurator;
}

/** Sets the number of available CPUs, used internally for testing. */
Expand Down Expand Up @@ -469,6 +479,25 @@ public Builder setChannelsPerCpu(double multiplier, int maxChannels) {
public InstantiatingGrpcChannelProvider build() {
return new InstantiatingGrpcChannelProvider(this);
}

/**
* Add a callback that can intercept channel creation.
*
* <p>This can be used for advanced configuration like setting the netty event loop. The
* callback will be invoked with a fully configured channel builder, which the callback can
* augment or replace.
*/
@BetaApi("Surface for advanced channel configuration is not yet stable")
public Builder setChannelConfigurator(
@Nullable ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator) {
this.channelConfigurator = channelConfigurator;
return this;
}

@Nullable
public ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> getChannelConfigurator() {
return channelConfigurator;
}
}

private static void validateEndpoint(String endpoint) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,21 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

import com.google.api.core.ApiFunction;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.threeten.bp.Duration;

Expand Down Expand Up @@ -168,4 +172,39 @@ private void testWithInterceptors(int numChannels) throws Exception {
channelProvider.getTransportChannel().shutdownNow();
Mockito.verify(interceptorProvider, Mockito.times(numChannels)).getInterceptors();
}

@Test
public void testChannelConfigurator() throws IOException {
final int numChannels = 5;

// Create a mock configurator that will insert mock channels
@SuppressWarnings("unchecked")
ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator =
Mockito.mock(ApiFunction.class);

ArgumentCaptor<ManagedChannelBuilder> channelBuilderCaptor =
ArgumentCaptor.forClass(ManagedChannelBuilder.class);

ManagedChannelBuilder swappedBuilder = Mockito.mock(ManagedChannelBuilder.class);
ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
Mockito.when(swappedBuilder.build()).thenReturn(fakeChannel);

Mockito.when(channelConfigurator.apply(channelBuilderCaptor.capture()))
.thenReturn(swappedBuilder);

// Invoke the provider
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutorProvider(Mockito.mock(ExecutorProvider.class))
.setChannelConfigurator(channelConfigurator)
.setPoolSize(numChannels)
.build()
.getTransportChannel();

// Make sure that the provider passed in a configured channel
assertThat(channelBuilderCaptor.getValue()).isNotNull();
// And that it was replaced with the mock
Mockito.verify(swappedBuilder, Mockito.times(numChannels)).build();
}
}