Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ private abstract class AbstractWindmillStream<RequestT, ResponseT> implements Wi
// The following should be protected by synchronizing on this, except for
// the atomics which may be read atomically for status pages.
private StreamObserver<RequestT> requestObserver;
// Indicates whether the current stream in requestObserver has been closed by a call to close()
private final AtomicBoolean streamClosed = new AtomicBoolean();
private final AtomicLong startTimeMs = new AtomicLong();
private final AtomicLong lastSendTimeMs = new AtomicLong();
private final AtomicLong lastResponseTimeMs = new AtomicLong();
Expand Down Expand Up @@ -663,7 +665,7 @@ protected AbstractWindmillStream(
protected final void send(RequestT request) {
lastSendTimeMs.set(Instant.now().getMillis());
synchronized (this) {
if (clientClosed.get()) {
if (streamClosed.get()) {
throw new IllegalStateException("Send called on a client closed stream.");
}
requestObserver.onNext(request);
Expand All @@ -681,6 +683,7 @@ protected final void startStream() {
startTimeMs.set(Instant.now().getMillis());
lastResponseTimeMs.set(0);
requestObserver = streamObserverFactory.from(clientFactory, new ResponseObserver());
streamClosed.set(false);
onNewStream();
if (clientClosed.get()) {
close();
Expand Down Expand Up @@ -742,10 +745,11 @@ public final void appendSummaryHtml(PrintWriter writer) {
writer.format(", %dms backoff remaining", sleepLeft);
}
writer.format(
", current stream is %dms old, last send %dms, last response %dms",
", current stream is %dms old, last send %dms, last response %dms, closed: %s",
debugDuration(nowMs, startTimeMs.get()),
debugDuration(nowMs, lastSendTimeMs.get()),
debugDuration(nowMs, lastResponseTimeMs.get()));
debugDuration(nowMs, lastResponseTimeMs.get()),
streamClosed.get());
}

// Don't require synchronization on stream, see the appendSummaryHtml comment.
Expand Down Expand Up @@ -838,6 +842,7 @@ public final synchronized void close() {
// Synchronization of close and onCompleted necessary for correct retry logic in onNewStream.
clientClosed.set(true);
requestObserver.onCompleted();
streamClosed.set(true);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.InputStream;
import java.io.SequenceInputStream;
Expand All @@ -28,11 +29,14 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
Expand Down Expand Up @@ -491,11 +495,96 @@ private WorkItemCommitRequest makeCommitRequest(int i, int size) {
.build();
}

// Builds the server-side request observer for a commitWorkStream call. It receives
// StreamingCommitWorkRequests, reassembles chunked WorkItemCommitRequests, verifies each
// against the expected entry in the provided commitRequests map (keyed by work token), and
// acks each completed item on responseObserver.
private StreamObserver<StreamingCommitWorkRequest> getTestCommitStreamObserver(
    StreamObserver<StreamingCommitResponse> responseObserver,
    Map<Long, WorkItemCommitRequest> commitRequests) {
  return new StreamObserver<StreamingCommitWorkRequest>() {
    // Set once the stream header (job/project/worker identity) has been received and checked.
    boolean sawHeader = false;
    // Accumulates serialized bytes of a WorkItemCommitRequest that spans multiple chunks;
    // null whenever no item is in progress.
    InputStream buffer = null;
    // Bytes still expected for the in-progress work item; 0 means the item is complete.
    long remainingBytes = 0;
    // Injects errors into the response stream to exercise client retry logic
    // (see ResponseErrorInjector; cancelled in onCompleted).
    ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver);

    @Override
    public void onNext(StreamingCommitWorkRequest request) {
      maybeInjectError(responseObserver);

      if (!sawHeader) {
        // The first message on the stream must be the header identifying the worker.
        errorCollector.checkThat(
            request.getHeader(),
            Matchers.equalTo(
                JobHeader.newBuilder()
                    .setJobId("job")
                    .setProjectId("project")
                    .setWorkerId("worker")
                    .build()));
        sawHeader = true;
        LOG.info("Received header");
      } else {
        boolean first = true;
        LOG.info("Received request with {} chunks", request.getCommitChunkCount());
        for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) {
          // Every chunk must respect the configured maximum chunk size.
          assertTrue(chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE);
          if (first || chunk.hasComputationId()) {
            errorCollector.checkThat(chunk.getComputationId(), Matchers.equalTo("computation"));
          }

          if (remainingBytes != 0) {
            // Continuation of a multi-chunk work item: this chunk's bytes plus its declared
            // remainder must equal what the previous chunk promised; append to the buffer.
            errorCollector.checkThat(buffer, Matchers.notNullValue());
            errorCollector.checkThat(
                remainingBytes,
                Matchers.is(
                    chunk.getSerializedWorkItemCommit().size()
                        + chunk.getRemainingBytesForWorkItem()));
            buffer =
                new SequenceInputStream(buffer, chunk.getSerializedWorkItemCommit().newInput());
          } else {
            // Start of a new work item; no partial buffer should be pending.
            errorCollector.checkThat(buffer, Matchers.nullValue());
            buffer = chunk.getSerializedWorkItemCommit().newInput();
          }
          remainingBytes = chunk.getRemainingBytesForWorkItem();
          if (remainingBytes == 0) {
            // Work item complete: parse it, compare with the expected request for this work
            // token, and ack the request id back to the client.
            try {
              WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer);
              errorCollector.checkThat(
                  received, Matchers.equalTo(commitRequests.get(received.getWorkToken())));
              try {
                responseObserver.onNext(
                    StreamingCommitResponse.newBuilder()
                        .addRequestId(chunk.getRequestId())
                        .build());
              } catch (IllegalStateException e) {
                // Stream is closed; dropping the ack is fine for this test server.
              }
            } catch (Exception e) {
              errorCollector.addError(e);
            }
            buffer = null;
          } else {
            // An unfinished work item is only expected on the request's first chunk.
            errorCollector.checkThat(first, Matchers.is(true));
          }
          first = false;
        }
      }
    }

    @Override
    public void onError(Throwable throwable) {}

    @Override
    public void onCompleted() {
      // Stop injecting errors before closing the response stream cleanly.
      injector.cancel();
      responseObserver.onCompleted();
    }
  };
}

@Test
public void testStreamingCommit() throws Exception {
List<WorkItemCommitRequest> commitRequestList = new ArrayList<>();
List<CountDownLatch> latches = new ArrayList<>();
Map<Long, WorkItemCommitRequest> commitRequests = new HashMap<>();
Map<Long, WorkItemCommitRequest> commitRequests = new ConcurrentHashMap<>();
for (int i = 0; i < 500; ++i) {
// Build some requests of varying size with a few big ones.
WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128));
Expand All @@ -505,92 +594,94 @@ public void testStreamingCommit() throws Exception {
}
Collections.shuffle(commitRequestList);

// This server receives WorkItemCommitRequests, and verifies they are equal to the above
// commitRequest.
serviceRegistry.addService(
new CloudWindmillServiceV1Alpha1ImplBase() {
@Override
public StreamObserver<StreamingCommitWorkRequest> commitWorkStream(
StreamObserver<StreamingCommitResponse> responseObserver) {
return new StreamObserver<StreamingCommitWorkRequest>() {
boolean sawHeader = false;
InputStream buffer = null;
long remainingBytes = 0;
ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver);
return getTestCommitStreamObserver(responseObserver, commitRequests);
}
});

@Override
public void onNext(StreamingCommitWorkRequest request) {
maybeInjectError(responseObserver);
// Make the commit requests, waiting for each of them to be verified and acknowledged.
CommitWorkStream stream = client.commitWorkStream();
for (int i = 0; i < commitRequestList.size(); ) {
final CountDownLatch latch = latches.get(i);
if (stream.commitWorkItem(
"computation",
commitRequestList.get(i),
(CommitStatus status) -> {
assertEquals(status, CommitStatus.OK);
latch.countDown();
})) {
i++;
} else {
stream.flush();
}
}
stream.flush();
stream.close();
for (CountDownLatch latch : latches) {
assertTrue(latch.await(1, TimeUnit.MINUTES));
}
assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
}

if (!sawHeader) {
errorCollector.checkThat(
request.getHeader(),
Matchers.equalTo(
JobHeader.newBuilder()
.setJobId("job")
.setProjectId("project")
.setWorkerId("worker")
.build()));
sawHeader = true;
LOG.info("Received header");
} else {
boolean first = true;
LOG.info("Received request with {} chunks", request.getCommitChunkCount());
for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) {
assertTrue(chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE);
if (first || chunk.hasComputationId()) {
errorCollector.checkThat(
chunk.getComputationId(), Matchers.equalTo("computation"));
}
@Test
// Tests stream retries on server errors before and after `close()`
public void testStreamingCommitClosedStream() throws Exception {
List<WorkItemCommitRequest> commitRequestList = new ArrayList<>();
List<CountDownLatch> latches = new ArrayList<>();
Map<Long, WorkItemCommitRequest> commitRequests = new ConcurrentHashMap<>();
AtomicBoolean shouldServerReturnError = new AtomicBoolean(true);
AtomicBoolean isClientClosed = new AtomicBoolean(false);
AtomicInteger errorsBeforeClose = new AtomicInteger();
AtomicInteger errorsAfterClose = new AtomicInteger();
for (int i = 0; i < 500; ++i) {
// Build some requests of varying size with a few big ones.
WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128));
commitRequestList.add(request);
commitRequests.put((long) i, request);
latches.add(new CountDownLatch(1));
}
Collections.shuffle(commitRequestList);

if (remainingBytes != 0) {
errorCollector.checkThat(buffer, Matchers.notNullValue());
errorCollector.checkThat(
remainingBytes,
Matchers.is(
chunk.getSerializedWorkItemCommit().size()
+ chunk.getRemainingBytesForWorkItem()));
buffer =
new SequenceInputStream(
buffer, chunk.getSerializedWorkItemCommit().newInput());
} else {
errorCollector.checkThat(buffer, Matchers.nullValue());
buffer = chunk.getSerializedWorkItemCommit().newInput();
}
remainingBytes = chunk.getRemainingBytesForWorkItem();
if (remainingBytes == 0) {
try {
WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer);
errorCollector.checkThat(
received,
Matchers.equalTo(commitRequests.get(received.getWorkToken())));
try {
responseObserver.onNext(
StreamingCommitResponse.newBuilder()
.addRequestId(chunk.getRequestId())
.build());
} catch (IllegalStateException e) {
// Stream is closed.
}
} catch (Exception e) {
errorCollector.addError(e);
}
buffer = null;
// This server returns errors if shouldServerReturnError is true, else returns valid responses.
serviceRegistry.addService(
new CloudWindmillServiceV1Alpha1ImplBase() {
@Override
public StreamObserver<StreamingCommitWorkRequest> commitWorkStream(
StreamObserver<StreamingCommitResponse> responseObserver) {
StreamObserver<StreamingCommitWorkRequest> testCommitStreamObserver =
getTestCommitStreamObserver(responseObserver, commitRequests);
return new StreamObserver<StreamingCommitWorkRequest>() {
@Override
public void onNext(StreamingCommitWorkRequest request) {
if (shouldServerReturnError.get()) {
try {
responseObserver.onError(
new RuntimeException("shouldServerReturnError = true"));
if (isClientClosed.get()) {
errorsAfterClose.incrementAndGet();
} else {
errorCollector.checkThat(first, Matchers.is(true));
errorsBeforeClose.incrementAndGet();
}
first = false;
} catch (IllegalStateException e) {
// The stream is already closed.
}
} else {
testCommitStreamObserver.onNext(request);
}
}

@Override
public void onError(Throwable throwable) {}
public void onError(Throwable throwable) {
testCommitStreamObserver.onError(throwable);
}

@Override
public void onCompleted() {
injector.cancel();
responseObserver.onCompleted();
testCommitStreamObserver.onCompleted();
}
};
}
Expand All @@ -613,11 +704,50 @@ public void onCompleted() {
}
}
stream.flush();
for (CountDownLatch latch : latches) {
assertTrue(latch.await(1, TimeUnit.MINUTES));

long deadline = System.currentTimeMillis() + 60_000; // 1 min
while (true) {
Thread.sleep(100);
int tmpErrorsBeforeClose = errorsBeforeClose.get();
// Wait for at least one error before close.
if (tmpErrorsBeforeClose > 0) {
break;
}
if (System.currentTimeMillis() > deadline) {
// Control should not reach here if the test is working as expected
fail(
String.format(
"Expected errors not sent by server errorsBeforeClose: %s"
+ " \n Should not reach here if the test is working as expected.",
tmpErrorsBeforeClose));
}
}

stream.close();
isClientClosed.set(true);

deadline = System.currentTimeMillis() + 60_000; // 1 min
while (true) {
Thread.sleep(100);
int tmpErrorsAfterClose = errorsAfterClose.get();
// Wait for at least one error after close.
if (tmpErrorsAfterClose > 0) {
break;
}
if (System.currentTimeMillis() > deadline) {
// Control should not reach here if the test is working as expected
fail(
String.format(
"Expected errors not sent by server errorsAfterClose: %s"
+ " \n Should not reach here if the test is working as expected.",
tmpErrorsAfterClose));
}
}

shouldServerReturnError.set(false);
for (CountDownLatch latch : latches) {
assertTrue(latch.await(1, TimeUnit.MINUTES));
}
assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
}

Expand Down