diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
index 781ee596708a..48526b80ee2f 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java
@@ -20,7 +20,6 @@
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import java.io.IOException;
 import java.io.InputStream;
-import java.io.OutputStream;
 import java.io.PrintWriter;
 import java.io.SequenceInputStream;
 import java.net.URI;
@@ -120,8 +119,8 @@ public class GrpcWindmillServer extends WindmillServerStub {
 
   private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServer.class);
 
-  // If a connection cannot be established, gRPC will fail fast so this deadline can be
-  // relatively high.
+  // If a connection cannot be established, gRPC will fail fast so this deadline can be relatively
+  // high.
   private static final long DEFAULT_UNARY_RPC_DEADLINE_SECONDS = 300;
   private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300;
 
@@ -1463,82 +1462,36 @@ private void issueBatchedRequest(Map<Long, PendingRequest> requests) {
       }
     }
 
-    // An OutputStream which splits the output into chunks of no more than COMMIT_STREAM_CHUNK_SIZE
-    // before calling the chunkWriter on each.
-    //
-    // This avoids materializing the whole serialized request in the case it is large.
-    private class ChunkingByteStream extends OutputStream {
-      private final ByteString.Output output = ByteString.newOutput(COMMIT_STREAM_CHUNK_SIZE);
-      private final Consumer<ByteString> chunkWriter;
-
-      ChunkingByteStream(Consumer<ByteString> chunkWriter) {
-        this.chunkWriter = chunkWriter;
-      }
-
-      @Override
-      public void close() {
-        flushBytes();
-      }
-
-      @Override
-      public void write(int b) throws IOException {
-        output.write(b);
-        if (output.size() == COMMIT_STREAM_CHUNK_SIZE) {
-          flushBytes();
-        }
-      }
-
-      @Override
-      public void write(byte b[], int currentOffset, int len) throws IOException {
-        final int endOffset = currentOffset + len;
-        while ((endOffset - currentOffset) + output.size() >= COMMIT_STREAM_CHUNK_SIZE) {
-          int writeSize = COMMIT_STREAM_CHUNK_SIZE - output.size();
-          output.write(b, currentOffset, writeSize);
-          currentOffset += writeSize;
-          flushBytes();
-        }
-        if (currentOffset != endOffset) {
-          output.write(b, currentOffset, endOffset - currentOffset);
-        }
-      }
-
-      private void flushBytes() {
-        if (output.size() == 0) {
-          return;
-        }
-        chunkWriter.accept(output.toByteString());
-        output.reset();
-      }
-    }
-
     private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) {
       Preconditions.checkNotNull(pendingRequest.computation);
-      Consumer<ByteString> chunkWriter =
-          new Consumer<ByteString>() {
-            private long remaining = pendingRequest.request.getSerializedSize();
+      final ByteString serializedCommit = pendingRequest.request.toByteString();
 
-            @Override
-            public void accept(ByteString chunk) {
-              StreamingCommitRequestChunk.Builder chunkBuilder =
-                  StreamingCommitRequestChunk.newBuilder()
-                      .setRequestId(id)
-                      .setSerializedWorkItemCommit(chunk)
-                      .setComputationId(pendingRequest.computation)
-                      .setShardingKey(pendingRequest.request.getShardingKey());
-              Preconditions.checkState(remaining >= chunk.size());
-              remaining -= chunk.size();
-              if (remaining > 0) {
-                chunkBuilder.setRemainingBytesForWorkItem(remaining);
-              }
-              StreamingCommitWorkRequest requestChunk =
-                  StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build();
-              send(requestChunk);
-            }
-          };
-      try (ChunkingByteStream s = new ChunkingByteStream(chunkWriter)) {
-        pendingRequest.request.writeTo(s);
-      } catch (IllegalStateException | IOException e) {
-        LOG.info("Stream was broken, request will be retried when stream is reopened.", e);
+      synchronized (this) {
+        pending.put(id, pendingRequest);
+        for (int i = 0; i < serializedCommit.size(); i += COMMIT_STREAM_CHUNK_SIZE) {
+          int end = i + COMMIT_STREAM_CHUNK_SIZE;
+          ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size()));
+
+          StreamingCommitRequestChunk.Builder chunkBuilder =
+              StreamingCommitRequestChunk.newBuilder()
+                  .setRequestId(id)
+                  .setSerializedWorkItemCommit(chunk)
+                  .setComputationId(pendingRequest.computation)
+                  .setShardingKey(pendingRequest.request.getShardingKey());
+          int remaining = serializedCommit.size() - end;
+          if (remaining > 0) {
+            chunkBuilder.setRemainingBytesForWorkItem(remaining);
+          }
+
+          StreamingCommitWorkRequest requestChunk =
+              StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build();
+          try {
+            send(requestChunk);
+          } catch (IllegalStateException e) {
+            // Stream was broken, request will be retried when stream is reopened.
+            break;
+          }
+        }
       }
     }
   }
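
For reference, below is a minimal standalone sketch of the chunking strategy the new code adopts: serialize the message once to a ByteString, slice it into fixed-size chunks with substring(), and report the bytes still to come so the receiver knows when the work item is complete. The class name, CHUNK_SIZE value, and the main() demo are illustrative assumptions, not part of GrpcWindmillServer.

import com.google.protobuf.ByteString;
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the ByteString-slicing approach used by the new
// issueMultiChunkRequest loop. CHUNK_SIZE stands in for the actual
// COMMIT_STREAM_CHUNK_SIZE, whose value is not shown in this diff.
public class ByteStringChunkingSketch {
  private static final int CHUNK_SIZE = 2 * 1024 * 1024; // assumed 2 MiB

  static List<ByteString> chunk(ByteString serialized) {
    List<ByteString> chunks = new ArrayList<>();
    for (int i = 0; i < serialized.size(); i += CHUNK_SIZE) {
      int end = Math.min(i + CHUNK_SIZE, serialized.size());
      // ByteString.substring typically returns a view sharing the backing
      // bytes, so slicing does not copy the payload chunk by chunk.
      chunks.add(serialized.substring(i, end));
    }
    return chunks;
  }

  public static void main(String[] args) {
    ByteString payload = ByteString.copyFrom(new byte[5 * CHUNK_SIZE + 123]);
    int sent = 0;
    for (ByteString chunk : chunk(payload)) {
      sent += chunk.size();
      // Analogue of setRemainingBytesForWorkItem: bytes still to come after
      // this chunk; the final chunk reports 0 and the field is left unset.
      int remaining = payload.size() - sent;
      System.out.println("chunk of " + chunk.size() + " bytes, " + remaining + " remaining");
    }
  }
}

Compared with the removed ChunkingByteStream, this approach materializes the whole serialized commit up front via toByteString(), trading the old streaming serialization for a simpler loop; the substring slices themselves remain cheap since they can share the underlying buffer.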