Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,23 @@

package com.google.cloud.dataflow.sdk.transforms;

import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.CoderException;
import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
import com.google.cloud.dataflow.sdk.coders.CustomCoder;
import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn;
import com.google.cloud.dataflow.sdk.util.VarInt;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UTFDataFormatException;
import java.util.Iterator;


/**
* {@code PTransform}s to count the elements in a {@link PCollection}.
*
Expand Down Expand Up @@ -106,30 +119,74 @@ public void processElement(ProcessContext c) {
/**
* A {@link CombineFn} that counts elements.
*/
private static class CountFn<T> extends CombineFn<T, Long, Long> {
private static class CountFn<T> extends CombineFn<T, long[], Long> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth a comment explaining that the accumulator is always size 1, and is used as a mutable box for the long.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

// Note that the long[] accumulator always has size 1, used as
// a box for a mutable long.

@Override
public Long createAccumulator() {
return 0L;
public long[] createAccumulator() {
return new long[] {0};
}

@Override
public Long addInput(Long accumulator, T input) {
return accumulator + 1;
public long[] addInput(long[] accumulator, T input) {
accumulator[0] += 1;
return accumulator;
}

@Override
public Long mergeAccumulators(Iterable<Long> accumulators) {
long result = 0L;
for (Long accum : accumulators) {
result += accum;
public long[] mergeAccumulators(Iterable<long[]> accumulators) {
Iterator<long[]> iter = accumulators.iterator();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to import Iterator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Done.

if (!iter.hasNext()) {
return createAccumulator();
}
return result;
long[] running = iter.next();
while (iter.hasNext()) {
running[0] += iter.next()[0];
}
return running;
}

@Override
public Long extractOutput(Long accumulator) {
return accumulator;
public Long extractOutput(long[] accumulator) {
return accumulator[0];
}

/**
 * Returns the coder used for the {@code long[]} accumulator.
 *
 * <p>Only element 0 of the array is written and read; it goes through
 * {@code VarInt.encode} / {@code VarInt.decodeLong}. Neither {@code registry}
 * nor {@code inputCoder} is consulted.
 */
@Override
public Coder<long[]> getAccumulatorCoder(CoderRegistry registry,
    Coder<T> inputCoder) {
  return new CustomCoder<long[]>() {
    /** Writes {@code value[0]} to {@code outStream}. */
    @Override
    public void encode(long[] value, OutputStream outStream, Context context)
        throws IOException {
      long count = value[0];
      VarInt.encode(count, outStream);
    }

    /**
     * Reads one long from {@code inStream} and returns it in a fresh
     * one-element array.
     *
     * @throws CoderException wrapping an {@link EOFException} or
     *     {@link UTFDataFormatException} raised while reading
     */
    @Override
    public long[] decode(InputStream inStream, Context context)
        throws IOException, CoderException {
      long[] box = new long[1];
      try {
        box[0] = VarInt.decodeLong(inStream);
      } catch (EOFException | UTFDataFormatException cause) {
        // Re-thrown as a CoderException, keeping the original as the cause.
        throw new CoderException(cause);
      }
      return box;
    }

    /** Always reports the size computation as cheap. */
    @Override
    public boolean isRegisterByteSizeObserverCheap(long[] value, Context context) {
      return true;
    }

    /** Returns {@code VarInt.getLength} of {@code value[0]}. */
    @Override
    protected long getEncodedElementByteSize(long[] value, Context context) {
      long count = value[0];
      return VarInt.getLength(count);
    }

    /** Identifier string for this encoding. */
    @Override
    public String getEncodingId() {
      return "VarLongSingletonArray";
    }
  };
}
}
}