From f0c8c08ffa338a8cd502d78fee384ff59b51172c Mon Sep 17 00:00:00 2001 From: Andrew Pilloud Date: Mon, 18 Apr 2022 13:00:41 -0700 Subject: [PATCH] [BEAM-14321] SQL passes Null for Null aggregates --- .../transform/BeamBuiltinAggregations.java | 52 +++++++++++++-- .../BeamSqlDslAggregationNullableTest.java | 64 +++++++++++++++++++ .../sdk/extensions/sql/utils/RowAsserts.java | 11 ++++ 3 files changed, 121 insertions(+), 6 deletions(-) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java index 47deb01aef94..14a2c29a7fba 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CountIf; import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CovarianceFn; import org.apache.beam.sdk.extensions.sql.impl.transform.agg.VarianceFn; @@ -45,6 +46,8 @@ import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.values.KV; import org.apache.beam.vendor.calcite.v1_28_0.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.checkerframework.checker.nullness.qual.Nullable; /** Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG/VAR_POP/VAR_SAMP. */ @@ -59,12 +62,14 @@ public class BeamBuiltinAggregations { ImmutableMap.>>builder() .put("ANY_VALUE", typeName -> Sample.anyValueCombineFn()) // Drop null elements for these aggregations. - .put("COUNT", typeName -> new DropNullFn(Count.combineFn())) + .put("COUNT", typeName -> new DropNullFnWithDefault(Count.combineFn())) .put("MAX", typeName -> new DropNullFn(BeamBuiltinAggregations.createMax(typeName))) .put("MIN", typeName -> new DropNullFn(BeamBuiltinAggregations.createMin(typeName))) .put("SUM", typeName -> new DropNullFn(BeamBuiltinAggregations.createSum(typeName))) .put( - "$SUM0", typeName -> new DropNullFn(BeamBuiltinAggregations.createSum0(typeName))) + "$SUM0", + typeName -> + new DropNullFnWithDefault(BeamBuiltinAggregations.createSum0(typeName))) .put("AVG", typeName -> new DropNullFn(BeamBuiltinAggregations.createAvg(typeName))) .put( "BIT_OR", @@ -360,7 +365,7 @@ static class BigDecimalSum0 extends BigDecimalSum { private static class DropNullFn extends CombineFn { - private final CombineFn combineFn; + protected final CombineFn combineFn; DropNullFn(CombineFn combineFn) { this.combineFn = combineFn; @@ -368,28 +373,63 @@ private static class DropNullFn @Override public AccumT createAccumulator() { - return combineFn.createAccumulator(); + return null; } @Override public AccumT addInput(AccumT accumulator, InputT input) { - return (input == null) ? accumulator : combineFn.addInput(accumulator, input); + if (input == null) { + return accumulator; + } + + if (accumulator == null) { + accumulator = combineFn.createAccumulator(); + } + return combineFn.addInput(accumulator, input); } @Override public AccumT mergeAccumulators(Iterable accumulators) { + // filter out nulls + accumulators = Iterables.filter(accumulators, Predicates.notNull()); + + // handle only nulls + if (!accumulators.iterator().hasNext()) { + return null; + } + return combineFn.mergeAccumulators(accumulators); } @Override public OutputT extractOutput(AccumT accumulator) { + if (accumulator == null) { + return null; + } return combineFn.extractOutput(accumulator); } @Override public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) throws CannotProvideCoderException { - return combineFn.getAccumulatorCoder(registry, inputCoder); + Coder coder = combineFn.getAccumulatorCoder(registry, inputCoder); + if (coder instanceof NullableCoder) { + return coder; + } + return NullableCoder.of(coder); + } + } + + private static class DropNullFnWithDefault + extends DropNullFn { + + DropNullFnWithDefault(CombineFn combineFn) { + super(combineFn); + } + + @Override + public AccumT createAccumulator() { + return combineFn.createAccumulator(); } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationNullableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationNullableTest.java index 91925f08d80e..429fbd608aab 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationNullableTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationNullableTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.extensions.sql; +import static org.apache.beam.sdk.extensions.sql.utils.RowAsserts.matchesNull; import static org.apache.beam.sdk.extensions.sql.utils.RowAsserts.matchesScalar; import static org.junit.Assert.assertEquals; @@ -71,6 +72,15 @@ public void testCount() { pipeline.run(); } + @Test + public void testCountNull() { + String sql = "SELECT COUNT(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3"; + + PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesScalar(0L)); + + pipeline.run(); + } + @Test public void testCountStar() { String sql = "SELECT COUNT(*) FROM PCOLLECTION GROUP BY f_int3"; @@ -111,6 +121,51 @@ public void testSum() { pipeline.run(); } + @Test + public void testSumNull() { + String sql = "SELECT SUM(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3"; + + PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesNull()); + + pipeline.run(); + } + + @Test + public void testMin() { + String sql = "SELECT MIN(f_int1) FROM PCOLLECTION GROUP BY f_int3"; + + PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesScalar(1)); + + pipeline.run(); + } + + @Test + public void testMinNull() { + String sql = "SELECT MIN(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3"; + + PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesNull()); + + pipeline.run(); + } + + @Test + public void testMax() { + String sql = "SELECT MAX(f_int1) FROM PCOLLECTION GROUP BY f_int3"; + + PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesScalar(3)); + + pipeline.run(); + } + + @Test + public void testMaxNull() { + String sql = "SELECT MAX(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3"; + + PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesNull()); + + pipeline.run(); + } + @Test public void testAvg() { String sql = "SELECT AVG(f_int1) FROM PCOLLECTION GROUP BY f_int3"; @@ -120,6 +175,15 @@ public void testAvg() { pipeline.run(); } + @Test + public void testAvgNull() { + String sql = "SELECT AVG(f_int1) FROM PCOLLECTION WHERE f_int2 IS NULL GROUP BY f_int3"; + + PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesNull()); + + pipeline.run(); + } + @Test public void testAvgGroupByNullable() { String sql = "SELECT AVG(f_int1), f_int2 FROM PCOLLECTION GROUP BY f_int2"; diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/RowAsserts.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/RowAsserts.java index a223495a281b..78dff2615120 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/RowAsserts.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/utils/RowAsserts.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; @@ -69,4 +70,14 @@ public static SerializableFunction, Void> matchesScalar( return null; }; } + + /** Asserts result contains single row with a single null field. */ + public static SerializableFunction, Void> matchesNull() { + return records -> { + Row row = Iterables.getOnlyElement(records); + assertNotNull(row); + assertNull(row.getValue(0)); + return null; + }; + } }