From 006dbbee619ba23a4a0791e22ef7e1a21c59f752 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Thu, 30 Apr 2020 00:35:31 +0300 Subject: [PATCH] provides extra encoding safety Signed-off-by: Oleh Dokuka --- .../io/rsocket/core/RSocketRequester.java | 20 ++++++++ .../io/rsocket/core/RSocketResponder.java | 37 +++++++++++---- .../io/rsocket/core/RSocketRequesterTest.java | 25 ++++++++++ .../io/rsocket/core/RSocketResponderTest.java | 46 +++++++++++++++++++ 4 files changed, 118 insertions(+), 10 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index fabea217b..ced250620 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -195,6 +195,10 @@ public Mono onClose() { } private Mono handleFireAndForget(Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + Throwable err = checkAvailable(); if (err != null) { payload.release(); @@ -227,6 +231,10 @@ private Mono handleFireAndForget(Payload payload) { } private Mono handleRequestResponse(final Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + Throwable err = checkAvailable(); if (err != null) { payload.release(); @@ -289,6 +297,10 @@ public void hookOnTerminal(SignalType signalType) { } private Flux handleRequestStream(final Payload payload) { + if (payload.refCnt() <= 0) { + return Flux.error(new IllegalReferenceCountException()); + } + Throwable err = checkAvailable(); if (err != null) { payload.release(); @@ -371,6 +383,10 @@ private Flux handleChannel(Flux request) { (s, flux) -> { Payload payload = s.get(); if (payload != null) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + if (!PayloadValidationUtils.isValid(mtu, payload)) { payload.release(); final IllegalArgumentException t = @@ -509,6 +525,10 @@ public void cancel() { } private Mono handleMetadataPush(Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + Throwable err = this.terminationError; if (err != null) { payload.release(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index f5c4aecec..2f073ba8a 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -450,8 +450,32 @@ protected void hookOnSubscribe(Subscription s) { @Override protected void hookOnNext(Payload payload) { - if (!PayloadValidationUtils.isValid(mtu, payload)) { - payload.release(); + try { + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + // specifically for requestChannel case so when Payload is invalid we will not be + // sending CancelFrame and ErrorFrame + // Note: CancelFrame is redundant and due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream + // is + // terminated on both Requester and Responder. + // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both the Requester and Responder. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); + } + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; + } + + ByteBuf byteBuf = + PayloadFrameFlyweight.encodeNextReleasingPayload(allocator, streamId, payload); + sendProcessor.onNext(byteBuf); + } catch (Throwable e) { // specifically for requestChannel case so when Payload is invalid we will not be // sending CancelFrame and ErrorFrame // Note: CancelFrame is redundant and due to spec @@ -464,15 +488,8 @@ protected void hookOnNext(Payload payload) { channelProcessors.remove(streamId, requestChannel); } cancel(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - handleError(streamId, t); - return; + handleError(streamId, e); } - - ByteBuf byteBuf = - PayloadFrameFlyweight.encodeNextReleasingPayload(allocator, streamId, payload); - sendProcessor.onNext(byteBuf); } @Override diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index d7cd8c24b..2117e195d 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -37,6 +37,7 @@ import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; import io.rsocket.Payload; @@ -775,6 +776,30 @@ static Stream encodeDecodePayloadCases() { Arguments.of(REQUEST_CHANNEL, 5, 5)); } + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload( + BiFunction> sourceProducer) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + + Publisher source = sourceProducer.apply(invalidPayload, rule.socket); + + StepVerifier.create(source, 0) + .expectError(IllegalReferenceCountException.class) + .verify(Duration.ofMillis(100)); + } + + private static Stream>> refCntCases() { + return Stream.of( + (p, r) -> r.fireAndForget(p), + (p, r) -> r.requestResponse(p), + (p, r) -> r.requestStream(p), + (p, r) -> r.requestChannel(Mono.just(p)), + (p, r) -> + r.requestChannel(Flux.just(EmptyPayload.INSTANCE, p).doOnSubscribe(s -> s.request(1)))); + } + @Test public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { Payload payload1 = ByteBufPayload.create("abc1"); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 2dbf6715b..9ec2a2df1 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -18,6 +18,7 @@ import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.frame.FrameHeaderFlyweight.frameType; +import static io.rsocket.frame.FrameType.ERROR; import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; import static io.rsocket.frame.FrameType.REQUEST_FNF; import static io.rsocket.frame.FrameType.REQUEST_N; @@ -711,6 +712,51 @@ static Stream encodeDecodePayloadCases() { Arguments.of(REQUEST_CHANNEL, 5, 5)); } + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload(FrameType frameType) { + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Mono.just(invalidPayload); + } + + @Override + public Flux requestStream(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + }); + + rule.sendRequest(1, frameType); + + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches( + bb -> frameType(bb) == ERROR, + "Expect frame type to be {" + + ERROR + + "} but was {" + + frameType(rule.connection.getSent().iterator().next()) + + "}"); + } + + private static Stream refCntCases() { + return Stream.of(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + } + public static class ServerSocketRule extends AbstractSocketRule { private RSocket acceptingSocket;