From a8d7213a41f22575a32d7d234260252b41bf6483 Mon Sep 17 00:00:00 2001 From: mingmxu Date: Mon, 26 Jun 2017 16:03:51 -0700 Subject: [PATCH 1/5] support of UDAF + rebase 1. support DECIMAL in built-in aggregators; 2. add JavaDoc for BeamSqlUdaf; --- .../org/apache/beam/dsls/sql/BeamSqlEnv.java | 10 + .../beam/dsls/sql/rel/BeamAggregationRel.java | 2 +- .../beam/dsls/sql/schema/BeamSqlUdaf.java | 58 ++ .../transform/BeamAggregationTransforms.java | 582 ++++-------------- .../transform/BeamBuiltinAggregations.java | 307 +++++++++ .../dsls/sql/BeamSqlDslAggregationTest.java | 66 ++ .../BeamAggregationTransformTest.java | 2 +- 7 files changed, 562 insertions(+), 465 deletions(-) create mode 100644 dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java create mode 100644 dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java index baa2617d9fee..078d9d34644d 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java @@ -22,6 +22,7 @@ import org.apache.beam.dsls.sql.planner.BeamQueryPlanner; import org.apache.beam.dsls.sql.schema.BaseBeamTable; import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; import org.apache.beam.dsls.sql.utils.CalciteUtils; import org.apache.calcite.DataContext; import org.apache.calcite.linq4j.Enumerable; @@ -32,6 +33,7 @@ import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.schema.Statistic; import org.apache.calcite.schema.Statistics; +import org.apache.calcite.schema.impl.AggregateFunctionImpl; import org.apache.calcite.schema.impl.ScalarFunctionImpl; import org.apache.calcite.tools.Frameworks; @@ -57,6 +59,14 @@ public void registerUdf(String functionName, Class clazz, String methodName) 
schema.add(functionName, ScalarFunctionImpl.create(clazz, methodName)); } + /** + * Register a UDAF function which can be used in GROUP-BY expression. + * See {@link BeamSqlUdaf} on how to implement a UDAF. + */ + public void registerUdaf(String functionName, Class clazz) { + schema.add(functionName, AggregateFunctionImpl.create(clazz)); + } + /** * Registers a {@link BaseBeamTable} which can be used for all subsequent queries. * diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java index 701f6206add6..e0c9a3c2f8c1 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java @@ -103,7 +103,7 @@ public PCollection buildBeamPipeline(PCollectionTuple inputPCollecti PCollection> aggregatedStream = exCombineByStream.apply( stageName + "_combineBy", Combine.perKey( - new BeamAggregationTransforms.AggregationCombineFn(getAggCallList(), + new BeamAggregationTransforms.AggregationAdaptor(getAggCallList(), CalciteUtils.toBeamRecordType(input.getRowType())))) .setCoder(KvCoder.of(keyCoder, aggCoder)); diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java new file mode 100644 index 000000000000..20e5c1089c35 --- /dev/null +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.dsls.sql.schema; + +import java.io.Serializable; +import org.apache.beam.sdk.transforms.Combine.CombineFn; + +/** + * abstract class of aggregation functions in Beam SQL. + * + *

There are several constraints for a UDAF:
+ * 1. A constructor with an empty argument list is required;
+ * 2. The type of {@code InputT} and {@code OutputT} can only be Integer/Long/Short/Byte/Double + * /Float/Date/BigDecimal, mapping as SQL type INTEGER/BIGINT/SMALLINT/TINYINT/DOUBLE/FLOAT + * /TIMESTAMP/DECIMAL;
+ * 3. wrap intermediate data in a {@link BeamSqlRow}, and do not rely on elements in class;
4. The intermediate value of a UDAF function is stored in a {@code BeamSqlRow} object.
+ */ +public abstract class BeamSqlUdaf implements Serializable { + public BeamSqlUdaf(){} + + /** + * create an initial aggregation object, equals to {@link CombineFn#createAccumulator()}. + */ + public abstract BeamSqlRow init(); + + /** + * add an input value, equals to {@link CombineFn#addInput(Object, Object)}. + */ + public abstract BeamSqlRow add(BeamSqlRow accumulator, InputT input); + + /** + * merge aggregation objects from parallel tasks, equals to + * {@link CombineFn#mergeAccumulators(Iterable)}. + */ + public abstract BeamSqlRow merge(Iterable accumulators); + + /** + * extract output value from aggregation object, equals to + * {@link CombineFn#extractOutput(Object)}. + */ + public abstract OutputT result(BeamSqlRow accumulator); +} diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java index 83d473a44233..1e1ac359f650 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java @@ -18,15 +18,15 @@ package org.apache.beam.dsls.sql.transform; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; -import java.util.Arrays; import java.util.Date; import java.util.List; import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlExpression; import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlInputRefExpression; -import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlPrimitive; import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; import org.apache.beam.dsls.sql.schema.BeamSqlRow; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; import org.apache.beam.dsls.sql.utils.CalciteUtils; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.DoFn; @@ -34,8 +34,8 @@ import 
org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.KV; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.schema.impl.AggregateFunctionImpl; +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction; import org.apache.calcite.util.ImmutableBitSet; import org.joda.time.Instant; @@ -134,545 +134,201 @@ public Instant apply(BeamSqlRow input) { } /** - * Aggregation function which supports COUNT, MAX, MIN, SUM, AVG. - * - *

Multiple aggregation functions are combined together. - * For each aggregation function, it may accept part of all data types:
- * 1). COUNT works for any data type;
- * 2). MAX/MIN works for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, TINYINT, TIMESTAMP;
- * 3). SUM/AVG works for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, TINYINT;
- * + * An adaptor class to invoke Calcite UDAF instances in Beam {@code CombineFn}. */ - public static class AggregationCombineFn extends CombineFn { - private BeamSqlRecordType aggDataType; + public static class AggregationAdaptor + extends CombineFn, BeamSqlRow> { + private List aggregators; + private List sourceFieldExps; + private BeamSqlRecordType finalRecordType; - private int countIndex = -1; - - List aggFunctions; - List aggElementExpressions; - - public AggregationCombineFn(List aggregationCalls, + public AggregationAdaptor(List aggregationCalls, BeamSqlRecordType sourceRowRecordType) { - this.aggFunctions = new ArrayList<>(); - this.aggElementExpressions = new ArrayList<>(); - - boolean hasAvg = false; - boolean hasCount = false; - int countIndex = -1; - List fieldNames = new ArrayList<>(); - List fieldTypes = new ArrayList<>(); - for (int idx = 0; idx < aggregationCalls.size(); ++idx) { - AggregateCall ac = aggregationCalls.get(idx); - //verify it's supported. - verifySupportedAggregation(ac); - - fieldNames.add(ac.name); - fieldTypes.add(CalciteUtils.toJavaType(ac.type.getSqlTypeName())); - - SqlAggFunction aggFn = ac.getAggregation(); - switch (aggFn.getName()) { + aggregators = new ArrayList<>(); + sourceFieldExps = new ArrayList<>(); + List outFieldsName = new ArrayList<>(); + List outFieldsType = new ArrayList<>(); + for (AggregateCall call : aggregationCalls) { + int refIndex = call.getArgList().size() > 0 ? 
call.getArgList().get(0) : 0; + BeamSqlExpression sourceExp = new BeamSqlInputRefExpression( + CalciteUtils.getFieldType(sourceRowRecordType, refIndex), refIndex); + sourceFieldExps.add(sourceExp); + + outFieldsName.add(call.name); + int outFieldType = CalciteUtils.toJavaType(call.type.getSqlTypeName()); + outFieldsType.add(outFieldType); + + switch (call.getAggregation().getName()) { case "COUNT": - aggElementExpressions.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - hasCount = true; - countIndex = idx; + aggregators.add(new BeamBuiltinAggregations.Count()); break; - case "SUM": case "MAX": - case "MIN": - case "AVG": - int refIndex = ac.getArgList().get(0); - aggElementExpressions.add(new BeamSqlInputRefExpression( - CalciteUtils.getFieldType(sourceRowRecordType, refIndex), refIndex)); - if ("AVG".equals(aggFn.getName())) { - hasAvg = true; - } - break; - - default: - break; - } - aggFunctions.add(aggFn.getName()); - } - - - // add a COUNT holder if only have AVG - if (hasAvg && !hasCount) { - fieldNames.add("__COUNT"); - fieldTypes.add(CalciteUtils.toJavaType(SqlTypeName.BIGINT)); - - aggFunctions.add("COUNT"); - aggElementExpressions.add(BeamSqlPrimitive.of(SqlTypeName.BIGINT, 1L)); - - hasCount = true; - countIndex = aggDataType.size() - 1; - } - - this.aggDataType = BeamSqlRecordType.create(fieldNames, fieldTypes); - this.countIndex = countIndex; - } - - private void verifySupportedAggregation(AggregateCall ac) { - //donot support DISTINCT - if (ac.isDistinct()) { - throw new UnsupportedOperationException("DISTINCT is not supported yet."); - } - String aggFnName = ac.getAggregation().getName(); - switch (aggFnName) { - case "COUNT": - //COUNT works for any data type; - break; - case "SUM": - // SUM only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, - // TINYINT now - if (!Arrays - .asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE, - SqlTypeName.SMALLINT, SqlTypeName.TINYINT) - 
.contains(ac.type.getSqlTypeName())) { - throw new UnsupportedOperationException( - "SUM only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT"); - } - break; - case "MAX": - case "MIN": - // MAX/MIN only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, - // TINYINT, TIMESTAMP now - if (!Arrays.asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, - SqlTypeName.DOUBLE, SqlTypeName.SMALLINT, SqlTypeName.TINYINT, - SqlTypeName.TIMESTAMP).contains(ac.type.getSqlTypeName())) { - throw new UnsupportedOperationException("MAX/MIN only support for INT, LONG, FLOAT," - + " DOUBLE, SMALLINT, TINYINT, TIMESTAMP"); - } - break; - case "AVG": - // AVG only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, - // TINYINT now - if (!Arrays - .asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE, - SqlTypeName.SMALLINT, SqlTypeName.TINYINT) - .contains(ac.type.getSqlTypeName())) { - throw new UnsupportedOperationException( - "AVG only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT"); - } - break; - default: - throw new UnsupportedOperationException( - String.format("[%s] is not supported.", aggFnName)); - } - } - - @Override - public BeamSqlRow createAccumulator() { - BeamSqlRow initialRecord = new BeamSqlRow(aggDataType); - for (int idx = 0; idx < aggElementExpressions.size(); ++idx) { - BeamSqlExpression ex = aggElementExpressions.get(idx); - String aggFnName = aggFunctions.get(idx); - switch (aggFnName) { - case "COUNT": - initialRecord.addField(idx, 0L); - break; - case "AVG": - case "SUM": - //for both AVG/SUM, a summary value is hold at first. 
- switch (ex.getOutputType()) { + switch (call.type.getSqlTypeName()) { case INTEGER: - initialRecord.addField(idx, 0); - break; - case BIGINT: - initialRecord.addField(idx, 0L); + aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); break; case SMALLINT: - initialRecord.addField(idx, (short) 0); + aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); break; case TINYINT: - initialRecord.addField(idx, (byte) 0); - break; - case FLOAT: - initialRecord.addField(idx, 0.0f); - break; - case DOUBLE: - initialRecord.addField(idx, 0.0); - break; - default: - break; - } - break; - case "MAX": - switch (ex.getOutputType()) { - case INTEGER: - initialRecord.addField(idx, Integer.MIN_VALUE); + aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); break; case BIGINT: - initialRecord.addField(idx, Long.MIN_VALUE); - break; - case SMALLINT: - initialRecord.addField(idx, Short.MIN_VALUE); - break; - case TINYINT: - initialRecord.addField(idx, Byte.MIN_VALUE); + aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); break; case FLOAT: - initialRecord.addField(idx, Float.MIN_VALUE); + aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); break; case DOUBLE: - initialRecord.addField(idx, Double.MIN_VALUE); + aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); break; case TIMESTAMP: - initialRecord.addField(idx, new Date(0)); + aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); break; - default: + case DECIMAL: + aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); break; + default: + throw new UnsupportedOperationException(); } break; case "MIN": - switch (ex.getOutputType()) { + switch (call.type.getSqlTypeName()) { case INTEGER: - initialRecord.addField(idx, Integer.MAX_VALUE); - break; - case BIGINT: - initialRecord.addField(idx, Long.MAX_VALUE); + aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); break; case SMALLINT: - initialRecord.addField(idx, Short.MAX_VALUE); + 
aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); break; case TINYINT: - initialRecord.addField(idx, Byte.MAX_VALUE); + aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); + break; + case BIGINT: + aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); break; case FLOAT: - initialRecord.addField(idx, Float.MAX_VALUE); + aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); break; case DOUBLE: - initialRecord.addField(idx, Double.MAX_VALUE); + aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); break; case TIMESTAMP: - initialRecord.addField(idx, new Date(Long.MAX_VALUE)); + aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); break; - default: + case DECIMAL: + aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); break; + default: + throw new UnsupportedOperationException(); } break; - default: - break; - } - } - return initialRecord; - } - - @Override - public BeamSqlRow addInput(BeamSqlRow accumulator, BeamSqlRow input) { - BeamSqlRow deltaRecord = new BeamSqlRow(aggDataType); - for (int idx = 0; idx < aggElementExpressions.size(); ++idx) { - BeamSqlExpression ex = aggElementExpressions.get(idx); - String aggFnName = aggFunctions.get(idx); - switch (aggFnName) { - case "COUNT": - deltaRecord.addField(idx, 1 + accumulator.getLong(idx)); - break; - case "AVG": case "SUM": - // for both AVG/SUM, a summary value is hold at first. 
- switch (ex.getOutputType()) { + switch (call.type.getSqlTypeName()) { case INTEGER: - deltaRecord.addField(idx, - ex.evaluate(input).getInteger() + accumulator.getInteger(idx)); - break; - case BIGINT: - deltaRecord.addField(idx, ex.evaluate(input).getLong() + accumulator.getLong(idx)); + aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); break; case SMALLINT: - deltaRecord.addField(idx, - (short) (ex.evaluate(input).getShort() + accumulator.getShort(idx))); + aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); break; case TINYINT: - deltaRecord.addField(idx, - (byte) (ex.evaluate(input).getByte() + accumulator.getByte(idx))); - break; - case FLOAT: - deltaRecord.addField(idx, - (float) (ex.evaluate(input).getFloat() + accumulator.getFloat(idx))); - break; - case DOUBLE: - deltaRecord.addField(idx, ex.evaluate(input).getDouble() + accumulator.getDouble(idx)); - break; - default: - break; - } - break; - case "MAX": - switch (ex.getOutputType()) { - case INTEGER: - deltaRecord.addField(idx, - Math.max(ex.evaluate(input).getInteger(), accumulator.getInteger(idx))); + aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); break; case BIGINT: - deltaRecord.addField(idx, - Math.max(ex.evaluate(input).getLong(), accumulator.getLong(idx))); - break; - case SMALLINT: - deltaRecord.addField(idx, - (short) Math.max(ex.evaluate(input).getShort(), accumulator.getShort(idx))); - break; - case TINYINT: - deltaRecord.addField(idx, - (byte) Math.max(ex.evaluate(input).getByte(), accumulator.getByte(idx))); + aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); break; case FLOAT: - deltaRecord.addField(idx, - Math.max(ex.evaluate(input).getFloat(), accumulator.getFloat(idx))); + aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); break; case DOUBLE: - deltaRecord.addField(idx, - Math.max(ex.evaluate(input).getDouble(), accumulator.getDouble(idx))); + aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); break; - 
case TIMESTAMP: - Date preDate = accumulator.getDate(idx); - Date nowDate = ex.evaluate(input).getDate(); - deltaRecord.addField(idx, preDate.getTime() > nowDate.getTime() ? preDate : nowDate); + case DECIMAL: + aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); break; default: - break; + throw new UnsupportedOperationException(); } break; - case "MIN": - switch (ex.getOutputType()) { + case "AVG": + switch (call.type.getSqlTypeName()) { case INTEGER: - deltaRecord.addField(idx, - Math.min(ex.evaluate(input).getInteger(), accumulator.getInteger(idx))); - break; - case BIGINT: - deltaRecord.addField(idx, - Math.min(ex.evaluate(input).getLong(), accumulator.getLong(idx))); + aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); break; case SMALLINT: - deltaRecord.addField(idx, - (short) Math.min(ex.evaluate(input).getShort(), accumulator.getShort(idx))); + aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); break; case TINYINT: - deltaRecord.addField(idx, - (byte) Math.min(ex.evaluate(input).getByte(), accumulator.getByte(idx))); + aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); + break; + case BIGINT: + aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); break; case FLOAT: - deltaRecord.addField(idx, - Math.min(ex.evaluate(input).getFloat(), accumulator.getFloat(idx))); + aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); break; case DOUBLE: - deltaRecord.addField(idx, - Math.min(ex.evaluate(input).getDouble(), accumulator.getDouble(idx))); + aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); break; - case TIMESTAMP: - Date preDate = accumulator.getDate(idx); - Date nowDate = ex.evaluate(input).getDate(); - deltaRecord.addField(idx, preDate.getTime() < nowDate.getTime() ? 
preDate : nowDate); + case DECIMAL: + aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); break; default: - break; + throw new UnsupportedOperationException(); } break; default: - break; + if (call.getAggregation() instanceof SqlUserDefinedAggFunction) { + // handle UDAF. + SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction) call.getAggregation(); + AggregateFunctionImpl fn = (AggregateFunctionImpl) udaf.function; + try { + aggregators.add((BeamSqlUdaf) fn.declaringClass.newInstance()); + } catch (Exception e) { + throw new IllegalStateException(e); + } + } else { + throw new UnsupportedOperationException(); + } } } - return deltaRecord; + finalRecordType = BeamSqlRecordType.create(outFieldsName, outFieldsType); } - @Override - public BeamSqlRow mergeAccumulators(Iterable accumulators) { - BeamSqlRow deltaRecord = new BeamSqlRow(aggDataType); - - while (accumulators.iterator().hasNext()) { - BeamSqlRow sa = accumulators.iterator().next(); - for (int idx = 0; idx < aggElementExpressions.size(); ++idx) { - BeamSqlExpression ex = aggElementExpressions.get(idx); - String aggFnName = aggFunctions.get(idx); - switch (aggFnName) { - case "COUNT": - deltaRecord.addField(idx, deltaRecord.getLong(idx) + sa.getLong(idx)); - break; - case "AVG": - case "SUM": - // for both AVG/SUM, a summary value is hold at first. 
- switch (ex.getOutputType()) { - case INTEGER: - deltaRecord.addField(idx, deltaRecord.getInteger(idx) + sa.getInteger(idx)); - break; - case BIGINT: - deltaRecord.addField(idx, deltaRecord.getLong(idx) + sa.getLong(idx)); - break; - case SMALLINT: - deltaRecord.addField(idx, (short) (deltaRecord.getShort(idx) + sa.getShort(idx))); - break; - case TINYINT: - deltaRecord.addField(idx, (byte) (deltaRecord.getByte(idx) + sa.getByte(idx))); - break; - case FLOAT: - deltaRecord.addField(idx, (float) (deltaRecord.getFloat(idx) + sa.getFloat(idx))); - break; - case DOUBLE: - deltaRecord.addField(idx, deltaRecord.getDouble(idx) + sa.getDouble(idx)); - break; - default: - break; - } - break; - case "MAX": - switch (ex.getOutputType()) { - case INTEGER: - deltaRecord.addField(idx, Math.max(deltaRecord.getInteger(idx), sa.getInteger(idx))); - break; - case BIGINT: - deltaRecord.addField(idx, Math.max(deltaRecord.getLong(idx), sa.getLong(idx))); - break; - case SMALLINT: - deltaRecord.addField(idx, - (short) Math.max(deltaRecord.getShort(idx), sa.getShort(idx))); - break; - case TINYINT: - deltaRecord.addField(idx, (byte) Math.max(deltaRecord.getByte(idx), sa.getByte(idx))); - break; - case FLOAT: - deltaRecord.addField(idx, Math.max(deltaRecord.getFloat(idx), sa.getFloat(idx))); - break; - case DOUBLE: - deltaRecord.addField(idx, Math.max(deltaRecord.getDouble(idx), sa.getDouble(idx))); - break; - case TIMESTAMP: - Date preDate = deltaRecord.getDate(idx); - Date nowDate = sa.getDate(idx); - deltaRecord.addField(idx, preDate.getTime() > nowDate.getTime() ? 
preDate : nowDate); - break; - default: - break; - } - break; - case "MIN": - switch (ex.getOutputType()) { - case INTEGER: - deltaRecord.addField(idx, Math.min(deltaRecord.getInteger(idx), sa.getInteger(idx))); - break; - case BIGINT: - deltaRecord.addField(idx, Math.min(deltaRecord.getLong(idx), sa.getLong(idx))); - break; - case SMALLINT: - deltaRecord.addField(idx, - (short) Math.min(deltaRecord.getShort(idx), sa.getShort(idx))); - break; - case TINYINT: - deltaRecord.addField(idx, (byte) Math.min(deltaRecord.getByte(idx), sa.getByte(idx))); - break; - case FLOAT: - deltaRecord.addField(idx, Math.min(deltaRecord.getFloat(idx), sa.getFloat(idx))); - break; - case DOUBLE: - deltaRecord.addField(idx, Math.min(deltaRecord.getDouble(idx), sa.getDouble(idx))); - break; - case TIMESTAMP: - Date preDate = deltaRecord.getDate(idx); - Date nowDate = sa.getDate(idx); - deltaRecord.addField(idx, preDate.getTime() < nowDate.getTime() ? preDate : nowDate); - break; - default: - break; - } - break; - default: - break; - } - } + public List createAccumulator() { + List initialAccu = new ArrayList<>(); + for (BeamSqlUdaf agg : aggregators) { + initialAccu.add(agg.init()); } - return deltaRecord; + return initialAccu; } - @Override - public BeamSqlRow extractOutput(BeamSqlRow accumulator) { - BeamSqlRow finalRecord = new BeamSqlRow(aggDataType); - for (int idx = 0; idx < aggElementExpressions.size(); ++idx) { - BeamSqlExpression ex = aggElementExpressions.get(idx); - String aggFnName = aggFunctions.get(idx); - switch (aggFnName) { - case "COUNT": - finalRecord.addField(idx, accumulator.getLong(idx)); - break; - case "AVG": - long count = accumulator.getLong(countIndex); - switch (ex.getOutputType()) { - case INTEGER: - finalRecord.addField(idx, (int) (accumulator.getInteger(idx) / count)); - break; - case BIGINT: - finalRecord.addField(idx, accumulator.getLong(idx) / count); - break; - case SMALLINT: - finalRecord.addField(idx, (short) (accumulator.getShort(idx) / count)); - 
break; - case TINYINT: - finalRecord.addField(idx, (byte) (accumulator.getByte(idx) / count)); - break; - case FLOAT: - finalRecord.addField(idx, (float) (accumulator.getFloat(idx) / count)); - break; - case DOUBLE: - finalRecord.addField(idx, accumulator.getDouble(idx) / count); - break; - default: - break; - } - break; - case "SUM": - switch (ex.getOutputType()) { - case INTEGER: - finalRecord.addField(idx, accumulator.getInteger(idx)); - break; - case BIGINT: - finalRecord.addField(idx, accumulator.getLong(idx)); - break; - case SMALLINT: - finalRecord.addField(idx, accumulator.getShort(idx)); - break; - case TINYINT: - finalRecord.addField(idx, accumulator.getByte(idx)); - break; - case FLOAT: - finalRecord.addField(idx, accumulator.getFloat(idx)); - break; - case DOUBLE: - finalRecord.addField(idx, accumulator.getDouble(idx)); - break; - default: - break; - } - break; - case "MAX": - case "MIN": - switch (ex.getOutputType()) { - case INTEGER: - finalRecord.addField(idx, accumulator.getInteger(idx)); - break; - case BIGINT: - finalRecord.addField(idx, accumulator.getLong(idx)); - break; - case SMALLINT: - finalRecord.addField(idx, accumulator.getShort(idx)); - break; - case TINYINT: - finalRecord.addField(idx, accumulator.getByte(idx)); - break; - case FLOAT: - finalRecord.addField(idx, accumulator.getFloat(idx)); - break; - case DOUBLE: - finalRecord.addField(idx, accumulator.getDouble(idx)); - break; - case TIMESTAMP: - finalRecord.addField(idx, accumulator.getDate(idx)); - break; - default: - break; - } - break; - default: - break; + public List addInput(List accumulator, BeamSqlRow input) { + List deltaAcc = new ArrayList<>(); + for (int idx = 0; idx < aggregators.size(); ++idx) { + deltaAcc.add(aggregators.get(idx).add(accumulator.get(idx), + sourceFieldExps.get(idx).evaluate(input).getValue())); + } + return deltaAcc; + } + @Override + public List mergeAccumulators(Iterable> accumulators) { + List deltaAcc = new ArrayList<>(); + for (int idx = 0; idx < 
aggregators.size(); ++idx) { + List accs = new ArrayList<>(); + while (accumulators.iterator().hasNext()) { + accs.add(accumulators.iterator().next().get(idx)); } + deltaAcc.add(aggregators.get(idx).merge(accs)); + } + return deltaAcc; + } + @Override + public BeamSqlRow extractOutput(List accumulator) { + BeamSqlRow result = new BeamSqlRow(finalRecordType); + for (int idx = 0; idx < aggregators.size(); ++idx) { + result.addField(idx, aggregators.get(idx).result(accumulator.get(idx))); } - return finalRecord; + return result; } } } diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java new file mode 100644 index 000000000000..d5a929dd9b79 --- /dev/null +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.dsls.sql.transform; + +import java.math.BigDecimal; +import java.sql.Types; +import java.util.Arrays; +import java.util.List; +import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; +import org.apache.beam.dsls.sql.schema.BeamSqlRow; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; +import org.apache.beam.dsls.sql.utils.CalciteUtils; + +/** + * Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG. + */ +class BeamBuiltinAggregations { + /** + * Built-in aggregation for COUNT. + */ + public static class Count extends BeamSqlUdaf { + private BeamSqlRecordType accType; + + public Count() { + accType = BeamSqlRecordType.create(Arrays.asList("__count"), Arrays.asList(Types.BIGINT)); + } + + @Override + public BeamSqlRow init() { + return new BeamSqlRow(accType, Arrays.asList(0L)); + } + + @Override + public BeamSqlRow add(BeamSqlRow accumulator, T input) { + return new BeamSqlRow(accType, Arrays.asList(accumulator.getLong(0) + 1)); + } + + @Override + public BeamSqlRow merge(Iterable accumulators) { + long v = 0L; + while (accumulators.iterator().hasNext()) { + v += accumulators.iterator().next().getLong(0); + } + return new BeamSqlRow(accType, Arrays.asList(v)); + } + + @Override + public Long result(BeamSqlRow accumulator) { + return accumulator.getLong(0); + } + } + + /** + * Built-in aggregation for MAX. + */ + public static class Max> extends BeamSqlUdaf { + private BeamSqlRecordType accType; + + public Max(int outputFieldType) { + this.accType = BeamSqlRecordType.create(Arrays.asList("__max"), + Arrays.asList(outputFieldType)); + } + + @Override + public BeamSqlRow init() { + return null; + } + + @Override + public BeamSqlRow add(BeamSqlRow accumulator, T input) { + return new BeamSqlRow(accType, + Arrays + .asList((accumulator == null || ((Comparable) accumulator.getFieldValue(0)) + .compareTo(input) < 0) + ? 
input : accumulator.getFieldValue(0))); + } + + @Override + public BeamSqlRow merge(Iterable accumulators) { + T mergedV = (T) accumulators.iterator().next().getFieldValue(0); + while (accumulators.iterator().hasNext()) { + T v = (T) accumulators.iterator().next().getFieldValue(0); + mergedV = mergedV.compareTo(v) > 0 ? mergedV : v; + } + return new BeamSqlRow(accType, Arrays.asList(mergedV)); + } + + @Override + public T result(BeamSqlRow accumulator) { + return (T) accumulator.getFieldValue(0); + } + } + + /** + * Built-in aggregation for MIN. + */ + public static class Min> extends BeamSqlUdaf { + private BeamSqlRecordType accType; + + public Min(int outputFieldType) { + this.accType = BeamSqlRecordType.create(Arrays.asList("__min"), + Arrays.asList(outputFieldType)); + } + + @Override + public BeamSqlRow init() { + return null; + } + + @Override + public BeamSqlRow add(BeamSqlRow accumulator, T input) { + return new BeamSqlRow(accType, + Arrays + .asList((accumulator == null || ((Comparable) accumulator.getFieldValue(0)) + .compareTo(input) > 0) + ? input : accumulator.getFieldValue(0))); + } + + @Override + public BeamSqlRow merge(Iterable accumulators) { + T mergedV = (T) accumulators.iterator().next().getFieldValue(0); + while (accumulators.iterator().hasNext()) { + T v = (T) accumulators.iterator().next().getFieldValue(0); + mergedV = mergedV.compareTo(v) < 0 ? mergedV : v; + } + return new BeamSqlRow(accType, Arrays.asList(mergedV)); + } + + @Override + public T result(BeamSqlRow accumulator) { + return (T) accumulator.getFieldValue(0); + } + } + + /** + * Built-in aggregation for SUM. 
+ */ + public static class Sum extends BeamSqlUdaf { + private static List supportedType = Arrays.asList(Types.INTEGER, + Types.BIGINT, Types.SMALLINT, Types.TINYINT, Types.DOUBLE, + Types.FLOAT, Types.DECIMAL); + + private int outputFieldType; + private BeamSqlRecordType accType; + public Sum(int outputFieldType) { + //check input data type is supported + if (!supportedType.contains(outputFieldType)) { + throw new UnsupportedOperationException(String.format( + "data type [%s] is not supported in SUM", CalciteUtils.toCalciteType(outputFieldType))); + } + + this.outputFieldType = outputFieldType; + this.accType = BeamSqlRecordType.create(Arrays.asList("__sum"), + Arrays.asList(Types.DECIMAL)); //by default use DOUBLE to store the value. + } + + @Override + public BeamSqlRow init() { + return new BeamSqlRow(accType, Arrays.asList(new BigDecimal(0))); + } + + @Override + public BeamSqlRow add(BeamSqlRow accumulator, T input) { + return new BeamSqlRow(accType, Arrays.asList(accumulator.getBigDecimal(0) + .add(new BigDecimal(input.toString())))); + } + + @Override + public BeamSqlRow merge(Iterable accumulators) { + BigDecimal v = new BigDecimal(0); + while (accumulators.iterator().hasNext()) { + v.add(accumulators.iterator().next().getBigDecimal(0)); + } + return new BeamSqlRow(accType, Arrays.asList(v)); + } + + @Override + public T result(BeamSqlRow accumulator) { + Object result = null; + switch (outputFieldType) { + case Types.INTEGER: + result = accumulator.getBigDecimal(0).intValue(); + break; + case Types.BIGINT: + result = accumulator.getBigDecimal(0).longValue(); + break; + case Types.SMALLINT: + result = accumulator.getBigDecimal(0).shortValue(); + break; + case Types.TINYINT: + result = accumulator.getBigDecimal(0).byteValue(); + break; + case Types.DOUBLE: + result = accumulator.getBigDecimal(0).doubleValue(); + break; + case Types.FLOAT: + result = accumulator.getBigDecimal(0).floatValue(); + break; + case Types.DECIMAL: + result = 
accumulator.getBigDecimal(0); + break; + default: + break; + } + return (T) result; + } + + } + + /** + * Built-in aggregation for AVG. + */ + public static class Avg extends BeamSqlUdaf { + private static List supportedType = Arrays.asList(Types.INTEGER, + Types.BIGINT, Types.SMALLINT, Types.TINYINT, Types.DOUBLE, + Types.FLOAT, Types.DECIMAL); + + private int outputFieldType; + private BeamSqlRecordType accType; + public Avg(int outputFieldType) { + //check input data type is supported + if (!supportedType.contains(outputFieldType)) { + throw new UnsupportedOperationException(String.format( + "data type [%s] is not supported in AVG", CalciteUtils.toCalciteType(outputFieldType))); + } + + this.outputFieldType = outputFieldType; + this.accType = BeamSqlRecordType.create(Arrays.asList("__sum", "size"), + Arrays.asList(Types.DECIMAL, Types.BIGINT)); //by default use DOUBLE to store the value. + } + + @Override + public BeamSqlRow init() { + return new BeamSqlRow(accType, Arrays.asList(new BigDecimal(0), 0L)); + } + + @Override + public BeamSqlRow add(BeamSqlRow accumulator, T input) { + return new BeamSqlRow(accType, + Arrays.asList( + accumulator.getBigDecimal(0).add(new BigDecimal(input.toString())), + accumulator.getLong(1) + 1)); + } + + @Override + public BeamSqlRow merge(Iterable accumulators) { + BigDecimal v = new BigDecimal(0); + long s = 0; + while (accumulators.iterator().hasNext()) { + BeamSqlRow r = accumulators.iterator().next(); + v.add(r.getBigDecimal(0)); + s += r.getLong(1); + } + return new BeamSqlRow(accType, Arrays.asList(v, s)); + } + + @Override + public T result(BeamSqlRow accumulator) { + Object result = null; + BigDecimal decimalAvg = accumulator.getBigDecimal(0).divide( + new BigDecimal(accumulator.getLong(1))); + switch (outputFieldType) { + case Types.INTEGER: + result = decimalAvg.intValue(); + break; + case Types.BIGINT: + result = decimalAvg.longValue(); + break; + case Types.SMALLINT: + result = decimalAvg.shortValue(); + break; + 
case Types.TINYINT: + result = decimalAvg.byteValue(); + break; + case Types.DOUBLE: + result = decimalAvg.doubleValue(); + break; + case Types.FLOAT: + result = decimalAvg.floatValue(); + break; + case Types.DECIMAL: + result = decimalAvg; + break; + default: + break; + } + return (T) result; + } + + } + +} diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslAggregationTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslAggregationTest.java index b0509ae1a330..5ecd73a4ea04 100644 --- a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslAggregationTest.java +++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslAggregationTest.java @@ -21,11 +21,13 @@ import java.util.Arrays; import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; import org.apache.beam.dsls.sql.schema.BeamSqlRow; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; +import org.junit.Ignore; import org.junit.Test; /** @@ -257,4 +259,68 @@ public void testSessionWindow() throws Exception { pipeline.run().waitUntilFinish(); } + + /** + * GROUP-BY with UDAF. 
+ */ + @Ignore + public void testAggregationWithUDAF() throws Exception { + String sql = "SELECT f_int2, squaresum(f_int) AS `squaresum` FROM TABLE_A GROUP BY f_int2"; + + //The test case is disabled temporarily as BeamSql doesn't have methods to register UDF/UDAF, + //pending on task BEAM-2520 +// BeamSqlEnv.registerUdaf("squaresum", SquareSum.class); + PCollection result = + inputA1.apply("testAggregationWithUDAF", BeamSql.simpleQuery(sql)); + + BeamSqlRecordType resultType = BeamSqlRecordType.create(Arrays.asList("f_int2", "squaresum"), + Arrays.asList(Types.INTEGER, Types.INTEGER)); + + BeamSqlRow record = new BeamSqlRow(resultType); + record.addField("f_int2", 0); + record.addField("squaresum", 30); + + PAssert.that(result).containsInAnyOrder(record); + + pipeline.run().waitUntilFinish(); + } + + /** + * UDAF for test, which returns the sum of square. + */ + public static class SquareSum extends BeamSqlUdaf { + private int outputFieldType; + private BeamSqlRecordType accType; + + public SquareSum() { + this.outputFieldType = Types.INTEGER; + accType = BeamSqlRecordType.create(Arrays.asList("__tudaf"), Arrays.asList(outputFieldType)); + } + + // @Override + public BeamSqlRow init() { + return new BeamSqlRow(accType, Arrays.asList(0)); + } + + // @Override + public BeamSqlRow add(BeamSqlRow accumulator, Integer input) { + return new BeamSqlRow(accType, + Arrays.asList(accumulator.getInteger(0) + input * input)); + } + + // @Override + public BeamSqlRow merge(Iterable accumulators) { + int v = 0; + while (accumulators.iterator().hasNext()) { + v += accumulators.iterator().next().getInteger(0); + } + return new BeamSqlRow(accType, Arrays.asList(v)); + } + + // @Override + public Integer result(BeamSqlRow accumulator) { + return accumulator.getInteger(0); + } + + } } diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java 
b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java index 388a34485ab3..2b01254d041f 100644 --- a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java +++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java @@ -117,7 +117,7 @@ public void testCountPerElementBasic() throws ParseException { //3. run aggregation functions PCollection> aggregatedStream = groupedStream.apply("aggregation", Combine.groupedValues( - new BeamAggregationTransforms.AggregationCombineFn(aggCalls, inputRowType))) + new BeamAggregationTransforms.AggregationAdaptor(aggCalls, inputRowType))) .setCoder(KvCoder.of(keyCoder, aggCoder)); //4. flat KV to a single record From 449d1fa6547cce7a9a59345cc3fa87de31111675 Mon Sep 17 00:00:00 2001 From: mingmxu Date: Tue, 27 Jun 2017 21:46:12 -0700 Subject: [PATCH 2/5] fix findbug reports --- .../beam/dsls/sql/transform/BeamBuiltinAggregations.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java index d5a929dd9b79..11097e37928b 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java @@ -184,7 +184,7 @@ public BeamSqlRow add(BeamSqlRow accumulator, T input) { public BeamSqlRow merge(Iterable accumulators) { BigDecimal v = new BigDecimal(0); while (accumulators.iterator().hasNext()) { - v.add(accumulators.iterator().next().getBigDecimal(0)); + v = v.add(accumulators.iterator().next().getBigDecimal(0)); } return new BeamSqlRow(accType, Arrays.asList(v)); } @@ -263,7 +263,7 @@ public BeamSqlRow merge(Iterable accumulators) { long s = 0; while (accumulators.iterator().hasNext()) { BeamSqlRow r 
= accumulators.iterator().next(); - v.add(r.getBigDecimal(0)); + v = v.add(r.getBigDecimal(0)); s += r.getLong(1); } return new BeamSqlRow(accType, Arrays.asList(v, s)); From 52013593a1855841308d42f8de09b423393b71f5 Mon Sep 17 00:00:00 2001 From: mingmxu Date: Thu, 29 Jun 2017 21:47:39 -0700 Subject: [PATCH 3/5] change BeamSqlUdaf to BeamSqlUdaf --- .../beam/dsls/sql/schema/BeamSqlUdaf.java | 28 +- .../transform/BeamAggregationTransforms.java | 242 ++++------ .../transform/BeamBuiltinAggregations.java | 453 +++++++++++------- .../dsls/sql/BeamSqlDslAggregationTest.java | 66 --- 4 files changed, 408 insertions(+), 381 deletions(-) diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java index 20e5c1089c35..9582ffaea898 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java @@ -18,6 +18,10 @@ package org.apache.beam.dsls.sql.schema; import java.io.Serializable; +import java.lang.reflect.ParameterizedType; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.transforms.Combine.CombineFn; /** @@ -28,31 +32,41 @@ * 2. The type of {@code InputT} and {@code OutputT} can only be Interger/Long/Short/Byte/Double * /Float/Date/BigDecimal, mapping as SQL type INTEGER/BIGINT/SMALLINT/TINYINE/DOUBLE/FLOAT * /TIMESTAMP/DECIMAL;
- * 3. wrap intermediate data in a {@link BeamSqlRow}, and do not rely on elements in class;
- * 4. The intermediate value of UDAF function is stored in a {@code BeamSqlRow} object.
+ * 3. Keep intermediate data in {@code AccumT}, and do not rely on elements in class;
*/ -public abstract class BeamSqlUdaf implements Serializable { +public abstract class BeamSqlUdaf implements Serializable { public BeamSqlUdaf(){} /** * create an initial aggregation object, equals to {@link CombineFn#createAccumulator()}. */ - public abstract BeamSqlRow init(); + public abstract AccumT init(); /** * add an input value, equals to {@link CombineFn#addInput(Object, Object)}. */ - public abstract BeamSqlRow add(BeamSqlRow accumulator, InputT input); + public abstract AccumT add(AccumT accumulator, InputT input); /** * merge aggregation objects from parallel tasks, equals to * {@link CombineFn#mergeAccumulators(Iterable)}. */ - public abstract BeamSqlRow merge(Iterable accumulators); + public abstract AccumT merge(Iterable accumulators); /** * extract output value from aggregation object, equals to * {@link CombineFn#extractOutput(Object)}. */ - public abstract OutputT result(BeamSqlRow accumulator); + public abstract OutputT result(AccumT accumulator); + + /** + * get the coder for AccumT which stores the intermediate result. + * By default it's fetched from {@link CoderRegistry}. 
+ */ + public Coder getAccumulatorCoder(CoderRegistry registry) + throws CannotProvideCoderException { + return registry.getCoder( + (Class) ((ParameterizedType) getClass() + .getGenericSuperclass()).getActualTypeArguments()[1]); + } } diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java index 1e1ac359f650..9c0b4a37ae7d 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java @@ -17,10 +17,13 @@ */ package org.apache.beam.dsls.sql.transform; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.Serializable; import java.math.BigDecimal; import java.util.ArrayList; -import java.util.Date; +import java.util.Iterator; import java.util.List; import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlExpression; import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlInputRefExpression; @@ -28,6 +31,13 @@ import org.apache.beam.dsls.sql.schema.BeamSqlRow; import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; import org.apache.beam.dsls.sql.utils.CalciteUtils; +import org.apache.beam.sdk.coders.BigDecimalCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.SerializableFunction; @@ -71,9 +81,7 @@ public void processElement(ProcessContext c, BoundedWindow window) { outRecord.addField(aggFieldNames.get(idx), kvRecord.getValue().getFieldValue(idx)); } - // if 
(c.pane().isLast()) { c.output(outRecord); - // } } } @@ -137,7 +145,7 @@ public Instant apply(BeamSqlRow input) { * An adaptor class to invoke Calcite UDAF instances in Beam {@code CombineFn}. */ public static class AggregationAdaptor - extends CombineFn, BeamSqlRow> { + extends CombineFn { private List aggregators; private List sourceFieldExps; private BeamSqlRecordType finalRecordType; @@ -159,176 +167,128 @@ public AggregationAdaptor(List aggregationCalls, outFieldsType.add(outFieldType); switch (call.getAggregation().getName()) { - case "COUNT": - aggregators.add(new BeamBuiltinAggregations.Count()); - break; - case "MAX": - switch (call.type.getSqlTypeName()) { - case INTEGER: - aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); - break; - case SMALLINT: - aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); - break; - case TINYINT: - aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); - break; - case BIGINT: - aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); - break; - case FLOAT: - aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); - break; - case DOUBLE: - aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); - break; - case TIMESTAMP: - aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); - break; - case DECIMAL: - aggregators.add(new BeamBuiltinAggregations.Max(outFieldType)); - break; - default: - throw new UnsupportedOperationException(); - } - break; - case "MIN": - switch (call.type.getSqlTypeName()) { - case INTEGER: - aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); - break; - case SMALLINT: - aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); - break; - case TINYINT: - aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); - break; - case BIGINT: - aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); + case "COUNT": + aggregators.add(new BeamBuiltinAggregations.Count()); break; - case FLOAT: - aggregators.add(new 
BeamBuiltinAggregations.Min(outFieldType)); + case "MAX": + aggregators.add(BeamBuiltinAggregations.Max.create(call.type.getSqlTypeName())); break; - case DOUBLE: - aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); + case "MIN": + aggregators.add(BeamBuiltinAggregations.Min.create(call.type.getSqlTypeName())); break; - case TIMESTAMP: - aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); + case "SUM": + aggregators.add(BeamBuiltinAggregations.Sum.create(call.type.getSqlTypeName())); break; - case DECIMAL: - aggregators.add(new BeamBuiltinAggregations.Min(outFieldType)); + case "AVG": + aggregators.add(BeamBuiltinAggregations.Avg.create(call.type.getSqlTypeName())); break; default: - throw new UnsupportedOperationException(); - } - break; - case "SUM": - switch (call.type.getSqlTypeName()) { - case INTEGER: - aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); - break; - case SMALLINT: - aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); - break; - case TINYINT: - aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); - break; - case BIGINT: - aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); - break; - case FLOAT: - aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); - break; - case DOUBLE: - aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); - break; - case DECIMAL: - aggregators.add(new BeamBuiltinAggregations.Sum(outFieldType)); - break; - default: - throw new UnsupportedOperationException(); - } - break; - case "AVG": - switch (call.type.getSqlTypeName()) { - case INTEGER: - aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); - break; - case SMALLINT: - aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); - break; - case TINYINT: - aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); - break; - case BIGINT: - aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); - break; - case FLOAT: - aggregators.add(new 
BeamBuiltinAggregations.Avg(outFieldType)); - break; - case DOUBLE: - aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); - break; - case DECIMAL: - aggregators.add(new BeamBuiltinAggregations.Avg(outFieldType)); - break; - default: - throw new UnsupportedOperationException(); - } - break; - default: - if (call.getAggregation() instanceof SqlUserDefinedAggFunction) { - // handle UDAF. - SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction) call.getAggregation(); - AggregateFunctionImpl fn = (AggregateFunctionImpl) udaf.function; - try { - aggregators.add((BeamSqlUdaf) fn.declaringClass.newInstance()); - } catch (Exception e) { - throw new IllegalStateException(e); + if (call.getAggregation() instanceof SqlUserDefinedAggFunction) { + // handle UDAF. + SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction) call.getAggregation(); + AggregateFunctionImpl fn = (AggregateFunctionImpl) udaf.function; + try { + aggregators.add((BeamSqlUdaf) fn.declaringClass.newInstance()); + } catch (Exception e) { + throw new IllegalStateException(e); + } + } else { + throw new UnsupportedOperationException( + String.format("Aggregator [%s] is not supported", + call.getAggregation().getName())); } - } else { - throw new UnsupportedOperationException(); - } + break; } } finalRecordType = BeamSqlRecordType.create(outFieldsName, outFieldsType); } @Override - public List createAccumulator() { - List initialAccu = new ArrayList<>(); + public AggregationAccumulator createAccumulator() { + AggregationAccumulator initialAccu = new AggregationAccumulator(); for (BeamSqlUdaf agg : aggregators) { - initialAccu.add(agg.init()); + initialAccu.accumulatorElements.add(agg.init()); } return initialAccu; } @Override - public List addInput(List accumulator, BeamSqlRow input) { - List deltaAcc = new ArrayList<>(); + public AggregationAccumulator addInput(AggregationAccumulator accumulator, BeamSqlRow input) { + AggregationAccumulator deltaAcc = new AggregationAccumulator(); for 
(int idx = 0; idx < aggregators.size(); ++idx) { - deltaAcc.add(aggregators.get(idx).add(accumulator.get(idx), + deltaAcc.accumulatorElements.add( + aggregators.get(idx).add(accumulator.accumulatorElements.get(idx), sourceFieldExps.get(idx).evaluate(input).getValue())); } return deltaAcc; } @Override - public List mergeAccumulators(Iterable> accumulators) { - List deltaAcc = new ArrayList<>(); + public AggregationAccumulator mergeAccumulators(Iterable accumulators) { + AggregationAccumulator deltaAcc = new AggregationAccumulator(); for (int idx = 0; idx < aggregators.size(); ++idx) { - List accs = new ArrayList<>(); - while (accumulators.iterator().hasNext()) { - accs.add(accumulators.iterator().next().get(idx)); + List accs = new ArrayList<>(); + Iterator ite = accumulators.iterator(); + while (ite.hasNext()) { + accs.add(ite.next().accumulatorElements.get(idx)); } - deltaAcc.add(aggregators.get(idx).merge(accs)); + deltaAcc.accumulatorElements.add(aggregators.get(idx).merge(accs)); } return deltaAcc; } @Override - public BeamSqlRow extractOutput(List accumulator) { + public BeamSqlRow extractOutput(AggregationAccumulator accumulator) { BeamSqlRow result = new BeamSqlRow(finalRecordType); for (int idx = 0; idx < aggregators.size(); ++idx) { - result.addField(idx, aggregators.get(idx).result(accumulator.get(idx))); + result.addField(idx, aggregators.get(idx).result(accumulator.accumulatorElements.get(idx))); } return result; } + @Override + public Coder getAccumulatorCoder( + CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of()); + List aggAccuCoderList = new ArrayList<>(); + for (BeamSqlUdaf udaf : aggregators) { + aggAccuCoderList.add(udaf.getAccumulatorCoder(registry)); + } + return new AggregationAccumulatorCoder(aggAccuCoderList); + } + } + + /** + * A class to hold varied accumulator objects. 
+ */ + public static class AggregationAccumulator{ + private List accumulatorElements = new ArrayList<>(); + } + + /** + * Coder for {@link AggregationAccumulator}. + */ + public static class AggregationAccumulatorCoder extends CustomCoder{ + private VarIntCoder sizeCoder = VarIntCoder.of(); + private List elementCoders; + + public AggregationAccumulatorCoder(List elementCoders) { + this.elementCoders = elementCoders; + } + + @Override + public void encode(AggregationAccumulator value, OutputStream outStream) + throws CoderException, IOException { + sizeCoder.encode(value.accumulatorElements.size(), outStream); + for (int idx = 0; idx < value.accumulatorElements.size(); ++idx) { + elementCoders.get(idx).encode(value.accumulatorElements.get(idx), outStream); + } + } + + @Override + public AggregationAccumulator decode(InputStream inStream) throws CoderException, IOException { + AggregationAccumulator accu = new AggregationAccumulator(); + int size = sizeCoder.decode(inStream); + for (int idx = 0; idx < size; ++idx) { + accu.accumulatorElements.add(elementCoders.get(idx).decode(inStream)); + } + return accu; + } } } diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java index 11097e37928b..4f2022a5ba66 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java @@ -18,13 +18,21 @@ package org.apache.beam.dsls.sql.transform; import java.math.BigDecimal; -import java.sql.Types; -import java.util.Arrays; -import java.util.List; -import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; -import org.apache.beam.dsls.sql.schema.BeamSqlRow; +import java.util.Date; +import java.util.Iterator; import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; -import org.apache.beam.dsls.sql.utils.CalciteUtils; +import 
org.apache.beam.sdk.coders.BigDecimalCoder; +import org.apache.beam.sdk.coders.ByteCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.values.KV; +import org.apache.calcite.sql.type.SqlTypeName; /** * Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG. @@ -33,275 +41,386 @@ class BeamBuiltinAggregations { /** * Built-in aggregation for COUNT. */ - public static class Count extends BeamSqlUdaf { - private BeamSqlRecordType accType; - - public Count() { - accType = BeamSqlRecordType.create(Arrays.asList("__count"), Arrays.asList(Types.BIGINT)); - } + public static final class Count extends BeamSqlUdaf { + public Count() {} @Override - public BeamSqlRow init() { - return new BeamSqlRow(accType, Arrays.asList(0L)); + public Long init() { + return 0L; } @Override - public BeamSqlRow add(BeamSqlRow accumulator, T input) { - return new BeamSqlRow(accType, Arrays.asList(accumulator.getLong(0) + 1)); + public Long add(Long accumulator, T input) { + return accumulator + 1; } @Override - public BeamSqlRow merge(Iterable accumulators) { + public Long merge(Iterable accumulators) { long v = 0L; - while (accumulators.iterator().hasNext()) { - v += accumulators.iterator().next().getLong(0); + Iterator ite = accumulators.iterator(); + while (ite.hasNext()) { + v += ite.next(); } - return new BeamSqlRow(accType, Arrays.asList(v)); + return v; } @Override - public Long result(BeamSqlRow accumulator) { - return accumulator.getLong(0); + public Long result(Long accumulator) { + return accumulator; } } /** * Built-in aggregation for MAX. 
*/ - public static class Max> extends BeamSqlUdaf { - private BeamSqlRecordType accType; + public static final class Max> extends BeamSqlUdaf { + public static Max create(SqlTypeName fieldType) { + switch (fieldType) { + case INTEGER: + return new BeamBuiltinAggregations.Max(fieldType); + case SMALLINT: + return new BeamBuiltinAggregations.Max(fieldType); + case TINYINT: + return new BeamBuiltinAggregations.Max(fieldType); + case BIGINT: + return new BeamBuiltinAggregations.Max(fieldType); + case FLOAT: + return new BeamBuiltinAggregations.Max(fieldType); + case DOUBLE: + return new BeamBuiltinAggregations.Max(fieldType); + case TIMESTAMP: + return new BeamBuiltinAggregations.Max(fieldType); + case DECIMAL: + return new BeamBuiltinAggregations.Max(fieldType); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in MAX", fieldType)); + } + } - public Max(int outputFieldType) { - this.accType = BeamSqlRecordType.create(Arrays.asList("__max"), - Arrays.asList(outputFieldType)); + private SqlTypeName fieldType; + private Max(SqlTypeName fieldType) { + this.fieldType = fieldType; } @Override - public BeamSqlRow init() { + public T init() { return null; } @Override - public BeamSqlRow add(BeamSqlRow accumulator, T input) { - return new BeamSqlRow(accType, - Arrays - .asList((accumulator == null || ((Comparable) accumulator.getFieldValue(0)) - .compareTo(input) < 0) - ? input : accumulator.getFieldValue(0))); + public T add(T accumulator, T input) { + return (accumulator == null || accumulator.compareTo(input) < 0) ? 
input : accumulator; } @Override - public BeamSqlRow merge(Iterable accumulators) { - T mergedV = (T) accumulators.iterator().next().getFieldValue(0); - while (accumulators.iterator().hasNext()) { - T v = (T) accumulators.iterator().next().getFieldValue(0); + public T merge(Iterable accumulators) { + Iterator ite = accumulators.iterator(); + T mergedV = ite.next(); + while (ite.hasNext()) { + T v = ite.next(); mergedV = mergedV.compareTo(v) > 0 ? mergedV : v; } - return new BeamSqlRow(accType, Arrays.asList(mergedV)); + return mergedV; + } + + @Override + public T result(T accumulator) { + return accumulator; } @Override - public T result(BeamSqlRow accumulator) { - return (T) accumulator.getFieldValue(0); + public Coder getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException { + switch (fieldType) { + case INTEGER: + return (Coder) VarIntCoder.of(); + case SMALLINT: + return (Coder) SerializableCoder.of(Short.class); + case TINYINT: + return (Coder) ByteCoder.of(); + case BIGINT: + return (Coder) VarLongCoder.of(); + case FLOAT: + return (Coder) SerializableCoder.of(Float.class); + case DOUBLE: + return (Coder) DoubleCoder.of(); + case TIMESTAMP: + return (Coder) SerializableCoder.of(Date.class); + case DECIMAL: + return (Coder) BigDecimalCoder.of(); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in MAX", fieldType)); + } } } /** * Built-in aggregation for MIN. 
*/ - public static class Min> extends BeamSqlUdaf { - private BeamSqlRecordType accType; + public static final class Min> extends BeamSqlUdaf { + public static Min create(SqlTypeName fieldType) { + switch (fieldType) { + case INTEGER: + return new BeamBuiltinAggregations.Min(fieldType); + case SMALLINT: + return new BeamBuiltinAggregations.Min(fieldType); + case TINYINT: + return new BeamBuiltinAggregations.Min(fieldType); + case BIGINT: + return new BeamBuiltinAggregations.Min(fieldType); + case FLOAT: + return new BeamBuiltinAggregations.Min(fieldType); + case DOUBLE: + return new BeamBuiltinAggregations.Min(fieldType); + case TIMESTAMP: + return new BeamBuiltinAggregations.Min(fieldType); + case DECIMAL: + return new BeamBuiltinAggregations.Min(fieldType); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in MIN", fieldType)); + } + } - public Min(int outputFieldType) { - this.accType = BeamSqlRecordType.create(Arrays.asList("__min"), - Arrays.asList(outputFieldType)); + private SqlTypeName fieldType; + private Min(SqlTypeName fieldType) { + this.fieldType = fieldType; } @Override - public BeamSqlRow init() { + public T init() { return null; } @Override - public BeamSqlRow add(BeamSqlRow accumulator, T input) { - return new BeamSqlRow(accType, - Arrays - .asList((accumulator == null || ((Comparable) accumulator.getFieldValue(0)) - .compareTo(input) > 0) - ? input : accumulator.getFieldValue(0))); + public T add(T accumulator, T input) { + return (accumulator == null || accumulator.compareTo(input) > 0) ? 
input : accumulator; } @Override - public BeamSqlRow merge(Iterable accumulators) { - T mergedV = (T) accumulators.iterator().next().getFieldValue(0); - while (accumulators.iterator().hasNext()) { - T v = (T) accumulators.iterator().next().getFieldValue(0); + public T merge(Iterable accumulators) { + Iterator ite = accumulators.iterator(); + T mergedV = ite.next(); + while (ite.hasNext()) { + T v = ite.next(); mergedV = mergedV.compareTo(v) < 0 ? mergedV : v; } - return new BeamSqlRow(accType, Arrays.asList(mergedV)); + return mergedV; } @Override - public T result(BeamSqlRow accumulator) { - return (T) accumulator.getFieldValue(0); + public T result(T accumulator) { + return accumulator; + } + + @Override + public Coder getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException { + switch (fieldType) { + case INTEGER: + return (Coder) VarIntCoder.of(); + case SMALLINT: + return (Coder) SerializableCoder.of(Short.class); + case TINYINT: + return (Coder) ByteCoder.of(); + case BIGINT: + return (Coder) VarLongCoder.of(); + case FLOAT: + return (Coder) SerializableCoder.of(Float.class); + case DOUBLE: + return (Coder) DoubleCoder.of(); + case TIMESTAMP: + return (Coder) SerializableCoder.of(Date.class); + case DECIMAL: + return (Coder) BigDecimalCoder.of(); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in MIN", fieldType)); + } } } /** * Built-in aggregation for SUM. 
*/ - public static class Sum extends BeamSqlUdaf { - private static List supportedType = Arrays.asList(Types.INTEGER, - Types.BIGINT, Types.SMALLINT, Types.TINYINT, Types.DOUBLE, - Types.FLOAT, Types.DECIMAL); - - private int outputFieldType; - private BeamSqlRecordType accType; - public Sum(int outputFieldType) { - //check input data type is supported - if (!supportedType.contains(outputFieldType)) { - throw new UnsupportedOperationException(String.format( - "data type [%s] is not supported in SUM", CalciteUtils.toCalciteType(outputFieldType))); + public static final class Sum extends BeamSqlUdaf { + public static Sum create(SqlTypeName fieldType) { + switch (fieldType) { + case INTEGER: + return new BeamBuiltinAggregations.Sum(fieldType); + case SMALLINT: + return new BeamBuiltinAggregations.Sum(fieldType); + case TINYINT: + return new BeamBuiltinAggregations.Sum(fieldType); + case BIGINT: + return new BeamBuiltinAggregations.Sum(fieldType); + case FLOAT: + return new BeamBuiltinAggregations.Sum(fieldType); + case DOUBLE: + return new BeamBuiltinAggregations.Sum(fieldType); + case TIMESTAMP: + return new BeamBuiltinAggregations.Sum(fieldType); + case DECIMAL: + return new BeamBuiltinAggregations.Sum(fieldType); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in SUM", fieldType)); } - - this.outputFieldType = outputFieldType; - this.accType = BeamSqlRecordType.create(Arrays.asList("__sum"), - Arrays.asList(Types.DECIMAL)); //by default use DOUBLE to store the value. 
} + private SqlTypeName fieldType; + private Sum(SqlTypeName fieldType) { + this.fieldType = fieldType; + } + @Override - public BeamSqlRow init() { - return new BeamSqlRow(accType, Arrays.asList(new BigDecimal(0))); + public BigDecimal init() { + return new BigDecimal(0); } @Override - public BeamSqlRow add(BeamSqlRow accumulator, T input) { - return new BeamSqlRow(accType, Arrays.asList(accumulator.getBigDecimal(0) - .add(new BigDecimal(input.toString())))); + public BigDecimal add(BigDecimal accumulator, T input) { + return accumulator.add(new BigDecimal(input.toString())); } @Override - public BeamSqlRow merge(Iterable accumulators) { + public BigDecimal merge(Iterable accumulators) { BigDecimal v = new BigDecimal(0); - while (accumulators.iterator().hasNext()) { - v = v.add(accumulators.iterator().next().getBigDecimal(0)); + Iterator ite = accumulators.iterator(); + while (ite.hasNext()) { + v = v.add(ite.next()); } - return new BeamSqlRow(accType, Arrays.asList(v)); + return v; } @Override - public T result(BeamSqlRow accumulator) { + public T result(BigDecimal accumulator) { Object result = null; - switch (outputFieldType) { - case Types.INTEGER: - result = accumulator.getBigDecimal(0).intValue(); - break; - case Types.BIGINT: - result = accumulator.getBigDecimal(0).longValue(); - break; - case Types.SMALLINT: - result = accumulator.getBigDecimal(0).shortValue(); - break; - case Types.TINYINT: - result = accumulator.getBigDecimal(0).byteValue(); - break; - case Types.DOUBLE: - result = accumulator.getBigDecimal(0).doubleValue(); - break; - case Types.FLOAT: - result = accumulator.getBigDecimal(0).floatValue(); - break; - case Types.DECIMAL: - result = accumulator.getBigDecimal(0); - break; - default: - break; + switch (fieldType) { + case INTEGER: + result = accumulator.intValue(); + break; + case BIGINT: + result = accumulator.longValue(); + break; + case SMALLINT: + result = accumulator.shortValue(); + break; + case TINYINT: + result = 
accumulator.byteValue(); + break; + case DOUBLE: + result = accumulator.doubleValue(); + break; + case FLOAT: + result = accumulator.floatValue(); + break; + case DECIMAL: + result = accumulator; + break; + default: + break; } return (T) result; } - } /** * Built-in aggregation for AVG. */ - public static class Avg extends BeamSqlUdaf { - private static List supportedType = Arrays.asList(Types.INTEGER, - Types.BIGINT, Types.SMALLINT, Types.TINYINT, Types.DOUBLE, - Types.FLOAT, Types.DECIMAL); - - private int outputFieldType; - private BeamSqlRecordType accType; - public Avg(int outputFieldType) { - //check input data type is supported - if (!supportedType.contains(outputFieldType)) { - throw new UnsupportedOperationException(String.format( - "data type [%s] is not supported in AVG", CalciteUtils.toCalciteType(outputFieldType))); + public static final class Avg extends BeamSqlUdaf, T> { + public static Avg create(SqlTypeName fieldType) { + switch (fieldType) { + case INTEGER: + return new BeamBuiltinAggregations.Avg(fieldType); + case SMALLINT: + return new BeamBuiltinAggregations.Avg(fieldType); + case TINYINT: + return new BeamBuiltinAggregations.Avg(fieldType); + case BIGINT: + return new BeamBuiltinAggregations.Avg(fieldType); + case FLOAT: + return new BeamBuiltinAggregations.Avg(fieldType); + case DOUBLE: + return new BeamBuiltinAggregations.Avg(fieldType); + case TIMESTAMP: + return new BeamBuiltinAggregations.Avg(fieldType); + case DECIMAL: + return new BeamBuiltinAggregations.Avg(fieldType); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not support in AVG", fieldType)); } - - this.outputFieldType = outputFieldType; - this.accType = BeamSqlRecordType.create(Arrays.asList("__sum", "size"), - Arrays.asList(Types.DECIMAL, Types.BIGINT)); //by default use DOUBLE to store the value. 
} + private SqlTypeName fieldType; + private Avg(SqlTypeName fieldType) { + this.fieldType = fieldType; + } + @Override - public BeamSqlRow init() { - return new BeamSqlRow(accType, Arrays.asList(new BigDecimal(0), 0L)); + public KV init() { + return KV.of(new BigDecimal(0), 0L); } @Override - public BeamSqlRow add(BeamSqlRow accumulator, T input) { - return new BeamSqlRow(accType, - Arrays.asList( - accumulator.getBigDecimal(0).add(new BigDecimal(input.toString())), - accumulator.getLong(1) + 1)); + public KV add(KV accumulator, T input) { + return KV.of( + accumulator.getKey().add(new BigDecimal(input.toString())), + accumulator.getValue() + 1); } @Override - public BeamSqlRow merge(Iterable accumulators) { + public KV merge(Iterable> accumulators) { BigDecimal v = new BigDecimal(0); long s = 0; - while (accumulators.iterator().hasNext()) { - BeamSqlRow r = accumulators.iterator().next(); - v = v.add(r.getBigDecimal(0)); - s += r.getLong(1); + Iterator> ite = accumulators.iterator(); + while (ite.hasNext()) { + KV r = ite.next(); + v = v.add(r.getKey()); + s += r.getValue(); } - return new BeamSqlRow(accType, Arrays.asList(v, s)); + return KV.of(v, s); } @Override - public T result(BeamSqlRow accumulator) { + public T result(KV accumulator) { + BigDecimal decimalAvg = accumulator.getKey().divide( + new BigDecimal(accumulator.getValue())); Object result = null; - BigDecimal decimalAvg = accumulator.getBigDecimal(0).divide( - new BigDecimal(accumulator.getLong(1))); - switch (outputFieldType) { - case Types.INTEGER: - result = decimalAvg.intValue(); - break; - case Types.BIGINT: - result = decimalAvg.longValue(); - break; - case Types.SMALLINT: - result = decimalAvg.shortValue(); - break; - case Types.TINYINT: - result = decimalAvg.byteValue(); - break; - case Types.DOUBLE: - result = decimalAvg.doubleValue(); - break; - case Types.FLOAT: - result = decimalAvg.floatValue(); - break; - case Types.DECIMAL: - result = decimalAvg; - break; - default: - break; + switch 
(fieldType) { + case INTEGER: + result = decimalAvg.intValue(); + break; + case BIGINT: + result = decimalAvg.longValue(); + break; + case SMALLINT: + result = decimalAvg.shortValue(); + break; + case TINYINT: + result = decimalAvg.byteValue(); + break; + case DOUBLE: + result = decimalAvg.doubleValue(); + break; + case FLOAT: + result = decimalAvg.floatValue(); + break; + case DECIMAL: + result = decimalAvg; + break; + default: + break; } return (T) result; } + @Override + public Coder> getAccumulatorCoder(CoderRegistry registry) + throws CannotProvideCoderException { + return KvCoder.of(BigDecimalCoder.of(), VarLongCoder.of()); + } } } diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslAggregationTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslAggregationTest.java index 5ecd73a4ea04..b0509ae1a330 100644 --- a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslAggregationTest.java +++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslAggregationTest.java @@ -21,13 +21,11 @@ import java.util.Arrays; import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; import org.apache.beam.dsls.sql.schema.BeamSqlRow; -import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; -import org.junit.Ignore; import org.junit.Test; /** @@ -259,68 +257,4 @@ public void testSessionWindow() throws Exception { pipeline.run().waitUntilFinish(); } - - /** - * GROUP-BY with UDAF. 
- */ - @Ignore - public void testAggregationWithUDAF() throws Exception { - String sql = "SELECT f_int2, squaresum(f_int) AS `squaresum` FROM TABLE_A GROUP BY f_int2"; - - //The test case is disabled temporally as BeamSql doesn't have methods to regester UDF/UDAF, - //pending on task BEAM-2520 -// BeamSqlEnv.registerUdaf("squaresum", SquareSum.class); - PCollection result = - inputA1.apply("testAggregationWithUDAF", BeamSql.simpleQuery(sql)); - - BeamSqlRecordType resultType = BeamSqlRecordType.create(Arrays.asList("f_int2", "squaresum"), - Arrays.asList(Types.INTEGER, Types.INTEGER)); - - BeamSqlRow record = new BeamSqlRow(resultType); - record.addField("f_int2", 0); - record.addField("squaresum", 30); - - PAssert.that(result).containsInAnyOrder(record); - - pipeline.run().waitUntilFinish(); - } - - /** - * UDAF for test, which returns the sum of square. - */ - public static class SquareSum extends BeamSqlUdaf { - private int outputFieldType; - private BeamSqlRecordType accType; - - public SquareSum() { - this.outputFieldType = Types.INTEGER; - accType = BeamSqlRecordType.create(Arrays.asList("__tudaf"), Arrays.asList(outputFieldType)); - } - - // @Override - public BeamSqlRow init() { - return new BeamSqlRow(accType, Arrays.asList(0)); - } - - // @Override - public BeamSqlRow add(BeamSqlRow accumulator, Integer input) { - return new BeamSqlRow(accType, - Arrays.asList(accumulator.getInteger(0) + input * input)); - } - - // @Override - public BeamSqlRow merge(Iterable accumulators) { - int v = 0; - while (accumulators.iterator().hasNext()) { - v += accumulators.iterator().next().getInteger(0); - } - return new BeamSqlRow(accType, Arrays.asList(v)); - } - - // @Override - public Integer result(BeamSqlRow accumulator) { - return accumulator.getInteger(0); - } - - } } From c81d52f3bbea2c1787d78f73e9a47f03d657b5e9 Mon Sep 17 00:00:00 2001 From: mingmxu Date: Fri, 30 Jun 2017 12:35:20 -0700 Subject: [PATCH 4/5] cleanup and update Coder behaviour in BeamSqlUdaf --- 
.../beam/dsls/sql/schema/BeamSqlUdaf.java | 13 +++- .../transform/BeamBuiltinAggregations.java | 74 ++++++++----------- 2 files changed, 42 insertions(+), 45 deletions(-) diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java index 9582ffaea898..5d28781c223d 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java @@ -19,9 +19,13 @@ import java.io.Serializable; import java.lang.reflect.ParameterizedType; +import java.math.BigDecimal; +import java.util.Date; +import org.apache.beam.sdk.coders.BigDecimalCoder; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.transforms.Combine.CombineFn; /** @@ -61,10 +65,17 @@ public BeamSqlUdaf(){} /** * get the coder for AccumT which stores the intermediate result. - * By default it's fetched from {@link CoderRegistry}. + * By default it's fetched from {@link CoderRegistry}, and Beam SQL field types are included, + * like Integer/Long/Short/Byte/Float/Double/BigDecimal/Date. 
*/ public Coder getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException { + //Register coder for Short/Float/BigDecimal/Date + registry.registerCoderForClass(Short.class, SerializableCoder.of(Short.class)); + registry.registerCoderForClass(Float.class, SerializableCoder.of(Float.class)); + registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of()); + registry.registerCoderForClass(Date.class, SerializableCoder.of(Date.class)); + return registry.getCoder( (Class) ((ParameterizedType) getClass() .getGenericSuperclass()).getActualTypeArguments()[1]); diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java index 4f2022a5ba66..fab26667e2e9 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java @@ -98,7 +98,7 @@ public static Max create(SqlTypeName fieldType) { } } - private SqlTypeName fieldType; + private final SqlTypeName fieldType; private Max(SqlTypeName fieldType) { this.fieldType = fieldType; } @@ -131,27 +131,7 @@ public T result(T accumulator) { @Override public Coder getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException { - switch (fieldType) { - case INTEGER: - return (Coder) VarIntCoder.of(); - case SMALLINT: - return (Coder) SerializableCoder.of(Short.class); - case TINYINT: - return (Coder) ByteCoder.of(); - case BIGINT: - return (Coder) VarLongCoder.of(); - case FLOAT: - return (Coder) SerializableCoder.of(Float.class); - case DOUBLE: - return (Coder) DoubleCoder.of(); - case TIMESTAMP: - return (Coder) SerializableCoder.of(Date.class); - case DECIMAL: - return (Coder) BigDecimalCoder.of(); - default: - throw new UnsupportedOperationException( - String.format("[%s] is not support in MAX", fieldType)); - } + return 
BeamBuiltinAggregations.getSqlTypeCoder(fieldType); } } @@ -183,7 +163,7 @@ public static Min create(SqlTypeName fieldType) { } } - private SqlTypeName fieldType; + private final SqlTypeName fieldType; private Min(SqlTypeName fieldType) { this.fieldType = fieldType; } @@ -216,27 +196,7 @@ public T result(T accumulator) { @Override public Coder getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException { - switch (fieldType) { - case INTEGER: - return (Coder) VarIntCoder.of(); - case SMALLINT: - return (Coder) SerializableCoder.of(Short.class); - case TINYINT: - return (Coder) ByteCoder.of(); - case BIGINT: - return (Coder) VarLongCoder.of(); - case FLOAT: - return (Coder) SerializableCoder.of(Float.class); - case DOUBLE: - return (Coder) DoubleCoder.of(); - case TIMESTAMP: - return (Coder) SerializableCoder.of(Date.class); - case DECIMAL: - return (Coder) BigDecimalCoder.of(); - default: - throw new UnsupportedOperationException( - String.format("[%s] is not support in MIN", fieldType)); - } + return BeamBuiltinAggregations.getSqlTypeCoder(fieldType); } } @@ -423,4 +383,30 @@ public Coder> getAccumulatorCoder(CoderRegistry registry) } } + /** + * Find {@link Coder} for Beam SQL field types. 
+ */ + private static Coder getSqlTypeCoder(SqlTypeName sqlType) { + switch (sqlType) { + case INTEGER: + return VarIntCoder.of(); + case SMALLINT: + return SerializableCoder.of(Short.class); + case TINYINT: + return ByteCoder.of(); + case BIGINT: + return VarLongCoder.of(); + case FLOAT: + return SerializableCoder.of(Float.class); + case DOUBLE: + return DoubleCoder.of(); + case TIMESTAMP: + return SerializableCoder.of(Date.class); + case DECIMAL: + return BigDecimalCoder.of(); + default: + throw new UnsupportedOperationException( + String.format("Cannot find a Coder for data type [%s]", sqlType)); + } + } } From 4e8e4ed941c94d45a417917ad836f46e4c9f26b0 Mon Sep 17 00:00:00 2001 From: mingmxu Date: Fri, 30 Jun 2017 13:19:13 -0700 Subject: [PATCH 5/5] cleanup Coders in BeamSqlUdaf --- .../apache/beam/dsls/sql/schema/BeamSqlUdaf.java | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java index 5d28781c223d..9582ffaea898 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java @@ -19,13 +19,9 @@ import java.io.Serializable; import java.lang.reflect.ParameterizedType; -import java.math.BigDecimal; -import java.util.Date; -import org.apache.beam.sdk.coders.BigDecimalCoder; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; -import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.transforms.Combine.CombineFn; /** @@ -65,17 +61,10 @@ public BeamSqlUdaf(){} /** * get the coder for AccumT which stores the intermediate result. 
- * By default it's fetched from {@link CoderRegistry}, and Beam SQL field types are included, - * like Integer/Long/Short/Byte/Float/Double/BigDecimal/Date. + * By default it's fetched from {@link CoderRegistry}. */ public Coder getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException { - //Register coder for Short/Float/BigDecimal/Date - registry.registerCoderForClass(Short.class, SerializableCoder.of(Short.class)); - registry.registerCoderForClass(Float.class, SerializableCoder.of(Float.class)); - registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of()); - registry.registerCoderForClass(Date.class, SerializableCoder.of(Date.class)); - return registry.getCoder( (Class) ((ParameterizedType) getClass() .getGenericSuperclass()).getActualTypeArguments()[1]);