diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java index 6631ffa13e8a..e914ef160deb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java @@ -632,6 +632,8 @@ private abstract class AbstractWindmillStream implements Wi // The following should be protected by synchronizing on this, except for // the atomics which may be read atomically for status pages. private StreamObserver requestObserver; + // Indicates if the current stream in requestObserver is closed by calling close() method + private final AtomicBoolean streamClosed = new AtomicBoolean(); private final AtomicLong startTimeMs = new AtomicLong(); private final AtomicLong lastSendTimeMs = new AtomicLong(); private final AtomicLong lastResponseTimeMs = new AtomicLong(); @@ -663,7 +665,7 @@ protected AbstractWindmillStream( protected final void send(RequestT request) { lastSendTimeMs.set(Instant.now().getMillis()); synchronized (this) { - if (clientClosed.get()) { + if (streamClosed.get()) { throw new IllegalStateException("Send called on a client closed stream."); } requestObserver.onNext(request); @@ -681,6 +683,7 @@ protected final void startStream() { startTimeMs.set(Instant.now().getMillis()); lastResponseTimeMs.set(0); requestObserver = streamObserverFactory.from(clientFactory, new ResponseObserver()); + streamClosed.set(false); onNewStream(); if (clientClosed.get()) { close(); @@ -742,10 +745,11 @@ public final void appendSummaryHtml(PrintWriter writer) { writer.format(", %dms backoff remaining", sleepLeft); } writer.format( - ", current stream is %dms old, last send %dms, last response %dms", + ", current stream is %dms old, last send %dms, last response %dms, closed: %s", debugDuration(nowMs, startTimeMs.get()), debugDuration(nowMs, lastSendTimeMs.get()), - debugDuration(nowMs, lastResponseTimeMs.get())); + debugDuration(nowMs, lastResponseTimeMs.get()), + streamClosed.get()); } // Don't require synchronization on stream, see the appendSummaryHtml comment. @@ -838,6 +842,7 @@ public final synchronized void close() { // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. clientClosed.set(true); requestObserver.onCompleted(); + streamClosed.set(true); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java index c5d7b0c0f323..64a31f368315 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServerTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.io.InputStream; import java.io.SequenceInputStream; @@ -28,11 +29,14 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; @@ -491,11 +495,96 @@ private WorkItemCommitRequest makeCommitRequest(int i, int size) { .build(); } + // This server receives WorkItemCommitRequests, and verifies they are equal to the provided + // commitRequest. + private StreamObserver getTestCommitStreamObserver( + StreamObserver responseObserver, + Map commitRequests) { + return new StreamObserver() { + boolean sawHeader = false; + InputStream buffer = null; + long remainingBytes = 0; + ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver); + + @Override + public void onNext(StreamingCommitWorkRequest request) { + maybeInjectError(responseObserver); + + if (!sawHeader) { + errorCollector.checkThat( + request.getHeader(), + Matchers.equalTo( + JobHeader.newBuilder() + .setJobId("job") + .setProjectId("project") + .setWorkerId("worker") + .build())); + sawHeader = true; + LOG.info("Received header"); + } else { + boolean first = true; + LOG.info("Received request with {} chunks", request.getCommitChunkCount()); + for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) { + assertTrue(chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE); + if (first || chunk.hasComputationId()) { + errorCollector.checkThat(chunk.getComputationId(), Matchers.equalTo("computation")); + } + + if (remainingBytes != 0) { + errorCollector.checkThat(buffer, Matchers.notNullValue()); + errorCollector.checkThat( + remainingBytes, + Matchers.is( + chunk.getSerializedWorkItemCommit().size() + + chunk.getRemainingBytesForWorkItem())); + buffer = + new SequenceInputStream(buffer, chunk.getSerializedWorkItemCommit().newInput()); + } else { + errorCollector.checkThat(buffer, Matchers.nullValue()); + buffer = chunk.getSerializedWorkItemCommit().newInput(); + } + remainingBytes = chunk.getRemainingBytesForWorkItem(); + if (remainingBytes == 0) { + try { + WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer); + errorCollector.checkThat( + received, Matchers.equalTo(commitRequests.get(received.getWorkToken()))); + try { + responseObserver.onNext( + StreamingCommitResponse.newBuilder() + .addRequestId(chunk.getRequestId()) + .build()); + } catch (IllegalStateException e) { + // Stream is closed. + } + } catch (Exception e) { + errorCollector.addError(e); + } + buffer = null; + } else { + errorCollector.checkThat(first, Matchers.is(true)); + } + first = false; + } + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + injector.cancel(); + responseObserver.onCompleted(); + } + }; + } + @Test public void testStreamingCommit() throws Exception { List commitRequestList = new ArrayList<>(); List latches = new ArrayList<>(); - Map commitRequests = new HashMap<>(); + Map commitRequests = new ConcurrentHashMap<>(); for (int i = 0; i < 500; ++i) { // Build some requests of varying size with a few big ones. WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128)); @@ -505,92 +594,94 @@ public void testStreamingCommit() throws Exception { } Collections.shuffle(commitRequestList); - // This server receives WorkItemCommitRequests, and verifies they are equal to the above - // commitRequest. serviceRegistry.addService( new CloudWindmillServiceV1Alpha1ImplBase() { @Override public StreamObserver commitWorkStream( StreamObserver responseObserver) { - return new StreamObserver() { - boolean sawHeader = false; - InputStream buffer = null; - long remainingBytes = 0; - ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver); + return getTestCommitStreamObserver(responseObserver, commitRequests); + } + }); - @Override - public void onNext(StreamingCommitWorkRequest request) { - maybeInjectError(responseObserver); + // Make the commit requests, waiting for each of them to be verified and acknowledged. + CommitWorkStream stream = client.commitWorkStream(); + for (int i = 0; i < commitRequestList.size(); ) { + final CountDownLatch latch = latches.get(i); + if (stream.commitWorkItem( + "computation", + commitRequestList.get(i), + (CommitStatus status) -> { + assertEquals(status, CommitStatus.OK); + latch.countDown(); + })) { + i++; + } else { + stream.flush(); + } + } + stream.flush(); + stream.close(); + for (CountDownLatch latch : latches) { + assertTrue(latch.await(1, TimeUnit.MINUTES)); + } + assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS)); + } - if (!sawHeader) { - errorCollector.checkThat( - request.getHeader(), - Matchers.equalTo( - JobHeader.newBuilder() - .setJobId("job") - .setProjectId("project") - .setWorkerId("worker") - .build())); - sawHeader = true; - LOG.info("Received header"); - } else { - boolean first = true; - LOG.info("Received request with {} chunks", request.getCommitChunkCount()); - for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) { - assertTrue(chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE); - if (first || chunk.hasComputationId()) { - errorCollector.checkThat( - chunk.getComputationId(), Matchers.equalTo("computation")); - } + @Test + // Tests stream retries on server errors before and after `close()` + public void testStreamingCommitClosedStream() throws Exception { + List commitRequestList = new ArrayList<>(); + List latches = new ArrayList<>(); + Map commitRequests = new ConcurrentHashMap<>(); + AtomicBoolean shouldServerReturnError = new AtomicBoolean(true); + AtomicBoolean isClientClosed = new AtomicBoolean(false); + AtomicInteger errorsBeforeClose = new AtomicInteger(); + AtomicInteger errorsAfterClose = new AtomicInteger(); + for (int i = 0; i < 500; ++i) { + // Build some requests of varying size with a few big ones. + WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128)); + commitRequestList.add(request); + commitRequests.put((long) i, request); + latches.add(new CountDownLatch(1)); + } + Collections.shuffle(commitRequestList); - if (remainingBytes != 0) { - errorCollector.checkThat(buffer, Matchers.notNullValue()); - errorCollector.checkThat( - remainingBytes, - Matchers.is( - chunk.getSerializedWorkItemCommit().size() - + chunk.getRemainingBytesForWorkItem())); - buffer = - new SequenceInputStream( - buffer, chunk.getSerializedWorkItemCommit().newInput()); - } else { - errorCollector.checkThat(buffer, Matchers.nullValue()); - buffer = chunk.getSerializedWorkItemCommit().newInput(); - } - remainingBytes = chunk.getRemainingBytesForWorkItem(); - if (remainingBytes == 0) { - try { - WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer); - errorCollector.checkThat( - received, - Matchers.equalTo(commitRequests.get(received.getWorkToken()))); - try { - responseObserver.onNext( - StreamingCommitResponse.newBuilder() - .addRequestId(chunk.getRequestId()) - .build()); - } catch (IllegalStateException e) { - // Stream is closed. - } - } catch (Exception e) { - errorCollector.addError(e); - } - buffer = null; + // This server returns errors if shouldServerReturnError is true, else returns valid responses. + serviceRegistry.addService( + new CloudWindmillServiceV1Alpha1ImplBase() { + @Override + public StreamObserver commitWorkStream( + StreamObserver responseObserver) { + StreamObserver testCommitStreamObserver = + getTestCommitStreamObserver(responseObserver, commitRequests); + return new StreamObserver() { + @Override + public void onNext(StreamingCommitWorkRequest request) { + if (shouldServerReturnError.get()) { + try { + responseObserver.onError( + new RuntimeException("shouldServerReturnError = true")); + if (isClientClosed.get()) { + errorsAfterClose.incrementAndGet(); } else { - errorCollector.checkThat(first, Matchers.is(true)); + errorsBeforeClose.incrementAndGet(); } - first = false; + } catch (IllegalStateException e) { + // The stream is already closed. } + } else { + testCommitStreamObserver.onNext(request); } } @Override - public void onError(Throwable throwable) {} + public void onError(Throwable throwable) { + testCommitStreamObserver.onError(throwable); + } @Override public void onCompleted() { - injector.cancel(); - responseObserver.onCompleted(); + testCommitStreamObserver.onCompleted(); } }; } @@ -613,11 +704,50 @@ public void onCompleted() { } } stream.flush(); - for (CountDownLatch latch : latches) { - assertTrue(latch.await(1, TimeUnit.MINUTES)); + + long deadline = System.currentTimeMillis() + 60_000; // 1 min + while (true) { + Thread.sleep(100); + int tmpErrorsBeforeClose = errorsBeforeClose.get(); + // wait for at least 1 errors before close + if (tmpErrorsBeforeClose > 0) { + break; + } + if (System.currentTimeMillis() > deadline) { + // Control should not reach here if the test is working as expected + fail( + String.format( + "Expected errors not sent by server errorsBeforeClose: %s" + + " \n Should not reach here if the test is working as expected.", + tmpErrorsBeforeClose)); + } } stream.close(); + isClientClosed.set(true); + + deadline = System.currentTimeMillis() + 60_000; // 1 min + while (true) { + Thread.sleep(100); + int tmpErrorsAfterClose = errorsAfterClose.get(); + // wait for at least 1 errors after close + if (tmpErrorsAfterClose > 0) { + break; + } + if (System.currentTimeMillis() > deadline) { + // Control should not reach here if the test is working as expected + fail( + String.format( + "Expected errors not sent by server errorsAfterClose: %s" + + " \n Should not reach here if the test is working as expected.", + tmpErrorsAfterClose)); + } + } + + shouldServerReturnError.set(false); + for (CountDownLatch latch : latches) { + assertTrue(latch.await(1, TimeUnit.MINUTES)); + } assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS)); }