diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java index 48526b80ee2f..781ee596708a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/GrpcWindmillServer.java @@ -20,6 +20,7 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.io.PrintWriter; import java.io.SequenceInputStream; import java.net.URI; @@ -119,8 +120,8 @@ public class GrpcWindmillServer extends WindmillServerStub { private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServer.class); - // If a connection cannot be established, gRPC will fail fast so this deadline can be relatively - // high. + // If a connection cannot be established, gRPC will fail fast so this deadline can be + // relatively high. private static final long DEFAULT_UNARY_RPC_DEADLINE_SECONDS = 300; private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; @@ -1462,37 +1463,83 @@ private void issueBatchedRequest(Map requests) { } } - private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { - Preconditions.checkNotNull(pendingRequest.computation); - final ByteString serializedCommit = pendingRequest.request.toByteString(); + // An OutputStream which splits the output into chunks of no more than COMMIT_STREAM_CHUNK_SIZE + // before calling the chunkWriter on each. + // + // This avoids materializing the whole serialized request in the case it is large. 
+ private class ChunkingByteStream extends OutputStream { + private final ByteString.Output output = ByteString.newOutput(COMMIT_STREAM_CHUNK_SIZE); + private final Consumer<ByteString> chunkWriter; - synchronized (this) { - pending.put(id, pendingRequest); - for (int i = 0; i < serializedCommit.size(); i += COMMIT_STREAM_CHUNK_SIZE) { - int end = i + COMMIT_STREAM_CHUNK_SIZE; - ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); - - StreamingCommitRequestChunk.Builder chunkBuilder = - StreamingCommitRequestChunk.newBuilder() - .setRequestId(id) - .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.computation) - .setShardingKey(pendingRequest.request.getShardingKey()); - int remaining = serializedCommit.size() - end; - if (remaining > 0) { - chunkBuilder.setRemainingBytesForWorkItem(remaining); - } + ChunkingByteStream(Consumer<ByteString> chunkWriter) { + this.chunkWriter = chunkWriter; + } - StreamingCommitWorkRequest requestChunk = - StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); - try { - send(requestChunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. 
- break; - } + @Override + public void close() { + flushBytes(); + } + + @Override + public void write(int b) throws IOException { + output.write(b); + if (output.size() == COMMIT_STREAM_CHUNK_SIZE) { + flushBytes(); } } + + @Override + public void write(byte b[], int currentOffset, int len) throws IOException { + final int endOffset = currentOffset + len; + while ((endOffset - currentOffset) + output.size() >= COMMIT_STREAM_CHUNK_SIZE) { + int writeSize = COMMIT_STREAM_CHUNK_SIZE - output.size(); + output.write(b, currentOffset, writeSize); + currentOffset += writeSize; + flushBytes(); + } + if (currentOffset != endOffset) { + output.write(b, currentOffset, endOffset - currentOffset); + } + } + + private void flushBytes() { + if (output.size() == 0) { + return; + } + chunkWriter.accept(output.toByteString()); + output.reset(); + } + } + + private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { + Preconditions.checkNotNull(pendingRequest.computation); + Consumer<ByteString> chunkWriter = + new Consumer<ByteString>() { + private long remaining = pendingRequest.request.getSerializedSize(); + + @Override + public void accept(ByteString chunk) { + StreamingCommitRequestChunk.Builder chunkBuilder = + StreamingCommitRequestChunk.newBuilder() + .setRequestId(id) + .setSerializedWorkItemCommit(chunk) + .setComputationId(pendingRequest.computation) + .setShardingKey(pendingRequest.request.getShardingKey()); + Preconditions.checkState(remaining >= chunk.size()); + remaining -= chunk.size(); + if (remaining > 0) { + chunkBuilder.setRemainingBytesForWorkItem(remaining); + } + StreamingCommitWorkRequest requestChunk = + StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); + send(requestChunk); + } + }; + try (ChunkingByteStream s = new ChunkingByteStream(chunkWriter)) { + pendingRequest.request.writeTo(s); + } catch (IllegalStateException | IOException e) { + LOG.info("Stream was broken, request will be retried when stream is reopened.", e); + } } }