Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,11 @@ HttpResponse<byte[]> sendInternal(Marshaler marshaler) throws IOException {
do {
if (attempt > 0) {
// Compute and sleep for backoff
long upperBoundNanos = Math.min(nextBackoffNanos, retryPolicy.getMaxBackoff().toNanos());
long backoffNanos = ThreadLocalRandom.current().nextLong(upperBoundNanos);
nextBackoffNanos = (long) (nextBackoffNanos * retryPolicy.getBackoffMultiplier());
long currentBackoffNanos =
Math.min(nextBackoffNanos, retryPolicy.getMaxBackoff().toNanos());
long backoffNanos =
(long) (ThreadLocalRandom.current().nextDouble(0.8d, 1.2d) * currentBackoffNanos);
nextBackoffNanos = (long) (currentBackoffNanos * retryPolicy.getBackoffMultiplier());
try {
TimeUnit.NANOSECONDS.sleep(backoffNanos);
} catch (InterruptedException e) {
Expand All @@ -227,16 +229,11 @@ HttpResponse<byte[]> sendInternal(Marshaler marshaler) throws IOException {
break;
}
}

attempt++;
httpResponse = null;
exception = null;
requestBuilder.timeout(Duration.ofNanos(timeoutNanos - (System.nanoTime() - startTimeNanos)));
try {
httpResponse = sendRequest(requestBuilder, byteBufferPool);
} catch (IOException e) {
exception = e;
}

if (httpResponse != null) {
boolean retryable = retryableStatusCodes.contains(httpResponse.statusCode());
if (logger.isLoggable(Level.FINER)) {
logger.log(
Expand All @@ -251,8 +248,8 @@ HttpResponse<byte[]> sendInternal(Marshaler marshaler) throws IOException {
if (!retryable) {
return httpResponse;
}
}
if (exception != null) {
} catch (IOException e) {
exception = e;
boolean retryable = retryExceptionPredicate.test(exception);
if (logger.isLoggable(Level.FINER)) {
logger.log(
Expand All @@ -268,7 +265,7 @@ HttpResponse<byte[]> sendInternal(Marshaler marshaler) throws IOException {
throw exception;
}
}
} while (attempt < retryPolicy.getMaxAttempts());
} while (++attempt < retryPolicy.getMaxAttempts());

if (httpResponse != null) {
return httpResponse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import okhttp3.Interceptor;
Expand All @@ -37,7 +38,7 @@ public final class RetryInterceptor implements Interceptor {
private final Function<Response, Boolean> isRetryable;
private final Predicate<IOException> retryExceptionPredicate;
private final Sleeper sleeper;
private final BoundedLongGenerator randomLong;
private final Supplier<Double> randomJitter;

/** Constructs a new retrier. */
public RetryInterceptor(RetryPolicy retryPolicy, Function<Response, Boolean> isRetryable) {
Expand All @@ -48,7 +49,7 @@ public RetryInterceptor(RetryPolicy retryPolicy, Function<Response, Boolean> isR
? RetryInterceptor::isRetryableException
: retryPolicy.getRetryExceptionPredicate(),
TimeUnit.NANOSECONDS::sleep,
bound -> ThreadLocalRandom.current().nextLong(bound));
() -> ThreadLocalRandom.current().nextDouble(0.8d, 1.2d));
}

// Visible for testing
Expand All @@ -57,12 +58,12 @@ public RetryInterceptor(RetryPolicy retryPolicy, Function<Response, Boolean> isR
Function<Response, Boolean> isRetryable,
Predicate<IOException> retryExceptionPredicate,
Sleeper sleeper,
BoundedLongGenerator randomLong) {
Supplier<Double> randomJitter) {
this.retryPolicy = retryPolicy;
this.isRetryable = isRetryable;
this.retryExceptionPredicate = retryExceptionPredicate;
this.sleeper = sleeper;
this.randomLong = randomLong;
this.randomJitter = randomJitter;
}

@Override
Expand All @@ -75,9 +76,10 @@ public Response intercept(Chain chain) throws IOException {
if (attempt > 0) {
// Compute and sleep for backoff
// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#exponential-backoff
long upperBoundNanos = Math.min(nextBackoffNanos, retryPolicy.getMaxBackoff().toNanos());
long backoffNanos = randomLong.get(upperBoundNanos);
nextBackoffNanos = (long) (nextBackoffNanos * retryPolicy.getBackoffMultiplier());
long currentBackoffNanos =
Math.min(nextBackoffNanos, retryPolicy.getMaxBackoff().toNanos());
long backoffNanos = (long) (randomJitter.get() * currentBackoffNanos);
nextBackoffNanos = (long) (currentBackoffNanos * retryPolicy.getBackoffMultiplier());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also update the implementation in JdkHttpSender.

I know its not ideal that there are two implementations.. Maybe worth adding a utility function to RetryPolicy that computes the backoff for a given attempt N. Signature might look like:

public long computeBackoffNanosForAttempt(int attempt, Random randomSource) {...}

It wouldn't be as efficient as the current implementation, but...

  • its such a tiny amount of compute that who cares
  • the compute is trivial compared to the overall cost of preparing and executing and HTTP request

Copy link
Contributor Author

@YuriyHolinko YuriyHolinko Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Added code for JdkHttpSender

I did not consider adding that method to calculate a backoff delay time
looking at the code I would say we can build more abstractions for sending requests and checking responses and exceptions, but not sure it's really helpful so let's probably move on with duplicated approach as we had before

try {
sleeper.sleep(backoffNanos);
} catch (InterruptedException e) {
Expand All @@ -88,31 +90,31 @@ public Response intercept(Chain chain) throws IOException {
if (response != null) {
response.close();
}
exception = null;
}

attempt++;
try {
response = chain.proceed(chain.request());
if (response != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the motivation for changing this part of the logic?

Copy link
Contributor Author

@YuriyHolinko YuriyHolinko Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if response is null and no exception happened, the code fails in throw exception; line because exception is null. when response is null it's not something transient

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this is possible but I haven't seen response null in practice. If it does occur, we can simply add a null check immediately after response = chain.proceed(chain.request()).

Copy link
Contributor Author

@YuriyHolinko YuriyHolinko Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's exactly what I did in this PR 😃
also null check was before this change so I intentionally keep it (but I suspect I can just drop it)

also previous code has the issue - if previous(before last) attempt returned rertryable response but the last attempt gets retryable exception method still returns the previous response which is not good as the last state (exception) should be returned

boolean retryable = Boolean.TRUE.equals(isRetryable.apply(response));
if (logger.isLoggable(Level.FINER)) {
logger.log(
Level.FINER,
"Attempt "
+ attempt
+ " returned "
+ (retryable ? "retryable" : "non-retryable")
+ " response: "
+ responseStringRepresentation(response));
}
if (!retryable) {
return response;
}
} else {
throw new NullPointerException("response cannot be null.");
}
} catch (IOException e) {
exception = e;
}
if (response != null) {
boolean retryable = Boolean.TRUE.equals(isRetryable.apply(response));
if (logger.isLoggable(Level.FINER)) {
logger.log(
Level.FINER,
"Attempt "
+ attempt
+ " returned "
+ (retryable ? "retryable" : "non-retryable")
+ " response: "
+ responseStringRepresentation(response));
}
if (!retryable) {
return response;
}
}
if (exception != null) {
response = null;
boolean retryable = retryExceptionPredicate.test(exception);
if (logger.isLoggable(Level.FINER)) {
logger.log(
Expand All @@ -128,8 +130,7 @@ public Response intercept(Chain chain) throws IOException {
throw exception;
}
}

} while (attempt < retryPolicy.getMaxAttempts());
} while (++attempt < retryPolicy.getMaxAttempts());

if (response != null) {
return response;
Expand Down Expand Up @@ -172,11 +173,6 @@ static boolean isRetryableException(IOException e) {
return false;
}

// Visible for testing
interface BoundedLongGenerator {
long get(long bound);
}

// Visible for testing
interface Sleeper {
void sleep(long delayNanos) throws InterruptedException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
Expand All @@ -32,9 +33,11 @@
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Stream;
import okhttp3.Interceptor;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
Expand All @@ -47,15 +50,17 @@
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.stubbing.Answer;

@ExtendWith(MockitoExtension.class)
class RetryInterceptorTest {

@RegisterExtension static final MockWebServerExtension server = new MockWebServerExtension();

@Mock private RetryInterceptor.Sleeper sleeper;
@Mock private RetryInterceptor.BoundedLongGenerator random;
@Mock private Supplier<Double> random;
private Predicate<IOException> retryExceptionPredicate;

private RetryInterceptor retrier;
Expand Down Expand Up @@ -91,6 +96,24 @@ public boolean test(IOException e) {
client = new OkHttpClient.Builder().addInterceptor(retrier).build();
}

@Test
void noRetryOnNullResponse() throws IOException {
Interceptor.Chain chain = mock(Interceptor.Chain.class);
when(chain.proceed(any())).thenReturn(null);
when(chain.request())
.thenReturn(new Request.Builder().url(server.httpUri().toString()).build());
assertThatThrownBy(
() -> {
retrier.intercept(chain);
})
.isInstanceOf(NullPointerException.class)
.hasMessage("response cannot be null.");

verifyNoInteractions(retryExceptionPredicate);
verifyNoInteractions(random);
verifyNoInteractions(sleeper);
}

@Test
void noRetry() throws Exception {
server.enqueue(HttpResponse.of(HttpStatus.OK));
Expand All @@ -109,17 +132,8 @@ void noRetry() throws Exception {
@ValueSource(ints = {5, 6})
void backsOff(int attempts) throws Exception {
succeedOnAttempt(attempts);

// Will backoff 4 times
when(random.get((long) (TimeUnit.SECONDS.toNanos(1) * Math.pow(1.6, 0)))).thenReturn(100L);
when(random.get((long) (TimeUnit.SECONDS.toNanos(1) * Math.pow(1.6, 1)))).thenReturn(50L);
// Capped
when(random.get(TimeUnit.SECONDS.toNanos(2))).thenReturn(500L).thenReturn(510L);

doNothing().when(sleeper).sleep(100);
doNothing().when(sleeper).sleep(50);
doNothing().when(sleeper).sleep(500);
doNothing().when(sleeper).sleep(510);
when(random.get()).thenReturn(1.0d);
doNothing().when(sleeper).sleep(anyLong());

try (Response response = sendRequest()) {
if (attempts <= 5) {
Expand All @@ -139,16 +153,26 @@ void interrupted() throws Exception {
succeedOnAttempt(5);

// Backs off twice, second is interrupted
when(random.get((long) (TimeUnit.SECONDS.toNanos(1) * Math.pow(1.6, 0)))).thenReturn(100L);
when(random.get((long) (TimeUnit.SECONDS.toNanos(1) * Math.pow(1.6, 1)))).thenReturn(50L);
when(random.get()).thenReturn(1.0d).thenReturn(1.0d);
doAnswer(
new Answer<Void>() {
int counter = 0;

doNothing().when(sleeper).sleep(100);
doThrow(new InterruptedException()).when(sleeper).sleep(50);
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
if (counter++ == 1) {
throw new InterruptedException();
}
return null;
}
})
.when(sleeper)
.sleep(anyLong());

try (Response response = sendRequest()) {
assertThat(response.isSuccessful()).isFalse();
}

verify(sleeper, times(2)).sleep(anyLong());
for (int i = 0; i < 2; i++) {
server.takeRequest(0, TimeUnit.NANOSECONDS);
}
Expand All @@ -157,7 +181,7 @@ void interrupted() throws Exception {
@Test
void connectTimeout() throws Exception {
client = connectTimeoutClient();
when(random.get(anyLong())).thenReturn(1L);
when(random.get()).thenReturn(1.0d);
doNothing().when(sleeper).sleep(anyLong());

// Connecting to a non-routable IP address to trigger connection error
Expand All @@ -174,7 +198,7 @@ void connectTimeout() throws Exception {
@Test
void connectException() throws Exception {
client = connectTimeoutClient();
when(random.get(anyLong())).thenReturn(1L);
when(random.get()).thenReturn(1.0d);
doNothing().when(sleeper).sleep(anyLong());

// Connecting to localhost on an unused port address to trigger java.net.ConnectException
Expand Down
Loading