diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java index ffa11d13a3c9..5ce4d2e671b0 100644 --- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java +++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/Count.java @@ -16,10 +16,23 @@ package com.google.cloud.dataflow.sdk.transforms; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.CustomCoder; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.util.VarInt; import com.google.cloud.dataflow.sdk.values.KV; import com.google.cloud.dataflow.sdk.values.PCollection; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UTFDataFormatException; +import java.util.Iterator; + + /** * {@code PTransorm}s to count the elements in a {@link PCollection}. * @@ -106,30 +119,74 @@ public void processElement(ProcessContext c) { /** * A {@link CombineFn} that counts elements. */ - private static class CountFn extends CombineFn { + private static class CountFn extends CombineFn { + // Note that the long[] accumulator always has size 1, used as + // a box for a mutable long. @Override - public Long createAccumulator() { - return 0L; + public long[] createAccumulator() { + return new long[] {0}; } @Override - public Long addInput(Long accumulator, T input) { - return accumulator + 1; + public long[] addInput(long[] accumulator, T input) { + accumulator[0] += 1; + return accumulator; } @Override - public Long mergeAccumulators(Iterable accumulators) { - long result = 0L; - for (Long accum : accumulators) { - result += accum; + public long[] mergeAccumulators(Iterable accumulators) { + Iterator iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); } - return result; + long[] running = iter.next(); + while (iter.hasNext()) { + running[0] += iter.next()[0]; + } + return running; } @Override - public Long extractOutput(Long accumulator) { - return accumulator; + public Long extractOutput(long[] accumulator) { + return accumulator[0]; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, + Coder inputCoder) { + return new CustomCoder() { + @Override + public void encode(long[] value, OutputStream outStream, Context context) + throws IOException { + VarInt.encode(value[0], outStream); + } + + @Override + public long[] decode(InputStream inStream, Context context) + throws IOException, CoderException { + try { + return new long[] {VarInt.decodeLong(inStream)}; + } catch (EOFException | UTFDataFormatException exn) { + throw new CoderException(exn); + } + } + + @Override + public boolean isRegisterByteSizeObserverCheap(long[] value, Context context) { + return true; + } + + @Override + protected long getEncodedElementByteSize(long[] value, Context context) { + return VarInt.getLength(value[0]); + } + + @Override + public String getEncodingId() { + return "VarLongSingletonArray"; + } + }; } } }