Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,23 @@
final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand {
private final NettyServerStream.TransportState stream;
private final Status reason;
private final PeerNotify peerNotify;

CancelServerStreamCommand(NettyServerStream.TransportState stream, Status reason) {
private CancelServerStreamCommand(
NettyServerStream.TransportState stream, Status reason, PeerNotify peerNotify) {
this.stream = Preconditions.checkNotNull(stream, "stream");
this.reason = Preconditions.checkNotNull(reason, "reason");
this.peerNotify = Preconditions.checkNotNull(peerNotify, "peerNotify");
}

static CancelServerStreamCommand withReset(
NettyServerStream.TransportState stream, Status reason) {
return new CancelServerStreamCommand(stream, reason, PeerNotify.RESET);
}

static CancelServerStreamCommand withReason(
NettyServerStream.TransportState stream, Status reason) {
return new CancelServerStreamCommand(stream, reason, PeerNotify.BEST_EFFORT_STATUS);
}

NettyServerStream.TransportState stream() {
Expand All @@ -41,6 +54,10 @@ Status reason() {
return reason;
}

boolean wantsHeaders() {
return peerNotify == PeerNotify.BEST_EFFORT_STATUS;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -68,4 +85,11 @@ public String toString() {
.add("reason", reason)
.toString();
}

private enum PeerNotify {
/** Notify the peer by sending a RST_STREAM with no other information. */
RESET,
/** Notify the peer about the {@link #reason} by sending structured headers, if possible. */
BEST_EFFORT_STATUS,
}
}
32 changes: 30 additions & 2 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -788,9 +788,37 @@ private void cancelStream(ChannelHandlerContext ctx, CancelServerStreamCommand c
PerfMark.linkIn(cmd.getLink());
// Notify the listener if we haven't already.
cmd.stream().transportReportStatus(cmd.reason());
// Terminate the stream.
encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise);

// Now we need to decide how we're going to notify the peer that this stream is closed.
// If possible, it's nice to inform the peer _why_ this stream was cancelled by sending
// a structured headers frame.
if (shouldCloseStreamWithHeaders(cmd, connection())) {
Metadata md = new Metadata();
md.put(InternalStatus.CODE_KEY, cmd.reason());
if (cmd.reason().getDescription() != null) {
md.put(InternalStatus.MESSAGE_KEY, cmd.reason().getDescription());
}
Http2Headers headers = Utils.convertServerHeaders(md);
encoder().writeHeaders(
ctx, cmd.stream().id(), headers, /* padding = */ 0, /* endStream = */ true, promise);
} else {
// Terminate the stream.
encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise);
}
}
}

// Determine whether a CancelServerStreamCommand should try to close the stream with a
// HEADERS or a RST_STREAM frame. The caller has some influence over this (they can
// configure cmd.wantsHeaders()). The state of the stream also has an influence: we
// only try to send HEADERS if the stream exists and hasn't already sent any headers.
private static boolean shouldCloseStreamWithHeaders(
CancelServerStreamCommand cmd, Http2Connection conn) {
if (!cmd.wantsHeaders()) {
return false;
}
Http2Stream stream = conn.stream(cmd.stream().id());
return stream != null && !stream.isHeadersSent();
}

private void gracefulClose(final ChannelHandlerContext ctx, final GracefulServerCloseCommand msg,
Expand Down
6 changes: 3 additions & 3 deletions netty/src/main/java/io/grpc/netty/NettyServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status)
@Override
public void cancel(Status status) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) {
writeQueue.enqueue(new CancelServerStreamCommand(transportState(), status), true);
writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true);
}
}
}
Expand Down Expand Up @@ -189,7 +189,7 @@ public void deframeFailed(Throwable cause) {
log.log(Level.WARNING, "Exception processing message", cause);
Status status = Status.fromThrowable(cause);
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReason(this, status), true);
}

private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) {
Expand Down Expand Up @@ -222,7 +222,7 @@ private void handleWriteFutureFailures(ChannelFuture future) {
*/
protected void http2ProcessingFailed(Status status) {
transportReportStatus(status);
handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true);
handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReset(this, status), true);
}

void inboundDataReceived(ByteBuf frame, boolean endOfStream) {
Expand Down
34 changes: 33 additions & 1 deletion netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@
import java.io.InputStream;
import java.nio.channels.ClosedChannelException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
Expand Down Expand Up @@ -469,11 +471,41 @@ public void connectionWindowShouldBeOverridden() throws Exception {
public void cancelShouldSendRstStream() throws Exception {
manualSetUp();
createStream();
enqueue(new CancelServerStreamCommand(stream.transportState(), Status.DEADLINE_EXCEEDED));
enqueue(CancelServerStreamCommand.withReset(stream.transportState(), Status.DEADLINE_EXCEEDED));
verifyWrite().writeRstStream(eq(ctx()), eq(stream.transportState().id()),
eq(Http2Error.CANCEL.code()), any(ChannelPromise.class));
}

@Test
public void cancelWithNotify_shouldSendHeaders() throws Exception {
manualSetUp();
createStream();

enqueue(CancelServerStreamCommand.withReason(
stream.transportState(),
Status.RESOURCE_EXHAUSTED.withDescription("my custom description")
));

ArgumentCaptor<Http2Headers> captor = ArgumentCaptor.forClass(Http2Headers.class);
verifyWrite()
.writeHeaders(
eq(ctx()),
eq(STREAM_ID),
captor.capture(),
eq(0),
eq(true),
any(ChannelPromise.class));

// For arcane reasons, the specific implementation of Http2Headers here doesn't actually support
// methods like `get(...)`, so we have to manually convert it into a map.
Map<String, String> actualHeaders = new HashMap<>();
for (Map.Entry<CharSequence, CharSequence> entry : captor.getValue()) {
actualHeaders.put(entry.getKey().toString(), entry.getValue().toString());
}
assertEquals("8", actualHeaders.get(InternalStatus.CODE_KEY.name()));
assertEquals("my custom description", actualHeaders.get(InternalStatus.MESSAGE_KEY.name()));
}

@Test
public void headersWithInvalidContentTypeShouldFail() throws Exception {
manualSetUp();
Expand Down
29 changes: 26 additions & 3 deletions netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static io.grpc.netty.NettyTestUtil.messageFrame;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
Expand All @@ -37,6 +36,7 @@
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import io.grpc.Attributes;
Expand Down Expand Up @@ -73,6 +73,8 @@
/** Unit tests for {@link NettyServerStream}. */
@RunWith(JUnit4.class)
public class NettyServerStreamTest extends NettyStreamTestBase<NettyServerStream> {
private static final int TEST_MAX_MESSAGE_SIZE = 128;

@Mock
protected ServerStreamListener serverListener;

Expand Down Expand Up @@ -380,18 +382,39 @@ public void emptyFramerShouldSendNoPayload() {
public void cancelStreamShouldSucceed() {
stream().cancel(Status.DEADLINE_EXCEEDED);
verify(writeQueue).enqueue(
new CancelServerStreamCommand(stream().transportState(), Status.DEADLINE_EXCEEDED),
CancelServerStreamCommand.withReset(stream().transportState(), Status.DEADLINE_EXCEEDED),
true);
}

@Test
public void oversizedMessagesResultInResourceExhaustedTrailers() throws Exception {
@SuppressWarnings("InlineMeInliner") // Requires Java 11
String oversizedMsg = Strings.repeat("a", TEST_MAX_MESSAGE_SIZE + 1);
stream.request(1);
stream.transportState().inboundDataReceived(messageFrame(oversizedMsg), false);
assertNull("message should have caused a deframer error", listenerMessageQueue().poll());

ArgumentCaptor<CancelServerStreamCommand> cancelCmdCap =
ArgumentCaptor.forClass(CancelServerStreamCommand.class);
verify(writeQueue).enqueue(cancelCmdCap.capture(), eq(true));

Status status = Status.RESOURCE_EXHAUSTED
.withDescription("gRPC message exceeds maximum size 128: 129");

CancelServerStreamCommand actualCmd = cancelCmdCap.getValue();
assertThat(actualCmd.reason().getCode()).isEqualTo(status.getCode());
assertThat(actualCmd.reason().getDescription()).isEqualTo(status.getDescription());
assertThat(actualCmd.wantsHeaders()).isTrue();
}

@Override
@SuppressWarnings("DirectInvocationOnMock")
protected NettyServerStream createStream() {
when(handler.getWriteQueue()).thenReturn(writeQueue);
StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
TransportTracer transportTracer = new TransportTracer();
NettyServerStream.TransportState state = new NettyServerStream.TransportState(
handler, channel.eventLoop(), http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx,
handler, channel.eventLoop(), http2Stream, TEST_MAX_MESSAGE_SIZE, statsTraceCtx,
transportTracer, "method");
NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY,
"test-authority", statsTraceCtx);
Expand Down