diff --git a/reactivesocket-core/src/main/java/io/reactivesocket/Frame.java b/reactivesocket-core/src/main/java/io/reactivesocket/Frame.java index 7f1aaabbd..312a00e55 100644 --- a/reactivesocket-core/src/main/java/io/reactivesocket/Frame.java +++ b/reactivesocket-core/src/main/java/io/reactivesocket/Frame.java @@ -368,7 +368,7 @@ public static Frame from(int streamId, int requestN) { return frame; } - public static long requestN(final Frame frame) { + public static int requestN(final Frame frame) { ensureFrameType(FrameType.REQUEST_N, frame); return RequestNFrameFlyweight.requestN(frame.directBuffer, frame.offset); } diff --git a/reactivesocket-core/src/main/java/io/reactivesocket/internal/Responder.java b/reactivesocket-core/src/main/java/io/reactivesocket/internal/Responder.java index fa6678868..b13ab2bb8 100644 --- a/reactivesocket-core/src/main/java/io/reactivesocket/internal/Responder.java +++ b/reactivesocket-core/src/main/java/io/reactivesocket/internal/Responder.java @@ -735,11 +735,11 @@ public void request(long n) { if (rn.intValue() > 0) { // initial requestN back to the requester (subtract 1 // for the initial frame which was already sent) - child.onNext(Frame.RequestN.from(streamId, rn.intValue() - 1)); + child.onNext(Frame.RequestN.from(streamId, Math.min(Integer.MAX_VALUE, rn.intValue() - 1))); } }, r -> { // requested - child.onNext(Frame.RequestN.from(streamId, r.intValue())); + child.onNext(Frame.RequestN.from(streamId, Math.min(Integer.MAX_VALUE, r.intValue()))); }); synchronized(Responder.this) { if(channels.get(streamId) != null) { diff --git a/reactivesocket-core/src/test/java/io/reactivesocket/internal/RequesterTest.java b/reactivesocket-core/src/test/java/io/reactivesocket/internal/RequesterTest.java index 1b90bf0a5..91d45a66d 100644 --- a/reactivesocket-core/src/test/java/io/reactivesocket/internal/RequesterTest.java +++ b/reactivesocket-core/src/test/java/io/reactivesocket/internal/RequesterTest.java @@ -15,30 +15,33 @@ */ package io.reactivesocket.internal; -import static io.reactivesocket.TestUtil.*; -import static org.junit.Assert.*; -import static io.reactivesocket.ConnectionSetupPayload.NO_FLAGS; -import static io.reactivex.Observable.*; - -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; - -import org.junit.Test; - import io.reactivesocket.ConnectionSetupPayload; import io.reactivesocket.Frame; import io.reactivesocket.FrameType; import io.reactivesocket.LatchedCompletable; import io.reactivesocket.Payload; import io.reactivesocket.TestConnection; -import io.reactivex.subscribers.TestSubscriber; +import io.reactivesocket.util.PayloadImpl; import io.reactivex.Observable; import io.reactivex.subjects.ReplaySubject; +import io.reactivex.subscribers.TestSubscriber; +import org.hamcrest.MatcherAssert; +import org.junit.Test; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import static io.reactivesocket.ConnectionSetupPayload.*; +import static io.reactivesocket.TestUtil.*; +import static io.reactivex.Observable.*; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + public class RequesterTest { final static Consumer ERROR_HANDLER = Throwable::printStackTrace; @@ -79,6 +82,30 @@ public void testReqMetaPushCancelBeforeRequestN() throws InterruptedException { testCancelBeforeRequestN(p.metadataPush(utf8EncodedPayload("hello", null))); } + @Test() + public void testReqStreamRequestLongMax() throws InterruptedException { + TestConnection testConnection = establishConnection(); + Requester p = createClientRequester(testConnection); + + testRequestLongMaxValue(p.requestStream(new PayloadImpl("")), testConnection); + } + + @Test() + public void testReqSubscriptionRequestLongMax() throws InterruptedException { + TestConnection testConnection = establishConnection(); + Requester p = createClientRequester(testConnection); + + testRequestLongMaxValue(p.requestSubscription(new PayloadImpl("")), testConnection); + } + + @Test() + public void testReqChannelRequestLongMax() throws InterruptedException { + TestConnection testConnection = establishConnection(); + Requester p = createClientRequester(testConnection); + + testRequestLongMaxValue(p.requestChannel(Publishers.just(new PayloadImpl(""))), testConnection); + } + @Test(timeout=2000) public void testRequestResponseSuccess() throws InterruptedException { TestConnection conn = establishConnection(); @@ -306,14 +333,35 @@ private static void testCancelBeforeRequestN(Publisher source) { testSubscriber.assertNotComplete(); } - private static Requester createClientRequester() throws InterruptedException { - TestConnection conn = establishConnection(); + private static void testRequestLongMaxValue(Publisher source, TestConnection testConnection) { + List requestNs = new ArrayList<>(); + testConnection.write.add(frame -> { + if (frame.getType() == FrameType.REQUEST_N) { + requestNs.add(Frame.RequestN.requestN(frame)); + } + }); + + TestSubscriber testSubscriber = new TestSubscriber(1L); + source.subscribe(testSubscriber); + + testSubscriber.request(Long.MAX_VALUE); + testSubscriber.assertNoErrors(); + testSubscriber.assertNotComplete(); + + MatcherAssert.assertThat("Negative requestNs received.", requestNs, not(contains(-1))); + } + + private static Requester createClientRequester(TestConnection connection) throws InterruptedException { LatchedCompletable rc = new LatchedCompletable(1); - Requester p = Requester.createClientRequester(conn, ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), ERROR_HANDLER, rc); + Requester p = Requester.createClientRequester(connection, ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), ERROR_HANDLER, rc); rc.await(); return p; } + private static Requester createClientRequester() throws InterruptedException { + return createClientRequester(establishConnection()); + } + private static TestConnection establishConnection() { return new TestConnection(); }