@@ -29,6 +29,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CountIf;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CovarianceFn;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.VarianceFn;
@@ -45,6 +46,8 @@
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.calcite.v1_28_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.checkerframework.checker.nullness.qual.Nullable;

/** Built-in aggregation functions for COUNT/MAX/MIN/SUM/AVG/VAR_POP/VAR_SAMP. */
@@ -59,12 +62,14 @@ public class BeamBuiltinAggregations {
ImmutableMap.<String, Function<Schema.FieldType, CombineFn<?, ?, ?>>>builder()
.put("ANY_VALUE", typeName -> Sample.anyValueCombineFn())
// Drop null elements for these aggregations.
.put("COUNT", typeName -> new DropNullFn(Count.combineFn()))
.put("COUNT", typeName -> new DropNullFnWithDefault(Count.combineFn()))
.put("MAX", typeName -> new DropNullFn(BeamBuiltinAggregations.createMax(typeName)))
.put("MIN", typeName -> new DropNullFn(BeamBuiltinAggregations.createMin(typeName)))
.put("SUM", typeName -> new DropNullFn(BeamBuiltinAggregations.createSum(typeName)))
.put(
"$SUM0", typeName -> new DropNullFn(BeamBuiltinAggregations.createSum0(typeName)))
"$SUM0",
typeName ->
new DropNullFnWithDefault(BeamBuiltinAggregations.createSum0(typeName)))
.put("AVG", typeName -> new DropNullFn(BeamBuiltinAggregations.createAvg(typeName)))
.put(
"BIT_OR",
@@ -360,36 +365,71 @@ static class BigDecimalSum0 extends BigDecimalSum {

private static class DropNullFn<InputT, AccumT, OutputT>
extends CombineFn<InputT, AccumT, OutputT> {
private final CombineFn<InputT, AccumT, OutputT> combineFn;
protected final CombineFn<InputT, AccumT, OutputT> combineFn;

DropNullFn(CombineFn<InputT, AccumT, OutputT> combineFn) {
this.combineFn = combineFn;
}

@Override
public AccumT createAccumulator() {
return combineFn.createAccumulator();
return null;
}

@Override
public AccumT addInput(AccumT accumulator, InputT input) {
return (input == null) ? accumulator : combineFn.addInput(accumulator, input);
if (input == null) {
return accumulator;
}

if (accumulator == null) {
accumulator = combineFn.createAccumulator();
}
return combineFn.addInput(accumulator, input);
}

@Override
public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
// filter out nulls
accumulators = Iterables.filter(accumulators, Predicates.notNull());

// If every accumulator was null, the merged result is null as well.
if (!accumulators.iterator().hasNext()) {
return null;
}

return combineFn.mergeAccumulators(accumulators);
}

@Override
public OutputT extractOutput(AccumT accumulator) {
if (accumulator == null) {
return null;
}
return combineFn.extractOutput(accumulator);
}

@Override
public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder)
throws CannotProvideCoderException {
return combineFn.getAccumulatorCoder(registry, inputCoder);
Coder<AccumT> coder = combineFn.getAccumulatorCoder(registry, inputCoder);
if (coder instanceof NullableCoder) {
return coder;
}
return NullableCoder.of(coder);
}
}
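Note on the merge path: mergeAccumulators first drops null accumulators via the vendored Guava helpers and only returns a null merged accumulator when nothing is left. Below is a minimal, standalone sketch of that idea; it uses plain Guava instead of Beam's vendored guava.v26_0_jre packages, and the class name is hypothetical.

import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import java.util.Arrays;
import java.util.List;

// Sketch only: filter out null accumulators, detect the "all null" case,
// otherwise combine what remains (a running sum stands in for the delegate's
// mergeAccumulators call).
public class MergeNullAccumulatorsSketch {
  public static void main(String[] args) {
    List<Long> accumulators = Arrays.asList(null, 3L, null, 4L);
    Iterable<Long> nonNull = Iterables.filter(accumulators, Predicates.notNull());
    if (!nonNull.iterator().hasNext()) {
      System.out.println("all accumulators were null -> merged accumulator is null");
    } else {
      long merged = 0;
      for (long a : nonNull) {
        merged += a; // stands in for combineFn.mergeAccumulators(...)
      }
      System.out.println("merged: " + merged); // prints "merged: 7"
    }
  }
}

Iterables.filter returns a lazy view, so the hasNext() check does not force a copy of the accumulators.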

private static class DropNullFnWithDefault<InputT, AccumT, OutputT>
extends DropNullFn<InputT, AccumT, OutputT> {

DropNullFnWithDefault(CombineFn<InputT, AccumT, OutputT> combineFn) {
super(combineFn);
}

@Override
public AccumT createAccumulator() {
return combineFn.createAccumulator();
}
}

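Taken together, the two wrappers give the SQL null semantics exercised by the tests below: DropNullFn leaves the accumulator null when a group contains only null inputs, so MAX, MIN, SUM and AVG yield SQL NULL, while DropNullFnWithDefault (used for COUNT and $SUM0) starts from the delegate's real accumulator, so an all-null group still produces the delegate's identity value, e.g. 0 for COUNT. Because accumulators can now legitimately be null, getAccumulatorCoder wraps the delegate's coder in NullableCoder. A minimal, standalone sketch of why that wrapping is needed; the class name is hypothetical and VarLongCoder merely stands in for whatever accumulator coder the delegate provides:

import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.util.CoderUtils;

// Sketch only: a plain VarLongCoder cannot encode null, so once accumulators
// may be null the accumulator coder has to be wrapped in NullableCoder,
// mirroring the getAccumulatorCoder override above.
public class NullableAccumulatorCoderSketch {
  public static void main(String[] args) throws Exception {
    NullableCoder<Long> coder = NullableCoder.of(VarLongCoder.of());
    byte[] encoded = CoderUtils.encodeToByteArray(coder, null);
    Long decoded = CoderUtils.decodeFromByteArray(coder, encoded);
    System.out.println(decoded); // prints "null": the null accumulator survives the round trip
  }
}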
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.extensions.sql;

import static org.apache.beam.sdk.extensions.sql.utils.RowAsserts.matchesNull;
import static org.apache.beam.sdk.extensions.sql.utils.RowAsserts.matchesScalar;
import static org.junit.Assert.assertEquals;

@@ -71,6 +72,15 @@ public void testCount() {
pipeline.run();
}

@Test
public void testCountNull() {
String sql = "SELECT COUNT(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3";

PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesScalar(0L));

pipeline.run();
}

@Test
public void testCountStar() {
String sql = "SELECT COUNT(*) FROM PCOLLECTION GROUP BY f_int3";
@@ -111,6 +121,51 @@ public void testSum() {
pipeline.run();
}

@Test
public void testSumNull() {
String sql = "SELECT SUM(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3";

PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesNull());

pipeline.run();
}

@Test
public void testMin() {
String sql = "SELECT MIN(f_int1) FROM PCOLLECTION GROUP BY f_int3";

PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesScalar(1));

pipeline.run();
}

@Test
public void testMinNull() {
String sql = "SELECT MIN(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3";

PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesNull());

pipeline.run();
}

@Test
public void testMax() {
String sql = "SELECT MAX(f_int1) FROM PCOLLECTION GROUP BY f_int3";

PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesScalar(3));

pipeline.run();
}

@Test
public void testMaxNull() {
String sql = "SELECT MAX(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3";

PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesNull());

pipeline.run();
}

@Test
public void testAvg() {
String sql = "SELECT AVG(f_int1) FROM PCOLLECTION GROUP BY f_int3";
@@ -120,6 +175,15 @@ public void testAvg() {
pipeline.run();
}

@Test
public void testAvgNull() {
String sql = "SELECT AVG(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3";

PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesNull());

pipeline.run();
}

@Test
public void testAvgGroupByNullable() {
String sql = "SELECT AVG(f_int1), f_int2 FROM PCOLLECTION GROUP BY f_int2";
@@ -19,6 +19,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;

import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;
@@ -69,4 +70,14 @@ public static SerializableFunction<Iterable<Row>, Void> matchesScalar(
return null;
};
}

/** Asserts that the result contains a single row whose first field is null. */
public static SerializableFunction<Iterable<Row>, Void> matchesNull() {
return records -> {
Row row = Iterables.getOnlyElement(records);
assertNotNull(row);
assertNull(row.getValue(0));
return null;
};
}
}