From bd374b7a1d2029266c7e82f82f4531ab460712e2 Mon Sep 17 00:00:00 2001 From: jiangtaoli2016 Date: Thu, 3 Jan 2019 16:51:37 -0800 Subject: [PATCH 1/5] ALTS: release handshaker channel if no longer needed --- .../java/io/grpc/alts/AltsChannelBuilder.java | 16 ++++----- .../java/io/grpc/alts/AltsServerBuilder.java | 15 ++++---- .../alts/GoogleDefaultChannelBuilder.java | 17 +++++---- .../grpc/alts/HandshakerServiceChannel.java | 11 +++--- .../alts/internal/AltsProtocolNegotiator.java | 36 ++++++++++++++----- .../GoogleDefaultProtocolNegotiator.java | 12 +++++-- .../grpc/alts/internal/TsiFrameHandler.java | 12 +++---- .../alts/internal/TsiHandshakeHandler.java | 7 ++-- .../alts/internal/TsiHandshakerFactory.java | 3 +- .../alts/HandshakerServiceChannelTest.java | 25 ++++++------- .../internal/AltsProtocolNegotiatorTest.java | 21 +++++++---- .../grpc/alts/internal/FakeTsiHandshaker.java | 9 ++--- .../alts/internal/TsiFrameHandlerTest.java | 11 +++--- 13 files changed, 114 insertions(+), 81 deletions(-) diff --git a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java index 6fb41d12a8c..c82443e7cb2 100644 --- a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java @@ -104,8 +104,9 @@ public AltsChannelBuilder enableUntrustedAltsForTesting() { public AltsChannelBuilder setHandshakerAddressForTesting(String handshakerAddress) { // Instead of using the default shared channel to the handshaker service, create a separate // resource to the test address. - handshakerChannelPool = SharedResourcePool.forResource( - HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress)); + handshakerChannelPool = + SharedResourcePool.forResource( + HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress)); return this; } @@ -147,10 +148,7 @@ public AltsProtocolNegotiator buildProtocolNegotiator() { TsiHandshakerFactory altsHandshakerFactory = new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(String authority) { - // Used the shared grpc channel to connecting to the ALTS handshaker service. - // TODO: Release the channel if it is not used. - // https://github.com/grpc/grpc-java/issues/4755. + public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { AltsClientOptions handshakerOptions = new AltsClientOptions.Builder() .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) @@ -158,12 +156,12 @@ public TsiHandshaker newHandshaker(String authority) { .setTargetName(authority) .build(); return AltsTsiHandshaker.newClient( - HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()), - handshakerOptions); + HandshakerServiceGrpc.newStub(handshakerChannel), handshakerOptions); } }; return negotiatorForTest = - AltsProtocolNegotiator.createClientNegotiator(altsHandshakerFactory); + AltsProtocolNegotiator.createClientNegotiator( + altsHandshakerFactory, handshakerChannelPool); } } diff --git a/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java b/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java index 5ca1f51b4aa..2980eb05b52 100644 --- a/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java @@ -92,8 +92,9 @@ public AltsServerBuilder enableUntrustedAltsForTesting() { public AltsServerBuilder setHandshakerAddressForTesting(String handshakerAddress) { // Instead of using the default shared channel to the handshaker service, create a separate // resource to the test address. - handshakerChannelPool = SharedResourcePool.forResource( - HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress)); + handshakerChannelPool = + SharedResourcePool.forResource( + HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress)); return this; } @@ -200,15 +201,13 @@ public Server build() { AltsProtocolNegotiator.createServerNegotiator( new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(String authority) { - // Used the shared grpc channel to connecting to the ALTS handshaker service. - // TODO: Release the channel if it is not used. - // https://github.com/grpc/grpc-java/issues/4755. + public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { return AltsTsiHandshaker.newServer( - HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()), + HandshakerServiceGrpc.newStub(handshakerChannel), new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions())); } - })); + }, + handshakerChannelPool)); return delegate.build(); } diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java index ccd5104ce8f..089bda16465 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java @@ -36,7 +36,7 @@ import io.grpc.alts.internal.TsiHandshakerFactory; import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.GrpcUtil; -import io.grpc.internal.SharedResourceHolder; +import io.grpc.internal.SharedResourcePool; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder; @@ -94,24 +94,20 @@ GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() { private final class ProtocolNegotiatorFactory implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory { + @Override public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() { TsiHandshakerFactory altsHandshakerFactory = new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(String authority) { - // Used the shared grpc channel to connecting to the ALTS handshaker service. - // TODO: Release the channel if it is not used. - // https://github.com/grpc/grpc-java/issues/4755. - Channel channel = - SharedResourceHolder.get(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL); + public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { AltsClientOptions handshakerOptions = new AltsClientOptions.Builder() .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) .setTargetName(authority) .build(); return AltsTsiHandshaker.newClient( - HandshakerServiceGrpc.newStub(channel), handshakerOptions); + HandshakerServiceGrpc.newStub(handshakerChannel), handshakerOptions); } }; SslContext sslContext; @@ -121,7 +117,10 @@ public TsiHandshaker newHandshaker(String authority) { throw new RuntimeException(ex); } return negotiatorForTest = - new GoogleDefaultProtocolNegotiator(altsHandshakerFactory, sslContext); + new GoogleDefaultProtocolNegotiator( + altsHandshakerFactory, + SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), + sslContext); } } diff --git a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java index 078b865026b..24ea6b8f7e1 100644 --- a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java +++ b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java @@ -55,11 +55,12 @@ public Channel create() { /* Use its own event loop thread pool to avoid blocking. */ EventLoopGroup eventGroup = new NioEventLoopGroup(1, new DefaultThreadFactory("handshaker pool", true)); - ManagedChannel channel = NettyChannelBuilder.forTarget(target) - .directExecutor() - .eventLoopGroup(eventGroup) - .usePlaintext() - .build(); + ManagedChannel channel = + NettyChannelBuilder.forTarget(target) + .directExecutor() + .eventLoopGroup(eventGroup) + .usePlaintext() + .build(); return new EventLoopHoldingChannel(channel, eventGroup); } diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index bebf968eac2..eb8f50a7362 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -21,6 +21,7 @@ import com.google.protobuf.Any; import io.grpc.Attributes; import io.grpc.CallCredentials; +import io.grpc.Channel; import io.grpc.Grpc; import io.grpc.InternalChannelz.OtherSecurity; import io.grpc.InternalChannelz.Security; @@ -28,6 +29,7 @@ import io.grpc.Status; import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult; import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent; +import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.ProtocolNegotiator; import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler; @@ -47,18 +49,28 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator { @Grpc.TransportAttr public static final Attributes.Key TSI_PEER_KEY = Attributes.Key.create("TSI_PEER"); + @Grpc.TransportAttr public static final Attributes.Key ALTS_CONTEXT_KEY = Attributes.Key.create("ALTS_CONTEXT_KEY"); + private static final AsciiString scheme = AsciiString.of("https"); /** Creates a negotiator used for ALTS client. */ public static AltsProtocolNegotiator createClientNegotiator( - final TsiHandshakerFactory handshakerFactory) { + final TsiHandshakerFactory handshakerFactory, + final ObjectPool handshakerChannelPool) { final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator { + + private Channel handshakerChannel = null; + @Override public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority()); + if (handshakerChannel == null) { + handshakerChannel = handshakerChannelPool.getObject(); + } + TsiHandshaker handshaker = + handshakerFactory.newHandshaker(handshakerChannel, grpcHandler.getAuthority()); return new BufferUntilAltsNegotiatedHandler( grpcHandler, new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)), @@ -68,20 +80,28 @@ public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { @Override public void close() { logger.finest("ALTS Client ProtocolNegotiator Closed"); - // TODO(jiangtaoli2016): release resources + handshakerChannelPool.returnObject(handshakerChannel); } } - + return new ClientAltsProtocolNegotiator(); } /** Creates a negotiator used for ALTS server. */ public static AltsProtocolNegotiator createServerNegotiator( - final TsiHandshakerFactory handshakerFactory) { + final TsiHandshakerFactory handshakerFactory, + final ObjectPool handshakerChannelPool) { final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator { + + private Channel handshakerChannel = null; + @Override public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - TsiHandshaker handshaker = handshakerFactory.newHandshaker(/*authority=*/ null); + if (handshakerChannel == null) { + handshakerChannel = handshakerChannelPool.getObject(); + } + TsiHandshaker handshaker = + handshakerFactory.newHandshaker(handshakerChannel, /*authority=*/ null); return new BufferUntilAltsNegotiatedHandler( grpcHandler, new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)), @@ -91,7 +111,7 @@ public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { @Override public void close() { logger.finest("ALTS Server ProtocolNegotiator Closed"); - // TODO(jiangtaoli2016): release resources + handshakerChannelPool.returnObject(handshakerChannel); } } @@ -129,7 +149,7 @@ public AsciiString scheme() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (logger.isLoggable(Level.FINEST)) { - logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[]{evt}); + logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[] {evt}); } if (evt instanceof TsiHandshakeCompletionEvent) { TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt; diff --git a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java index b08453452cd..a84bc5a1aed 100644 --- a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java @@ -17,7 +17,9 @@ package io.grpc.alts.internal; import com.google.common.annotations.VisibleForTesting; +import io.grpc.Channel; import io.grpc.internal.GrpcAttributes; +import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.ProtocolNegotiator; import io.grpc.netty.ProtocolNegotiators; @@ -25,11 +27,17 @@ /** A client-side GPRC {@link ProtocolNegotiator} for Google Default Channel. */ public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator { + private final ProtocolNegotiator altsProtocolNegotiator; private final ProtocolNegotiator tlsProtocolNegotiator; - public GoogleDefaultProtocolNegotiator(TsiHandshakerFactory altsFactory, SslContext sslContext) { - altsProtocolNegotiator = AltsProtocolNegotiator.createClientNegotiator(altsFactory); + /** Constructor for protocol negotiator of Google Default Channel. */ + public GoogleDefaultProtocolNegotiator( + TsiHandshakerFactory altsFactory, + ObjectPool handshakerChannelPool, + SslContext sslContext) { + altsProtocolNegotiator = + AltsProtocolNegotiator.createClientNegotiator(altsFactory, handshakerChannelPool); tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext); } diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java index 264541223b4..e5aa5e0096b 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java @@ -71,7 +71,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception { if (logger.isLoggable(Level.FINEST)) { - logger.log(Level.FINEST, "TsiFrameHandler user event triggered", new Object[]{event}); + logger.log(Level.FINEST, "TsiFrameHandler user event triggered", new Object[] {event}); } if (event instanceof TsiHandshakeCompletionEvent) { TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event; @@ -96,9 +96,7 @@ void setProtector(TsiFrameProtector protector) { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - checkState( - state == State.PROTECTED, - "Cannot read frames while the TSI handshake is %s", state); + checkState(state == State.PROTECTED, "Cannot read frames while the TSI handshake is %s", state); protector.unprotect(in, out, ctx.alloc()); } @@ -106,8 +104,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) throws Exception { checkState( - state == State.PROTECTED, - "Cannot write frames while the TSI handshake state is %s", state); + state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state); ByteBuf msg = (ByteBuf) message; if (!msg.isReadable()) { // Nothing to encode. @@ -193,7 +190,8 @@ public void read(ChannelHandlerContext ctx) { public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException { if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) { logger.fine( - String.format("FrameHandler is inactive(%s), channel id: %s", + String.format( + "FrameHandler is inactive(%s), channel id: %s", state, ctx.channel().id().asShortText())); return; } diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java index 98dd1f90908..8fea539e616 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java @@ -172,10 +172,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t try { ctx.pipeline().remove(this); protector = handshaker.createFrameProtector(ctx.alloc()); - TsiHandshakeCompletionEvent evt = new TsiHandshakeCompletionEvent( - protector, - handshaker.extractPeer(), - handshaker.extractPeerObject()); + TsiHandshakeCompletionEvent evt = + new TsiHandshakeCompletionEvent( + protector, handshaker.extractPeer(), handshaker.extractPeerObject()); protector = null; ctx.fireUserEventTriggered(evt); // No need to do anything with the in buffer, it will be re added to the pipeline when this diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java index 996bd003654..ae893626a80 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java @@ -16,11 +16,12 @@ package io.grpc.alts.internal; +import io.grpc.Channel; import javax.annotation.Nullable; /** Factory that manufactures instances of {@link TsiHandshaker}. */ public interface TsiHandshakerFactory { /** Creates a new handshaker. */ - TsiHandshaker newHandshaker(@Nullable String authority); + TsiHandshaker newHandshaker(@Nullable Channel handshakerChannel, @Nullable String authority); } diff --git a/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java b/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java index ac6ca6f8cb3..39ff67fe442 100644 --- a/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java +++ b/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java @@ -35,18 +35,19 @@ @RunWith(JUnit4.class) public final class HandshakerServiceChannelTest { - @Rule - public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private Server server = grpcCleanup.register( - ServerBuilder.forPort(0) - .addService(new SimpleServiceGrpc.SimpleServiceImplBase() { - @Override - public void unaryRpc(SimpleRequest request, StreamObserver so) { - so.onNext(SimpleResponse.getDefaultInstance()); - so.onCompleted(); - } - }) - .build()); + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private Server server = + grpcCleanup.register( + ServerBuilder.forPort(0) + .addService( + new SimpleServiceGrpc.SimpleServiceImplBase() { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver so) { + so.onNext(SimpleResponse.getDefaultInstance()); + so.onCompleted(); + } + }) + .build()); private Resource resource; @Before diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java index bc5193665c6..9324e52ef20 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java @@ -25,12 +25,17 @@ import io.grpc.Attributes; import io.grpc.CallCredentials; +import io.grpc.Channel; import io.grpc.Grpc; import io.grpc.InternalChannelz; +import io.grpc.ManagedChannel; import io.grpc.SecurityLevel; import io.grpc.alts.internal.TsiFrameProtector.Consumer; import io.grpc.alts.internal.TsiPeer.Property; +import io.grpc.internal.FixedObjectPool; +import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.NettyChannelBuilder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; @@ -133,8 +138,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E TsiHandshakerFactory handshakerFactory = new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) { @Override - public TsiHandshaker newHandshaker(String authority) { - return new DelegatingTsiHandshaker(super.newHandshaker(authority)) { + public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { + return new DelegatingTsiHandshaker(super.newHandshaker(handshakerChannel, authority)) { @Override public TsiPeer extractPeer() throws GeneralSecurityException { return mockedTsiPeer; @@ -147,8 +152,11 @@ public Object extractPeerObject() throws GeneralSecurityException { }; } }; + ManagedChannel fakeChannel = NettyChannelBuilder.forTarget("localhost:8080").build(); + ObjectPool fakeChannelPool = new FixedObjectPool(fakeChannel); handler = - AltsProtocolNegotiator.createServerNegotiator(handshakerFactory).newHandler(grpcHandler); + AltsProtocolNegotiator.createServerNegotiator(handshakerFactory, fakeChannelPool) + .newHandler(grpcHandler); channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler); } @@ -340,8 +348,7 @@ public void doNotFlushEmptyBuffer() throws Exception { public void peerPropagated() throws Exception { doHandshake(); - assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)) - .isEqualTo(mockedTsiPeer); + assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)).isEqualTo(mockedTsiPeer); assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.ALTS_CONTEXT_KEY)) .isEqualTo(mockedAltsContext); assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString()) @@ -425,8 +432,8 @@ private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFact } @Override - public TsiHandshaker newHandshaker(String authority) { - return delegate.newHandshaker(authority); + public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { + return delegate.newHandshaker(handshakerChannel, authority); } } diff --git a/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java b/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java index d742607618a..01958e0aff7 100644 --- a/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java +++ b/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java @@ -19,6 +19,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Preconditions; +import io.grpc.Channel; import io.grpc.alts.internal.TsiPeer.Property; import io.netty.buffer.ByteBufAllocator; import java.nio.ByteBuffer; @@ -37,7 +38,7 @@ public class FakeTsiHandshaker implements TsiHandshaker { private static final TsiHandshakerFactory clientHandshakerFactory = new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(String authority) { + public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { return new FakeTsiHandshaker(true); } }; @@ -45,7 +46,7 @@ public TsiHandshaker newHandshaker(String authority) { private static final TsiHandshakerFactory serverHandshakerFactory = new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(String authority) { + public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { return new FakeTsiHandshaker(false); } }; @@ -83,11 +84,11 @@ public static TsiHandshakerFactory serverHandshakerFactory() { } public static TsiHandshaker newFakeHandshakerClient() { - return clientHandshakerFactory.newHandshaker(null); + return clientHandshakerFactory.newHandshaker(null, null); } public static TsiHandshaker newFakeHandshakerServer() { - return serverHandshakerFactory.newHandshaker(null); + return serverHandshakerFactory.newHandshaker(null, null); } protected FakeTsiHandshaker(boolean isClient) { diff --git a/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java b/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java index c24ad498c73..fe4126c9f07 100644 --- a/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java @@ -43,8 +43,7 @@ @RunWith(JUnit4.class) public class TsiFrameHandlerTest { - @Rule - public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5)); + @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5)); private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler(); private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler); @@ -115,7 +114,8 @@ public void close_shouldFlushRemainingMessage() throws InterruptedException { channel.close().sync(); assertWithMessage("pending write should be flushed on close") - .that((Object) channel.readOutbound()).isEqualTo(msg); + .that((Object) channel.readOutbound()) + .isEqualTo(msg); channel.checkException(); } @@ -128,8 +128,9 @@ private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() { private static final class IdentityFrameProtector implements TsiFrameProtector { @Override - public void protectFlush(List unprotectedBufs, Consumer ctxWrite, - ByteBufAllocator alloc) throws GeneralSecurityException { + public void protectFlush( + List unprotectedBufs, Consumer ctxWrite, ByteBufAllocator alloc) + throws GeneralSecurityException { for (ByteBuf unprotectedBuf : unprotectedBufs) { ctxWrite.accept(unprotectedBuf); } From efce2dfe51cca0cf55e2d9df18d0fe939e60fd4a Mon Sep 17 00:00:00 2001 From: jiangtaoli2016 Date: Fri, 4 Jan 2019 13:04:27 -0800 Subject: [PATCH 2/5] Add comments to TsiHandshakerFactory --- .../java/io/grpc/alts/internal/TsiHandshakerFactory.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java index ae893626a80..9fb1779f2be 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java @@ -22,6 +22,11 @@ /** Factory that manufactures instances of {@link TsiHandshaker}. */ public interface TsiHandshakerFactory { - /** Creates a new handshaker. */ + /** + * Creates a new handshaker. + * + * @param handshakerChannel the shared channel to the handshaker service. + * @param authority the destination that the channel connects to. + */ TsiHandshaker newHandshaker(@Nullable Channel handshakerChannel, @Nullable String authority); } From e618040192faa69284c7f0cdd1da89ea568ba3ab Mon Sep 17 00:00:00 2001 From: jiangtaoli2016 Date: Fri, 4 Jan 2019 15:48:08 -0800 Subject: [PATCH 3/5] Revise based on review comments --- .../java/io/grpc/alts/internal/AltsProtocolNegotiator.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index eb8f50a7362..f75646618ba 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -62,7 +62,7 @@ public static AltsProtocolNegotiator createClientNegotiator( final ObjectPool handshakerChannelPool) { final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator { - private Channel handshakerChannel = null; + private Channel handshakerChannel; @Override public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { @@ -93,7 +93,7 @@ public static AltsProtocolNegotiator createServerNegotiator( final ObjectPool handshakerChannelPool) { final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator { - private Channel handshakerChannel = null; + private Channel handshakerChannel; @Override public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { From 9b9aa0137370ae59df2b6c92b98b39e6abf3c1de Mon Sep 17 00:00:00 2001 From: jiangtaoli2016 Date: Fri, 11 Jan 2019 10:27:24 -0800 Subject: [PATCH 4/5] Use LazyChannel insted of modifying TsiHandshakerFactory interface --- .../java/io/grpc/alts/AltsChannelBuilder.java | 8 ++-- .../java/io/grpc/alts/AltsServerBuilder.java | 9 ++-- .../alts/GoogleDefaultChannelBuilder.java | 13 +++-- .../grpc/alts/HandshakerServiceChannel.java | 11 ++--- .../alts/internal/AltsProtocolNegotiator.java | 47 ++++++++++++------- .../GoogleDefaultProtocolNegotiator.java | 5 +- .../grpc/alts/internal/TsiFrameHandler.java | 12 +++-- .../alts/internal/TsiHandshakeHandler.java | 7 +-- .../alts/internal/TsiHandshakerFactory.java | 10 +--- .../internal/AltsProtocolNegotiatorTest.java | 13 +++-- .../grpc/alts/internal/FakeTsiHandshaker.java | 9 ++-- .../alts/internal/TsiFrameHandlerTest.java | 11 ++--- 12 files changed, 87 insertions(+), 68 deletions(-) diff --git a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java index c82443e7cb2..760100b6bae 100644 --- a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java @@ -30,6 +30,7 @@ import io.grpc.Status; import io.grpc.alts.internal.AltsClientOptions; import io.grpc.alts.internal.AltsProtocolNegotiator; +import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; import io.grpc.alts.internal.AltsTsiHandshaker; import io.grpc.alts.internal.HandshakerServiceGrpc; import io.grpc.alts.internal.RpcProtocolVersionsUtil; @@ -145,10 +146,11 @@ private final class ProtocolNegotiatorFactory @Override public AltsProtocolNegotiator buildProtocolNegotiator() { final ImmutableList targetServiceAccounts = targetServiceAccountsBuilder.build(); + final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); TsiHandshakerFactory altsHandshakerFactory = new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { + public TsiHandshaker newHandshaker(String authority) { AltsClientOptions handshakerOptions = new AltsClientOptions.Builder() .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) @@ -156,12 +158,12 @@ public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) .setTargetName(authority) .build(); return AltsTsiHandshaker.newClient( - HandshakerServiceGrpc.newStub(handshakerChannel), handshakerOptions); + HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions); } }; return negotiatorForTest = AltsProtocolNegotiator.createClientNegotiator( - altsHandshakerFactory, handshakerChannelPool); + altsHandshakerFactory, handshakerChannelPool, lazyHandshakerChannel); } } diff --git a/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java b/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java index 2980eb05b52..d484d4f15c4 100644 --- a/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java @@ -35,6 +35,7 @@ import io.grpc.Status; import io.grpc.alts.internal.AltsHandshakerOptions; import io.grpc.alts.internal.AltsProtocolNegotiator; +import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; import io.grpc.alts.internal.AltsTsiHandshaker; import io.grpc.alts.internal.HandshakerServiceGrpc; import io.grpc.alts.internal.RpcProtocolVersionsUtil; @@ -197,17 +198,19 @@ public Server build() { } } + final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); delegate.protocolNegotiator( AltsProtocolNegotiator.createServerNegotiator( new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { + public TsiHandshaker newHandshaker(String authority) { return AltsTsiHandshaker.newServer( - HandshakerServiceGrpc.newStub(handshakerChannel), + HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions())); } }, - handshakerChannelPool)); + handshakerChannelPool, + lazyHandshakerChannel)); return delegate.build(); } diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java index 089bda16465..96fb379df6e 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java @@ -28,6 +28,7 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import io.grpc.alts.internal.AltsClientOptions; +import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; import io.grpc.alts.internal.AltsTsiHandshaker; import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator; import io.grpc.alts.internal.HandshakerServiceGrpc; @@ -36,6 +37,7 @@ import io.grpc.alts.internal.TsiHandshakerFactory; import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.InternalNettyChannelBuilder; @@ -97,17 +99,20 @@ private final class ProtocolNegotiatorFactory @Override public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() { + final ObjectPool handshakerChannelPool = + SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL); + final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); TsiHandshakerFactory altsHandshakerFactory = new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { + public TsiHandshaker newHandshaker(String authority) { AltsClientOptions handshakerOptions = new AltsClientOptions.Builder() .setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) .setTargetName(authority) .build(); return AltsTsiHandshaker.newClient( - HandshakerServiceGrpc.newStub(handshakerChannel), handshakerOptions); + HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions); } }; SslContext sslContext; @@ -118,9 +123,7 @@ public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) } return negotiatorForTest = new GoogleDefaultProtocolNegotiator( - altsHandshakerFactory, - SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL), - sslContext); + altsHandshakerFactory, handshakerChannelPool, lazyHandshakerChannel, sslContext); } } diff --git a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java index 24ea6b8f7e1..078b865026b 100644 --- a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java +++ b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java @@ -55,12 +55,11 @@ public Channel create() { /* Use its own event loop thread pool to avoid blocking. */ EventLoopGroup eventGroup = new NioEventLoopGroup(1, new DefaultThreadFactory("handshaker pool", true)); - ManagedChannel channel = - NettyChannelBuilder.forTarget(target) - .directExecutor() - .eventLoopGroup(eventGroup) - .usePlaintext() - .build(); + ManagedChannel channel = NettyChannelBuilder.forTarget(target) + .directExecutor() + .eventLoopGroup(eventGroup) + .usePlaintext() + .build(); return new EventLoopHoldingChannel(channel, eventGroup); } diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index f75646618ba..410ad6a584f 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -59,18 +59,13 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator { /** Creates a negotiator used for ALTS client. */ public static AltsProtocolNegotiator createClientNegotiator( final TsiHandshakerFactory handshakerFactory, - final ObjectPool handshakerChannelPool) { + final ObjectPool handshakerChannelPool, + final LazyChannel lazyHandshakerChannel) { final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator { - private Channel handshakerChannel; - @Override public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - if (handshakerChannel == null) { - handshakerChannel = handshakerChannelPool.getObject(); - } - TsiHandshaker handshaker = - handshakerFactory.newHandshaker(handshakerChannel, grpcHandler.getAuthority()); + TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority()); return new BufferUntilAltsNegotiatedHandler( grpcHandler, new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)), @@ -80,7 +75,7 @@ public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { @Override public void close() { logger.finest("ALTS Client ProtocolNegotiator Closed"); - handshakerChannelPool.returnObject(handshakerChannel); + handshakerChannelPool.returnObject(lazyHandshakerChannel.get()); } } @@ -90,18 +85,13 @@ public void close() { /** Creates a negotiator used for ALTS server. */ public static AltsProtocolNegotiator createServerNegotiator( final TsiHandshakerFactory handshakerFactory, - final ObjectPool handshakerChannelPool) { + final ObjectPool handshakerChannelPool, + final LazyChannel lazyHandshakerChannel) { final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator { - private Channel handshakerChannel; - @Override public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - if (handshakerChannel == null) { - handshakerChannel = handshakerChannelPool.getObject(); - } - TsiHandshaker handshaker = - handshakerFactory.newHandshaker(handshakerChannel, /*authority=*/ null); + TsiHandshaker handshaker = handshakerFactory.newHandshaker(/*authority=*/ null); return new BufferUntilAltsNegotiatedHandler( grpcHandler, new TsiHandshakeHandler(new NettyTsiHandshaker(handshaker)), @@ -111,13 +101,34 @@ public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { @Override public void close() { logger.finest("ALTS Server ProtocolNegotiator Closed"); - handshakerChannelPool.returnObject(handshakerChannel); + handshakerChannelPool.returnObject(lazyHandshakerChannel.get()); } } return new ServerAltsProtocolNegotiator(); } + /** Channel created from a channel pool lazily. */ + public static class LazyChannel { + private ObjectPool channelPool; + private Channel channel; + + public LazyChannel(ObjectPool channelPool) { + this.channelPool = channelPool; + } + + /** + * On the first call, it gets a channel from the channel pool. On the remaining calls, it + * returns the cached channel. + */ + public synchronized Channel get() { + if (channel == null) { + channel = channelPool.getObject(); + } + return channel; + } + } + /** Buffers all writes until the ALTS handshake is complete. */ @VisibleForTesting static final class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler diff --git a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java index a84bc5a1aed..a6f71283c39 100644 --- a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Channel; +import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -35,9 +36,11 @@ public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator public GoogleDefaultProtocolNegotiator( TsiHandshakerFactory altsFactory, ObjectPool handshakerChannelPool, + LazyChannel lazyHandshakerChannel, SslContext sslContext) { altsProtocolNegotiator = - AltsProtocolNegotiator.createClientNegotiator(altsFactory, handshakerChannelPool); + AltsProtocolNegotiator.createClientNegotiator( + altsFactory, handshakerChannelPool, lazyHandshakerChannel); tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext); } diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java index e5aa5e0096b..264541223b4 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiFrameHandler.java @@ -71,7 +71,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception { if (logger.isLoggable(Level.FINEST)) { - logger.log(Level.FINEST, "TsiFrameHandler user event triggered", new Object[] {event}); + logger.log(Level.FINEST, "TsiFrameHandler user event triggered", new Object[]{event}); } if (event instanceof TsiHandshakeCompletionEvent) { TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event; @@ -96,7 +96,9 @@ void setProtector(TsiFrameProtector protector) { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - checkState(state == State.PROTECTED, "Cannot read frames while the TSI handshake is %s", state); + checkState( + state == State.PROTECTED, + "Cannot read frames while the TSI handshake is %s", state); protector.unprotect(in, out, ctx.alloc()); } @@ -104,7 +106,8 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) throws Exception { checkState( - state == State.PROTECTED, "Cannot write frames while the TSI handshake state is %s", state); + state == State.PROTECTED, + "Cannot write frames while the TSI handshake state is %s", state); ByteBuf msg = (ByteBuf) message; if (!msg.isReadable()) { // Nothing to encode. @@ -190,8 +193,7 @@ public void read(ChannelHandlerContext ctx) { public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException { if (state == State.CLOSED || state == State.HANDSHAKE_FAILED) { logger.fine( - String.format( - "FrameHandler is inactive(%s), channel id: %s", + String.format("FrameHandler is inactive(%s), channel id: %s", state, ctx.channel().id().asShortText())); return; } diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java index 8fea539e616..98dd1f90908 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakeHandler.java @@ -172,9 +172,10 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t try { ctx.pipeline().remove(this); protector = handshaker.createFrameProtector(ctx.alloc()); - TsiHandshakeCompletionEvent evt = - new TsiHandshakeCompletionEvent( - protector, handshaker.extractPeer(), handshaker.extractPeerObject()); + TsiHandshakeCompletionEvent evt = new TsiHandshakeCompletionEvent( + protector, + handshaker.extractPeer(), + handshaker.extractPeerObject()); protector = null; ctx.fireUserEventTriggered(evt); // No need to do anything with the in buffer, it will be re added to the pipeline when this diff --git a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java index 9fb1779f2be..996bd003654 100644 --- a/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java +++ b/alts/src/main/java/io/grpc/alts/internal/TsiHandshakerFactory.java @@ -16,17 +16,11 @@ package io.grpc.alts.internal; -import io.grpc.Channel; import javax.annotation.Nullable; /** Factory that manufactures instances of {@link TsiHandshaker}. */ public interface TsiHandshakerFactory { - /** - * Creates a new handshaker. - * - * @param handshakerChannel the shared channel to the handshaker service. - * @param authority the destination that the channel connects to. - */ - TsiHandshaker newHandshaker(@Nullable Channel handshakerChannel, @Nullable String authority); + /** Creates a new handshaker. */ + TsiHandshaker newHandshaker(@Nullable String authority); } diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java index 9324e52ef20..21303e428f8 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java @@ -30,6 +30,7 @@ import io.grpc.InternalChannelz; import io.grpc.ManagedChannel; import io.grpc.SecurityLevel; +import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; import io.grpc.alts.internal.TsiFrameProtector.Consumer; import io.grpc.alts.internal.TsiPeer.Property; import io.grpc.internal.FixedObjectPool; @@ -138,8 +139,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E TsiHandshakerFactory handshakerFactory = new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) { @Override - public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { - return new DelegatingTsiHandshaker(super.newHandshaker(handshakerChannel, authority)) { + public TsiHandshaker newHandshaker(String authority) { + return new DelegatingTsiHandshaker(super.newHandshaker(authority)) { @Override public TsiPeer extractPeer() throws GeneralSecurityException { return mockedTsiPeer; @@ -154,8 +155,10 @@ public Object extractPeerObject() throws GeneralSecurityException { }; ManagedChannel fakeChannel = NettyChannelBuilder.forTarget("localhost:8080").build(); ObjectPool fakeChannelPool = new FixedObjectPool(fakeChannel); + LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool); handler = - AltsProtocolNegotiator.createServerNegotiator(handshakerFactory, fakeChannelPool) + AltsProtocolNegotiator.createServerNegotiator( + handshakerFactory, fakeChannelPool, lazyFakeChannel) .newHandler(grpcHandler); channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler); } @@ -432,8 +435,8 @@ private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFact } @Override - public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { - return delegate.newHandshaker(handshakerChannel, authority); + public TsiHandshaker newHandshaker(String authority) { + return delegate.newHandshaker(authority); } } diff --git a/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java b/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java index 01958e0aff7..d742607618a 100644 --- a/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java +++ b/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java @@ -19,7 +19,6 @@ import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Preconditions; -import io.grpc.Channel; import io.grpc.alts.internal.TsiPeer.Property; import io.netty.buffer.ByteBufAllocator; import java.nio.ByteBuffer; @@ -38,7 +37,7 @@ public class FakeTsiHandshaker implements TsiHandshaker { private static final TsiHandshakerFactory clientHandshakerFactory = new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { + public TsiHandshaker newHandshaker(String authority) { return new FakeTsiHandshaker(true); } }; @@ -46,7 +45,7 @@ public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) private static final TsiHandshakerFactory serverHandshakerFactory = new TsiHandshakerFactory() { @Override - public TsiHandshaker newHandshaker(Channel handshakerChannel, String authority) { + public TsiHandshaker newHandshaker(String authority) { return new FakeTsiHandshaker(false); } }; @@ -84,11 +83,11 @@ public static TsiHandshakerFactory serverHandshakerFactory() { } public static TsiHandshaker newFakeHandshakerClient() { - return clientHandshakerFactory.newHandshaker(null, null); + return clientHandshakerFactory.newHandshaker(null); } public static TsiHandshaker newFakeHandshakerServer() { - return serverHandshakerFactory.newHandshaker(null, null); + return serverHandshakerFactory.newHandshaker(null); } protected FakeTsiHandshaker(boolean isClient) { diff --git a/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java b/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java index 991524cae1c..efc1f57ba37 100644 --- a/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/TsiFrameHandlerTest.java @@ -43,7 +43,8 @@ @RunWith(JUnit4.class) public class TsiFrameHandlerTest { - @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5)); + @Rule + public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(5)); private final TsiFrameHandler tsiFrameHandler = new TsiFrameHandler(); private final EmbeddedChannel channel = new EmbeddedChannel(tsiFrameHandler); @@ -115,8 +116,7 @@ public void close_shouldFlushRemainingMessage() throws InterruptedException { channel.close().sync(); assertWithMessage("pending write should be flushed on close") - .that((Object) channel.readOutbound()) - .isEqualTo(msg); + .that((Object) channel.readOutbound()).isEqualTo(msg); channel.checkException(); } @@ -129,9 +129,8 @@ private TsiHandshakeCompletionEvent getHandshakeSuccessEvent() { private static final class IdentityFrameProtector implements TsiFrameProtector { @Override - public void protectFlush( - List unprotectedBufs, Consumer ctxWrite, ByteBufAllocator alloc) - throws GeneralSecurityException { + public void protectFlush(List unprotectedBufs, Consumer ctxWrite, + ByteBufAllocator alloc) throws GeneralSecurityException { for (ByteBuf unprotectedBuf : unprotectedBufs) { ctxWrite.accept(unprotectedBuf); } From 5fbd5ed3c874d82134b48770fab1d6be28f6acd2 Mon Sep 17 00:00:00 2001 From: jiangtaoli2016 Date: Fri, 11 Jan 2019 12:58:34 -0800 Subject: [PATCH 5/5] Add a close to LazyChannel --- .../java/io/grpc/alts/AltsChannelBuilder.java | 2 +- .../java/io/grpc/alts/AltsServerBuilder.java | 1 - .../alts/GoogleDefaultChannelBuilder.java | 9 +++---- .../alts/internal/AltsProtocolNegotiator.java | 25 +++++++++++-------- .../GoogleDefaultProtocolNegotiator.java | 10 ++------ .../internal/AltsProtocolNegotiatorTest.java | 3 +-- 6 files changed, 22 insertions(+), 28 deletions(-) diff --git a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java index 760100b6bae..8252884b249 100644 --- a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java @@ -163,7 +163,7 @@ public TsiHandshaker newHandshaker(String authority) { }; return negotiatorForTest = AltsProtocolNegotiator.createClientNegotiator( - altsHandshakerFactory, handshakerChannelPool, lazyHandshakerChannel); + altsHandshakerFactory, lazyHandshakerChannel); } } diff --git a/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java b/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java index d484d4f15c4..7cffef0283d 100644 --- a/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsServerBuilder.java @@ -209,7 +209,6 @@ public TsiHandshaker newHandshaker(String authority) { new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions())); } }, - handshakerChannelPool, lazyHandshakerChannel)); return delegate.build(); } diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java index 96fb379df6e..3752462c9e4 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java @@ -37,7 +37,6 @@ import io.grpc.alts.internal.TsiHandshakerFactory; import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.GrpcUtil; -import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.InternalNettyChannelBuilder; @@ -99,9 +98,9 @@ private final class ProtocolNegotiatorFactory @Override public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() { - final ObjectPool handshakerChannelPool = - SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL); - final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool); + final LazyChannel lazyHandshakerChannel = + new LazyChannel( + SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL)); TsiHandshakerFactory altsHandshakerFactory = new TsiHandshakerFactory() { @Override @@ -123,7 +122,7 @@ public TsiHandshaker newHandshaker(String authority) { } return negotiatorForTest = new GoogleDefaultProtocolNegotiator( - altsHandshakerFactory, handshakerChannelPool, lazyHandshakerChannel, sslContext); + altsHandshakerFactory, lazyHandshakerChannel, sslContext); } } diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index 9a5f365284b..042215f8d66 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -58,9 +58,7 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator { /** Creates a negotiator used for ALTS client. */ public static AltsProtocolNegotiator createClientNegotiator( - final TsiHandshakerFactory handshakerFactory, - final ObjectPool handshakerChannelPool, - final LazyChannel lazyHandshakerChannel) { + final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) { final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator { @Override @@ -75,7 +73,7 @@ public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { @Override public void close() { logger.finest("ALTS Client ProtocolNegotiator Closed"); - handshakerChannelPool.returnObject(lazyHandshakerChannel.get()); + lazyHandshakerChannel.close(); } } @@ -84,9 +82,7 @@ public void close() { /** Creates a negotiator used for ALTS server. */ public static AltsProtocolNegotiator createServerNegotiator( - final TsiHandshakerFactory handshakerFactory, - final ObjectPool handshakerChannelPool, - final LazyChannel lazyHandshakerChannel) { + final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) { final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator { @Override @@ -101,7 +97,7 @@ public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { @Override public void close() { logger.finest("ALTS Server ProtocolNegotiator Closed"); - handshakerChannelPool.returnObject(lazyHandshakerChannel.get()); + lazyHandshakerChannel.close(); } } @@ -110,7 +106,7 @@ public void close() { /** Channel created from a channel pool lazily. */ public static class LazyChannel { - private ObjectPool channelPool; + private final ObjectPool channelPool; private Channel channel; public LazyChannel(ObjectPool channelPool) { @@ -118,8 +114,8 @@ public LazyChannel(ObjectPool channelPool) { } /** - * On the first call, it gets a channel from the channel pool. On the remaining calls, it - * returns the cached channel. + * If channel is null, gets a channel from the channel pool, otherwise, returns the cached + * channel. */ public synchronized Channel get() { if (channel == null) { @@ -127,6 +123,13 @@ public synchronized Channel get() { } return channel; } + + /** Returns the cached channel to the channel pool. */ + public synchronized void close() { + if (channel != null) { + channelPool.returnObject(channel); + } + } } /** Buffers all writes until the ALTS handshake is complete. */ diff --git a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java index a6f71283c39..a4d611d0788 100644 --- a/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiator.java @@ -17,10 +17,8 @@ package io.grpc.alts.internal; import com.google.common.annotations.VisibleForTesting; -import io.grpc.Channel; import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; import io.grpc.internal.GrpcAttributes; -import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.ProtocolNegotiator; import io.grpc.netty.ProtocolNegotiators; @@ -34,13 +32,9 @@ public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator /** Constructor for protocol negotiator of Google Default Channel. */ public GoogleDefaultProtocolNegotiator( - TsiHandshakerFactory altsFactory, - ObjectPool handshakerChannelPool, - LazyChannel lazyHandshakerChannel, - SslContext sslContext) { + TsiHandshakerFactory altsFactory, LazyChannel lazyHandshakerChannel, SslContext sslContext) { altsProtocolNegotiator = - AltsProtocolNegotiator.createClientNegotiator( - altsFactory, handshakerChannelPool, lazyHandshakerChannel); + AltsProtocolNegotiator.createClientNegotiator(altsFactory, lazyHandshakerChannel); tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext); } diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java index 8dacaecdefc..d474382bb9f 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java @@ -157,8 +157,7 @@ public Object extractPeerObject() throws GeneralSecurityException { ObjectPool fakeChannelPool = new FixedObjectPool(fakeChannel); LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool); handler = - AltsProtocolNegotiator.createServerNegotiator( - handshakerFactory, fakeChannelPool, lazyFakeChannel) + AltsProtocolNegotiator.createServerNegotiator(handshakerFactory, lazyFakeChannel) .newHandler(grpcHandler); channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler); }