Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.HandshakerServiceGrpc;
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
Expand Down Expand Up @@ -104,8 +105,9 @@ public AltsChannelBuilder enableUntrustedAltsForTesting() {
public AltsChannelBuilder setHandshakerAddressForTesting(String handshakerAddress) {
// Instead of using the default shared channel to the handshaker service, create a separate
// resource to the test address.
handshakerChannelPool = SharedResourcePool.forResource(
HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress));
handshakerChannelPool =
SharedResourcePool.forResource(
HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress));
return this;
}

Expand Down Expand Up @@ -144,26 +146,24 @@ private final class ProtocolNegotiatorFactory
@Override
public AltsProtocolNegotiator buildProtocolNegotiator() {
final ImmutableList<String> targetServiceAccounts = targetServiceAccountsBuilder.build();
final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetServiceAccounts(targetServiceAccounts)
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()),
handshakerOptions);
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
}
};
return negotiatorForTest =
AltsProtocolNegotiator.createClientNegotiator(altsHandshakerFactory);
AltsProtocolNegotiator.createClientNegotiator(
altsHandshakerFactory, lazyHandshakerChannel);
}
}

Expand Down
15 changes: 8 additions & 7 deletions alts/src/main/java/io/grpc/alts/AltsServerBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.grpc.Status;
import io.grpc.alts.internal.AltsHandshakerOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.HandshakerServiceGrpc;
import io.grpc.alts.internal.RpcProtocolVersionsUtil;
Expand Down Expand Up @@ -92,8 +93,9 @@ public AltsServerBuilder enableUntrustedAltsForTesting() {
public AltsServerBuilder setHandshakerAddressForTesting(String handshakerAddress) {
// Instead of using the default shared channel to the handshaker service, create a separate
// resource to the test address.
handshakerChannelPool = SharedResourcePool.forResource(
HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress));
handshakerChannelPool =
SharedResourcePool.forResource(
HandshakerServiceChannel.getHandshakerChannelForTesting(handshakerAddress));
return this;
}

Expand Down Expand Up @@ -196,19 +198,18 @@ public Server build() {
}
}

final LazyChannel lazyHandshakerChannel = new LazyChannel(handshakerChannelPool);
delegate.protocolNegotiator(
AltsProtocolNegotiator.createServerNegotiator(
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
return AltsTsiHandshaker.newServer(
HandshakerServiceGrpc.newStub(handshakerChannelPool.getObject()),
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()),
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
}
}));
},
lazyHandshakerChannel));
return delegate.build();
}

Expand Down
17 changes: 9 additions & 8 deletions alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.alts.internal.AltsClientOptions;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.AltsTsiHandshaker;
import io.grpc.alts.internal.GoogleDefaultProtocolNegotiator;
import io.grpc.alts.internal.HandshakerServiceGrpc;
Expand All @@ -36,7 +37,7 @@
import io.grpc.alts.internal.TsiHandshakerFactory;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder;
import io.grpc.internal.SharedResourcePool;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.NettyChannelBuilder;
Expand Down Expand Up @@ -94,24 +95,23 @@ GoogleDefaultProtocolNegotiator getProtocolNegotiatorForTest() {

private final class ProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {

@Override
public GoogleDefaultProtocolNegotiator buildProtocolNegotiator() {
final LazyChannel lazyHandshakerChannel =
new LazyChannel(
SharedResourcePool.forResource(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL));
TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
// TODO: Release the channel if it is not used.
// https://github.com/grpc/grpc-java/issues/4755.
Channel channel =
SharedResourceHolder.get(HandshakerServiceChannel.SHARED_HANDSHAKER_CHANNEL);
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(channel), handshakerOptions);
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
}
};
SslContext sslContext;
Expand All @@ -121,7 +121,8 @@ public TsiHandshaker newHandshaker(String authority) {
throw new RuntimeException(ex);
}
return negotiatorForTest =
new GoogleDefaultProtocolNegotiator(altsHandshakerFactory, sslContext);
new GoogleDefaultProtocolNegotiator(
altsHandshakerFactory, lazyHandshakerChannel, sslContext);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Preconditions;
import com.google.protobuf.Any;
import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.Grpc;
import io.grpc.InternalChannelz.OtherSecurity;
import io.grpc.InternalChannelz.Security;
Expand All @@ -28,6 +29,7 @@
import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult;
import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.ProtocolNegotiator;
import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler;
Expand All @@ -47,15 +49,18 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {

@Grpc.TransportAttr
public static final Attributes.Key<TsiPeer> TSI_PEER_KEY = Attributes.Key.create("TSI_PEER");

@Grpc.TransportAttr
public static final Attributes.Key<AltsAuthContext> ALTS_CONTEXT_KEY =
Attributes.Key.create("ALTS_CONTEXT_KEY");

private static final AsciiString scheme = AsciiString.of("https");

/** Creates a negotiator used for ALTS client. */
public static AltsProtocolNegotiator createClientNegotiator(
final TsiHandshakerFactory handshakerFactory) {
final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) {
final class ClientAltsProtocolNegotiator extends AltsProtocolNegotiator {

@Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
Expand All @@ -68,17 +73,18 @@ public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
@Override
public void close() {
logger.finest("ALTS Client ProtocolNegotiator Closed");
// TODO(jiangtaoli2016): release resources
lazyHandshakerChannel.close();
}
}

return new ClientAltsProtocolNegotiator();
}

/** Creates a negotiator used for ALTS server. */
public static AltsProtocolNegotiator createServerNegotiator(
final TsiHandshakerFactory handshakerFactory) {
final TsiHandshakerFactory handshakerFactory, final LazyChannel lazyHandshakerChannel) {
final class ServerAltsProtocolNegotiator extends AltsProtocolNegotiator {

@Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(/*authority=*/ null);
Expand All @@ -91,13 +97,41 @@ public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
@Override
public void close() {
logger.finest("ALTS Server ProtocolNegotiator Closed");
// TODO(jiangtaoli2016): release resources
lazyHandshakerChannel.close();
}
}

return new ServerAltsProtocolNegotiator();
}

/** Channel created from a channel pool lazily. */
public static class LazyChannel {
private final ObjectPool<Channel> channelPool;
private Channel channel;

public LazyChannel(ObjectPool<Channel> channelPool) {
this.channelPool = channelPool;
}

/**
* If channel is null, gets a channel from the channel pool, otherwise, returns the cached
* channel.
*/
public synchronized Channel get() {
if (channel == null) {
channel = channelPool.getObject();
}
return channel;
}

/** Returns the cached channel to the channel pool. */
public synchronized void close() {
if (channel != null) {
channelPool.returnObject(channel);
}
}
}

/** Buffers all writes until the ALTS handshake is complete. */
@VisibleForTesting
static final class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler
Expand Down Expand Up @@ -129,7 +163,7 @@ public AsciiString scheme() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (logger.isLoggable(Level.FINEST)) {
logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[]{evt});
logger.log(Level.FINEST, "User Event triggered while negotiating ALTS", new Object[] {evt});
}
if (evt instanceof TsiHandshakeCompletionEvent) {
TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.grpc.alts.internal;

import com.google.common.annotations.VisibleForTesting;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.internal.GrpcAttributes;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.ProtocolNegotiator;
Expand All @@ -25,11 +26,15 @@

/** A client-side GPRC {@link ProtocolNegotiator} for Google Default Channel. */
public final class GoogleDefaultProtocolNegotiator implements ProtocolNegotiator {

private final ProtocolNegotiator altsProtocolNegotiator;
private final ProtocolNegotiator tlsProtocolNegotiator;

public GoogleDefaultProtocolNegotiator(TsiHandshakerFactory altsFactory, SslContext sslContext) {
altsProtocolNegotiator = AltsProtocolNegotiator.createClientNegotiator(altsFactory);
/** Constructor for protocol negotiator of Google Default Channel. */
public GoogleDefaultProtocolNegotiator(
TsiHandshakerFactory altsFactory, LazyChannel lazyHandshakerChannel, SslContext sslContext) {
altsProtocolNegotiator =
AltsProtocolNegotiator.createClientNegotiator(altsFactory, lazyHandshakerChannel);
tlsProtocolNegotiator = ProtocolNegotiators.tls(sslContext);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@
import static org.junit.Assert.assertTrue;

import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.Grpc;
import io.grpc.InternalChannelz;
import io.grpc.ManagedChannel;
import io.grpc.SecurityLevel;
import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel;
import io.grpc.alts.internal.TsiFrameProtector.Consumer;
import io.grpc.alts.internal.TsiPeer.Property;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
Expand Down Expand Up @@ -147,8 +153,12 @@ public Object extractPeerObject() throws GeneralSecurityException {
};
}
};
ManagedChannel fakeChannel = NettyChannelBuilder.forTarget("localhost:8080").build();
ObjectPool<Channel> fakeChannelPool = new FixedObjectPool<Channel>(fakeChannel);
LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool);
handler =
AltsProtocolNegotiator.createServerNegotiator(handshakerFactory).newHandler(grpcHandler);
AltsProtocolNegotiator.createServerNegotiator(handshakerFactory, lazyFakeChannel)
.newHandler(grpcHandler);
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
}

Expand Down Expand Up @@ -340,8 +350,7 @@ public void doNotFlushEmptyBuffer() throws Exception {
public void peerPropagated() throws Exception {
doHandshake();

assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY))
.isEqualTo(mockedTsiPeer);
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)).isEqualTo(mockedTsiPeer);
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.ALTS_CONTEXT_KEY))
.isEqualTo(mockedAltsContext);
assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())
Expand Down