From d533a78f5dad3f1ef085a19c8297204a1c99fd6d Mon Sep 17 00:00:00 2001 From: Jark Wu Date: Wed, 13 Mar 2019 23:53:23 +0800 Subject: [PATCH] [FLINK-11788][table-planner-blink] Support Code Generation for RexNode 1. Introduce ExprCodeGenerator to generate code for RexNodes 2. Introduce SqlOperatorGens to generate code for SQL scalar operators 3. Introduce GenerateUtils to generate code for general purpose, e.g. generateLiteral, generateFieldAccess This closes #7982 --- .../functions/sql/FlinkSqlOperatorTable.java | 275 +++ .../sql/ProctimeMaterializeSqlFunction.java | 63 + .../functions/sql/ProctimeSqlFunction.java | 52 + .../sql/StreamRecordTimestampSqlFunction.java | 55 + .../flink/table/api/TableEnvironment.scala | 2 + .../flink/table/calcite/FlinkLocalRef.scala | 48 + .../table/calcite/FlinkTypeFactory.scala | 101 +- .../flink/table/calcite/FlinkTypeSystem.scala | 15 +- .../flink/table/codegen/CodeGenUtils.scala | 493 ++-- .../table/codegen/CodeGeneratorContext.scala | 108 +- .../table/codegen/ExprCodeGenerator.scala | 721 ++++++ .../table/codegen/FunctionCodeGenerator.scala | 162 ++ .../flink/table/codegen/GenerateUtils.scala | 751 ++++++ .../table/codegen/GeneratedExpression.scala | 45 +- .../table/codegen/SortCodeGenerator.scala | 2 +- .../table/codegen/calls/BuiltInMethods.scala | 42 + .../table/codegen/calls/CallGenerator.scala | 35 + .../table/codegen/calls/ConstantCallGen.scala | 37 + .../calls/CurrentTimePointCallGen.scala | 56 + .../codegen/calls/ScalarOperatorGens.scala | 2017 +++++++++++++++++ .../plan/schema/GenericRelDataType.scala | 59 + .../table/typeutils/TypeCheckUtils.scala | 42 + .../flink/table/typeutils/TypeCoercion.scala | 159 ++ .../table/expressions/ArrayTypeTest.scala | 215 ++ .../expressions/CompositeAccessTest.scala | 141 ++ .../table/expressions/DecimalTypeTest.scala | 222 ++ .../flink/table/expressions/LiteralTest.scala | 159 ++ .../flink/table/expressions/MapTypeTest.scala | 193 ++ .../table/expressions/MathFunctionsTest.scala | 694 
++++++ .../flink/table/expressions/RowTypeTest.scala | 88 + .../expressions/ScalarOperatorsTest.scala | 125 + .../table/expressions/SqlExpressionTest.scala | 314 +++ .../expressions/utils/ArrayTypeTestBase.scala | 63 + .../utils/CompositeTypeTestBase.scala | 92 + .../utils/ExpressionTestBase.scala | 199 ++ .../expressions/utils/MapTypeTestBase.scala | 75 + .../expressions/utils/RowTypeTestBase.scala | 67 + .../utils/ScalarOperatorsTestBase.scala | 75 + .../utils/ScalarTypesTestBase.scala | 130 ++ .../validation/ArrayTypeValidationTest.scala | 58 + .../CompositeAccessValidationTest.scala | 38 + .../validation/MapTypeValidationTest.scala | 46 + .../validation/RowTypeValidationTest.scala | 41 + .../ScalarFunctionsValidationTest.scala | 136 ++ .../flink/table/util/DateTimeTestUtil.scala | 38 + .../dataformat/DataFormatConverters.java | 78 +- .../flink/table/dataformat/Decimal.java | 40 + .../table/dataformat/LazyBinaryFormat.java | 1 + .../runtime/functions/DateTimeUtils.java | 130 +- .../runtime/functions/ThreadLocalCache.java | 68 + .../runtime/util/collections/ByteHashSet.java | 151 ++ .../util/collections/DoubleHashSet.java | 133 ++ .../util/collections/FloatHashSet.java | 133 ++ .../runtime/util/collections/IntHashSet.java | 151 ++ .../runtime/util/collections/LongHashSet.java | 151 ++ .../util/collections/ObjectHashSet.java | 53 + .../util/collections/OptimizableHashSet.java | 156 ++ .../util/collections/ShortHashSet.java | 151 ++ .../flink/table/type/InternalTypeUtils.java | 36 + .../org/apache/flink/table/type/RowType.java | 3 +- .../flink/table/type/TypeConverters.java | 3 + .../flink/table/util/MurmurHashUtil.java | 11 +- 62 files changed, 9742 insertions(+), 256 deletions(-) create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/FlinkSqlOperatorTable.java create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeMaterializeSqlFunction.java create 
mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/StreamRecordTimestampSqlFunction.java create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLocalRef.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ExprCodeGenerator.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/FunctionCodeGenerator.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GenerateUtils.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ConstantCallGen.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperatorGens.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/GenericRelDataType.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCoercion.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ArrayTypeTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala create mode 100644 
flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/LiteralTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MapTypeTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MathFunctionsTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/RowTypeTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ArrayTypeTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/MapTypeTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/RowTypeTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarOperatorsTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ArrayTypeValidationTest.scala 
create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/CompositeAccessValidationTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/MapTypeValidationTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/RowTypeValidationTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/DateTimeTestUtil.scala create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ThreadLocalCache.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ByteHashSet.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/DoubleHashSet.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/FloatHashSet.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/IntHashSet.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/LongHashSet.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ObjectHashSet.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/OptimizableHashSet.java create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ShortHashSet.java create mode 100644 
flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/InternalTypeUtils.java diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/FlinkSqlOperatorTable.java new file mode 100644 index 0000000000000..19af644be0fc8 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/FlinkSqlOperatorTable.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.functions.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlBinaryOperator; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; + +/** + * Operator table that contains only Flink-specific functions and operators. + */ +public class FlinkSqlOperatorTable extends ReflectiveSqlOperatorTable { + + /** + * The table of contains Flink-specific operators. + */ + private static FlinkSqlOperatorTable instance; + + /** + * Returns the Flink operator table, creating it if necessary. + */ + public static synchronized FlinkSqlOperatorTable instance() { + if (instance == null) { + // Creates and initializes the standard operator table. + // Uses two-phase construction, because we can't initialize the + // table until the constructor of the sub-class has completed. 
+ instance = new FlinkSqlOperatorTable(); + instance.init(); + } + return instance; + } + + // ----------------------------------------------------------------------------- + + private static final SqlReturnTypeInference FLINK_QUOTIENT_NULLABLE = new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + RelDataType type1 = opBinding.getOperandType(0); + RelDataType type2 = opBinding.getOperandType(1); + if (SqlTypeUtil.isDecimal(type1) || SqlTypeUtil.isDecimal(type2)) { + return ReturnTypes.QUOTIENT_NULLABLE.inferReturnType(opBinding); + } else { + RelDataType doubleType = opBinding.getTypeFactory().createSqlType(SqlTypeName.DOUBLE); + if (type1.isNullable() || type2.isNullable()) { + return opBinding.getTypeFactory().createTypeWithNullability(doubleType, true); + } else { + return doubleType; + } + } + } + }; + + // ----------------------------------------------------------------------------- + // Flink specific built-in scalar SQL functions + // ----------------------------------------------------------------------------- + + /** + * Arithmetic division operator, '/'. + */ + public static final SqlBinaryOperator DIVIDE = new SqlBinaryOperator( + "/", + SqlKind.DIVIDE, + 60, + true, + FLINK_QUOTIENT_NULLABLE, + InferTypes.FIRST_KNOWN, + OperandTypes.DIVISION_OPERATOR); + + /** Function used to access a processing time attribute. */ + public static final SqlFunction PROCTIME = new ProctimeSqlFunction(); + + /** Function used to materialize a processing time attribute. */ + public static final SqlFunction PROCTIME_MATERIALIZE = new ProctimeMaterializeSqlFunction(); + + /** Function to access the timestamp of a StreamRecord. 
*/ + public static final SqlFunction STREAMRECORD_TIMESTAMP = new StreamRecordTimestampSqlFunction(); + + // ----------------------------------------------------------------------------- + // Window SQL functions + // ----------------------------------------------------------------------------- + + // TODO: add window functions here + + // ----------------------------------------------------------------------------- + // operators extend from Calcite + // ----------------------------------------------------------------------------- + + // SET OPERATORS + public static final SqlOperator UNION = SqlStdOperatorTable.UNION; + public static final SqlOperator UNION_ALL = SqlStdOperatorTable.UNION_ALL; + public static final SqlOperator EXCEPT = SqlStdOperatorTable.EXCEPT; + public static final SqlOperator EXCEPT_ALL = SqlStdOperatorTable.EXCEPT_ALL; + public static final SqlOperator INTERSECT = SqlStdOperatorTable.INTERSECT; + public static final SqlOperator INTERSECT_ALL = SqlStdOperatorTable.INTERSECT_ALL; + + // BINARY OPERATORS + public static final SqlOperator AND = SqlStdOperatorTable.AND; + public static final SqlOperator AS = SqlStdOperatorTable.AS; + public static final SqlOperator CONCAT = SqlStdOperatorTable.CONCAT; + public static final SqlOperator DIVIDE_INTEGER = SqlStdOperatorTable.DIVIDE_INTEGER; + public static final SqlOperator DOT = SqlStdOperatorTable.DOT; + public static final SqlOperator EQUALS = SqlStdOperatorTable.EQUALS; + public static final SqlOperator GREATER_THAN = SqlStdOperatorTable.GREATER_THAN; + public static final SqlOperator IS_DISTINCT_FROM = SqlStdOperatorTable.IS_DISTINCT_FROM; + public static final SqlOperator IS_NOT_DISTINCT_FROM = SqlStdOperatorTable.IS_NOT_DISTINCT_FROM; + public static final SqlOperator GREATER_THAN_OR_EQUAL = SqlStdOperatorTable.GREATER_THAN_OR_EQUAL; + public static final SqlOperator LESS_THAN = SqlStdOperatorTable.LESS_THAN; + public static final SqlOperator LESS_THAN_OR_EQUAL = 
SqlStdOperatorTable.LESS_THAN_OR_EQUAL; + public static final SqlOperator MINUS = SqlStdOperatorTable.MINUS; + public static final SqlOperator MINUS_DATE = SqlStdOperatorTable.MINUS_DATE; + public static final SqlOperator MULTIPLY = SqlStdOperatorTable.MULTIPLY; + public static final SqlOperator NOT_EQUALS = SqlStdOperatorTable.NOT_EQUALS; + public static final SqlOperator OR = SqlStdOperatorTable.OR; + public static final SqlOperator PLUS = SqlStdOperatorTable.PLUS; + public static final SqlOperator DATETIME_PLUS = SqlStdOperatorTable.DATETIME_PLUS; + + // POSTFIX OPERATORS + public static final SqlOperator DESC = SqlStdOperatorTable.DESC; + public static final SqlOperator NULLS_FIRST = SqlStdOperatorTable.NULLS_FIRST; + public static final SqlOperator NULLS_LAST = SqlStdOperatorTable.NULLS_LAST; + public static final SqlOperator IS_NOT_NULL = SqlStdOperatorTable.IS_NOT_NULL; + public static final SqlOperator IS_NULL = SqlStdOperatorTable.IS_NULL; + public static final SqlOperator IS_NOT_TRUE = SqlStdOperatorTable.IS_NOT_TRUE; + public static final SqlOperator IS_TRUE = SqlStdOperatorTable.IS_TRUE; + public static final SqlOperator IS_NOT_FALSE = SqlStdOperatorTable.IS_NOT_FALSE; + public static final SqlOperator IS_FALSE = SqlStdOperatorTable.IS_FALSE; + public static final SqlOperator IS_NOT_UNKNOWN = SqlStdOperatorTable.IS_NOT_UNKNOWN; + public static final SqlOperator IS_UNKNOWN = SqlStdOperatorTable.IS_UNKNOWN; + + // PREFIX OPERATORS + public static final SqlOperator NOT = SqlStdOperatorTable.NOT; + public static final SqlOperator UNARY_MINUS = SqlStdOperatorTable.UNARY_MINUS; + public static final SqlOperator UNARY_PLUS = SqlStdOperatorTable.UNARY_PLUS; + + // GROUPING FUNCTIONS + public static final SqlFunction GROUP_ID = SqlStdOperatorTable.GROUP_ID; + public static final SqlFunction GROUPING = SqlStdOperatorTable.GROUPING; + public static final SqlFunction GROUPING_ID = SqlStdOperatorTable.GROUPING_ID; + + // AGGREGATE OPERATORS + public static final 
SqlAggFunction SUM = SqlStdOperatorTable.SUM; + public static final SqlAggFunction SUM0 = SqlStdOperatorTable.SUM0; + public static final SqlAggFunction COUNT = SqlStdOperatorTable.COUNT; + public static final SqlAggFunction APPROX_COUNT_DISTINCT = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; + public static final SqlAggFunction COLLECT = SqlStdOperatorTable.COLLECT; + public static final SqlAggFunction MIN = SqlStdOperatorTable.MIN; + public static final SqlAggFunction MAX = SqlStdOperatorTable.MAX; + public static final SqlAggFunction AVG = SqlStdOperatorTable.AVG; + public static final SqlAggFunction STDDEV = SqlStdOperatorTable.STDDEV; + public static final SqlAggFunction STDDEV_POP = SqlStdOperatorTable.STDDEV_POP; + public static final SqlAggFunction STDDEV_SAMP = SqlStdOperatorTable.STDDEV_SAMP; + public static final SqlAggFunction VARIANCE = SqlStdOperatorTable.VARIANCE; + public static final SqlAggFunction VAR_POP = SqlStdOperatorTable.VAR_POP; + public static final SqlAggFunction VAR_SAMP = SqlStdOperatorTable.VAR_SAMP; + + // ARRAY OPERATORS + public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR; + public static final SqlOperator ELEMENT = SqlStdOperatorTable.ELEMENT; + + // MAP OPERATORS + public static final SqlOperator MAP_VALUE_CONSTRUCTOR = SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR; + + // ARRAY MAP SHARED OPERATORS + public static final SqlOperator ITEM = SqlStdOperatorTable.ITEM; + public static final SqlOperator CARDINALITY = SqlStdOperatorTable.CARDINALITY; + + // SPECIAL OPERATORS + public static final SqlOperator MULTISET_VALUE = SqlStdOperatorTable.MULTISET_VALUE; + public static final SqlOperator ROW = SqlStdOperatorTable.ROW; + public static final SqlOperator OVERLAPS = SqlStdOperatorTable.OVERLAPS; + public static final SqlOperator LITERAL_CHAIN = SqlStdOperatorTable.LITERAL_CHAIN; + public static final SqlOperator BETWEEN = SqlStdOperatorTable.BETWEEN; + public static final SqlOperator 
SYMMETRIC_BETWEEN = SqlStdOperatorTable.SYMMETRIC_BETWEEN; + public static final SqlOperator NOT_BETWEEN = SqlStdOperatorTable.NOT_BETWEEN; + public static final SqlOperator SYMMETRIC_NOT_BETWEEN = SqlStdOperatorTable.SYMMETRIC_NOT_BETWEEN; + public static final SqlOperator NOT_LIKE = SqlStdOperatorTable.NOT_LIKE; + public static final SqlOperator LIKE = SqlStdOperatorTable.LIKE; + public static final SqlOperator NOT_SIMILAR_TO = SqlStdOperatorTable.NOT_SIMILAR_TO; + public static final SqlOperator SIMILAR_TO = SqlStdOperatorTable.SIMILAR_TO; + public static final SqlOperator CASE = SqlStdOperatorTable.CASE; + public static final SqlOperator REINTERPRET = SqlStdOperatorTable.REINTERPRET; + public static final SqlOperator EXTRACT = SqlStdOperatorTable.EXTRACT; + public static final SqlOperator IN = SqlStdOperatorTable.IN; + public static final SqlOperator NOT_IN = SqlStdOperatorTable.NOT_IN; + + // FUNCTIONS + public static final SqlFunction OVERLAY = SqlStdOperatorTable.OVERLAY; + public static final SqlFunction TRIM = SqlStdOperatorTable.TRIM; + public static final SqlFunction POSITION = SqlStdOperatorTable.POSITION; + public static final SqlFunction CHAR_LENGTH = SqlStdOperatorTable.CHAR_LENGTH; + public static final SqlFunction CHARACTER_LENGTH = SqlStdOperatorTable.CHARACTER_LENGTH; + public static final SqlFunction UPPER = SqlStdOperatorTable.UPPER; + public static final SqlFunction LOWER = SqlStdOperatorTable.LOWER; + public static final SqlFunction INITCAP = SqlStdOperatorTable.INITCAP; + public static final SqlFunction POWER = SqlStdOperatorTable.POWER; + public static final SqlFunction SQRT = SqlStdOperatorTable.SQRT; + public static final SqlFunction MOD = SqlStdOperatorTable.MOD; + public static final SqlFunction LN = SqlStdOperatorTable.LN; + public static final SqlFunction LOG10 = SqlStdOperatorTable.LOG10; + public static final SqlFunction ABS = SqlStdOperatorTable.ABS; + public static final SqlFunction EXP = SqlStdOperatorTable.EXP; + public static 
final SqlFunction NULLIF = SqlStdOperatorTable.NULLIF; + public static final SqlFunction COALESCE = SqlStdOperatorTable.COALESCE; + public static final SqlFunction FLOOR = SqlStdOperatorTable.FLOOR; + public static final SqlFunction CEIL = SqlStdOperatorTable.CEIL; + public static final SqlFunction LOCALTIME = SqlStdOperatorTable.LOCALTIME; + public static final SqlFunction LOCALTIMESTAMP = SqlStdOperatorTable.LOCALTIMESTAMP; + public static final SqlFunction CURRENT_TIME = SqlStdOperatorTable.CURRENT_TIME; + public static final SqlFunction CURRENT_TIMESTAMP = SqlStdOperatorTable.CURRENT_TIMESTAMP; + public static final SqlFunction CURRENT_DATE = SqlStdOperatorTable.CURRENT_DATE; + public static final SqlFunction CAST = SqlStdOperatorTable.CAST; + public static final SqlFunction QUARTER = SqlStdOperatorTable.QUARTER; + public static final SqlOperator SCALAR_QUERY = SqlStdOperatorTable.SCALAR_QUERY; + public static final SqlOperator EXISTS = SqlStdOperatorTable.EXISTS; + public static final SqlFunction SIN = SqlStdOperatorTable.SIN; + public static final SqlFunction COS = SqlStdOperatorTable.COS; + public static final SqlFunction TAN = SqlStdOperatorTable.TAN; + public static final SqlFunction COT = SqlStdOperatorTable.COT; + public static final SqlFunction ASIN = SqlStdOperatorTable.ASIN; + public static final SqlFunction ACOS = SqlStdOperatorTable.ACOS; + public static final SqlFunction ATAN = SqlStdOperatorTable.ATAN; + public static final SqlFunction ATAN2 = SqlStdOperatorTable.ATAN2; + public static final SqlFunction DEGREES = SqlStdOperatorTable.DEGREES; + public static final SqlFunction RADIANS = SqlStdOperatorTable.RADIANS; + public static final SqlFunction SIGN = SqlStdOperatorTable.SIGN; + public static final SqlFunction PI = SqlStdOperatorTable.PI; + public static final SqlFunction RAND = SqlStdOperatorTable.RAND; + public static final SqlFunction RAND_INTEGER = SqlStdOperatorTable.RAND_INTEGER; + public static final SqlFunction TIMESTAMP_ADD = 
SqlStdOperatorTable.TIMESTAMP_ADD; + public static final SqlFunction TIMESTAMP_DIFF = SqlStdOperatorTable.TIMESTAMP_DIFF; + + // MATCH_RECOGNIZE + public static final SqlFunction FIRST = SqlStdOperatorTable.FIRST; + public static final SqlFunction LAST = SqlStdOperatorTable.LAST; + public static final SqlFunction PREV = SqlStdOperatorTable.PREV; + public static final SqlFunction NEXT = SqlStdOperatorTable.NEXT; + public static final SqlFunction CLASSIFIER = SqlStdOperatorTable.CLASSIFIER; + public static final SqlOperator FINAL = SqlStdOperatorTable.FINAL; + public static final SqlOperator RUNNING = SqlStdOperatorTable.RUNNING; + + // OVER FUNCTIONS + public static final SqlAggFunction RANK = SqlStdOperatorTable.RANK; + public static final SqlAggFunction DENSE_RANK = SqlStdOperatorTable.DENSE_RANK; + public static final SqlAggFunction ROW_NUMBER = SqlStdOperatorTable.ROW_NUMBER; + public static final SqlAggFunction LEAD = SqlStdOperatorTable.LEAD; + public static final SqlAggFunction LAG = SqlStdOperatorTable.LAG; +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeMaterializeSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeMaterializeSqlFunction.java new file mode 100644 index 0000000000000..17b9ea3356c82 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeMaterializeSqlFunction.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.sql; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlMonotonicity; + +/** + * Function that materializes a processing time attribute. + * After materialization the result can be used in regular arithmetical calculations. 
+ */ +public class ProctimeMaterializeSqlFunction extends SqlFunction { + + public ProctimeMaterializeSqlFunction() { + super( + "PROCTIME_MATERIALIZE", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.TIMESTAMP), + InferTypes.RETURN_TYPE, + OperandTypes.family(SqlTypeFamily.TIMESTAMP), + SqlFunctionCategory.SYSTEM); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.FUNCTION; + } + + @Override + public SqlMonotonicity getMonotonicity(SqlOperatorBinding call) { + return SqlMonotonicity.INCREASING; + } + + @Override + public boolean isDeterministic() { + return false; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java new file mode 100644 index 0000000000000..45827deafcb9c --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.functions.sql; + +import org.apache.flink.table.calcite.FlinkTypeFactory; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelProtoDataType; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; + +/** + * Function used to access a proctime attribute. + */ +public class ProctimeSqlFunction extends SqlFunction { + public ProctimeSqlFunction() { + super( + "PROCTIME", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(new ProctimeRelProtoDataType()), + null, + OperandTypes.NILADIC, + SqlFunctionCategory.TIMEDATE); + } + + private static class ProctimeRelProtoDataType implements RelProtoDataType { + @Override + public RelDataType apply(RelDataTypeFactory factory) { + return ((FlinkTypeFactory) factory).createRowtimeIndicatorType(); + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/StreamRecordTimestampSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/StreamRecordTimestampSqlFunction.java new file mode 100644 index 0000000000000..9a794257c4710 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/StreamRecordTimestampSqlFunction.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.sql; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; + +/** + * Function to access the timestamp of a StreamRecord. + */ +public class StreamRecordTimestampSqlFunction extends SqlFunction { + + public StreamRecordTimestampSqlFunction() { + super( + "STREAMRECORD_TIMESTAMP", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.BIGINT), + InferTypes.RETURN_TYPE, + OperandTypes.family(SqlTypeFamily.NUMERIC), + SqlFunctionCategory.SYSTEM); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.FUNCTION; + } + + @Override + public boolean isDeterministic() { + return true; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala index 97cf376eed868..4411f6f234741 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala @@ -32,6 +32,7 @@ import 
org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.{RowTypeInfo, _} import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkRelBuilder, FlinkTypeFactory, FlinkTypeSystem} +import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable import org.apache.flink.table.plan.cost.FlinkCostFactory import org.apache.flink.table.plan.schema.RelTable import org.apache.flink.types.Row @@ -63,6 +64,7 @@ abstract class TableEnvironment(val config: TableConfig) { .costFactory(new FlinkCostFactory) .typeSystem(new FlinkTypeSystem) .sqlToRelConverterConfig(getSqlToRelConverterConfig) + .operatorTable(FlinkSqlOperatorTable.instance()) // TODO: introduce ExpressionReducer after codegen // set the executor to evaluate constant expressions // .executor(new ExpressionReducer(config)) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLocalRef.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLocalRef.scala new file mode 100644 index 0000000000000..62831dd629802 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLocalRef.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.calcite + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rex.RexLocalRef +import org.apache.flink.table.`type`.InternalType + +/** + * Special reference which represents a local field, such as aggregate buffers or constants. + * These are stored as class members, so the field can be referenced directly. + * We should use a unique name to locate the field. + * + * See [[org.apache.flink.table.codegen.ExprCodeGenerator.visitLocalRef()]] + */ +case class RexAggLocalVariable( + fieldTerm: String, + nullTerm: String, + dataType: RelDataType, + internalType: InternalType) + extends RexLocalRef(0, dataType) + +/** + * Special reference which represents a distinct key input field. + * We use the name to locate the distinct key field. 
+ * + * See [[org.apache.flink.table.codegen.ExprCodeGenerator.visitLocalRef()]] + */ +case class RexDistinctKeyVariable( + keyTerm: String, + dataType: RelDataType, + internalType: InternalType) + extends RexLocalRef(0, dataType) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala index 2b683119dced8..0d525c7359fb4 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala @@ -18,15 +18,19 @@ package org.apache.flink.table.calcite -import org.apache.flink.table.`type`.{ArrayType, DecimalType, InternalType, InternalTypes, MapType, RowType} +import org.apache.flink.table.`type`._ import org.apache.flink.table.api.{TableException, TableSchema} -import org.apache.flink.table.plan.schema.{ArrayRelDataType, MapRelDataType, RowRelDataType, RowSchema, TimeIndicatorRelDataType} - +import org.apache.flink.table.plan.schema.{GenericRelDataType, _} +import org.apache.calcite.avatica.util.TimeUnit import org.apache.calcite.jdbc.JavaTypeFactoryImpl import org.apache.calcite.rel.`type`._ +import org.apache.calcite.sql.SqlIntervalQualifier import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.`type`.{BasicSqlType, SqlTypeName} +import org.apache.calcite.sql.parser.SqlParserPos +import org.apache.calcite.util.ConversionUtil +import java.nio.charset.Charset import java.util import scala.collection.JavaConverters._ @@ -64,12 +68,25 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp case InternalTypes.DATE => createSqlType(DATE) case InternalTypes.TIME => createSqlType(TIME) case InternalTypes.TIMESTAMP => createSqlType(TIMESTAMP) + case InternalTypes.PROCTIME_INDICATOR => 
createProctimeIndicatorType() + case InternalTypes.ROWTIME_INDICATOR => createRowtimeIndicatorType() + + // interval types + case InternalTypes.INTERVAL_MONTHS => + createSqlIntervalType( + new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)) + case InternalTypes.INTERVAL_MILLIS => + createSqlIntervalType( + new SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.SECOND, SqlParserPos.ZERO)) case InternalTypes.BINARY => createSqlType(VARBINARY) case InternalTypes.CHAR => throw new TableException("Character type is not supported.") + case decimal: DecimalType => + createSqlType(DECIMAL, decimal.precision(), decimal.scale()) + case rowType: RowType => new RowRelDataType(rowType, isNullable, this) case arrayType: ArrayType => new ArrayRelDataType(arrayType, @@ -81,6 +98,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp createTypeFromInternalType(mapType.getValueType, isNullable = true), isNullable) + case generic: GenericType[_] => + new GenericRelDataType(generic, isNullable, getTypeSystem) + case _@t => throw new TableException(s"Type is not supported: $t") } @@ -91,6 +111,19 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp createTypeWithNullability(relType, isNullable) } + /** + * Creates a indicator type for processing-time, but with similar properties as SQL timestamp. + */ + def createProctimeIndicatorType(): RelDataType = { + val originalType = createTypeFromInternalType(InternalTypes.TIMESTAMP, isNullable = false) + canonize( + new TimeIndicatorRelDataType( + getTypeSystem, + originalType.asInstanceOf[BasicSqlType], + isEventTime = false) + ) + } + /** * Creates a indicator type for event-time, but with similar properties as SQL timestamp. 
*/ @@ -209,6 +242,28 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp } } + + override def createArrayType(elementType: RelDataType, maxCardinality: Long): RelDataType = { + val arrayType = InternalTypes.createArrayType(FlinkTypeFactory.toInternalType(elementType)) + val relType = new ArrayRelDataType( + arrayType, + elementType, + isNullable = false) + canonize(relType) + } + + override def createMapType(keyType: RelDataType, valueType: RelDataType): RelDataType = { + val internalKeyType = FlinkTypeFactory.toInternalType(keyType) + val internalValueType = FlinkTypeFactory.toInternalType(valueType) + val internalMapType = InternalTypes.createMapType(internalKeyType, internalValueType) + val relType = new MapRelDataType( + internalMapType, + keyType, + valueType, + isNullable = false) + canonize(relType) + } + override def createSqlType(typeName: SqlTypeName): RelDataType = { if (typeName == DECIMAL) { // if we got here, the precision and scale are not specified, here we @@ -231,7 +286,22 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp } // change nullability - val newType = super.createTypeWithNullability(relDataType, isNullable) + val newType = relDataType match { + case array: ArrayRelDataType => + new ArrayRelDataType(array.arrayType, array.getComponentType, isNullable) + + case map: MapRelDataType => + new MapRelDataType(map.mapType, map.keyType, map.valueType, isNullable) + + case generic: GenericRelDataType => + new GenericRelDataType(generic.genericType, isNullable, typeSystem) + + case timeIndicator: TimeIndicatorRelDataType => + timeIndicator + + case _ => + super.createTypeWithNullability(relDataType, isNullable) + } canonize(newType) } @@ -273,6 +343,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp } } + override def getDefaultCharset: Charset = { + Charset.forName(ConversionUtil.NATIVE_UTF16_CHARSET_NAME) + } } object FlinkTypeFactory { @@ -291,22 
+364,32 @@ object FlinkTypeFactory { case FLOAT => InternalTypes.FLOAT case DOUBLE => InternalTypes.DOUBLE case VARCHAR | CHAR => InternalTypes.STRING - case DECIMAL => throw new RuntimeException("Not support yet.") + case VARBINARY | BINARY => InternalTypes.BINARY + case DECIMAL => InternalTypes.createDecimalType(relDataType.getPrecision, relDataType.getScale) + + // time indicators + case TIMESTAMP if relDataType.isInstanceOf[TimeIndicatorRelDataType] => + val indicator = relDataType.asInstanceOf[TimeIndicatorRelDataType] + if (indicator.isEventTime) { + InternalTypes.ROWTIME_INDICATOR + } else { + InternalTypes.PROCTIME_INDICATOR + } // temporal types case DATE => InternalTypes.DATE case TIME => InternalTypes.TIME case TIMESTAMP => InternalTypes.TIMESTAMP - - case VARBINARY => InternalTypes.BINARY + case typeName if YEAR_INTERVAL_TYPES.contains(typeName) => InternalTypes.INTERVAL_MONTHS + case typeName if DAY_INTERVAL_TYPES.contains(typeName) => InternalTypes.INTERVAL_MILLIS case NULL => throw new TableException( "Type NULL is not supported. Null values must have a supported type.") // symbol for special flags e.g. 
TRIM's BOTH, LEADING, TRAILING - // are represented as integer - case SYMBOL => InternalTypes.INT + // are represented as Enum + case SYMBOL => InternalTypes.createGenericType(classOf[Enum[_]]) case ROW if relDataType.isInstanceOf[RowRelDataType] => val compositeRelDataType = relDataType.asInstanceOf[RowRelDataType] diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeSystem.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeSystem.scala index 99d8cab4f39aa..6b6e6d1952001 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeSystem.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeSystem.scala @@ -20,19 +20,18 @@ package org.apache.flink.table.calcite import org.apache.calcite.rel.`type`.RelDataTypeSystemImpl import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.flink.table.`type`.DecimalType /** * Custom type system for Flink. */ class FlinkTypeSystem extends RelDataTypeSystemImpl { - // we cannot use Int.MaxValue because of an overflow in Calcite's type inference logic - // half should be enough for all use cases - override def getMaxNumericScale: Int = Int.MaxValue / 2 + // set the maximum precision of a NUMERIC or DECIMAL type to DecimalType.MAX_PRECISION. 
+ override def getMaxNumericPrecision: Int = DecimalType.MAX_PRECISION - // we cannot use Int.MaxValue because of an overflow in Calcite's type inference logic - // half should be enough for all use cases - override def getMaxNumericPrecision: Int = Int.MaxValue / 2 + // the max scale can't be greater than precision + override def getMaxNumericScale: Int = DecimalType.MAX_PRECISION override def getDefaultPrecision(typeName: SqlTypeName): Int = typeName match { @@ -48,4 +47,8 @@ class FlinkTypeSystem extends RelDataTypeSystemImpl { super.getDefaultPrecision(typeName) } + // when union a number of CHAR types of different lengths, we should cast to a VARCHAR + // this fixes the problem of CASE WHEN with different length string literals but get wrong + // result with additional space suffix + override def shouldConvertRaggedUnionTypesToVarying(): Boolean = true } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala index d07815fad5928..5d4d3f33896ff 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala @@ -18,35 +18,54 @@ package org.apache.flink.table.codegen -import org.apache.flink.api.common.ExecutionConfig -import org.apache.flink.api.common.typeinfo.{AtomicType => AtomicTypeInfo} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.core.memory.MemorySegment import org.apache.flink.table.`type`._ -import org.apache.flink.table.calcite.FlinkPlannerImpl -import org.apache.flink.table.dataformat._ +import org.apache.flink.table.dataformat.{Decimal, _} import org.apache.flink.table.typeutils.TypeCheckUtils import java.lang.reflect.Method import java.lang.{Boolean => JBoolean, Byte => JByte, Character => 
JChar, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Short => JShort} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable - object CodeGenUtils { // ------------------------------- DEFAULT TERMS ------------------------------------------ val DEFAULT_TIMEZONE_TERM = "timeZone" + val DEFAULT_INPUT1_TERM = "in1" + + val DEFAULT_INPUT2_TERM = "in2" + + val DEFAULT_COLLECTOR_TERM = "c" + + val DEFAULT_OUT_RECORD_TERM = "out" + + val DEFAULT_OPERATOR_COLLECTOR_TERM = "output" + + val DEFAULT_OUT_RECORD_WRITER_TERM = "outWriter" + + val DEFAULT_CONTEXT_TERM = "ctx" + // -------------------------- CANONICAL CLASS NAMES --------------------------------------- val BINARY_ROW: String = className[BinaryRow] + val BINARY_ARRAY: String = className[BinaryArray] + val BINARY_GENERIC: String = className[BinaryGeneric[_]] + val BINARY_STRING: String = className[BinaryString] + + val BINARY_MAP: String = className[BinaryMap] + val BASE_ROW: String = className[BaseRow] + val GENERIC_ROW: String = className[GenericRow] + + val DECIMAL: String = className[Decimal] + val SEGMENT: String = className[MemorySegment] // ---------------------------------------------------------------------------------------- @@ -68,15 +87,6 @@ object CodeGenUtils { */ def className[T](implicit m: Manifest[T]): String = m.runtimeClass.getCanonicalName - def needCopyForType(t: InternalType): Boolean = t match { - case InternalTypes.STRING => true - case _: ArrayType => true - case _: MapType => true - case _: RowType => true - case _: GenericType[_] => true - case _ => false - } - // when casting we first need to unbox Primitives, for example, // float a = 1.0f; // byte b = (byte) a; @@ -120,7 +130,6 @@ object CodeGenUtils { case InternalTypes.BINARY => "byte[]" case _: DecimalType => className[Decimal] - // BINARY is also an ArrayType and uses BinaryArray internally too case _: ArrayType => className[BinaryArray] case _: MapType => className[BinaryMap] case 
_: RowType => className[BaseRow] @@ -195,6 +204,16 @@ object CodeGenUtils { throw new CodeGenException("Boolean expression type expected.") } + def requireTemporal(genExpr: GeneratedExpression): Unit = + if (!TypeCheckUtils.isTemporal(genExpr.resultType)) { + throw new CodeGenException("Temporal expression type expected.") + } + + def requireTimeInterval(genExpr: GeneratedExpression): Unit = + if (!TypeCheckUtils.isTimeInterval(genExpr.resultType)) { + throw new CodeGenException("Interval expression type expected.") + } + def requireArray(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isArray(genExpr.resultType)) { throw new CodeGenException("Array expression type expected.") @@ -210,227 +229,267 @@ object CodeGenUtils { throw new CodeGenException("Integer expression type expected.") } - // --------------------------- Generate Utils --------------------------------------- - - def generateOutputRecordStatement( - t: InternalType, - clazz: Class[_], - outRecordTerm: String, - outRecordWriterTerm: Option[String] = None): String = { - t match { - case rt: RowType if clazz == classOf[BinaryRow] => - val writerTerm = outRecordWriterTerm.getOrElse( - throw new CodeGenException("No writer is specified when writing BinaryRow record.") - ) - val binaryRowWriter = className[BinaryRowWriter] - val typeTerm = clazz.getCanonicalName - s""" - |final $typeTerm $outRecordTerm = new $typeTerm(${rt.getArity}); - |final $binaryRowWriter $writerTerm = new $binaryRowWriter($outRecordTerm); - |""".stripMargin.trim - case rt: RowType if classOf[ObjectArrayRow].isAssignableFrom(clazz) => - val typeTerm = clazz.getCanonicalName - s"final $typeTerm $outRecordTerm = new $typeTerm(${rt.getArity});" - case _: RowType if clazz == classOf[JoinedRow] => - val typeTerm = clazz.getCanonicalName - s"final $typeTerm $outRecordTerm = new $typeTerm();" - case _ => - val typeTerm = boxedTypeTermForType(t) - s"final $typeTerm $outRecordTerm = new $typeTerm();" - } - } + // 
-------------------------------------------------------------------------------- + // DataFormat Operations + // -------------------------------------------------------------------------------- + + // -------------------------- BaseRow Read Access ------------------------------- def baseRowFieldReadAccess( - ctx: CodeGeneratorContext, pos: Int, rowTerm: String, fieldType: InternalType) : String = - baseRowFieldReadAccess(ctx, pos.toString, rowTerm, fieldType) + ctx: CodeGeneratorContext, + index: Int, + rowTerm: String, + fieldType: InternalType) : String = + baseRowFieldReadAccess(ctx, index.toString, rowTerm, fieldType) def baseRowFieldReadAccess( - ctx: CodeGeneratorContext, pos: String, rowTerm: String, fieldType: InternalType) : String = + ctx: CodeGeneratorContext, + indexTerm: String, + rowTerm: String, + fieldType: InternalType) : String = fieldType match { - case InternalTypes.INT => s"$rowTerm.getInt($pos)" - case InternalTypes.LONG => s"$rowTerm.getLong($pos)" - case InternalTypes.SHORT => s"$rowTerm.getShort($pos)" - case InternalTypes.BYTE => s"$rowTerm.getByte($pos)" - case InternalTypes.FLOAT => s"$rowTerm.getFloat($pos)" - case InternalTypes.DOUBLE => s"$rowTerm.getDouble($pos)" - case InternalTypes.BOOLEAN => s"$rowTerm.getBoolean($pos)" - case InternalTypes.STRING => s"$rowTerm.getString($pos)" - case InternalTypes.BINARY => s"$rowTerm.getBinary($pos)" - case dt: DecimalType => s"$rowTerm.getDecimal($pos, ${dt.precision()}, ${dt.scale()})" - case InternalTypes.CHAR => s"$rowTerm.getChar($pos)" - case _: TimestampType => s"$rowTerm.getLong($pos)" - case _: DateType => s"$rowTerm.getInt($pos)" - case InternalTypes.TIME => s"$rowTerm.getInt($pos)" - case _: ArrayType => s"$rowTerm.getArray($pos)" - case _: MapType => s"$rowTerm.getMap($pos)" - case rt: RowType => s"$rowTerm.getRow($pos, ${rt.getArity})" - case _: GenericType[_] => s"$rowTerm.getGeneric($pos)" + // primitive types + case InternalTypes.BOOLEAN => s"$rowTerm.getBoolean($indexTerm)" + 
case InternalTypes.BYTE => s"$rowTerm.getByte($indexTerm)" + case InternalTypes.CHAR => s"$rowTerm.getChar($indexTerm)" + case InternalTypes.SHORT => s"$rowTerm.getShort($indexTerm)" + case InternalTypes.INT => s"$rowTerm.getInt($indexTerm)" + case InternalTypes.LONG => s"$rowTerm.getLong($indexTerm)" + case InternalTypes.FLOAT => s"$rowTerm.getFloat($indexTerm)" + case InternalTypes.DOUBLE => s"$rowTerm.getDouble($indexTerm)" + case InternalTypes.STRING => s"$rowTerm.getString($indexTerm)" + case InternalTypes.BINARY => s"$rowTerm.getBinary($indexTerm)" + case dt: DecimalType => s"$rowTerm.getDecimal($indexTerm, ${dt.precision()}, ${dt.scale()})" + + // temporal types + case _: DateType => s"$rowTerm.getInt($indexTerm)" + case InternalTypes.TIME => s"$rowTerm.getInt($indexTerm)" + case _: TimestampType => s"$rowTerm.getLong($indexTerm)" + + // complex types + case _: ArrayType => s"$rowTerm.getArray($indexTerm)" + case _: MapType => s"$rowTerm.getMap($indexTerm)" + case rt: RowType => s"$rowTerm.getRow($indexTerm, ${rt.getArity})" + + case _: GenericType[_] => s"$rowTerm.getGeneric($indexTerm)" } - /** - * Generates code for comparing two field. - */ - def genCompare( - ctx: CodeGeneratorContext, - t: InternalType, - nullsIsLast: Boolean, - c1: String, - c2: String): String = t match { - case InternalTypes.BOOLEAN => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" - case _: PrimitiveType | _: DateType | _: TimeType | _: TimestampType => - s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" - case InternalTypes.BINARY => - val sortUtil = classOf[org.apache.flink.table.runtime.sort.SortUtil].getCanonicalName - s"$sortUtil.compareBinary($c1, $c2)" - case at: ArrayType => - val compareFunc = newName("compareArray") - val compareCode = genArrayCompare( - ctx, - FlinkPlannerImpl.getNullDefaultOrder(true), at, "a", "b") - val funcCode: String = - s""" - public int $compareFunc($BINARY_ARRAY a, $BINARY_ARRAY b) { - $compareCode - return 0; - } - """ - ctx.addReusableMember(funcCode) - s"$compareFunc($c1, $c2)" - case rowType: RowType => - val orders = rowType.getFieldTypes.map(_ => true) - val comparisons = genRowCompare( - ctx, - rowType.getFieldTypes.indices.toArray, - rowType.getFieldTypes, - orders, - FlinkPlannerImpl.getNullDefaultOrders(orders), - "a", - "b") - val compareFunc = newName("compareRow") - val funcCode: String = - s""" - public int $compareFunc($BASE_ROW a, $BASE_ROW b) { - $comparisons - return 0; + // -------------------------- BaseRow Set Field ------------------------------- + + def baseRowSetField( + ctx: CodeGeneratorContext, + rowClass: Class[_ <: BaseRow], + rowTerm: String, + indexTerm: String, + fieldExpr: GeneratedExpression, + binaryRowWriterTerm: Option[String]): String = { + + val fieldType = fieldExpr.resultType + val fieldTerm = fieldExpr.resultTerm + + if (rowClass == classOf[BinaryRow]) { + binaryRowWriterTerm match { + case Some(writer) => + // use writer to set field + val writeField = binaryWriterWriteField(ctx, indexTerm, fieldTerm, writer, fieldType) + if (ctx.nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | ${binaryWriterWriteNull(indexTerm, writer, fieldType)}; + |} else { + | $writeField; + |} + """.stripMargin + } else { + s""" + |${fieldExpr.code} + |$writeField; + """.stripMargin } - """ - ctx.addReusableMember(funcCode) - s"$compareFunc($c1, $c2)" - case gt: GenericType[_] => - val ser = ctx.addReusableObject(gt.getSerializer, "serializer") - val comp = 
ctx.addReusableObject( - gt.getTypeInfo.asInstanceOf[AtomicTypeInfo[_]].createComparator(true, new ExecutionConfig), - "comparator") - s""" - |$comp.compare( - | $BINARY_GENERIC.getJavaObjectFromBinaryGeneric($c1, $ser), - | $BINARY_GENERIC.getJavaObjectFromBinaryGeneric($c2, $ser) - |) - """.stripMargin - case other if other.isInstanceOf[AtomicType] => s"$c1.compareTo($c2)" - } - /** - * Generates code for comparing array. - */ - def genArrayCompare( - ctx: CodeGeneratorContext, nullsIsLast: Boolean, t: ArrayType, a: String, b: String) - : String = { - val nullIsLastRet = if (nullsIsLast) 1 else -1 - val elementType = t.getElementType - val fieldA = newName("fieldA") - val isNullA = newName("isNullA") - val lengthA = newName("lengthA") - val fieldB = newName("fieldB") - val isNullB = newName("isNullB") - val lengthB = newName("lengthB") - val minLength = newName("minLength") - val i = newName("i") - val comp = newName("comp") - val typeTerm = primitiveTypeTermForType(elementType) - s""" - int $lengthA = a.numElements(); - int $lengthB = b.numElements(); - int $minLength = ($lengthA > $lengthB) ? 
$lengthB : $lengthA; - for (int $i = 0; $i < $minLength; $i++) { - boolean $isNullA = a.isNullAt($i); - boolean $isNullB = b.isNullAt($i); - if ($isNullA && $isNullB) { - // Continue to compare the next element - } else if ($isNullA) { - return $nullIsLastRet; - } else if ($isNullB) { - return ${-nullIsLastRet}; + case None => + // directly set field to BinaryRow, this depends on all the fields are fixed length + val writeField = binaryRowFieldSetAccess(indexTerm, rowTerm, fieldType, fieldTerm) + if (ctx.nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | ${binaryRowSetNull(indexTerm, rowTerm, fieldType)}; + |} else { + | $writeField; + |} + """.stripMargin } else { - $typeTerm $fieldA = ${baseRowFieldReadAccess(ctx, i, a, elementType)}; - $typeTerm $fieldB = ${baseRowFieldReadAccess(ctx, i, b, elementType)}; - int $comp = ${genCompare(ctx, elementType, nullsIsLast, fieldA, fieldB)}; - if ($comp != 0) { - return $comp; - } + s""" + |${fieldExpr.code} + |$writeField; + """.stripMargin } - } - - if ($lengthA < $lengthB) { - return -1; - } else if ($lengthA > $lengthB) { - return 1; - } - """ + } + } else if (rowClass == classOf[GenericRow] || rowClass == classOf[BoxedWrapperRow]) { + val writeField = if (rowClass == classOf[GenericRow]) { + s"$rowTerm.setField($indexTerm, $fieldTerm);" + } else { + boxedWrapperRowFieldSetAccess(rowTerm, indexTerm, fieldTerm, fieldType) + } + if (ctx.nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | $rowTerm.setNullAt($indexTerm); + |} else { + | $writeField; + |} + """.stripMargin + } else { + s""" + |${fieldExpr.code} + |$writeField; + """.stripMargin + } + } else { + throw new UnsupportedOperationException("Not support set field for " + rowClass) + } } - /** - * Generates code for comparing row keys. 
- */ - def genRowCompare( - ctx: CodeGeneratorContext, - keys: Array[Int], - keyTypes: Array[InternalType], - orders: Array[Boolean], - nullsIsLast: Array[Boolean], - row1: String, - row2: String): String = { + // -------------------------- BinaryRow Set Field ------------------------------- + + def binaryRowSetNull(index: Int, rowTerm: String, t: InternalType): String = + binaryRowSetNull(index.toString, rowTerm, t) + + def binaryRowSetNull(indexTerm: String, rowTerm: String, t: InternalType): String = t match { + case d: DecimalType if !Decimal.isCompact(d.precision()) => + s"$rowTerm.setDecimal($indexTerm, null, ${d.precision()}, ${d.scale()})" + case _ => s"$rowTerm.setNullAt($indexTerm)" + } - val compares = new mutable.ArrayBuffer[String] + def binaryRowFieldSetAccess( + index: Int, + binaryRowTerm: String, + fieldType: InternalType, + fieldValTerm: String): String = + binaryRowFieldSetAccess(index.toString, binaryRowTerm, fieldType, fieldValTerm) + + def binaryRowFieldSetAccess( + index: String, + binaryRowTerm: String, + fieldType: InternalType, + fieldValTerm: String): String = + fieldType match { + case InternalTypes.INT => s"$binaryRowTerm.setInt($index, $fieldValTerm)" + case InternalTypes.LONG => s"$binaryRowTerm.setLong($index, $fieldValTerm)" + case InternalTypes.SHORT => s"$binaryRowTerm.setShort($index, $fieldValTerm)" + case InternalTypes.BYTE => s"$binaryRowTerm.setByte($index, $fieldValTerm)" + case InternalTypes.FLOAT => s"$binaryRowTerm.setFloat($index, $fieldValTerm)" + case InternalTypes.DOUBLE => s"$binaryRowTerm.setDouble($index, $fieldValTerm)" + case InternalTypes.BOOLEAN => s"$binaryRowTerm.setBoolean($index, $fieldValTerm)" + case InternalTypes.CHAR => s"$binaryRowTerm.setChar($index, $fieldValTerm)" + case _: DateType => s"$binaryRowTerm.setInt($index, $fieldValTerm)" + case InternalTypes.TIME => s"$binaryRowTerm.setInt($index, $fieldValTerm)" + case _: TimestampType => s"$binaryRowTerm.setLong($index, $fieldValTerm)" + case d: 
DecimalType => + s"$binaryRowTerm.setDecimal($index, $fieldValTerm, ${d.precision()}, ${d.scale()})" + case _ => + throw new CodeGenException("Fail to find binary row field setter method of InternalType " + + fieldType + ".") + } - for (i <- keys.indices) { - val index = keys(i) + // -------------------------- BoxedWrapperRow Set Field ------------------------------- - val symbol = if (orders(i)) "" else "-" + def boxedWrapperRowFieldSetAccess( + rowTerm: String, + indexTerm: String, + fieldTerm: String, + fieldType: InternalType): String = + fieldType match { + case InternalTypes.INT => s"$rowTerm.setInt($indexTerm, $fieldTerm)" + case InternalTypes.LONG => s"$rowTerm.setLong($indexTerm, $fieldTerm)" + case InternalTypes.SHORT => s"$rowTerm.setShort($indexTerm, $fieldTerm)" + case InternalTypes.BYTE => s"$rowTerm.setByte($indexTerm, $fieldTerm)" + case InternalTypes.FLOAT => s"$rowTerm.setFloat($indexTerm, $fieldTerm)" + case InternalTypes.DOUBLE => s"$rowTerm.setDouble($indexTerm, $fieldTerm)" + case InternalTypes.BOOLEAN => s"$rowTerm.setBoolean($indexTerm, $fieldTerm)" + case InternalTypes.CHAR => s"$rowTerm.setChar($indexTerm, $fieldTerm)" + case _: DateType => s"$rowTerm.setInt($indexTerm, $fieldTerm)" + case InternalTypes.TIME => s"$rowTerm.setInt($indexTerm, $fieldTerm)" + case _: TimestampType => s"$rowTerm.setLong($indexTerm, $fieldTerm)" + case _ => s"$rowTerm.setNonPrimitiveValue($indexTerm, $fieldTerm)" + } - val nullIsLastRet = if (nullsIsLast(i)) 1 else -1 + // -------------------------- BinaryArray Set Access ------------------------------- + + def binaryArraySetNull( + index: Int, + arrayTerm: String, + elementType: InternalType): String = elementType match { + case InternalTypes.BOOLEAN => s"$arrayTerm.setNullBoolean($index)" + case InternalTypes.BYTE => s"$arrayTerm.setNullByte($index)" + case InternalTypes.CHAR => s"$arrayTerm.setNullChar($index)" + case InternalTypes.SHORT => s"$arrayTerm.setNullShort($index)" + case InternalTypes.INT => 
s"$arrayTerm.setNullInt($index)" + case InternalTypes.LONG => s"$arrayTerm.setNullLong($index)" + case InternalTypes.FLOAT => s"$arrayTerm.setNullFloat($index)" + case InternalTypes.DOUBLE => s"$arrayTerm.setNullDouble($index)" + case InternalTypes.TIME => s"$arrayTerm.setNullInt($index)" + case _: DateType => s"$arrayTerm.setNullInt($index)" + case _: TimestampType => s"$arrayTerm.setNullLong($index)" + case _ => s"$arrayTerm.setNullLong($index)" + } - val t = keyTypes(i) + // -------------------------- BinaryWriter Write ------------------------------- - val typeTerm = primitiveTypeTermForType(t) - val fieldA = newName("fieldA") - val isNullA = newName("isNullA") - val fieldB = newName("fieldB") - val isNullB = newName("isNullB") - val comp = newName("comp") + def binaryWriterWriteNull(index: Int, writerTerm: String, t: InternalType): String = + binaryWriterWriteNull(index.toString, writerTerm, t) - val code = - s""" - |boolean $isNullA = $row1.isNullAt($index); - |boolean $isNullB = $row2.isNullAt($index); - |if ($isNullA && $isNullB) { - | // Continue to compare the next element - |} else if ($isNullA) { - | return $nullIsLastRet; - |} else if ($isNullB) { - | return ${-nullIsLastRet}; - |} else { - | $typeTerm $fieldA = ${baseRowFieldReadAccess(ctx, index, row1, t)}; - | $typeTerm $fieldB = ${baseRowFieldReadAccess(ctx, index, row2, t)}; - | int $comp = ${genCompare(ctx, t, nullsIsLast(i), fieldA, fieldB)}; - | if ($comp != 0) { - | return $symbol$comp; - | } - |} - """.stripMargin - compares += code - } - compares.mkString + def binaryWriterWriteNull( + indexTerm: String, + writerTerm: String, + t: InternalType): String = t match { + case d: DecimalType if !Decimal.isCompact(d.precision()) => + s"$writerTerm.writeDecimal($indexTerm, null, ${d.precision()})" + case _ => s"$writerTerm.setNullAt($indexTerm)" } + + def binaryWriterWriteField( + ctx: CodeGeneratorContext, + index: Int, + fieldValTerm: String, + writerTerm: String, + fieldType: InternalType): 
String = + binaryWriterWriteField(ctx, index.toString, fieldValTerm, writerTerm, fieldType) + + def binaryWriterWriteField( + ctx: CodeGeneratorContext, + indexTerm: String, + fieldValTerm: String, + writerTerm: String, + fieldType: InternalType): String = + fieldType match { + case InternalTypes.INT => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" + case InternalTypes.LONG => s"$writerTerm.writeLong($indexTerm, $fieldValTerm)" + case InternalTypes.SHORT => s"$writerTerm.writeShort($indexTerm, $fieldValTerm)" + case InternalTypes.BYTE => s"$writerTerm.writeByte($indexTerm, $fieldValTerm)" + case InternalTypes.FLOAT => s"$writerTerm.writeFloat($indexTerm, $fieldValTerm)" + case InternalTypes.DOUBLE => s"$writerTerm.writeDouble($indexTerm, $fieldValTerm)" + case InternalTypes.BOOLEAN => s"$writerTerm.writeBoolean($indexTerm, $fieldValTerm)" + case InternalTypes.BINARY => s"$writerTerm.writeBinary($indexTerm, $fieldValTerm)" + case InternalTypes.STRING => s"$writerTerm.writeString($indexTerm, $fieldValTerm)" + case d: DecimalType => + s"$writerTerm.writeDecimal($indexTerm, $fieldValTerm, ${d.precision()})" + case InternalTypes.CHAR => s"$writerTerm.writeChar($indexTerm, $fieldValTerm)" + case _: DateType => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" + case InternalTypes.TIME => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" + case _: TimestampType => s"$writerTerm.writeLong($indexTerm, $fieldValTerm)" + + // complex types + case _: ArrayType => s"$writerTerm.writeArray($indexTerm, $fieldValTerm)" + case _: MapType => s"$writerTerm.writeMap($indexTerm, $fieldValTerm)" + case _: RowType => + val serializerTerm = ctx.addReusableTypeSerializer(fieldType) + s"$writerTerm.writeRow($indexTerm, $fieldValTerm, $serializerTerm)" + + case _: GenericType[_] => s"$writerTerm.writeGeneric($indexTerm, $fieldValTerm)" + } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGeneratorContext.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGeneratorContext.scala index 98f1eb8eef51d..f611aec28647f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGeneratorContext.scala @@ -18,15 +18,20 @@ package org.apache.flink.table.codegen -import org.apache.calcite.avatica.util.DateTimeUtils +import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.api.common.functions.{Function, RuntimeContext} -import org.apache.flink.table.`type`.{InternalType, InternalTypes, RowType} +import org.apache.flink.api.common.typeutils.TypeSerializer +import org.apache.flink.table.`type`.{InternalType, InternalTypes, RowType, TypeConverters} import org.apache.flink.table.api.TableConfig import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.GenerateUtils.generateRecordStatement import org.apache.flink.table.dataformat.GenericRow import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction} +import org.apache.flink.table.runtime.util.collections._ import org.apache.flink.util.InstantiationUtil +import org.apache.calcite.avatica.util.DateTimeUtils + import scala.collection.mutable /** @@ -92,12 +97,10 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { // string_constant -> reused_term private val reusableStringConstants: mutable.Map[String, String] = mutable.Map[String, String]() - // map of local variable statements. It will be placed in method if method code not excess - // max code length, otherwise will be placed in member area of the class. The statements - // are maintained for multiple methods, so that it's a map from method_name to variables. 
- // - // method_name -> local_variable_statements - private val reusableLocalVariableStatements = mutable.Map[String, mutable.LinkedHashSet[String]]() + // map of type serializer that will be added only once + // InternalType -> reused_term + private val reusableTypeSerializers: mutable.Map[InternalType, String] = + mutable.Map[InternalType, String]() /** * The current method name for [[reusableLocalVariableStatements]]. You can start a new @@ -105,6 +108,14 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { */ private var currentMethodNameForLocalVariables = "DEFAULT" + // map of local variable statements. It will be placed in method if method code not excess + // max code length, otherwise will be placed in member area of the class. The statements + // are maintained for multiple methods, so that it's a map from method_name to variables. + // + // method_name -> local_variable_statements + private val reusableLocalVariableStatements = mutable.Map[String, mutable.LinkedHashSet[String]]( + (currentMethodNameForLocalVariables, mutable.LinkedHashSet[String]())) + // --------------------------------------------------------------------------------- // Getter // --------------------------------------------------------------------------------- @@ -112,7 +123,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { def getReusableInputUnboxingExprs(inputTerm: String, index: Int): Option[GeneratedExpression] = reusableInputUnboxingExprs.get((inputTerm, index)) - def getNullCheck: Boolean = tableConfig.getNullCheck + def nullCheck: Boolean = tableConfig.getNullCheck // --------------------------------------------------------------------------------- // Local Variables for Code Split @@ -137,7 +148,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { * @param fieldTypeTerm the field type term * @return a new generated unique field name */ - def newReusableLocalVariable(fieldTypeTerm: String, fieldName: String): String = { + def 
addReusableLocalVariable(fieldTypeTerm: String, fieldName: String): String = { val fieldTerm = newName(fieldName) reusableLocalVariableStatements .getOrElse(currentMethodNameForLocalVariables, mutable.LinkedHashSet[String]()) @@ -154,7 +165,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { * left is field type term and right is field name * @return the new generated unique field names for each variable pairs */ - def newReusableLocalFields(fieldTypeAndNames: (String, String)*): Seq[String] = { + def addReusableLocalVariables(fieldTypeAndNames: (String, String)*): Seq[String] = { val fieldTerms = newNames(fieldTypeAndNames.map(_._2): _*) fieldTypeAndNames.map(_._1).zip(fieldTerms).foreach { case (fieldTypeTerm, fieldTerm) => reusableLocalVariableStatements @@ -296,6 +307,11 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { reusableMemberStatements.add(memberStatement) } + /** + * Adds a reusable init statement which will be placed in constructor. + */ + def addReusableInitStatement(s: String): Unit = reusableInitStatements.add(s) + /** * Adds a reusable per record statement */ @@ -338,7 +354,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { clazz: Class[_], outRecordTerm: String, outRecordWriterTerm: Option[String] = None): Unit = { - val statement = generateOutputRecordStatement(t, clazz, outRecordTerm, outRecordWriterTerm) + val statement = generateRecordStatement(t, clazz, outRecordTerm, outRecordWriterTerm) reusableMemberStatements.add(statement) } @@ -352,6 +368,42 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { rowTerm) } + /** + * Adds a reusable internal hash set to the member area of the generated class. 
+ */ + def addReusableHashSet(elements: Seq[GeneratedExpression], elementType: InternalType): String = { + val fieldTerm = newName("set") + + val setTypeTerm = elementType match { + case InternalTypes.BYTE => className[ByteHashSet] + case InternalTypes.SHORT => className[ShortHashSet] + case InternalTypes.INT => className[IntHashSet] + case InternalTypes.LONG => className[LongHashSet] + case InternalTypes.FLOAT => className[FloatHashSet] + case InternalTypes.DOUBLE => className[DoubleHashSet] + case _ => className[ObjectHashSet[_]] + } + + addReusableMember( + s"final $setTypeTerm $fieldTerm = new $setTypeTerm(${elements.size})") + + elements.foreach { element => + val content = + s""" + |${element.code} + |if (${element.nullTerm}) { + | $fieldTerm.addNull(); + |} else { + | $fieldTerm.add(${element.resultTerm}); + |} + |""".stripMargin + reusableInitStatements.add(content) + } + reusableInitStatements.add(s"$fieldTerm.optimize();") + + fieldTerm + } + /** * Adds a reusable timestamp to the beginning of the SAM of the generated class. */ @@ -458,7 +510,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { |""".stripMargin val fieldInit = seedExpr match { - case Some(s) if getNullCheck => + case Some(s) if nullCheck => s""" |${s.code} |if (!${s.nullTerm}) { @@ -550,6 +602,28 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { fieldTerm } + /** + * Adds a reusable [[TypeSerializer]] to the member area of the generated class. 
+ * + * @param t the internal type which used to generate internal type serializer + * @return member variable term + */ + def addReusableTypeSerializer(t: InternalType): String = { + // if type serializer has been used before, we can reuse the code that + // has already been generated + reusableTypeSerializers.get(t) match { + case Some(term) => term + + case None => + val term = newName("typeSerializer") + val ser = TypeConverters.createInternalTypeInfoFromInternalType(t) + .createSerializer(new ExecutionConfig) + addReusableObjectInternal(ser, term, ser.getClass.getCanonicalName) + reusableTypeSerializers(t) = term + term + } + } + /** * Adds a reusable static SLF4J Logger to the member area of the generated class. */ @@ -607,7 +681,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { val field = newName("str") val stmt = s""" - |private final $BINARY_STRING $field = $BINARY_STRING.fromString("$value");" + |private final $BINARY_STRING $field = $BINARY_STRING.fromString("$value"); """.stripMargin reusableMemberStatements.add(stmt) reusableStringConstants(value) = field @@ -686,3 +760,9 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { fieldTerm } } + +object CodeGeneratorContext { + def apply(config: TableConfig): CodeGeneratorContext = { + new CodeGeneratorContext(config) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ExprCodeGenerator.scala new file mode 100644 index 0000000000000..14a326c579139 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ExprCodeGenerator.scala @@ -0,0 +1,721 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen + +import org.apache.calcite.rex._ +import org.apache.calcite.sql.SqlOperator +import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName} +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.table.`type`._ +import org.apache.flink.table.api.TableException +import org.apache.flink.table.calcite.{FlinkTypeFactory, RexAggLocalVariable, RexDistinctKeyVariable} +import org.apache.flink.table.codegen.CodeGenUtils.{requireTemporal, requireTimeInterval, _} +import org.apache.flink.table.codegen.GenerateUtils._ +import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE} +import org.apache.flink.table.codegen.calls.ScalarOperatorGens._ +import org.apache.flink.table.dataformat._ +import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable._ +import org.apache.flink.table.typeutils.TypeCheckUtils.{isNumeric, isTemporal, isTimeInterval} +import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TypeCheckUtils} + +import scala.collection.JavaConversions._ + +/** + * This code generator is mainly responsible for generating codes for a given calcite [[RexNode]]. + * It can also generate type conversion codes for the result converter. 
+ */ +class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) + extends RexVisitor[GeneratedExpression] { + + // check if nullCheck is enabled when inputs can be null + if (nullableInput && !ctx.nullCheck) { + throw new CodeGenException("Null check must be enabled if entire rows can be null.") + } + + /** + * term of the [[ProcessFunction]]'s context, can be changed when needed + */ + var contextTerm = "ctx" + + /** + * information of the first input + */ + var input1Type: InternalType = _ + var input1Term: String = _ + var input1FieldMapping: Option[Array[Int]] = None + + /** + * information of the optional second input + */ + var input2Type: Option[InternalType] = None + var input2Term: Option[String] = None + var input2FieldMapping: Option[Array[Int]] = None + + /** + * Bind the input information, should be called before generating expression. + */ + def bindInput( + inputType: InternalType, + inputTerm: String = DEFAULT_INPUT1_TERM, + inputFieldMapping: Option[Array[Int]] = None): ExprCodeGenerator = { + input1Type = inputType + input1Term = inputTerm + input1FieldMapping = inputFieldMapping + this + } + + /** + * In some cases, the expression will have two inputs (e.g. join condition and udtf). We should + * bind second input information before use. 
+ */ + def bindSecondInput( + inputType: InternalType, + inputTerm: String = DEFAULT_INPUT2_TERM, + inputFieldMapping: Option[Array[Int]] = None): ExprCodeGenerator = { + input2Type = Some(inputType) + input2Term = Some(inputTerm) + input2FieldMapping = inputFieldMapping + this + } + + private lazy val input1Mapping: Array[Int] = input1FieldMapping match { + case Some(mapping) => mapping + case _ => fieldIndices(input1Type) + } + + private lazy val input2Mapping: Array[Int] = input2FieldMapping match { + case Some(mapping) => mapping + case _ => input2Type match { + case Some(input) => fieldIndices(input) + case _ => Array[Int]() + } + } + + private def fieldIndices(t: InternalType): Array[Int] = t match { + case rt: RowType => (0 until rt.getArity).toArray + case _ => Array(0) + } + + /** + * Generates an expression from a RexNode. If objects or variables can be reused, they will be + * added to reusable code sections internally. + * + * @param rex Calcite row expression + * @return instance of GeneratedExpression + */ + def generateExpression(rex: RexNode): GeneratedExpression = { + rex.accept(this) + } + + /** + * Generates an expression that converts the first input (and second input) into the given type. + * If two inputs are converted, the second input is appended. If objects or variables can + * be reused, they will be added to reusable code sections internally. The evaluation result + * will be stored in the variable outRecordTerm. + * + * @param returnType conversion target type. Inputs and output must have the same arity. + * @param outRecordTerm the result term + * @param outRecordWriterTerm the result writer term + * @param reusedOutRow If objects or variables can be reused, they will be added to reusable + * code sections internally. 
+ * @return instance of GeneratedExpression + */ + def generateConverterResultExpression( + returnType: RowType, + returnTypeClazz: Class[_ <: BaseRow], + outRecordTerm: String = DEFAULT_OUT_RECORD_TERM, + outRecordWriterTerm: String = DEFAULT_OUT_RECORD_WRITER_TERM, + reusedOutRow: Boolean = true, + fieldCopy: Boolean = false, + rowtimeExpression: Option[RexNode] = None) + : GeneratedExpression = { + val input1AccessExprs = input1Mapping.map { + case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER | + TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER if rowtimeExpression.isDefined => + // generate rowtime attribute from expression + generateExpression(rowtimeExpression.get) + case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER | + TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER => + throw new TableException("Rowtime extraction expression missing. Please report a bug.") + case TimeIndicatorTypeInfo.PROCTIME_STREAM_MARKER => + // attribute is proctime indicator. + // we use a null literal and generate a timestamp when we need it. + generateNullLiteral(InternalTypes.PROCTIME_INDICATOR, ctx.nullCheck) + case TimeIndicatorTypeInfo.PROCTIME_BATCH_MARKER => + // attribute is proctime field in a batch query. + // it is initialized with the current time. + generateCurrentTimestamp(ctx) + case idx => + // get type of result field + generateInputAccess( + ctx, + input1Type, + input1Term, + idx, + nullableInput, + fieldCopy) + } + + val input2AccessExprs = input2Type match { + case Some(ti) => + input2Mapping.map(idx => generateInputAccess( + ctx, + ti, + input2Term.get, + idx, + nullableInput, + ctx.nullCheck) + ).toSeq + case None => Seq() // add nothing + } + + generateResultExpression( + input1AccessExprs ++ input2AccessExprs, + returnType, + returnTypeClazz, + outRow = outRecordTerm, + outRowWriter = Some(outRecordWriterTerm), + reusedOutRow = reusedOutRow) + } + + /** + * Generates an expression from a sequence of other expressions. 
The evaluation result + * may be stored in the variable outRecordTerm. + * + * @param fieldExprs field expressions to be converted + * @param returnType conversion target type. Type must have the same arity than fieldExprs. + * @param outRow the result term + * @param outRowWriter the result writer term for BinaryRow. + * @param reusedOutRow If objects or variables can be reused, they will be added to reusable + * code sections internally. + * @param outRowAlreadyExists Don't need addReusableRecord if out row already exists. + * @return instance of GeneratedExpression + */ + def generateResultExpression( + fieldExprs: Seq[GeneratedExpression], + returnType: RowType, + returnTypeClazz: Class[_ <: BaseRow], + outRow: String = DEFAULT_OUT_RECORD_TERM, + outRowWriter: Option[String] = Some(DEFAULT_OUT_RECORD_WRITER_TERM), + reusedOutRow: Boolean = true, + outRowAlreadyExists: Boolean = false): GeneratedExpression = { + val fieldExprIdxToOutputRowPosMap = fieldExprs.indices.map(i => i -> i).toMap + generateResultExpression(fieldExprs, fieldExprIdxToOutputRowPosMap, returnType, + returnTypeClazz, outRow, outRowWriter, reusedOutRow, outRowAlreadyExists) + } + + /** + * Generates an expression from a sequence of other expressions. The evaluation result + * may be stored in the variable outRecordTerm. + * + * @param fieldExprs field expressions to be converted + * @param fieldExprIdxToOutputRowPosMap Mapping index of fieldExpr in `fieldExprs` + * to position of output row. + * @param returnType conversion target type. Type must have the same arity than fieldExprs. + * @param outRow the result term + * @param outRowWriter the result writer term for BinaryRow. + * @param reusedOutRow If objects or variables can be reused, they will be added to reusable + * code sections internally. + * @param outRowAlreadyExists Don't need addReusableRecord if out row already exists. 
+ * @return instance of GeneratedExpression + */ + def generateResultExpression( + fieldExprs: Seq[GeneratedExpression], + fieldExprIdxToOutputRowPosMap: Map[Int, Int], + returnType: RowType, + returnTypeClazz: Class[_ <: BaseRow], + outRow: String, + outRowWriter: Option[String], + reusedOutRow: Boolean, + outRowAlreadyExists: Boolean) + : GeneratedExpression = { + // initial type check + if (returnType.getArity != fieldExprs.length) { + throw new CodeGenException( + s"Arity [${returnType.getArity}] of result type [$returnType] does not match " + + s"number [${fieldExprs.length}] of expressions [$fieldExprs].") + } + if (fieldExprIdxToOutputRowPosMap.size != fieldExprs.length) { + throw new CodeGenException( + s"Size [${returnType.getArity}] of fieldExprIdxToOutputRowPosMap does not match " + + s"number [${fieldExprs.length}] of expressions [$fieldExprs].") + } + // type check + fieldExprs.zipWithIndex foreach { + // timestamp type(Include TimeIndicator) and generic type can compatible with each other. + case (fieldExpr, i) + if fieldExpr.resultType.isInstanceOf[GenericType[_]] || + fieldExpr.resultType.isInstanceOf[TimestampType] => + if (returnType.getTypeAt(i).getClass != fieldExpr.resultType.getClass + && !returnType.getTypeAt(i).isInstanceOf[GenericType[_]]) { + throw new CodeGenException( + s"Incompatible types of expression and result type, Expression[$fieldExpr] type is " + + s"[${fieldExpr.resultType}], result type is [${returnType.getTypeAt(i)}]") + } + case (fieldExpr, i) if fieldExpr.resultType != returnType.getTypeAt(i) => + throw new CodeGenException( + s"Incompatible types of expression and result type. 
Expression[$fieldExpr] type is " + + s"[${fieldExpr.resultType}], result type is [${returnType.getTypeAt(i)}]") + case _ => // ok + } + + val setFieldsCode = fieldExprs.zipWithIndex.map { case (fieldExpr, index) => + val pos = fieldExprIdxToOutputRowPosMap.getOrElse(index, + throw new CodeGenException(s"Illegal field expr index: $index")) + baseRowSetField(ctx, returnTypeClazz, outRow, pos.toString, fieldExpr, outRowWriter) + }.mkString("\n") + + val outRowInitCode = if (!outRowAlreadyExists) { + val initCode = generateRecordStatement(returnType, returnTypeClazz, outRow, outRowWriter) + if (reusedOutRow) { + ctx.addReusableMember(initCode) + NO_CODE + } else { + initCode + } + } else { + NO_CODE + } + + val code = if (returnTypeClazz == classOf[BinaryRow] && outRowWriter.isDefined) { + val writer = outRowWriter.get + val resetWriter = if (ctx.nullCheck) s"$writer.reset();" else s"$writer.resetCursor();" + val completeWriter: String = s"$writer.complete();" + s""" + |$outRowInitCode + |$resetWriter + |$setFieldsCode + |$completeWriter + """.stripMargin + } else { + s""" + |$outRowInitCode + |$setFieldsCode + """.stripMargin + } + GeneratedExpression(outRow, NEVER_NULL, code, returnType) + } + + override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = { + // if inputRef index is within size of input1 we work with input1, input2 otherwise + val input = if (inputRef.getIndex < InternalTypeUtils.getArity(input1Type)) { + (input1Type, input1Term) + } else { + (input2Type.getOrElse(throw new CodeGenException("Invalid input access.")), + input2Term.getOrElse(throw new CodeGenException("Invalid input access."))) + } + + val index = if (input._2 == input1Term) { + inputRef.getIndex + } else { + inputRef.getIndex - InternalTypeUtils.getArity(input1Type) + } + + generateInputAccess(ctx, input._1, input._2, index, nullableInput, ctx.nullCheck) + } + + override def visitTableInputRef(rexTableInputRef: RexTableInputRef): GeneratedExpression = + 
visitInputRef(rexTableInputRef) + + override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = { + val refExpr = rexFieldAccess.getReferenceExpr.accept(this) + val index = rexFieldAccess.getField.getIndex + val fieldAccessExpr = generateFieldAccess( + ctx, + refExpr.resultType, + refExpr.resultTerm, + index) + + val resultTypeTerm = primitiveTypeTermForType(fieldAccessExpr.resultType) + val defaultValue = primitiveDefaultValue(fieldAccessExpr.resultType) + val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables( + (resultTypeTerm, "result"), + ("boolean", "isNull")) + + val resultCode = if (ctx.nullCheck) { + s""" + |${refExpr.code} + |if (${refExpr.nullTerm}) { + | $resultTerm = $defaultValue; + | $nullTerm = true; + |} + |else { + | ${fieldAccessExpr.code} + | $resultTerm = ${fieldAccessExpr.resultTerm}; + | $nullTerm = ${fieldAccessExpr.nullTerm}; + |} + |""".stripMargin + } else { + s""" + |${refExpr.code} + |${fieldAccessExpr.code} + |$resultTerm = ${fieldAccessExpr.resultTerm}; + |""".stripMargin + } + + GeneratedExpression(resultTerm, nullTerm, resultCode, fieldAccessExpr.resultType) + } + + override def visitLiteral(literal: RexLiteral): GeneratedExpression = { + val resultType = FlinkTypeFactory.toInternalType(literal.getType) + val value = literal.getValue3 + generateLiteral(ctx, resultType, value) + } + + override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = { + GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1Type) + } + + override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = localRef match { + case local: RexAggLocalVariable => + GeneratedExpression(local.fieldTerm, local.nullTerm, NO_CODE, local.internalType) + case value: RexDistinctKeyVariable => + val inputExpr = ctx.getReusableInputUnboxingExprs(input1Term, 0) match { + case Some(expr) => expr + case None => + val pType = primitiveTypeTermForType(value.internalType) + val defaultValue = 
primitiveDefaultValue(value.internalType) + val resultTerm = newName("field") + val nullTerm = newName("isNull") + val code = + s""" + |$pType $resultTerm = $defaultValue; + |boolean $nullTerm = true; + |if ($input1Term != null) { + | $nullTerm = false; + | $resultTerm = ($pType) $input1Term; + |} + """.stripMargin + val expr = GeneratedExpression(resultTerm, nullTerm, code, value.internalType) + ctx.addReusableInputUnboxingExprs(input1Term, 0, expr) + expr + } + // hide the generated code as it will be executed only once + GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, NO_CODE, inputExpr.resultType) + case _ => throw new CodeGenException("Local variables are not supported yet.") + } + + override def visitRangeRef(rangeRef: RexRangeRef): GeneratedExpression = + throw new CodeGenException("Range references are not supported yet.") + + override def visitDynamicParam(dynamicParam: RexDynamicParam): GeneratedExpression = + throw new CodeGenException("Dynamic parameter references are not supported yet.") + + override def visitCall(call: RexCall): GeneratedExpression = { + + val resultType = FlinkTypeFactory.toInternalType(call.getType) + + // convert operands and help giving untyped NULL literals a type + val operands = call.getOperands.zipWithIndex.map { + + // this helps e.g. 
for AS(null) + // we might need to extend this logic in case some rules do not create typed NULLs + case (operandLiteral: RexLiteral, 0) if + operandLiteral.getType.getSqlTypeName == SqlTypeName.NULL && + call.getOperator.getReturnTypeInference == ReturnTypes.ARG0 => + generateNullLiteral(resultType, ctx.nullCheck) + + case (o@_, _) => o.accept(this) + } + + generateCallExpression(ctx, call.getOperator, operands, resultType) + } + + override def visitOver(over: RexOver): GeneratedExpression = + throw new CodeGenException("Aggregate functions over windows are not supported yet.") + + override def visitSubQuery(subQuery: RexSubQuery): GeneratedExpression = + throw new CodeGenException("Subqueries are not supported yet.") + + override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = + throw new CodeGenException("Pattern field references are not supported yet.") + + // ---------------------------------------------------------------------------------------- + + private def generateCallExpression( + ctx: CodeGeneratorContext, + operator: SqlOperator, + operands: Seq[GeneratedExpression], + resultType: InternalType): GeneratedExpression = { + operator match { + // arithmetic + case PLUS if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "+", resultType, left, right) + + case PLUS | DATETIME_PLUS if isTemporal(resultType) => + val left = operands.head + val right = operands(1) + requireTemporal(left) + requireTemporal(right) + generateTemporalPlusMinus(ctx, plus = true, resultType, left, right) + + case MINUS if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "-", resultType, left, right) + + case MINUS | MINUS_DATE if isTemporal(resultType) => + val left = operands.head + val right = operands(1) + 
requireTemporal(left) + requireTemporal(right) + generateTemporalPlusMinus(ctx, plus = false, resultType, left, right) + + case MULTIPLY if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "*", resultType, left, right) + + case MULTIPLY if isTimeInterval(resultType) => + val left = operands.head + val right = operands(1) + requireTimeInterval(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "*", resultType, left, right) + + case DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "/", resultType, left, right) + + case MOD if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "%", resultType, left, right) + + case UNARY_MINUS if isNumeric(resultType) => + val operand = operands.head + requireNumeric(operand) + generateUnaryArithmeticOperator(ctx, "-", resultType, operand) + + case UNARY_MINUS if isTimeInterval(resultType) => + val operand = operands.head + requireTimeInterval(operand) + generateUnaryIntervalPlusMinus(ctx, plus = false, operand) + + case UNARY_PLUS if isNumeric(resultType) => + val operand = operands.head + requireNumeric(operand) + generateUnaryArithmeticOperator(ctx, "+", resultType, operand) + + case UNARY_PLUS if isTimeInterval(resultType) => + val operand = operands.head + requireTimeInterval(operand) + generateUnaryIntervalPlusMinus(ctx, plus = true, operand) + + // comparison + case EQUALS => + val left = operands.head + val right = operands(1) + generateEquals(ctx, left, right) + + case NOT_EQUALS => + val left = operands.head + val right = operands(1) + generateNotEquals(ctx, left, right) + + case GREATER_THAN => + val left = operands.head + val right = 
operands(1) + requireComparable(left) + requireComparable(right) + generateComparison(ctx, ">", left, right) + + case GREATER_THAN_OR_EQUAL => + val left = operands.head + val right = operands(1) + requireComparable(left) + requireComparable(right) + generateComparison(ctx, ">=", left, right) + + case LESS_THAN => + val left = operands.head + val right = operands(1) + requireComparable(left) + requireComparable(right) + generateComparison(ctx, "<", left, right) + + case LESS_THAN_OR_EQUAL => + val left = operands.head + val right = operands(1) + requireComparable(left) + requireComparable(right) + generateComparison(ctx, "<=", left, right) + + case IS_NULL => + val operand = operands.head + generateIsNull(ctx, operand) + + case IS_NOT_NULL => + val operand = operands.head + generateIsNotNull(ctx, operand) + + // logic + case AND => + operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) => + requireBoolean(left) + requireBoolean(right) + generateAnd(ctx, left, right) + } + + case OR => + operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) => + requireBoolean(left) + requireBoolean(right) + generateOr(ctx, left, right) + } + + case NOT => + val operand = operands.head + requireBoolean(operand) + generateNot(ctx, operand) + + case CASE => + generateIfElse(ctx, operands, resultType) + + case IS_TRUE => + val operand = operands.head + requireBoolean(operand) + generateIsTrue(operand) + + case IS_NOT_TRUE => + val operand = operands.head + requireBoolean(operand) + generateIsNotTrue(operand) + + case IS_FALSE => + val operand = operands.head + requireBoolean(operand) + generateIsFalse(operand) + + case IS_NOT_FALSE => + val operand = operands.head + requireBoolean(operand) + generateIsNotFalse(operand) + + case IN => + val left = operands.head + val right = operands.tail + generateIn(ctx, left, right) + + case NOT_IN => + val left = operands.head + val right = operands.tail + generateNot(ctx, generateIn(ctx, left, 
right)) + + // casting + case CAST => + val operand = operands.head + generateCast(ctx, operand, resultType) + + // Reinterpret + case REINTERPRET => + val operand = operands.head + generateReinterpret(ctx, operand, resultType) + + // as / renaming + case AS => + operands.head + + // rows + case ROW => + generateRow(ctx, resultType, operands) + + // arrays + case ARRAY_VALUE_CONSTRUCTOR => + generateArray(ctx, resultType, operands) + + // maps + case MAP_VALUE_CONSTRUCTOR => + generateMap(ctx, resultType, operands) + + case ITEM => + operands.head.resultType match { + case t: InternalType if TypeCheckUtils.isArray(t) => + val array = operands.head + val index = operands(1) + requireInteger(index) + generateArrayElementAt(ctx, array, index) + + case t: InternalType if TypeCheckUtils.isMap(t) => + val key = operands(1) + generateMapGet(ctx, operands.head, key) + + case _ => throw new CodeGenException("Expect an array or a map.") + } + + case CARDINALITY => + operands.head.resultType match { + case t: InternalType if TypeCheckUtils.isArray(t) => + val array = operands.head + generateArrayCardinality(ctx, array) + + case t: InternalType if TypeCheckUtils.isMap(t) => + val map = operands.head + generateMapCardinality(ctx, map) + + case _ => throw new CodeGenException("Expect an array or a map.") + } + + case ELEMENT => + val array = operands.head + requireArray(array) + generateArrayElement(ctx, array) + + case DOT => + generateDot(ctx, operands) + + case PROCTIME => + // attribute is proctime indicator. + // We use a null literal and generate a timestamp when we need it. 
+ generateNullLiteral(InternalTypes.PROCTIME_INDICATOR, ctx.nullCheck) + + case PROCTIME_MATERIALIZE => + generateProctimeTimestamp(ctx, contextTerm) + + case STREAMRECORD_TIMESTAMP => + generateRowtimeAccess(ctx, contextTerm) + + // advanced scalar functions + case sqlOperator: SqlOperator => + val explainCall = s"$sqlOperator(${operands.map(_.resultType).mkString(", ")})" + // TODO: support BinaryStringCallGen and FunctionGenerator + throw new CodeGenException(s"Unsupported call: $explainCall \n" + + s"If you think this function should be supported, " + + s"you can create an issue and start a discussion for it.") + + // unknown or invalid + case call@_ => + val explainCall = s"$call(${operands.map(_.resultType).mkString(", ")})" + throw new CodeGenException(s"Unsupported call: $explainCall") + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/FunctionCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/FunctionCodeGenerator.scala new file mode 100644 index 0000000000000..f123b48a29aeb --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/FunctionCodeGenerator.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.codegen + +import org.apache.flink.api.common.functions._ +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.functions.async.{AsyncFunction, RichAsyncFunction} +import org.apache.flink.table.`type`.InternalType +import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.Indenter.toISC +import org.apache.flink.table.generated.GeneratedFunction + +/** + * A code generator for generating Flink [[org.apache.flink.api.common.functions.Function]]s. + * Including [[MapFunction]], [[FlatMapFunction]], [[FlatJoinFunction]], [[ProcessFunction]], and + * the corresponding rich version of the functions. + */ +object FunctionCodeGenerator { + + /** + * Generates a [[org.apache.flink.api.common.functions.Function]] that can be passed to Java + * compiler. + * + * @param ctx The context of the code generator + * @param name Class name of the Function. Must not be unique but has to be a valid Java class + * identifier. + * @param clazz Flink Function to be generated. + * @param bodyCode code contents of the SAM (Single Abstract Method). Inputs, collector, or + * output record can be accessed via the given term methods. + * @param returnType expected return type + * @param input1Type the first input type + * @param input1Term the first input term + * @param input2Type the second input type, optional. + * @param input2Term the second input term. + * @param collectorTerm the collector term + * @param contextTerm the context term + * @tparam F Flink Function to be generated. 
+ * @return instance of GeneratedFunction + */ + def generateFunction[F <: Function]( + ctx: CodeGeneratorContext, + name: String, + clazz: Class[F], + bodyCode: String, + returnType: InternalType, + input1Type: InternalType, + input1Term: String = DEFAULT_INPUT1_TERM, + input2Type: Option[InternalType] = None, + input2Term: Option[String] = Some(DEFAULT_INPUT2_TERM), + collectorTerm: String = DEFAULT_COLLECTOR_TERM, + contextTerm: String = DEFAULT_CONTEXT_TERM) + : GeneratedFunction[F] = { + val funcName = newName(name) + val inputTypeTerm = boxedTypeTermForType(input1Type) + + // Janino does not support generics, that's why we need + // manual casting here + val samHeader = + // FlatMapFunction + if (clazz == classOf[FlatMapFunction[_, _]]) { + val baseClass = classOf[RichFlatMapFunction[_, _]] + (baseClass, + s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) + } + + // MapFunction + else if (clazz == classOf[MapFunction[_, _]]) { + val baseClass = classOf[RichMapFunction[_, _]] + (baseClass, + "Object map(Object _in1)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) + } + + // FlatJoinFunction + else if (clazz == classOf[FlatJoinFunction[_, _, _]]) { + val baseClass = classOf[RichFlatJoinFunction[_, _, _]] + val inputTypeTerm2 = boxedTypeTermForType(input2Type.getOrElse( + throw new CodeGenException("Input 2 for FlatJoinFunction should not be null"))) + (baseClass, + s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;", + s"$inputTypeTerm2 ${input2Term.get} = ($inputTypeTerm2) _in2;")) + } + + // ProcessFunction + else if (clazz == classOf[ProcessFunction[_, _]]) { + val baseClass = classOf[ProcessFunction[_, _]] + (baseClass, + s"void processElement(Object _in1, " + + s"org.apache.flink.streaming.api.functions.ProcessFunction.Context $contextTerm," + + 
s"org.apache.flink.util.Collector $collectorTerm)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) + } + + // AsyncFunction + else if (clazz == classOf[AsyncFunction[_, _]]) { + val baseClass = classOf[RichAsyncFunction[_, _]] + (baseClass, + s"void asyncInvoke(Object _in1, " + + s"org.apache.flink.streaming.api.functions.async.ResultFuture $collectorTerm)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) + } + + else { + // TODO more functions + throw new CodeGenException("Unsupported Function.") + } + + val funcCode = + j""" + public class $funcName + extends ${samHeader._1.getCanonicalName} { + + ${ctx.reuseMemberCode()} + + public $funcName(Object[] references) throws Exception { + ${ctx.reuseInitCode()} + } + + ${ctx.reuseConstructorCode(funcName)} + + @Override + public void open(${classOf[Configuration].getCanonicalName} parameters) throws Exception { + ${ctx.reuseOpenCode()} + } + + @Override + public ${samHeader._2} throws Exception { + ${samHeader._3.mkString("\n")} + ${ctx.reusePerRecordCode()} + ${ctx.reuseLocalVariableCode()} + ${ctx.reuseInputUnboxingCode()} + $bodyCode + } + + @Override + public void close() throws Exception { + ${ctx.reuseCloseCode()} + } + } + """.stripMargin + + new GeneratedFunction(funcName, funcCode, ctx.references.toArray) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GenerateUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GenerateUtils.scala new file mode 100644 index 0000000000000..28fe506dddeff --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GenerateUtils.scala @@ -0,0 +1,751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen

import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.typeinfo.{AtomicType => AtomicTypeInfo}
import org.apache.flink.table.`type`._
import org.apache.flink.table.calcite.FlinkPlannerImpl
import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.calls.CurrentTimePointCallGen
import org.apache.flink.table.dataformat._
import org.apache.flink.table.typeutils.TypeCheckUtils.{isReference, isTemporal}

import org.apache.calcite.avatica.util.ByteString
import org.apache.commons.lang3.StringEscapeUtils

import java.math.{BigDecimal => JBigDecimal}

import scala.collection.mutable

/**
 * Utilities to generate code for general purpose, e.g. literals, field access,
 * null-safe call wrappers and comparison code shared by the code generators.
 */
object GenerateUtils {

  // ----------------------------------------------------------------------------------------
  // basic call generate utils
  // ----------------------------------------------------------------------------------------

  /**
   * Generates a call with a single result statement. The call is only evaluated when all
   * operands are non-null; otherwise the result is null (see
   * [[generateCallWithStmtIfArgsNotNull]], to which this delegates with an empty statement).
   *
   * @param ctx code generator context which maintains various code statements
   * @param returnType result type of the call
   * @param operands the already-generated operand expressions
   * @param resultNullable forces a null re-check of the result term even for primitive-like types
   * @param call maps the operand result terms to the result expression string
   */
  def generateCallIfArgsNotNull(
      ctx: CodeGeneratorContext,
      returnType: InternalType,
      operands: Seq[GeneratedExpression],
      resultNullable: Boolean = false)
      (call: Seq[String] => String): GeneratedExpression = {
    generateCallWithStmtIfArgsNotNull(ctx, returnType, operands, resultNullable) {
      args => ("", call(args))
    }
  }

  /**
   * Generates a call with auxiliary statements and result expression.
   * The result is null if any operand is null; when `ctx.nullCheck` is enabled and the result
   * type is a nullable reference type (or `resultNullable` is set), the produced result term is
   * additionally re-checked for null after evaluation.
   *
   * @param call maps the operand result terms to a pair of
   *             (auxiliary statements, result expression)
   */
  def generateCallWithStmtIfArgsNotNull(
      ctx: CodeGeneratorContext,
      returnType: InternalType,
      operands: Seq[GeneratedExpression],
      resultNullable: Boolean = false)
      (call: Seq[String] => (String, String)): GeneratedExpression = {
    val resultTypeTerm = primitiveTypeTermForType(returnType)
    val nullTerm = ctx.addReusableLocalVariable("boolean", "isNull")
    val resultTerm = ctx.addReusableLocalVariable(resultTypeTerm, "result")
    val defaultValue = primitiveDefaultValue(returnType)
    // reference (boxed) non-temporal results may come back null from the call itself
    val isResultNullable = resultNullable || (isReference(returnType) && !isTemporal(returnType))
    val nullTermCode = if (ctx.nullCheck && isResultNullable) {
      s"$nullTerm = ($resultTerm == null);"
    } else {
      ""
    }

    val (stmt, result) = call(operands.map(_.resultTerm))

    val resultCode = if (ctx.nullCheck && operands.nonEmpty) {
      // null if any operand is null; only evaluate the call when all are non-null
      s"""
         |${operands.map(_.code).mkString("\n")}
         |$nullTerm = ${operands.map(_.nullTerm).mkString(" || ")};
         |$resultTerm = $defaultValue;
         |if (!$nullTerm) {
         |  $stmt
         |  $resultTerm = $result;
         |  $nullTermCode
         |}
         |""".stripMargin
    } else if (ctx.nullCheck && operands.isEmpty) {
      // zero-arg call: nothing can be null up front, but the result may still be re-checked
      s"""
         |${operands.map(_.code).mkString("\n")}
         |$nullTerm = false;
         |$stmt
         |$resultTerm = $result;
         |$nullTermCode
         |""".stripMargin
    } else {
      // null checking disabled: evaluate unconditionally
      s"""
         |$nullTerm = false;
         |${operands.map(_.code).mkString("\n")}
         |$stmt
         |$resultTerm = $result;
         |""".stripMargin
    }

    GeneratedExpression(resultTerm, nullTerm, resultCode, returnType)
  }

  /**
   * Generates a string result call with a single result statement.
   * This will convert the String result to BinaryString.
   */
  def generateStringResultCallIfArgsNotNull(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression])
      (call: Seq[String] => String): GeneratedExpression = {
    generateCallIfArgsNotNull(ctx, InternalTypes.STRING, operands) {
      args => s"$BINARY_STRING.fromString(${call(args)})"
    }
  }

  /**
   * Generates a string result call with auxiliary statements and result expression.
   * This will convert the String result to BinaryString.
   */
  def generateStringResultCallWithStmtIfArgsNotNull(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression])
      (call: Seq[String] => (String, String)): GeneratedExpression = {
    generateCallWithStmtIfArgsNotNull(ctx, InternalTypes.STRING, operands) {
      args =>
        val (stmt, result) = call(args)
        (stmt, s"$BINARY_STRING.fromString($result)")
    }
  }

  // --------------------------- General Generate Utils ----------------------------------

  /**
   * Generates a record declaration statement. The record can be any type of BaseRow or
   * other types.
   *
   * @param t the record type
   * @param clazz the specified class of the type (only used when RowType)
   * @param recordTerm the record term to be declared
   * @param recordWriterTerm the record writer term (required when clazz is BinaryRow,
   *                         otherwise ignored)
   * @return the record declaration statement
   * @throws CodeGenException if clazz is BinaryRow and no recordWriterTerm was given
   */
  def generateRecordStatement(
      t: InternalType,
      clazz: Class[_],
      recordTerm: String,
      recordWriterTerm: Option[String] = None): String = {
    t match {
      case rt: RowType if clazz == classOf[BinaryRow] =>
        // BinaryRow needs a companion writer to populate its fixed-length binary layout
        val writerTerm = recordWriterTerm.getOrElse(
          throw new CodeGenException("No writer is specified when writing BinaryRow record.")
        )
        val binaryRowWriter = className[BinaryRowWriter]
        val typeTerm = clazz.getCanonicalName
        s"""
           |final $typeTerm $recordTerm = new $typeTerm(${rt.getArity});
           |final $binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm);
           |""".stripMargin.trim
      case rt: RowType if classOf[ObjectArrayRow].isAssignableFrom(clazz) =>
        // object-array based rows are constructed with their arity
        val typeTerm = clazz.getCanonicalName
        s"final $typeTerm $recordTerm = new $typeTerm(${rt.getArity});"
      case _: RowType if clazz == classOf[JoinedRow] =>
        // JoinedRow wraps two other rows; it takes no constructor arguments
        val typeTerm = clazz.getCanonicalName
        s"final $typeTerm $recordTerm = new $typeTerm();"
      case _ =>
        // non-row types: assume a no-arg constructor on the boxed type
        val typeTerm = boxedTypeTermForType(t)
        s"final $typeTerm $recordTerm = new $typeTerm();"
    }
  }

  /**
   * Generates a null literal of the given type. The result term is the type's primitive
   * default value cast to the result type; the null term is constantly true.
   *
   * @throws CodeGenException if nullCheck is disabled, since null cannot be represented then
   */
  def generateNullLiteral(
      resultType: InternalType,
      nullCheck: Boolean): GeneratedExpression = {
    val defaultValue = primitiveDefaultValue(resultType)
    val resultTypeTerm = primitiveTypeTermForType(resultType)
    if (nullCheck) {
      GeneratedExpression(
        s"(($resultTypeTerm) $defaultValue)",
        ALWAYS_NULL,
        NO_CODE,
        resultType,
        literalValue = Some(null)) // the literal is null
    } else {
      throw new CodeGenException("Null literals are not allowed if nullCheck is disabled.")
    }
  }

  /**
   * Generates a non-null literal expression.
   *
   * @param literalType type of the literal
   * @param literalCode Java source text producing the literal value
   * @param literalValue the original object of the literal (kept for constant folding / reuse)
   */
  def generateNonNullLiteral(
      literalType: InternalType,
      literalCode: String,
      literalValue: Any): GeneratedExpression = {
    val resultTypeTerm = primitiveTypeTermForType(literalType)
    GeneratedExpression(
      s"(($resultTypeTerm) $literalCode)",
      NEVER_NULL,
      NO_CODE,
      literalType,
      literalValue = Some(literalValue))
  }

  /**
   * Generates a literal expression from a Calcite literal value.
   *
   * The expected runtime class of `literalValue` follows Calcite's RexLiteral conventions,
   * as shown by the casts below: java.math.BigDecimal for all numeric types and intervals,
   * ByteString for BINARY, Long (epoch millis) for TIMESTAMP, Enum for symbol flags.
   *
   * @throws CodeGenException for unsupported types or out-of-range interval values
   */
  def generateLiteral(
      ctx: CodeGeneratorContext,
      literalType: InternalType,
      literalValue: Any): GeneratedExpression = {
    if (literalValue == null) {
      return generateNullLiteral(literalType, ctx.nullCheck)
    }
    // non-null values
    literalType match {

      case InternalTypes.BOOLEAN =>
        generateNonNullLiteral(literalType, literalValue.toString, literalValue)

      case InternalTypes.BYTE =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        generateNonNullLiteral(literalType, decimal.byteValue().toString, decimal.byteValue())

      case InternalTypes.SHORT =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        generateNonNullLiteral(literalType, decimal.shortValue().toString, decimal.shortValue())

      case InternalTypes.INT =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        generateNonNullLiteral(literalType, decimal.intValue().toString, decimal.intValue())

      case InternalTypes.LONG =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        generateNonNullLiteral(
          literalType, decimal.longValue().toString + "L", decimal.longValue())

      case InternalTypes.FLOAT =>
        val floatValue = literalValue.asInstanceOf[JBigDecimal].floatValue()
        // NaN / infinities have no Java literal syntax, spell them via java.lang.Float
        floatValue match {
          case Float.NaN => generateNonNullLiteral(
            literalType, "java.lang.Float.NaN", Float.NaN)
          case Float.NegativeInfinity =>
            generateNonNullLiteral(
              literalType,
              "java.lang.Float.NEGATIVE_INFINITY",
              Float.NegativeInfinity)
          case Float.PositiveInfinity => generateNonNullLiteral(
            literalType,
            "java.lang.Float.POSITIVE_INFINITY",
            Float.PositiveInfinity)
          case _ => generateNonNullLiteral(
            literalType, floatValue.toString + "f", floatValue)
        }

      case InternalTypes.DOUBLE =>
        val doubleValue = literalValue.asInstanceOf[JBigDecimal].doubleValue()
        // same special-value handling as FLOAT above
        doubleValue match {
          case Double.NaN => generateNonNullLiteral(
            literalType, "java.lang.Double.NaN", Double.NaN)
          case Double.NegativeInfinity =>
            generateNonNullLiteral(
              literalType,
              "java.lang.Double.NEGATIVE_INFINITY",
              Double.NegativeInfinity)
          case Double.PositiveInfinity =>
            generateNonNullLiteral(
              literalType,
              "java.lang.Double.POSITIVE_INFINITY",
              Double.PositiveInfinity)
          case _ => generateNonNullLiteral(
            literalType, doubleValue.toString + "d", doubleValue)
        }
      case decimal: DecimalType =>
        // decimals are parsed once into a reusable member field instead of per record
        val precision = decimal.precision()
        val scale = decimal.scale()
        val fieldTerm = newName("decimal")
        val decimalClass = className[Decimal]
        val fieldDecimal =
          s"""
             |$decimalClass $fieldTerm =
             |  $decimalClass.castFrom("${literalValue.toString}", $precision, $scale);
             |""".stripMargin
        ctx.addReusableMember(fieldDecimal)
        val value = Decimal.fromBigDecimal(
          literalValue.asInstanceOf[JBigDecimal], precision, scale)
        generateNonNullLiteral(literalType, fieldTerm, value)

      case InternalTypes.STRING =>
        // escape so the value can be embedded in generated Java source; reused as a constant
        // NOTE(review): the BinaryString literalValue is built from the ESCAPED text, not the
        // original string — confirm downstream consumers of literalValue expect that.
        val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(literalValue.toString)
        val field = ctx.addReusableStringConstants(escapedValue)
        generateNonNullLiteral(literalType, field, BinaryString.fromString(escapedValue))

      case InternalTypes.BINARY =>
        val bytesVal = literalValue.asInstanceOf[ByteString].getBytes
        val fieldTerm = ctx.addReusableObject(
          bytesVal, "binary", bytesVal.getClass.getCanonicalName)
        generateNonNullLiteral(literalType, fieldTerm, bytesVal)

      case InternalTypes.DATE =>
        generateNonNullLiteral(literalType, literalValue.toString, literalValue)

      case InternalTypes.TIME =>
        generateNonNullLiteral(literalType, literalValue.toString, literalValue)

      case InternalTypes.TIMESTAMP =>
        // Hack
        // Currently, in RexLiteral/SqlLiteral(Calcite), TimestampString has no time zone.
        // TimeString, DateString TimestampString are treated as UTC time/(unix time)
        // when they are converted/formatted/validated
        // Here, we adjust millis before Calcite solve TimeZone perfectly
        val millis = literalValue.asInstanceOf[Long]
        val adjustedValue = millis - ctx.tableConfig.getTimeZone.getOffset(millis)
        generateNonNullLiteral(literalType, adjustedValue.toString + "L", adjustedValue)

      case InternalTypes.INTERVAL_MONTHS =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        if (decimal.isValidInt) {
          generateNonNullLiteral(literalType, decimal.intValue().toString, decimal.intValue())
        } else {
          throw new CodeGenException(
            s"Decimal '$decimal' can not be converted to interval of months.")
        }

      case InternalTypes.INTERVAL_MILLIS =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        if (decimal.isValidLong) {
          generateNonNullLiteral(
            literalType,
            decimal.longValue().toString + "L",
            decimal.longValue())
        } else {
          throw new CodeGenException(
            s"Decimal '$decimal' can not be converted to interval of milliseconds.")
        }

      // Symbol type for special flags e.g. TRIM's BOTH, LEADING, TRAILING
      // NOTE(review): `typeClass.isAssignableFrom(classOf[Enum[_]])` tests that the type class
      // is a SUPERTYPE of Enum, not that it IS an enum — verify the intended direction.
      case symbol: GenericType[_] if symbol.getTypeClass.isAssignableFrom(classOf[Enum[_]]) =>
        generateSymbol(literalValue.asInstanceOf[Enum[_]])

      case t@_ =>
        throw new CodeGenException(s"Type not supported: $t")
    }
  }

  /**
   * Generates a literal expression for an enum symbol (e.g. TRIM's BOTH/LEADING/TRAILING).
   * The result term is the fully qualified enum constant; no runtime code is produced.
   */
  def generateSymbol(enum: Enum[_]): GeneratedExpression = {
    GeneratedExpression(
      qualifyEnum(enum),
      NEVER_NULL,
      NO_CODE,
      new GenericType(enum.getDeclaringClass),
      literalValue = Some(enum))
  }

  /**
   * Generates access to a non-null field that does not require unboxing logic.
   *
   * @param fieldType type of field
   * @param fieldTerm expression term of field (already unboxed)
   * @return internal unboxed field representation
   */
  private[flink] def generateNonNullField(
      fieldType: InternalType,
      fieldTerm: String)
    : GeneratedExpression = {
    val resultTypeTerm = primitiveTypeTermForType(fieldType)
    GeneratedExpression(s"(($resultTypeTerm) $fieldTerm)", NEVER_NULL, NO_CODE, fieldType)
  }

  /**
   * Generates code reading the current processing time from the runtime context's timer
   * service. Evaluated per record (not cached).
   *
   * @param contextTerm term of the function context exposing timerService()
   */
  def generateProctimeTimestamp(
      ctx: CodeGeneratorContext,
      contextTerm: String): GeneratedExpression = {
    val resultTerm = ctx.addReusableLocalVariable("long", "result")
    val resultCode =
      s"""
         |$resultTerm = $contextTerm.timerService().currentProcessingTime();
         |""".stripMargin.trim
    // the proctime has been materialized, so it's TIMESTAMP now, not PROCTIME_INDICATOR
    GeneratedExpression(resultTerm, NEVER_NULL, resultCode, InternalTypes.TIMESTAMP)
  }

  /** Generates code producing the current timestamp (non-local, see CurrentTimePointCallGen). */
  def generateCurrentTimestamp(
      ctx: CodeGeneratorContext): GeneratedExpression = {
    new CurrentTimePointCallGen(false).generate(ctx, Seq(), InternalTypes.TIMESTAMP)
  }

  /**
   * Generates code reading the rowtime timestamp from the context. Fails at runtime with an
   * instructive message if the StreamRecord carries no timestamp.
   */
  def generateRowtimeAccess(
      ctx: CodeGeneratorContext,
      contextTerm: String): GeneratedExpression = {
    val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
      ("Long", "result"),
      ("boolean", "isNull"))

    val accessCode =
      s"""
         |$resultTerm = $contextTerm.timestamp();
         |if ($resultTerm == null) {
         |  throw new RuntimeException("Rowtime timestamp is null. Please make sure that a " +
         |    "proper TimestampAssigner is defined and the stream environment uses the EventTime " +
         |    "time characteristic.");
         |}
         |$nullTerm = false;
       """.stripMargin.trim

    GeneratedExpression(resultTerm, nullTerm, accessCode, InternalTypes.ROWTIME_INDICATOR)
  }

  /**
   * Generates access to a field of the input. Access/unboxing code is generated at most once
   * per (inputTerm, index) and reused afterwards via the context's input unboxing cache.
   *
   * @param ctx code generator context which maintains various code statements.
   * @param inputType input type
   * @param inputTerm input term
   * @param index the field index to access
   * @param nullableInput whether the input itself (not just the field) may be null
   * @param deepCopy whether to copy the accessed field (usually needed when buffered)
   */
  def generateInputAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      nullableInput: Boolean,
      deepCopy: Boolean = false): GeneratedExpression = {
    // if input has been used before, we can reuse the code that
    // has already been generated
    val inputExpr = ctx.getReusableInputUnboxingExprs(inputTerm, index) match {
      // input access and unboxing has already been generated
      case Some(expr) => expr

      // generate input access and unboxing if necessary
      case None =>
        val expr = if (nullableInput) {
          generateNullableInputFieldAccess(ctx, inputType, inputTerm, index, deepCopy)
        } else {
          generateFieldAccess(ctx, inputType, inputTerm, index, deepCopy)
        }

        ctx.addReusableInputUnboxingExprs(inputTerm, index, expr)
        expr
    }
    // hide the generated code as it will be executed only once
    GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, "", inputExpr.resultType)
  }

  /**
   * Generates field access that also guards against the input record itself being null:
   * the result is null when the input is null, otherwise it is the field access result.
   */
  def generateNullableInputFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      deepCopy: Boolean = false): GeneratedExpression = {

    // non-row inputs are treated as a single field; index is only used for rows
    val fieldType = inputType match {
      case ct: RowType => ct.getFieldTypes()(index)
      case _ => inputType
    }
    val resultTypeTerm = primitiveTypeTermForType(fieldType)
    val defaultValue = primitiveDefaultValue(fieldType)
    val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
      (resultTypeTerm, "result"),
      ("boolean", "isNull"))

    val fieldAccessExpr = generateFieldAccess(
      ctx, inputType, inputTerm, index, deepCopy)

    val inputCheckCode =
      s"""
         |$resultTerm = $defaultValue;
         |$nullTerm = true;
         |if ($inputTerm != null) {
         |  ${fieldAccessExpr.code}
         |  $resultTerm = ${fieldAccessExpr.resultTerm};
         |  $nullTerm = ${fieldAccessExpr.nullTerm};
         |}
         |""".stripMargin.trim

    GeneratedExpression(resultTerm, nullTerm, inputCheckCode, fieldType)
  }

  /**
   * Converts the external boxed format to an internal mostly primitive field representation.
   * Wrapper types can be autoboxed to their corresponding primitive type (Integer -> int).
   *
   * @param ctx code generator context which maintains various code statements.
   * @param fieldType type of field
   * @param fieldTerm expression term of field to be unboxed
   * @return internal unboxed field representation
   */
  def generateInputFieldUnboxing(
      ctx: CodeGeneratorContext,
      fieldType: InternalType,
      fieldTerm: String): GeneratedExpression = {

    val resultTypeTerm = primitiveTypeTermForType(fieldType)
    val defaultValue = primitiveDefaultValue(fieldType)

    val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
      (resultTypeTerm, "result"),
      ("boolean", "isNull"))

    val wrappedCode = if (ctx.nullCheck) {
      s"""
         |$nullTerm = $fieldTerm == null;
         |$resultTerm = $defaultValue;
         |if (!$nullTerm) {
         |  $resultTerm = $fieldTerm;
         |}
         |""".stripMargin.trim
    } else {
      s"""
         |$resultTerm = $fieldTerm;
         |""".stripMargin.trim
    }

    GeneratedExpression(resultTerm, nullTerm, wrappedCode, fieldType)
  }

  /**
   * Generates field access code expression. The difference between this method and
   * [[generateFieldAccess(ctx, inputType, inputTerm, index)]] is that this method
   * accepts an additional `deepCopy` parameter. When deepCopy is set to true, the returned
   * result will be copied.
   *
   * NOTE: Please set `deepCopy` to true when the result will be buffered.
   */
  def generateFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      deepCopy: Boolean): GeneratedExpression = {
    val expr = generateFieldAccess(ctx, inputType, inputTerm, index)
    if (deepCopy) {
      expr.deepCopy(ctx)
    } else {
      expr
    }
  }

  /**
   * Generates field access code expression. For a RowType input, reads field `index` with a
   * per-field null check (when nullCheck is enabled); any other input type is treated as the
   * single field itself and only unboxed.
   */
  def generateFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int): GeneratedExpression =
    inputType match {
      case ct: RowType =>
        val fieldType = ct.getFieldTypes()(index)
        val resultTypeTerm = primitiveTypeTermForType(fieldType)
        val defaultValue = primitiveDefaultValue(fieldType)
        val readCode = baseRowFieldReadAccess(ctx, index.toString, inputTerm, fieldType)
        val Seq(fieldTerm, nullTerm) = ctx.addReusableLocalVariables(
          (resultTypeTerm, "field"),
          ("boolean", "isNull"))

        val inputCode = if (ctx.nullCheck) {
          s"""
             |$nullTerm = $inputTerm.isNullAt($index);
             |$fieldTerm = $defaultValue;
             |if (!$nullTerm) {
             |  $fieldTerm = $readCode;
             |}
           """.stripMargin.trim
        } else {
          s"""
             |$nullTerm = false;
             |$fieldTerm = $readCode;
           """.stripMargin
        }
        GeneratedExpression(fieldTerm, nullTerm, inputCode, fieldType)

      case _ =>
        val fieldTypeTerm = boxedTypeTermForType(inputType)
        val inputCode = s"($fieldTypeTerm) $inputTerm"
        generateInputFieldUnboxing(ctx, inputType, inputCode)
    }

  /**
   * Generates code for comparing two field.
   * Returns a Java int expression (negative/zero/positive) comparing leftTerm and rightTerm.
   * For arrays and rows, a reusable member comparison method is generated and invoked.
   *
   * NOTE(review): this match has no default case — an unsupported InternalType will fail with
   * a scala.MatchError at code-generation time.
   */
  def generateCompare(
      ctx: CodeGeneratorContext,
      t: InternalType,
      nullsIsLast: Boolean,
      leftTerm: String,
      rightTerm: String): String = t match {
    case InternalTypes.BOOLEAN => s"($leftTerm == $rightTerm ? 0 : ($leftTerm ? 1 : -1))"
    case _: PrimitiveType | _: DateType | _: TimeType | _: TimestampType =>
      s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)"
    case InternalTypes.BINARY =>
      val sortUtil = classOf[org.apache.flink.table.runtime.sort.SortUtil].getCanonicalName
      s"$sortUtil.compareBinary($leftTerm, $rightTerm)"
    case at: ArrayType =>
      // arrays are compared by a generated helper method added to the member area
      val compareFunc = newName("compareArray")
      val compareCode = generateArrayCompare(
        ctx,
        FlinkPlannerImpl.getNullDefaultOrder(true), at, "a", "b")
      val funcCode: String =
        s"""
          public int $compareFunc($BINARY_ARRAY a, $BINARY_ARRAY b) {
            $compareCode
            return 0;
          }
        """
      ctx.addReusableMember(funcCode)
      s"$compareFunc($leftTerm, $rightTerm)"
    case rowType: RowType =>
      // rows are compared field by field, all ascending, via a generated helper method
      val orders = rowType.getFieldTypes.map(_ => true)
      val comparisons = generateRowCompare(
        ctx,
        rowType.getFieldTypes.indices.toArray,
        rowType.getFieldTypes,
        orders,
        FlinkPlannerImpl.getNullDefaultOrders(orders),
        "a",
        "b")
      val compareFunc = newName("compareRow")
      val funcCode: String =
        s"""
          public int $compareFunc($BASE_ROW a, $BASE_ROW b) {
            $comparisons
            return 0;
          }
        """
      ctx.addReusableMember(funcCode)
      s"$compareFunc($leftTerm, $rightTerm)"
    case gt: GenericType[_] =>
      // generic types: deserialize from BinaryGeneric and delegate to the type's comparator
      val ser = ctx.addReusableObject(gt.getSerializer, "serializer")
      val comp = ctx.addReusableObject(
        gt.getTypeInfo.asInstanceOf[AtomicTypeInfo[_]].createComparator(true, new ExecutionConfig),
        "comparator")
      s"""
         |$comp.compare(
         |  $BINARY_GENERIC.getJavaObjectFromBinaryGeneric($leftTerm, $ser),
         |  $BINARY_GENERIC.getJavaObjectFromBinaryGeneric($rightTerm, $ser)
         |)
       """.stripMargin
    case other if other.isInstanceOf[AtomicType] => s"$leftTerm.compareTo($rightTerm)"
  }

  /**
   * Generates code for comparing array.
   * The produced statements `return` a non-zero int on the first differing element (elementwise,
   * then by length) and fall through when equal; they are meant to be embedded in a generated
   * comparison method body (see generateCompare's ArrayType case).
   *
   * @param nullsIsLast whether a null element sorts after a non-null one
   */
  def generateArrayCompare(
      ctx: CodeGeneratorContext,
      nullsIsLast: Boolean,
      arrayType: ArrayType,
      leftTerm: String,
      rightTerm: String)
    : String = {
    val nullIsLastRet = if (nullsIsLast) 1 else -1
    val elementType = arrayType.getElementType
    val fieldA = newName("fieldA")
    val isNullA = newName("isNullA")
    val lengthA = newName("lengthA")
    val fieldB = newName("fieldB")
    val isNullB = newName("isNullB")
    val lengthB = newName("lengthB")
    val minLength = newName("minLength")
    val i = newName("i")
    val comp = newName("comp")
    val typeTerm = primitiveTypeTermForType(elementType)
    s"""
        int $lengthA = a.numElements();
        int $lengthB = b.numElements();
        int $minLength = ($lengthA > $lengthB) ? $lengthB : $lengthA;
        for (int $i = 0; $i < $minLength; $i++) {
          boolean $isNullA = a.isNullAt($i);
          boolean $isNullB = b.isNullAt($i);
          if ($isNullA && $isNullB) {
            // Continue to compare the next element
          } else if ($isNullA) {
            return $nullIsLastRet;
          } else if ($isNullB) {
            return ${-nullIsLastRet};
          } else {
            $typeTerm $fieldA = ${baseRowFieldReadAccess(ctx, i, leftTerm, elementType)};
            $typeTerm $fieldB = ${baseRowFieldReadAccess(ctx, i, rightTerm, elementType)};
            int $comp = ${generateCompare(ctx, elementType, nullsIsLast, fieldA, fieldB)};
            if ($comp != 0) {
              return $comp;
            }
          }
        }

        if ($lengthA < $lengthB) {
          return -1;
        } else if ($lengthA > $lengthB) {
          return 1;
        }
     """
  }

  /**
   * Generates code for comparing row keys.
   * Emits one guarded comparison per key; descending order is realized by negating the
   * per-field comparison result (the "-" symbol). The statements `return` on the first
   * difference and fall through when all keys compare equal.
   *
   * @param keys field indexes to compare, in significance order
   * @param keyTypes types of the key fields (parallel to keys)
   * @param orders true for ascending, false for descending (parallel to keys)
   * @param nullsIsLast per-key null ordering (parallel to keys)
   */
  def generateRowCompare(
      ctx: CodeGeneratorContext,
      keys: Array[Int],
      keyTypes: Array[InternalType],
      orders: Array[Boolean],
      nullsIsLast: Array[Boolean],
      leftTerm: String,
      rightTerm: String): String = {

    val compares = new mutable.ArrayBuffer[String]

    for (i <- keys.indices) {
      val index = keys(i)

      val symbol = if (orders(i)) "" else "-"

      val nullIsLastRet = if (nullsIsLast(i)) 1 else -1

      val t = keyTypes(i)

      val typeTerm = primitiveTypeTermForType(t)
      val fieldA = newName("fieldA")
      val isNullA = newName("isNullA")
      val fieldB = newName("fieldB")
      val isNullB = newName("isNullB")
      val comp = newName("comp")

      val code =
        s"""
           |boolean $isNullA = $leftTerm.isNullAt($index);
           |boolean $isNullB = $rightTerm.isNullAt($index);
           |if ($isNullA && $isNullB) {
           |  // Continue to compare the next element
           |} else if ($isNullA) {
           |  return $nullIsLastRet;
           |} else if ($isNullB) {
           |  return ${-nullIsLastRet};
           |} else {
           |  $typeTerm $fieldA = ${baseRowFieldReadAccess(ctx, index, leftTerm, t)};
           |  $typeTerm $fieldB = ${baseRowFieldReadAccess(ctx, index, rightTerm, t)};
           |  int $comp = ${generateCompare(ctx, t, nullsIsLast(i), fieldA, fieldB)};
           |  if ($comp != 0) {
           |    return $symbol$comp;
           |  }
           |}
         """.stripMargin
      compares += code
    }
    compares.mkString
  }

}
Indicates a constant expression that does not reference input
// if the type needs to be copied, it must be a boxed type
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen.calls + +import org.apache.calcite.linq4j.tree.Types +import org.apache.flink.table.runtime.functions.DateTimeUtils + +import java.util.TimeZone + +object BuiltInMethods { + + val STRING_TO_TIMESTAMP = Types.lookupMethod( + classOf[DateTimeUtils], + "strToTimestamp", + classOf[String], classOf[TimeZone]) + + val UNIX_TIME_TO_STRING = Types.lookupMethod( + classOf[DateTimeUtils], + "timeToString", + classOf[Int]) + + val TIMESTAMP_TO_STRING = Types.lookupMethod( + classOf[DateTimeUtils], + "timestampToString", + classOf[Long], classOf[Int], classOf[TimeZone]) +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala new file mode 100644 index 0000000000000..a35bc5d798c9c --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen.calls + +import org.apache.flink.table.`type`.InternalType +import org.apache.flink.table.codegen.{CodeGeneratorContext, GeneratedExpression} + +/** + * Generator to generate a call expression. It is usually used when the generation + * depends on some other parameters or the generation is too complex to be a util method. + */ +trait CallGenerator { + + def generate( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression], + returnType: InternalType): GeneratedExpression + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ConstantCallGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ConstantCallGen.scala new file mode 100644 index 0000000000000..d51674a66236a --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ConstantCallGen.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen.calls + +import org.apache.flink.table.`type`.InternalType +import org.apache.flink.table.codegen.{CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.codegen.GenerateUtils.generateNonNullLiteral + +/** + * Generates a function call which returns a constant. + */ +class ConstantCallGen(constantCode: String, constantValue: Any) extends CallGenerator { + + override def generate( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression], + returnType: InternalType): GeneratedExpression = { + generateNonNullLiteral(returnType, constantCode, constantValue) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala new file mode 100644 index 0000000000000..6b81d9de92401 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen.calls + +import org.apache.flink.table.`type`.{InternalType, InternalTypes} +import org.apache.flink.table.codegen.{CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.codegen.GenerateUtils.generateNonNullField + +/** + * Generates function call to determine current time point (as date/time/timestamp) in + * local timezone or not. + */ +class CurrentTimePointCallGen(local: Boolean) extends CallGenerator { + + override def generate( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression], + returnType: InternalType): GeneratedExpression = returnType match { + case InternalTypes.TIME if local => + val time = ctx.addReusableLocalTime() + generateNonNullField(returnType, time) + + case InternalTypes.TIMESTAMP if local => + val timestamp = ctx.addReusableLocalTimestamp() + generateNonNullField(returnType, timestamp) + + case InternalTypes.DATE => + val date = ctx.addReusableDate() + generateNonNullField(returnType, date) + + case InternalTypes.TIME => + val time = ctx.addReusableTime() + generateNonNullField(returnType, time) + + case InternalTypes.TIMESTAMP => + val timestamp = ctx.addReusableTimestamp() + generateNonNullField(returnType, timestamp) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperatorGens.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperatorGens.scala new file mode 100644 index 0000000000000..b9b5ae07591d3 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperatorGens.scala @@ -0,0 +1,2017 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.codegen.calls + +import org.apache.flink.table.`type`._ +import org.apache.flink.table.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull, binaryWriterWriteField, binaryWriterWriteNull, _} +import org.apache.flink.table.codegen.GenerateUtils._ +import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE, ALWAYS_NULL} +import org.apache.flink.table.codegen.{CodeGenException, CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.dataformat._ +import org.apache.flink.table.typeutils.TypeCheckUtils._ +import org.apache.flink.table.typeutils.{TypeCheckUtils, TypeCoercion} +import org.apache.flink.util.Preconditions.checkArgument + +import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY +import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange} +import org.apache.calcite.util.BuiltInMethod + +import java.lang.{StringBuilder => JStringBuilder} +import java.nio.charset.StandardCharsets + +/** + * Utilities to generate SQL scalar operators, e.g. arithmetic operator, + * compare operator, equal operator, etc. + */ +object ScalarOperatorGens { + + // ---------------------------------------------------------------------------------------- + // scalar operators generate utils + // ---------------------------------------------------------------------------------------- + + /** + * Generates a binary arithmetic operator, e.g. 
+ - * / % + */ + def generateBinaryArithmeticOperator( + ctx: CodeGeneratorContext, + operator: String, + resultType: InternalType, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + + resultType match { + case dt: DecimalType => + return generateDecimalBinaryArithmeticOperator(ctx, operator, dt, left, right) + case _ => + } + + val leftCasting = operator match { + case "%" => + if (left.resultType == right.resultType) { + numericCasting(left.resultType, resultType) + } else { + val castedType = if (isDecimal(left.resultType)) { + InternalTypes.LONG + } else { + left.resultType + } + numericCasting(left.resultType, castedType) + } + case _ => numericCasting(left.resultType, resultType) + } + + val rightCasting = numericCasting(right.resultType, resultType) + val resultTypeTerm = primitiveTypeTermForType(resultType) + + generateOperatorIfNotNull(ctx, resultType, left, right) { + (leftTerm, rightTerm) => + s"($resultTypeTerm) (${leftCasting(leftTerm)} $operator ${rightCasting(rightTerm)})" + } + } + + /** + * Generates a binary arithmetic operator for Decimal, e.g. + - * / % + */ + private def generateDecimalBinaryArithmeticOperator( + ctx: CodeGeneratorContext, + operator: String, + resultType: DecimalType, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + + // do not cast a decimal operand to resultType, which may change its value. + // use it as is during calculation. 
Generates a unary arithmetic operator, e.g. -num
InternalTypes.INTERVAL_MONTHS) => + generateOperatorIfNotNull(ctx, InternalTypes.DATE, left, right) { + (l, r) => s"${qualifyMethod(BuiltInMethod.ADD_MONTHS.method)}($l, $op($r))" + } + + case (InternalTypes.TIME, InternalTypes.INTERVAL_MILLIS) => + generateOperatorIfNotNull(ctx, InternalTypes.TIME, left, right) { + (l, r) => s"$l $op ((int) ($r))" + } + + case (InternalTypes.TIMESTAMP, InternalTypes.INTERVAL_MILLIS) => + generateOperatorIfNotNull(ctx, InternalTypes.TIMESTAMP, left, right) { + (l, r) => s"$l $op $r" + } + + case (InternalTypes.TIMESTAMP, InternalTypes.INTERVAL_MONTHS) => + generateOperatorIfNotNull(ctx, InternalTypes.TIMESTAMP, left, right) { + (l, r) => s"${qualifyMethod(BuiltInMethod.ADD_MONTHS.method)}($l, $op($r))" + } + + // minus arithmetic of time points (i.e. for TIMESTAMPDIFF) + case (InternalTypes.TIMESTAMP | InternalTypes.TIME | InternalTypes.DATE, + InternalTypes.TIMESTAMP | InternalTypes.TIME | InternalTypes.DATE) if !plus => + resultType match { + case InternalTypes.INTERVAL_MONTHS => + generateOperatorIfNotNull(ctx, resultType, left, right) { + (ll, rr) => (left.resultType, right.resultType) match { + case (InternalTypes.TIMESTAMP, InternalTypes.DATE) => + s"${qualifyMethod(BuiltInMethod.SUBTRACT_MONTHS.method)}" + + s"($ll, $rr * ${MILLIS_PER_DAY}L)" + case (InternalTypes.DATE, InternalTypes.TIMESTAMP) => + s"${qualifyMethod(BuiltInMethod.SUBTRACT_MONTHS.method)}" + + s"($ll * ${MILLIS_PER_DAY}L, $rr)" + case _ => + s"${qualifyMethod(BuiltInMethod.SUBTRACT_MONTHS.method)}($ll, $rr)" + } + } + + case InternalTypes.INTERVAL_MILLIS => + generateOperatorIfNotNull(ctx, resultType, left, right) { + (ll, rr) => (left.resultType, right.resultType) match { + case (InternalTypes.TIMESTAMP, InternalTypes.TIMESTAMP) => + s"$ll $op $rr" + case (InternalTypes.DATE, InternalTypes.DATE) => + s"($ll * ${MILLIS_PER_DAY}L) $op ($rr * ${MILLIS_PER_DAY}L)" + case (InternalTypes.TIMESTAMP, InternalTypes.DATE) => + s"$ll $op ($rr * ${MILLIS_PER_DAY}L)" + 
case (InternalTypes.DATE, InternalTypes.TIMESTAMP) => + s"($ll * ${MILLIS_PER_DAY}L) $op $rr" + } + } + } + + case _ => + throw new CodeGenException("Unsupported temporal arithmetic.") + } + } + + def generateUnaryIntervalPlusMinus( + ctx: CodeGeneratorContext, + plus: Boolean, + operand: GeneratedExpression) + : GeneratedExpression = { + val operator = if (plus) "+" else "-" + generateUnaryArithmeticOperator(ctx, operator, operand.resultType, operand) + } + + // ---------------------------------------------------------------------------------------- + // scalar expression generate utils + // ---------------------------------------------------------------------------------------- + + /** + * Generates IN expression using a HashSet + */ + def generateIn( + ctx: CodeGeneratorContext, + needle: GeneratedExpression, + haystack: Seq[GeneratedExpression]) + : GeneratedExpression = { + + // add elements to hash set if they are constant + if (haystack.forall(_.literal)) { + + // determine common numeric type + val widerType = TypeCoercion.widerTypeOf( + needle.resultType, + haystack.head.resultType) + + // we need to normalize the values for the hash set + val castNumeric = widerType match { + case Some(t) => (value: GeneratedExpression) => + numericCasting(value.resultType, t)(value.resultTerm) + case None => (value: GeneratedExpression) => value.resultTerm + } + + val resultType = widerType match { + case Some(t) => t + case None => needle.resultType + } + + val elements = haystack.map { element => + element.copy( + castNumeric(element), // cast element to wider type + element.nullTerm, + element.code, + resultType) + } + val setTerm = ctx.addReusableHashSet(elements, resultType) + + val castedNeedle = needle.copy( + castNumeric(needle), // cast needle to wider type + needle.nullTerm, + needle.code, + resultType) + + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val resultTypeTerm = primitiveTypeTermForType(InternalTypes.BOOLEAN) + val defaultValue = 
primitiveDefaultValue(InternalTypes.BOOLEAN) + + val operatorCode = if (ctx.nullCheck) { + s""" + |${castedNeedle.code} + |$resultTypeTerm $resultTerm = $defaultValue; + |boolean $nullTerm = true; + |if (!${castedNeedle.nullTerm}) { + | $resultTerm = $setTerm.contains(${castedNeedle.resultTerm}); + | $nullTerm = !$resultTerm && $setTerm.containsNull(); + |} + |""".stripMargin.trim + } + else { + s""" + |${castedNeedle.code} + |$resultTypeTerm $resultTerm = $setTerm.contains(${castedNeedle.resultTerm}); + |""".stripMargin.trim + } + + GeneratedExpression(resultTerm, nullTerm, operatorCode, InternalTypes.BOOLEAN) + } else { + // we use a chain of ORs for a set that contains non-constant elements + haystack + .map(generateEquals(ctx, needle, _)) + .reduce((left, right) => + generateOr(ctx, left, right) + ) + } + } + + def generateEquals( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + if (left.resultType == InternalTypes.STRING && right.resultType == InternalTypes.STRING) { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + (leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)" + } + } + // numeric types + else if (isNumeric(left.resultType) && isNumeric(right.resultType)) { + generateComparison(ctx, "==", left, right) + } + // temporal types + else if (isTemporal(left.resultType) && left.resultType == right.resultType) { + generateComparison(ctx, "==", left, right) + } + // array types + else if (isArray(left.resultType) && left.resultType == right.resultType) { + generateArrayComparison(ctx, left, right) + } + // map types + else if (isMap(left.resultType) && left.resultType == right.resultType) { + generateMapComparison(ctx, left, right) + } + // comparable types of same type + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + generateComparison(ctx, "==", left, right) + } + // support date/time/timestamp equalTo string. 
+ // for performance, we cast literal string to literal time. + else if (isTimePoint(left.resultType) && right.resultType == InternalTypes.STRING) { + if (right.literal) { + generateEquals(ctx, left, generateCastStringLiteralToDateTime(ctx, right, left.resultType)) + } else { + generateEquals(ctx, left, generateCast(ctx, right, left.resultType)) + } + } + else if (isTimePoint(right.resultType) && left.resultType == InternalTypes.STRING) { + if (left.literal) { + generateEquals( + ctx, + generateCastStringLiteralToDateTime(ctx, left, right.resultType), + right) + } else { + generateEquals(ctx, generateCast(ctx, left, right.resultType), right) + } + } + // non comparable types + else { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + if (isReference(left)) { + (leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)" + } + else if (isReference(right)) { + (leftTerm, rightTerm) => s"$rightTerm.equals($leftTerm)" + } + else { + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") + } + } + } + } + + def generateNotEquals( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + if (left.resultType == InternalTypes.STRING && right.resultType == InternalTypes.STRING) { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + (leftTerm, rightTerm) => s"!$leftTerm.equals($rightTerm)" + } + } + // numeric types + else if (isNumeric(left.resultType) && isNumeric(right.resultType)) { + generateComparison(ctx, "!=", left, right) + } + // temporal types + else if (isTemporal(left.resultType) && left.resultType == right.resultType) { + generateComparison(ctx, "!=", left, right) + } + // array types + else if (isArray(left.resultType) && left.resultType == right.resultType) { + val equalsExpr = generateEquals(ctx, left, right) + GeneratedExpression( + s"(!${equalsExpr.resultTerm})", equalsExpr.nullTerm, equalsExpr.code, 
InternalTypes.BOOLEAN) + } + // map types + else if (isMap(left.resultType) && left.resultType == right.resultType) { + val equalsExpr = generateEquals(ctx, left, right) + GeneratedExpression( + s"(!${equalsExpr.resultTerm})", equalsExpr.nullTerm, equalsExpr.code, InternalTypes.BOOLEAN) + } + // comparable types + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + generateComparison(ctx, "!=", left, right) + } + // non-comparable types + else { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + if (isReference(left)) { + (leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))" + } + else if (isReference(right)) { + (leftTerm, rightTerm) => s"!($rightTerm.equals($leftTerm))" + } + else { + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") + } + } + } + } + + /** + * Generates comparison code for numeric types and comparable types of same type. + */ + def generateComparison( + ctx: CodeGeneratorContext, + operator: String, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + // either side is decimal + if (isDecimal(left.resultType) || isDecimal(right.resultType)) { + (leftTerm, rightTerm) => { + s"${className[Decimal]}.compare($leftTerm, $rightTerm) $operator 0" + } + } + // both sides are numeric + else if (isNumeric(left.resultType) && isNumeric(right.resultType)) { + (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" + } + // both sides are temporal of same type + else if (isTemporal(left.resultType) && left.resultType == right.resultType) { + (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" + } + // both sides are boolean + else if (isBoolean(left.resultType) && left.resultType == right.resultType) { + operator match { + case "==" | "!=" => (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" + case ">" | "<" | "<=" | ">=" => 
+ (leftTerm, rightTerm) => + s"java.lang.Boolean.compare($leftTerm, $rightTerm) $operator 0" + case _ => throw new CodeGenException(s"Unsupported boolean comparison '$operator'.") + } + } + // both sides are binary type + else if (isBinary(left.resultType) && left.resultType == right.resultType) { + (leftTerm, rightTerm) => + s"java.util.Arrays.equals($leftTerm, $rightTerm)" + } + // both sides are same comparable type + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + (leftTerm, rightTerm) => + s"(($leftTerm == null) ? (($rightTerm == null) ? 0 : -1) : (($rightTerm == null) ? " + + s"1 : ($leftTerm.compareTo($rightTerm)))) $operator 0" + } + else { + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") + } + } + } + + def generateIsNull( + ctx: CodeGeneratorContext, + operand: GeneratedExpression): GeneratedExpression = { + if (ctx.nullCheck) { + GeneratedExpression(operand.nullTerm, NEVER_NULL, operand.code, InternalTypes.BOOLEAN) + } + else if (!ctx.nullCheck && isReference(operand)) { + val resultTerm = newName("isNull") + val operatorCode = + s""" + |${operand.code} + |boolean $resultTerm = ${operand.resultTerm} == null; + |""".stripMargin + GeneratedExpression(resultTerm, NEVER_NULL, operatorCode, InternalTypes.BOOLEAN) + } + else { + GeneratedExpression("false", NEVER_NULL, operand.code, InternalTypes.BOOLEAN) + } + } + + def generateIsNotNull( + ctx: CodeGeneratorContext, + operand: GeneratedExpression): GeneratedExpression = { + if (ctx.nullCheck) { + val resultTerm = newName("result") + val operatorCode = + s""" + |${operand.code} + |boolean $resultTerm = !${operand.nullTerm}; + |""".stripMargin.trim + GeneratedExpression(resultTerm, NEVER_NULL, operatorCode, InternalTypes.BOOLEAN) + } + else if (!ctx.nullCheck && isReference(operand)) { + val resultTerm = newName("result") + val operatorCode = + s""" + |${operand.code} + |boolean $resultTerm = ${operand.resultTerm} != 
null; + |""".stripMargin.trim + GeneratedExpression(resultTerm, NEVER_NULL, operatorCode, InternalTypes.BOOLEAN) + } + else { + GeneratedExpression("true", NEVER_NULL, operand.code, InternalTypes.BOOLEAN) + } + } + + def generateAnd( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + + val operatorCode = if (ctx.nullCheck) { + // Three-valued logic: + // no Unknown -> Two-valued logic + // True && Unknown -> Unknown + // False && Unknown -> False + // Unknown && True -> Unknown + // Unknown && False -> False + // Unknown && Unknown -> Unknown + s""" + |${left.code} + | + |boolean $resultTerm = false; + |boolean $nullTerm = false; + |if (!${left.nullTerm} && !${left.resultTerm}) { + | // left expr is false, skip right expr + |} else { + | ${right.code} + | + | if (!${left.nullTerm} && !${right.nullTerm}) { + | $resultTerm = ${left.resultTerm} && ${right.resultTerm}; + | $nullTerm = false; + | } + | else if (!${left.nullTerm} && ${left.resultTerm} && ${right.nullTerm}) { + | $resultTerm = false; + | $nullTerm = true; + | } + | else if (!${left.nullTerm} && !${left.resultTerm} && ${right.nullTerm}) { + | $resultTerm = false; + | $nullTerm = false; + | } + | else if (${left.nullTerm} && !${right.nullTerm} && ${right.resultTerm}) { + | $resultTerm = false; + | $nullTerm = true; + | } + | else if (${left.nullTerm} && !${right.nullTerm} && !${right.resultTerm}) { + | $resultTerm = false; + | $nullTerm = false; + | } + | else { + | $resultTerm = false; + | $nullTerm = true; + | } + |} + """.stripMargin.trim + } + else { + s""" + |${left.code} + |boolean $resultTerm = false; + |if (${left.resultTerm}) { + | ${right.code} + | $resultTerm = ${right.resultTerm}; + |} + |""".stripMargin.trim + } + + GeneratedExpression(resultTerm, nullTerm, operatorCode, InternalTypes.BOOLEAN) + } + + def generateOr( + ctx: CodeGeneratorContext, + left: 
GeneratedExpression, + right: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + + val operatorCode = if (ctx.nullCheck) { + // Three-valued logic: + // no Unknown -> Two-valued logic + // True || Unknown -> True + // False || Unknown -> Unknown + // Unknown || True -> True + // Unknown || False -> Unknown + // Unknown || Unknown -> Unknown + s""" + |${left.code} + | + |boolean $resultTerm = true; + |boolean $nullTerm = false; + |if (!${left.nullTerm} && ${left.resultTerm}) { + | // left expr is true, skip right expr + |} else { + | ${right.code} + | + | if (!${left.nullTerm} && !${right.nullTerm}) { + | $resultTerm = ${left.resultTerm} || ${right.resultTerm}; + | $nullTerm = false; + | } + | else if (!${left.nullTerm} && ${left.resultTerm} && ${right.nullTerm}) { + | $resultTerm = true; + | $nullTerm = false; + | } + | else if (!${left.nullTerm} && !${left.resultTerm} && ${right.nullTerm}) { + | $resultTerm = false; + | $nullTerm = true; + | } + | else if (${left.nullTerm} && !${right.nullTerm} && ${right.resultTerm}) { + | $resultTerm = true; + | $nullTerm = false; + | } + | else if (${left.nullTerm} && !${right.nullTerm} && !${right.resultTerm}) { + | $resultTerm = false; + | $nullTerm = true; + | } + | else { + | $resultTerm = false; + | $nullTerm = true; + | } + |} + |""".stripMargin.trim + } + else { + s""" + |${left.code} + |boolean $resultTerm = true; + |if (!${left.resultTerm}) { + | ${right.code} + | $resultTerm = ${right.resultTerm}; + |} + |""".stripMargin.trim + } + + GeneratedExpression(resultTerm, nullTerm, operatorCode, InternalTypes.BOOLEAN) + } + + def generateNot( + ctx: CodeGeneratorContext, + operand: GeneratedExpression) + : GeneratedExpression = { + // Three-valued logic: + // no Unknown -> Two-valued logic + // Unknown -> Unknown + generateUnaryOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, operand) { + operandTerm => s"!($operandTerm)" + } + } + + def generateIsTrue(operand: 
GeneratedExpression): GeneratedExpression = { + GeneratedExpression( + operand.resultTerm, // unknown is always false by default + GeneratedExpression.NEVER_NULL, + operand.code, + InternalTypes.BOOLEAN) + } + + def generateIsNotTrue(operand: GeneratedExpression): GeneratedExpression = { + GeneratedExpression( + s"(!${operand.resultTerm})", // unknown is always false by default + GeneratedExpression.NEVER_NULL, + operand.code, + InternalTypes.BOOLEAN) + } + + def generateIsFalse(operand: GeneratedExpression): GeneratedExpression = { + GeneratedExpression( + s"(!${operand.resultTerm} && !${operand.nullTerm})", + GeneratedExpression.NEVER_NULL, + operand.code, + InternalTypes.BOOLEAN) + } + + def generateIsNotFalse(operand: GeneratedExpression): GeneratedExpression = { + GeneratedExpression( + s"(${operand.resultTerm} || ${operand.nullTerm})", + GeneratedExpression.NEVER_NULL, + operand.code, + InternalTypes.BOOLEAN) + } + + def generateReinterpret( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + targetType: InternalType) + : GeneratedExpression = (operand.resultType, targetType) match { + + case (fromTp, toTp) if fromTp == toTp => + operand + + // internal reinterpretation of temporal types + // Date -> Integer + // Time -> Integer + // Timestamp -> Long + // Integer -> Date + // Integer -> Time + // Long -> Timestamp + // Integer -> Interval Months + // Long -> Interval Millis + // Interval Months -> Integer + // Interval Millis -> Long + // Date -> Long + // Time -> Long + // Interval Months -> Long + case (InternalTypes.DATE, InternalTypes.INT) | + (InternalTypes.TIME, InternalTypes.INT) | + (_: TimestampType, InternalTypes.LONG) | + (InternalTypes.INT, InternalTypes.DATE) | + (InternalTypes.INT, InternalTypes.TIME) | + (InternalTypes.LONG, _: TimestampType) | + (InternalTypes.INT, InternalTypes.INTERVAL_MONTHS) | + (InternalTypes.LONG, InternalTypes.INTERVAL_MILLIS) | + (InternalTypes.INTERVAL_MONTHS, InternalTypes.INT) | + 
(InternalTypes.INTERVAL_MILLIS, InternalTypes.LONG) | + (InternalTypes.DATE, InternalTypes.LONG) | + (InternalTypes.TIME, InternalTypes.LONG) | + (InternalTypes.INTERVAL_MONTHS, InternalTypes.LONG) => + internalExprCasting(operand, targetType) + + case (from, to) => + throw new CodeGenException(s"Unsupported reinterpret from '$from' to '$to'.") + } + + def generateCast( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + targetType: InternalType) + : GeneratedExpression = (operand.resultType, targetType) match { + + // special case: cast from TimeIndicatorTypeInfo to SqlTimeTypeInfo + case (InternalTypes.PROCTIME_INDICATOR, InternalTypes.TIMESTAMP) | + (InternalTypes.ROWTIME_INDICATOR, InternalTypes.TIMESTAMP) | + (InternalTypes.TIMESTAMP, InternalTypes.PROCTIME_INDICATOR) | + (InternalTypes.TIMESTAMP, InternalTypes.ROWTIME_INDICATOR) => + operand.copy(resultType = InternalTypes.TIMESTAMP) // just replace the DataType + + // identity casting + case (fromTp, toTp) if fromTp == toTp => + operand + + // Date/Time/Timestamp -> String + case (left, InternalTypes.STRING) if TypeCheckUtils.isTimePoint(left) => + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + operandTerm => + val zoneTerm = ctx.addReusableTimeZone() + s"${internalToStringCode(left, operandTerm.head, zoneTerm)}" + } + + // Interval Months -> String + case (InternalTypes.INTERVAL_MONTHS, InternalTypes.STRING) => + val method = qualifyMethod(BuiltInMethod.INTERVAL_YEAR_MONTH_TO_STRING.method) + val timeUnitRange = qualifyEnum(TimeUnitRange.YEAR_TO_MONTH) + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => s"$method(${terms.head}, $timeUnitRange)" + } + + // Interval Millis -> String + case (InternalTypes.INTERVAL_MILLIS, InternalTypes.STRING) => + val method = qualifyMethod(BuiltInMethod.INTERVAL_DAY_TIME_TO_STRING.method) + val timeUnitRange = qualifyEnum(TimeUnitRange.DAY_TO_SECOND) + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => 
s"$method(${terms.head}, $timeUnitRange, 3)" // milli second precision + } + + // Array -> String + case (at: ArrayType, InternalTypes.STRING) => + generateCastArrayToString(ctx, operand, at) + + // Byte array -> String UTF-8 + case (InternalTypes.BINARY, InternalTypes.STRING) => + val charset = classOf[StandardCharsets].getCanonicalName + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => s"(new String(${terms.head}.toByteArray(), $charset.UTF_8))" + } + + + // Map -> String + case (mt: MapType, InternalTypes.STRING) => + generateCastMapToString(ctx, operand, mt) + + // composite type -> String + case (brt: RowType, InternalTypes.STRING) => + generateCastBaseRowToString(ctx, operand, brt) + + case (g: GenericType[_], InternalTypes.STRING) => + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => + val converter = DataFormatConverters.getConverterForTypeInfo(g.getTypeInfo) + val converterTerm = ctx.addReusableObject(converter, "converter") + s""" "" + $converterTerm.toExternal(${terms.head})""" + } + + // * (not Date/Time/Timestamp) -> String + // TODO: GenericType with Date/Time/Timestamp -> String would call toString implicitly + case (_, InternalTypes.STRING) => + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => s""" "" + ${terms.head}""" + } + + // * -> Character + case (_, InternalTypes.CHAR) => + throw new CodeGenException("Character type not supported.") + + // String -> Boolean + case (InternalTypes.STRING, InternalTypes.BOOLEAN) => + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => s"$operandTerm.toBooleanSQL()" + } + + // String -> NUMERIC TYPE (not Character) + case (InternalTypes.STRING, _) + if TypeCheckUtils.isNumeric(targetType) => + targetType match { + case dt: DecimalType => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$operandTerm.toDecimal(${dt.precision}, ${dt.scale})" + } + case _ => + val 
methodName = targetType match { + case InternalTypes.BYTE => "toByte" + case InternalTypes.SHORT => "toShort" + case InternalTypes.INT => "toInt" + case InternalTypes.LONG => "toLong" + case InternalTypes.DOUBLE => "toDouble" + case InternalTypes.FLOAT => "toFloat" + case _ => null + } + assert(methodName != null, "Unexpected data type.") + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => s"($operandTerm.trim().$methodName())" + } + } + + // String -> Date + case (InternalTypes.STRING, InternalTypes.DATE) => + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => + s"${qualifyMethod(BuiltInMethod.STRING_TO_DATE.method)}($operandTerm.toString())" + } + + // String -> Time + case (InternalTypes.STRING, InternalTypes.TIME) => + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => + s"${qualifyMethod(BuiltInMethod.STRING_TO_TIME.method)}($operandTerm.toString())" + } + + // String -> Timestamp + case (InternalTypes.STRING, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => + val zoneTerm = ctx.addReusableTimeZone() + s"""${qualifyMethod(BuiltInMethods.STRING_TO_TIMESTAMP)}($operandTerm.toString(), + | $zoneTerm)""".stripMargin + } + + // String -> binary + case (InternalTypes.STRING, InternalTypes.BINARY) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$operandTerm.getBytes()" + } + + // Note: SQL2003 $6.12 - casting is not allowed between boolean and numeric types. + // Calcite does not allow it either. 
+ + // Boolean -> BigDecimal + case (InternalTypes.BOOLEAN, dt: DecimalType) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$DECIMAL.castFrom($operandTerm, ${dt.precision}, ${dt.scale})" + } + + // Boolean -> NUMERIC TYPE + case (InternalTypes.BOOLEAN, _) if TypeCheckUtils.isNumeric(targetType) => + val targetTypeTerm = primitiveTypeTermForType(targetType) + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"($targetTypeTerm) ($operandTerm ? 1 : 0)" + } + + // BigDecimal -> Boolean + case (_: DecimalType, InternalTypes.BOOLEAN) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$DECIMAL.castToBoolean($operandTerm)" + } + + // BigDecimal -> Timestamp + case (_: DecimalType, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$DECIMAL.castToTimestamp($operandTerm)" + } + + // NUMERIC TYPE -> Boolean + case (left, InternalTypes.BOOLEAN) if isNumeric(left) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$operandTerm != 0" + } + + // between NUMERIC TYPE | Decimal + case (left, right) if isNumeric(left) && isNumeric(right) => + val operandCasting = numericCasting(operand.resultType, targetType) + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"${operandCasting(operandTerm)}" + } + + // Date -> Timestamp + case (InternalTypes.DATE, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => + s"$operandTerm * ${classOf[DateTimeUtils].getCanonicalName}.MILLIS_PER_DAY" + } + + // Timestamp -> Date + case (InternalTypes.TIMESTAMP, InternalTypes.DATE) => + val targetTypeTerm = primitiveTypeTermForType(targetType) + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => + s"($targetTypeTerm) ($operandTerm / " + + s"${classOf[DateTimeUtils].getCanonicalName}.MILLIS_PER_DAY)" + } + + // Time -> 
Timestamp + case (InternalTypes.TIME, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$operandTerm" + } + + // Timestamp -> Time + case (InternalTypes.TIMESTAMP, InternalTypes.TIME) => + val targetTypeTerm = primitiveTypeTermForType(targetType) + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => + s"($targetTypeTerm) ($operandTerm % " + + s"${classOf[DateTimeUtils].getCanonicalName}.MILLIS_PER_DAY)" + } + + // Timestamp -> Decimal + case (InternalTypes.TIMESTAMP, dt: DecimalType) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$DECIMAL.castFrom" + + s"(((double) ($operandTerm / 1000.0)), ${dt.precision}, ${dt.scale})" + } + + // Tinyint -> Timestamp + // Smallint -> Timestamp + // Int -> Timestamp + // Bigint -> Timestamp + case (InternalTypes.BYTE, InternalTypes.TIMESTAMP) | + (InternalTypes.SHORT,InternalTypes.TIMESTAMP) | + (InternalTypes.INT, InternalTypes.TIMESTAMP) | + (InternalTypes.LONG, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"(((long) $operandTerm) * 1000)" + } + + // Float -> Timestamp + // Double -> Timestamp + case (InternalTypes.FLOAT, InternalTypes.TIMESTAMP) | + (InternalTypes.DOUBLE, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((long) ($operandTerm * 1000))" + } + + // Timestamp -> Tinyint + case (InternalTypes.TIMESTAMP, InternalTypes.BYTE) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((byte) ($operandTerm / 1000))" + } + + // Timestamp -> Smallint + case (InternalTypes.TIMESTAMP, InternalTypes.SHORT) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((short) ($operandTerm / 1000))" + } + + // Timestamp -> Int + case (InternalTypes.TIMESTAMP, InternalTypes.INT) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + 
operandTerm => s"((int) ($operandTerm / 1000))" + } + + // Timestamp -> BigInt + case (InternalTypes.TIMESTAMP, InternalTypes.LONG) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((long) ($operandTerm / 1000))" + } + + // Timestamp -> Float + case (InternalTypes.TIMESTAMP, InternalTypes.FLOAT) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((float) ($operandTerm / 1000.0))" + } + + // Timestamp -> Double + case (InternalTypes.TIMESTAMP, InternalTypes.DOUBLE) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((double) ($operandTerm / 1000.0))" + } + + // internal temporal casting + // Date -> Integer + // Time -> Integer + // Integer -> Date + // Integer -> Time + // Integer -> Interval Months + // Long -> Interval Millis + // Interval Months -> Integer + // Interval Millis -> Long + case (InternalTypes.DATE, InternalTypes.INT) | + (InternalTypes.TIME, InternalTypes.INT) | + (InternalTypes.INT, InternalTypes.DATE) | + (InternalTypes.INT, InternalTypes.TIME) | + (InternalTypes.INT, InternalTypes.INTERVAL_MONTHS) | + (InternalTypes.LONG, InternalTypes.INTERVAL_MILLIS) | + (InternalTypes.INTERVAL_MONTHS, InternalTypes.INT) | + (InternalTypes.INTERVAL_MILLIS, InternalTypes.LONG) => + internalExprCasting(operand, targetType) + + // internal reinterpretation of temporal types + // Date, Time, Interval Months -> Long + case (InternalTypes.DATE, InternalTypes.LONG) + | (InternalTypes.TIME, InternalTypes.LONG) + | (InternalTypes.INTERVAL_MONTHS, InternalTypes.LONG) => + internalExprCasting(operand, targetType) + + case (from, to) => + throw new CodeGenException(s"Unsupported cast from '$from' to '$to'.") + } + + def generateIfElse( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression], + resultType: InternalType, + i: Int = 0) + : GeneratedExpression = { + // else part + if (i == operands.size - 1) { + generateCast(ctx, operands(i), resultType) + } + else { + 
// check that the condition is boolean + // we do not check for null instead we use the default value + // thus null is false + requireBoolean(operands(i)) + val condition = operands(i) + val trueAction = generateCast(ctx, operands(i + 1), resultType) + val falseAction = generateIfElse(ctx, operands, resultType, i + 2) + + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val resultTypeTerm = primitiveTypeTermForType(resultType) + + val operatorCode = if (ctx.nullCheck) { + s""" + |${condition.code} + |$resultTypeTerm $resultTerm; + |boolean $nullTerm; + |if (${condition.resultTerm}) { + | ${trueAction.code} + | $resultTerm = ${trueAction.resultTerm}; + | $nullTerm = ${trueAction.nullTerm}; + |} + |else { + | ${falseAction.code} + | $resultTerm = ${falseAction.resultTerm}; + | $nullTerm = ${falseAction.nullTerm}; + |} + |""".stripMargin.trim + } + else { + s""" + |${condition.code} + |$resultTypeTerm $resultTerm; + |if (${condition.resultTerm}) { + | ${trueAction.code} + | $resultTerm = ${trueAction.resultTerm}; + |} + |else { + | ${falseAction.code} + | $resultTerm = ${falseAction.resultTerm}; + |} + |""".stripMargin.trim + } + + GeneratedExpression(resultTerm, nullTerm, operatorCode, resultType) + } + } + + def generateDot( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression]): GeneratedExpression = { + + // due to https://issues.apache.org/jira/browse/CALCITE-2162, expression such as + // "array[1].a.b" won't work now. 
+ if (operands.size > 2) { + throw new CodeGenException( + "A DOT operator with more than 2 operands is not supported yet.") + } + + checkArgument(operands(1).literal) + checkArgument(operands(1).resultType == InternalTypes.STRING) + checkArgument(operands.head.resultType.isInstanceOf[RowType]) + + val fieldName = operands(1).literalValue.get.toString + val fieldIdx = operands + .head + .resultType + .asInstanceOf[RowType] + .getFieldIndex(fieldName) + + val access = generateFieldAccess( + ctx, + operands.head.resultType, + operands.head.resultTerm, + fieldIdx) + + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val resultTypeTerm = primitiveTypeTermForType(access.resultType) + val defaultValue = primitiveDefaultValue(access.resultType) + + val resultCode = if (ctx.nullCheck) { + s""" + |${operands.map(_.code).mkString("\n")} + |$resultTypeTerm $resultTerm; + |boolean $nullTerm; + |if (${operands.map(_.nullTerm).mkString(" || ")}) { + | $resultTerm = $defaultValue; + | $nullTerm = true; + |} + |else { + | ${access.code} + | $resultTerm = ${access.resultTerm}; + | $nullTerm = ${access.nullTerm}; + |} + |""".stripMargin + } else { + s""" + |${operands.map(_.code).mkString("\n")} + |${access.code} + |$resultTypeTerm $resultTerm = ${access.resultTerm}; + |""".stripMargin + } + + + GeneratedExpression( + resultTerm, + nullTerm, + resultCode, + access.resultType + ) + } + + // ---------------------------------------------------------------------------------------- + // value construction and accessing generate utils + // ---------------------------------------------------------------------------------------- + + def generateRow( + ctx: CodeGeneratorContext, + resultType: InternalType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + checkArgument(resultType.isInstanceOf[RowType]) + val rowType = resultType.asInstanceOf[RowType] + val fieldTypes = rowType.getFieldTypes + val isLiteral = elements.forall(e => e.literal) + val isPrimitive = 
fieldTypes.forall(f => f.isInstanceOf[PrimitiveType]) + + if (isLiteral) { + // generate literal row + generateLiteralRow(ctx, rowType, elements) + } else { + if (isPrimitive) { + // generate primitive row + val mapped = elements.zipWithIndex.map { case (element, idx) => + if (element.literal) { + element + } else { + val tpe = fieldTypes(idx) + val resultTerm = primitiveDefaultValue(tpe) + GeneratedExpression(resultTerm, ALWAYS_NULL, NO_CODE, tpe, Some(null)) + } + } + val row = generateLiteralRow(ctx, rowType, mapped) + val code = elements.zipWithIndex.map { case (element, idx) => + val tpe = fieldTypes(idx) + if (element.literal) { + "" + } else if(ctx.nullCheck) { + s""" + |${element.code} + |if (${element.nullTerm}) { + | ${binaryRowSetNull(idx, row.resultTerm, tpe)}; + |} else { + | ${binaryRowFieldSetAccess(idx, row.resultTerm, tpe, element.resultTerm)}; + |} + """.stripMargin + } else { + s""" + |${element.code} + |${binaryRowFieldSetAccess(idx, row.resultTerm, tpe, element.resultTerm)}; + """.stripMargin + } + }.mkString("\n") + GeneratedExpression(row.resultTerm, NEVER_NULL, code, rowType) + } else { + // generate general row + generateNonLiteralRow(ctx, rowType, elements) + } + } + } + + private def generateLiteralRow( + ctx: CodeGeneratorContext, + rowType: RowType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + checkArgument(elements.forall(e => e.literal)) + val expr = generateNonLiteralRow(ctx, rowType, elements) + ctx.addReusableInitStatement(expr.code) + GeneratedExpression(expr.resultTerm, GeneratedExpression.NEVER_NULL, NO_CODE, rowType) + } + + private def generateNonLiteralRow( + ctx: CodeGeneratorContext, + rowType: RowType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + + val rowTerm = newName("row") + val writerTerm = newName("writer") + val writerCls = className[BinaryRowWriter] + + val writeCode = elements.zipWithIndex.map { + case (element, idx) => + val tpe = rowType.getTypeAt(idx) + if (ctx.nullCheck) 
{ + s""" + |${element.code} + |if (${element.nullTerm}) { + | ${binaryWriterWriteNull(idx, writerTerm, tpe)}; + |} else { + | ${binaryWriterWriteField(ctx, idx, element.resultTerm, writerTerm, tpe)}; + |} + """.stripMargin + } else { + s""" + |${element.code} + |${binaryWriterWriteField(ctx, idx, element.resultTerm, writerTerm, tpe)}; + """.stripMargin + } + }.mkString("\n") + + val code = + s""" + |$writerTerm.reset(); + |$writeCode + |$writerTerm.complete(); + """.stripMargin + + ctx.addReusableMember(s"$BINARY_ROW $rowTerm = new $BINARY_ROW(${rowType.getArity});") + ctx.addReusableMember(s"$writerCls $writerTerm = new $writerCls($rowTerm);") + GeneratedExpression(rowTerm, GeneratedExpression.NEVER_NULL, code, rowType) + } + + def generateArray( + ctx: CodeGeneratorContext, + resultType: InternalType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + + checkArgument(resultType.isInstanceOf[ArrayType]) + val arrayType = resultType.asInstanceOf[ArrayType] + val elementType = arrayType.getElementType + val isLiteral = elements.forall(e => e.literal) + val isPrimitive = elementType.isInstanceOf[PrimitiveType] + + if (isLiteral) { + // generate literal array + generateLiteralArray(ctx, arrayType, elements) + } else { + if (isPrimitive) { + // generate primitive array + val mapped = elements.map { element => + if (element.literal) { + element + } else { + val resultTerm = primitiveDefaultValue(elementType) + GeneratedExpression(resultTerm, ALWAYS_NULL, NO_CODE, elementType, Some(null)) + } + } + val array = generateLiteralArray(ctx, arrayType, mapped) + val code = generatePrimitiveArrayUpdateCode(ctx, array.resultTerm, elementType, elements) + GeneratedExpression(array.resultTerm, GeneratedExpression.NEVER_NULL, code, arrayType) + } else { + // generate general array + generateNonLiteralArray(ctx, arrayType, elements) + } + } + } + + private def generatePrimitiveArrayUpdateCode( + ctx: CodeGeneratorContext, + arrayTerm: String, + elementType: 
InternalType, + elements: Seq[GeneratedExpression]): String = { + elements.zipWithIndex.map { case (element, idx) => + if (element.literal) { + "" + } else if (ctx.nullCheck) { + s""" + |${element.code} + |if (${element.nullTerm}) { + | ${binaryArraySetNull(idx, arrayTerm, elementType)}; + |} else { + | ${binaryRowFieldSetAccess( + idx, arrayTerm, elementType, element.resultTerm)}; + |} + """.stripMargin + } else { + s""" + |${element.code} + |${binaryRowFieldSetAccess( + idx, arrayTerm, elementType, element.resultTerm)}; + """.stripMargin + } + }.mkString("\n") + } + + private def generateLiteralArray( + ctx: CodeGeneratorContext, + arrayType: ArrayType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + checkArgument(elements.forall(e => e.literal)) + val expr = generateNonLiteralArray(ctx, arrayType, elements) + ctx.addReusableInitStatement(expr.code) + GeneratedExpression(expr.resultTerm, GeneratedExpression.NEVER_NULL, NO_CODE, arrayType) + } + + private def generateNonLiteralArray( + ctx: CodeGeneratorContext, + arrayType: ArrayType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + + val elementType = arrayType.getElementType + val arrayTerm = newName("array") + val writerTerm = newName("writer") + val writerCls = className[BinaryArrayWriter] + val elementSize = BinaryArray.calculateFixLengthPartSize(elementType) + + val writeCode = elements.zipWithIndex.map { + case (element, idx) => + s""" + |${element.code} + |if (${element.nullTerm}) { + | ${binaryArraySetNull(idx, writerTerm, elementType)}; + |} else { + | ${binaryWriterWriteField(ctx, idx, element.resultTerm, writerTerm, elementType)}; + |} + """.stripMargin + }.mkString("\n") + + val code = + s""" + |$writerTerm.reset(); + |$writeCode + |$writerTerm.complete(); + """.stripMargin + + val memberStmt = + s""" + |$BINARY_ARRAY $arrayTerm = new $BINARY_ARRAY(); + |$writerCls $writerTerm = new $writerCls($arrayTerm, ${elements.length}, $elementSize); + """.stripMargin + + 
ctx.addReusableMember(memberStmt) + GeneratedExpression(arrayTerm, GeneratedExpression.NEVER_NULL, code, arrayType) + } + + def generateArrayElementAt( + ctx: CodeGeneratorContext, + array: GeneratedExpression, + index: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val componentInfo = array.resultType.asInstanceOf[ArrayType].getElementType + val resultTypeTerm = primitiveTypeTermForType(componentInfo) + val defaultTerm = primitiveDefaultValue(componentInfo) + + val idxStr = s"${index.resultTerm} - 1" + val arrayIsNull = s"${array.resultTerm}.isNullAt($idxStr)" + val arrayGet = + baseRowFieldReadAccess(ctx, idxStr, array.resultTerm, componentInfo) + + val arrayAccessCode = + s""" + |${array.code} + |${index.code} + |boolean $nullTerm = ${array.nullTerm} || ${index.nullTerm} || $arrayIsNull; + |$resultTypeTerm $resultTerm = $nullTerm ? $defaultTerm : $arrayGet; + |""".stripMargin + + GeneratedExpression(resultTerm, nullTerm, arrayAccessCode, componentInfo) + } + + def generateArrayElement( + ctx: CodeGeneratorContext, + array: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val resultType = array.resultType.asInstanceOf[ArrayType].getElementType + val resultTypeTerm = primitiveTypeTermForType(resultType) + val defaultValue = primitiveDefaultValue(resultType) + + val arrayLengthCode = s"${array.nullTerm} ? 0 : ${array.resultTerm}.numElements()" + + val arrayGet = baseRowFieldReadAccess(ctx, 0, array.resultTerm, resultType) + val arrayAccessCode = + s""" + |${array.code} + |boolean $nullTerm; + |$resultTypeTerm $resultTerm; + |switch ($arrayLengthCode) { + | case 0: + | $nullTerm = true; + | $resultTerm = $defaultValue; + | break; + | case 1: + | $nullTerm = ${array.resultTerm}.isNullAt(0); + | $resultTerm = $nullTerm ? 
$defaultValue : $arrayGet; + | break; + | default: + | throw new RuntimeException("Array has more than one element."); + |} + |""".stripMargin + + GeneratedExpression(resultTerm, nullTerm, arrayAccessCode, resultType) + } + + def generateArrayCardinality( + ctx: CodeGeneratorContext, + array: GeneratedExpression) + : GeneratedExpression = { + generateUnaryOperatorIfNotNull(ctx, InternalTypes.INT, array) { + _ => s"${array.resultTerm}.numElements()" + } + } + + def generateMap( + ctx: CodeGeneratorContext, + resultType: InternalType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + + checkArgument(resultType.isInstanceOf[MapType]) + val mapType = resultType.asInstanceOf[MapType] + val mapTerm = newName("map") + + // prepare map key array + val keyElements = elements.grouped(2).map { case Seq(key, _) => key }.toSeq + val keyType = mapType.getKeyType + val keyExpr = generateArray(ctx, InternalTypes.createArrayType(keyType), keyElements) + val isKeyFixLength = keyType.isInstanceOf[PrimitiveType] + + // prepare map value array + val valueElements = elements.grouped(2).map { case Seq(_, value) => value }.toSeq + val valueType = mapType.getValueType + val valueExpr = generateArray(ctx, InternalTypes.createArrayType(valueType), valueElements) + val isValueFixLength = valueType.isInstanceOf[PrimitiveType] + + // construct binary map + ctx.addReusableMember(s"$BINARY_MAP $mapTerm = null;") + + val code = if (isKeyFixLength && isValueFixLength) { + // the key and value are fixed length, initialize and reuse the map in constructor + val init = s"$mapTerm = $BINARY_MAP.valueOf(${keyExpr.resultTerm}, ${valueExpr.resultTerm});" + ctx.addReusableInitStatement(init) + // there are some non-literal primitive fields need to update + val keyArrayTerm = newName("keyArray") + val valueArrayTerm = newName("valueArray") + val keyUpdate = generatePrimitiveArrayUpdateCode( + ctx, keyArrayTerm, keyType, keyElements) + val valueUpdate = generatePrimitiveArrayUpdateCode( + ctx, 
valueArrayTerm, valueType, valueElements) + s""" + |$BINARY_ARRAY $keyArrayTerm = $mapTerm.keyArray(); + |$keyUpdate + |$BINARY_ARRAY $valueArrayTerm = $mapTerm.valueArray(); + |$valueUpdate + """.stripMargin + } else { + // the key or value is not fixed length, re-create the map on every update + s""" + |${keyExpr.code} + |${valueExpr.code} + |$mapTerm = $BINARY_MAP.valueOf(${keyExpr.resultTerm}, ${valueExpr.resultTerm}); + """.stripMargin + } + GeneratedExpression(mapTerm, NEVER_NULL, code, resultType) + } + + def generateMapGet( + ctx: CodeGeneratorContext, + map: GeneratedExpression, + key: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val tmpKey = newName("key") + val length = newName("length") + val keys = newName("keys") + val values = newName("values") + val index = newName("index") + val found = newName("found") + + val mapType = map.resultType.asInstanceOf[MapType] + val keyType = mapType.getKeyType + val valueType = mapType.getValueType + + // use primitive for key as key is not null + val keyTypeTerm = primitiveTypeTermForType(keyType) + val valueTypeTerm = primitiveTypeTermForType(valueType) + val valueDefault = primitiveDefaultValue(valueType) + + val mapTerm = map.resultTerm + + val equal = generateEquals(ctx, key, GeneratedExpression(tmpKey, NEVER_NULL, NO_CODE, keyType)) + val code = + s""" + |final int $length = $mapTerm.numElements(); + |final $BINARY_ARRAY $keys = $mapTerm.keyArray(); + |final $BINARY_ARRAY $values = $mapTerm.valueArray(); + | + |int $index = 0; + |boolean $found = false; + |if (${key.nullTerm}) { + | while ($index < $length && !$found) { + | if ($keys.isNullAt($index)) { + | $found = true; + | } else { + | $index++; + | } + | } + |} else { + | while ($index < $length && !$found) { + | final $keyTypeTerm $tmpKey = ${baseRowFieldReadAccess(ctx, index, keys, keyType)}; + | ${equal.code} + | if (${equal.resultTerm}) { + | $found = true; + | } else { + | $index++; + | 
} + | } + |} + | + |if (!$found || $values.isNullAt($index)) { + | $nullTerm = true; + |} else { + | $resultTerm = ${baseRowFieldReadAccess(ctx, index, values, valueType)}; + |} + """.stripMargin + + val accessCode = + s""" + |${map.code} + |${key.code} + |boolean $nullTerm = (${map.nullTerm} || ${key.nullTerm}); + |$valueTypeTerm $resultTerm = $valueDefault; + |if (!$nullTerm) { + | $code + |} + """.stripMargin + + GeneratedExpression(resultTerm, nullTerm, accessCode, valueType) + } + + def generateMapCardinality( + ctx: CodeGeneratorContext, + map: GeneratedExpression): GeneratedExpression = { + generateUnaryOperatorIfNotNull(ctx, InternalTypes.INT, map) { + _ => s"${map.resultTerm}.numElements()" + } + } + + // ---------------------------------------------------------------------------------------- + // private generate utils + // ---------------------------------------------------------------------------------------- + + private def generateCastStringLiteralToDateTime( + ctx: CodeGeneratorContext, + stringLiteral: GeneratedExpression, + expectType: InternalType): GeneratedExpression = { + checkArgument(stringLiteral.literal) + val rightTerm = stringLiteral.resultTerm + val typeTerm = primitiveTypeTermForType(expectType) + val defaultTerm = primitiveDefaultValue(expectType) + val term = newName("stringToTime") + val zoneTerm = ctx.addReusableTimeZone() + val code = stringToInternalCode(expectType, rightTerm, zoneTerm) + val stmt = s"$typeTerm $term = ${stringLiteral.nullTerm} ? 
$defaultTerm : $code;" + ctx.addReusableMember(stmt) + stringLiteral.copy(resultType = expectType, resultTerm = term) + } + + private def generateCastArrayToString( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + at: ArrayType): GeneratedExpression = + generateStringResultCallWithStmtIfArgsNotNull(ctx, Seq(operand)) { + terms => + val builderCls = classOf[JStringBuilder].getCanonicalName + val builderTerm = newName("builder") + ctx.addReusableMember(s"""$builderCls $builderTerm = new $builderCls();""") + + val arrayTerm = terms.head + + val indexTerm = newName("i") + val numTerm = newName("num") + + val elementType = at.getElementType + val elementCls = primitiveTypeTermForType(elementType) + val elementTerm = newName("element") + val elementNullTerm = newName("isNull") + val elementCode = + s""" + |$elementCls $elementTerm = ${primitiveDefaultValue(elementType)}; + |boolean $elementNullTerm = $arrayTerm.isNullAt($indexTerm); + |if (!$elementNullTerm) { + | $elementTerm = ($elementCls) ${ + baseRowFieldReadAccess(ctx, indexTerm, arrayTerm, elementType)}; + |} + """.stripMargin + val elementExpr = GeneratedExpression( + elementTerm, elementNullTerm, elementCode, elementType) + val castExpr = generateCast(ctx, elementExpr, InternalTypes.STRING) + + val stmt = + s""" + |$builderTerm.setLength(0); + |$builderTerm.append("["); + |int $numTerm = $arrayTerm.numElements(); + |for (int $indexTerm = 0; $indexTerm < $numTerm; $indexTerm++) { + | if ($indexTerm != 0) { + | $builderTerm.append(", "); + | } + | + | ${castExpr.code} + | if (${castExpr.nullTerm}) { + | $builderTerm.append("null"); + | } else { + | $builderTerm.append(${castExpr.resultTerm}); + | } + |} + |$builderTerm.append("]"); + """.stripMargin + (stmt, s"$builderTerm.toString()") + } + + private def generateCastMapToString( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + mt: MapType): GeneratedExpression = + generateStringResultCallWithStmtIfArgsNotNull(ctx, Seq(operand)) { + 
terms => + val resultTerm = newName("toStringResult") + + val builderCls = classOf[JStringBuilder].getCanonicalName + val builderTerm = newName("builder") + ctx.addReusableMember(s"$builderCls $builderTerm = new $builderCls();") + + val binaryMapTerm = terms.head + val arrayCls = classOf[BinaryArray].getCanonicalName + val keyArrayTerm = newName("keyArray") + val valueArrayTerm = newName("valueArray") + + val indexTerm = newName("i") + val numTerm = newName("num") + + val keyType = mt.getKeyType + val keyCls = primitiveTypeTermForType(keyType) + val keyTerm = newName("key") + val keyNullTerm = newName("isNull") + val keyCode = + s""" + |$keyCls $keyTerm = ${primitiveDefaultValue(keyType)}; + |boolean $keyNullTerm = $keyArrayTerm.isNullAt($indexTerm); + |if (!$keyNullTerm) { + | $keyTerm = ($keyCls) ${ + baseRowFieldReadAccess(ctx, indexTerm, keyArrayTerm, keyType)}; + |} + """.stripMargin + val keyExpr = GeneratedExpression(keyTerm, keyNullTerm, keyCode, keyType) + val keyCastExpr = generateCast(ctx, keyExpr, InternalTypes.STRING) + + val valueType = mt.getValueType + val valueCls = primitiveTypeTermForType(valueType) + val valueTerm = newName("value") + val valueNullTerm = newName("isNull") + val valueCode = + s""" + |$valueCls $valueTerm = ${primitiveDefaultValue(valueType)}; + |boolean $valueNullTerm = $valueArrayTerm.isNullAt($indexTerm); + |if (!$valueNullTerm) { + | $valueTerm = ($valueCls) ${ + baseRowFieldReadAccess(ctx, indexTerm, valueArrayTerm, valueType)}; + |} + """.stripMargin + val valueExpr = GeneratedExpression(valueTerm, valueNullTerm, valueCode, valueType) + val valueCastExpr = generateCast(ctx, valueExpr, InternalTypes.STRING) + + val stmt = + s""" + |String $resultTerm; + |$arrayCls $keyArrayTerm = $binaryMapTerm.keyArray(); + |$arrayCls $valueArrayTerm = $binaryMapTerm.valueArray(); + | + |$builderTerm.setLength(0); + |$builderTerm.append("{"); + | + |int $numTerm = $binaryMapTerm.numElements(); + |for (int $indexTerm = 0; $indexTerm < 
$numTerm; $indexTerm++) { + | if ($indexTerm != 0) { + | $builderTerm.append(", "); + | } + | + | ${keyCastExpr.code} + | if (${keyCastExpr.nullTerm}) { + | $builderTerm.append("null"); + | } else { + | $builderTerm.append(${keyCastExpr.resultTerm}); + | } + | $builderTerm.append("="); + | + | ${valueCastExpr.code} + | if (${valueCastExpr.nullTerm}) { + | $builderTerm.append("null"); + | } else { + | $builderTerm.append(${valueCastExpr.resultTerm}); + | } + |} + |$builderTerm.append("}"); + | + |$resultTerm = $builderTerm.toString(); + """.stripMargin + (stmt, resultTerm) + } + + private def generateCastBaseRowToString( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + brt: RowType): GeneratedExpression = + generateStringResultCallWithStmtIfArgsNotNull(ctx, Seq(operand)) { + terms => + val builderCls = classOf[JStringBuilder].getCanonicalName + val builderTerm = newName("builder") + ctx.addReusableMember(s"""$builderCls $builderTerm = new $builderCls();""") + + val rowTerm = terms.head + + val appendCode = brt.getFieldTypes.zipWithIndex.map { + case (elementType, idx) => + val elementCls = primitiveTypeTermForType(elementType) + val elementTerm = newName("element") + val elementExpr = GeneratedExpression( + elementTerm, s"$rowTerm.isNullAt($idx)", + s"$elementCls $elementTerm = ($elementCls) ${baseRowFieldReadAccess( + ctx, idx, rowTerm, elementType)};", elementType) + val castExpr = generateCast(ctx, elementExpr, InternalTypes.STRING) + s""" + |${if (idx != 0) s"""$builderTerm.append(",");""" else ""} + |${castExpr.code} + |if (${castExpr.nullTerm}) { + | $builderTerm.append("null"); + |} else { + | $builderTerm.append(${castExpr.resultTerm}); + |} + """.stripMargin + }.mkString("\n") + + val stmt = + s""" + |$builderTerm.setLength(0); + |$builderTerm.append("("); + |$appendCode + |$builderTerm.append(")"); + """.stripMargin + (stmt, s"$builderTerm.toString()") + } + + private def generateArrayComparison( + ctx: CodeGeneratorContext, + left: 
GeneratedExpression, + right: GeneratedExpression): GeneratedExpression = + generateCallWithStmtIfArgsNotNull(ctx, InternalTypes.BOOLEAN, Seq(left, right)) { + args => + val leftTerm = args.head + val rightTerm = args(1) + + val resultTerm = newName("compareResult") + val binaryArrayCls = classOf[BinaryArray].getCanonicalName + + val elementType = left.resultType.asInstanceOf[ArrayType].getElementType + val elementCls = primitiveTypeTermForType(elementType) + val elementDefault = primitiveDefaultValue(elementType) + + val leftElementTerm = newName("leftElement") + val leftElementNullTerm = newName("leftElementIsNull") + val leftElementExpr = + GeneratedExpression(leftElementTerm, leftElementNullTerm, "", elementType) + + val rightElementTerm = newName("rightElement") + val rightElementNullTerm = newName("rightElementIsNull") + val rightElementExpr = + GeneratedExpression(rightElementTerm, rightElementNullTerm, "", elementType) + + val indexTerm = newName("index") + val elementEqualsExpr = generateEquals(ctx, leftElementExpr, rightElementExpr) + + val stmt = + s""" + |boolean $resultTerm; + |if ($leftTerm instanceof $binaryArrayCls && $rightTerm instanceof $binaryArrayCls) { + | $resultTerm = $leftTerm.equals($rightTerm); + |} else { + | if ($leftTerm.numElements() == $rightTerm.numElements()) { + | $resultTerm = true; + | for (int $indexTerm = 0; $indexTerm < $leftTerm.numElements(); $indexTerm++) { + | $elementCls $leftElementTerm = $elementDefault; + | boolean $leftElementNullTerm = $leftTerm.isNullAt($indexTerm); + | if (!$leftElementNullTerm) { + | $leftElementTerm = + | ${baseRowFieldReadAccess(ctx, indexTerm, leftTerm, elementType)}; + | } + | + | $elementCls $rightElementTerm = $elementDefault; + | boolean $rightElementNullTerm = $rightTerm.isNullAt($indexTerm); + | if (!$rightElementNullTerm) { + | $rightElementTerm = + | ${baseRowFieldReadAccess(ctx, indexTerm, rightTerm, elementType)}; + | } + | + | ${elementEqualsExpr.code} + | if 
(!${elementEqualsExpr.resultTerm}) { + | $resultTerm = false; + | break; + | } + | } + | } else { + | $resultTerm = false; + | } + |} + """.stripMargin + (stmt, resultTerm) + } + + private def generateMapComparison( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression): GeneratedExpression = + generateCallWithStmtIfArgsNotNull(ctx, InternalTypes.BOOLEAN, Seq(left, right)) { + args => + val leftTerm = args.head + val rightTerm = args(1) + val resultTerm = newName("compareResult") + val stmt = s"boolean $resultTerm = $leftTerm.equals($rightTerm);" + (stmt, resultTerm) + } + + // ------------------------------------------------------------------------------------------ + + private def generateUnaryOperatorIfNotNull( + ctx: CodeGeneratorContext, + returnType: InternalType, + operand: GeneratedExpression, + resultNullable: Boolean = false) + (expr: String => String): GeneratedExpression = { + generateCallIfArgsNotNull(ctx, returnType, Seq(operand), resultNullable) { + args => expr(args.head) + } + } + + private def generateOperatorIfNotNull( + ctx: CodeGeneratorContext, + returnType: InternalType, + left: GeneratedExpression, + right: GeneratedExpression, + resultNullable: Boolean = false) + (expr: (String, String) => String) + : GeneratedExpression = { + generateCallIfArgsNotNull(ctx, returnType, Seq(left, right), resultNullable) { + args => expr(args.head, args(1)) + } + } + + // ---------------------------------------------------------------------------------------------- + + private def internalExprCasting( + expr: GeneratedExpression, + targetType: InternalType) + : GeneratedExpression = { + expr.copy(resultType = targetType) + } + + private def numericCasting( + operandType: InternalType, + resultType: InternalType): String => String = { + + val resultTypeTerm = primitiveTypeTermForType(resultType) + + def decToPrimMethod(targetType: InternalType): String = targetType match { + case InternalTypes.BYTE => "castToByte" + case 
InternalTypes.SHORT => "castToShort" + case InternalTypes.INT => "castToInt" + case InternalTypes.LONG => "castToLong" + case InternalTypes.FLOAT => "castToFloat" + case InternalTypes.DOUBLE => "castToDouble" + case InternalTypes.BOOLEAN => "castToBoolean" + case _ => throw new CodeGenException(s"Unsupported decimal casting type: '$targetType'") + } + + // no casting necessary + if (operandType == resultType) { + operandTerm => s"$operandTerm" + } + // decimal to decimal, may have different precision/scale + else if (isDecimal(resultType) && isDecimal(operandType)) { + val dt = resultType.asInstanceOf[DecimalType] + operandTerm => + s"$DECIMAL.castToDecimal($operandTerm, ${dt.precision()}, ${dt.scale()})" + } + // non_decimal_numeric to decimal + else if (isDecimal(resultType) && isNumeric(operandType)) { + val dt = resultType.asInstanceOf[DecimalType] + operandTerm => + s"$DECIMAL.castFrom($operandTerm, ${dt.precision()}, ${dt.scale()})" + } + // decimal to non_decimal_numeric + else if (isNumeric(resultType) && isDecimal(operandType) ) { + operandTerm => + s"$DECIMAL.${decToPrimMethod(resultType)}($operandTerm)" + } + // numeric to numeric + // TODO: Create a wrapper layer that handles type conversion between numeric. 
+ else if (isNumeric(operandType) && isNumeric(resultType)) { + val resultTypeValue = resultTypeTerm + "Value()" + val boxedTypeTerm = boxedTypeTermForType(operandType) + operandTerm => + s"(new $boxedTypeTerm($operandTerm)).$resultTypeValue" + } + // result type is time interval and operand type is integer + else if (isTimeInterval(resultType) && isInteger(operandType)){ + operandTerm => s"(($resultTypeTerm) $operandTerm)" + } + else { + throw new CodeGenException(s"Unsupported casting from $operandType to $resultType.") + } + } + + private def stringToInternalCode( + targetType: InternalType, + operandTerm: String, + zoneTerm: String): String = + targetType match { + case InternalTypes.DATE => + s"${qualifyMethod(BuiltInMethod.STRING_TO_DATE.method)}($operandTerm.toString())" + case InternalTypes.TIME => + s"${qualifyMethod(BuiltInMethod.STRING_TO_TIME.method)}($operandTerm.toString())" + case InternalTypes.TIMESTAMP => + s"""${qualifyMethod(BuiltInMethods.STRING_TO_TIMESTAMP)}($operandTerm.toString(), + | $zoneTerm)""".stripMargin + case _ => throw new UnsupportedOperationException + } + + private def internalToStringCode( + fromType: InternalType, + operandTerm: String, + zoneTerm: String): String = + fromType match { + case InternalTypes.DATE => + s"${qualifyMethod(BuiltInMethod.UNIX_DATE_TO_STRING.method)}($operandTerm)" + case InternalTypes.TIME => + s"${qualifyMethod(BuiltInMethods.UNIX_TIME_TO_STRING)}($operandTerm)" + case _: TimestampType => // including rowtime indicator + s"${qualifyMethod(BuiltInMethods.TIMESTAMP_TO_STRING)}($operandTerm, 3, $zoneTerm)" + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/GenericRelDataType.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/GenericRelDataType.scala new file mode 100644 index 0000000000000..b35967a7429ed --- /dev/null +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/GenericRelDataType.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.schema
+
+import org.apache.calcite.rel.`type`.RelDataTypeSystem
+import org.apache.calcite.sql.`type`.{BasicSqlType, SqlTypeName}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.`type`.GenericType
+
+/**
+ * Generic type for encapsulating Flink's [[TypeInformation]].
+ *
+ * @param genericType GenericType to encapsulate
+ * @param nullable flag if type can be nullable
+ * @param typeSystem Flink's type system
+ */
+class GenericRelDataType(
+    val genericType: GenericType[_],
+    val nullable: Boolean,
+    typeSystem: RelDataTypeSystem)
+  extends BasicSqlType(
+    typeSystem,
+    SqlTypeName.ANY) {
+
+  isNullable = nullable
+
+  override def toString = s"ANY($genericType)"
+
+  def canEqual(other: Any): Boolean = other.isInstanceOf[GenericRelDataType]
+
+  override def equals(other: Any): Boolean = other match {
+    case that: GenericRelDataType =>
+      super.equals(that) &&
+        (that canEqual this) &&
+        genericType == that.genericType &&
+        nullable == that.nullable
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    genericType.hashCode()
+  }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
index 41676a84ac220..fa8c1faae622d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala
@@ -19,6 +19,7 @@
 package org.apache.flink.table.typeutils
 
 import org.apache.flink.table.`type`._
+import org.apache.flink.table.codegen.GeneratedExpression
 
 object TypeCheckUtils {
@@ -29,6 +30,26 @@ object TypeCheckUtils {
     case _ => false
   }
 
+  def isTemporal(dataType: InternalType): Boolean =
+    isTimePoint(dataType) || isTimeInterval(dataType)
+
+  def isTimePoint(dataType: InternalType): Boolean = dataType match {
+    case InternalTypes.INTERVAL_MILLIS | InternalTypes.INTERVAL_MONTHS => false
+    case _: TimeType | _: DateType | _: TimestampType => true
+    case _ => false
+  }
+
+  def isRowTime(dataType: InternalType): Boolean =
+    dataType == InternalTypes.ROWTIME_INDICATOR
+
+  def isProcTime(dataType: InternalType): Boolean =
+    dataType == InternalTypes.PROCTIME_INDICATOR
+
+  def isTimeInterval(dataType: InternalType): Boolean = dataType match {
+    case InternalTypes.INTERVAL_MILLIS | InternalTypes.INTERVAL_MONTHS => true
+    case _ => false
+  }
+
   def isString(dataType: InternalType): Boolean = dataType == InternalTypes.STRING
 
   def isBinary(dataType: InternalType): Boolean = dataType == InternalTypes.BINARY
@@ -51,4 +72,25 @@ object TypeCheckUtils {
     !dataType.isInstanceOf[RowType] &&
     !isArray(dataType)
 
+  def isMutable(dataType: InternalType): Boolean = dataType match {
+    // the internal representation of String is BinaryString which is mutable
+    case InternalTypes.STRING => true
+    case _: ArrayType | _: MapType | _: RowType | _: GenericType[_] => true
+    case _ => false
+  }
+
+  def isReference(t: InternalType): Boolean = t match {
+    case InternalTypes.INT
+      | InternalTypes.LONG
+      | InternalTypes.SHORT
+      | InternalTypes.BYTE
+      | InternalTypes.FLOAT
+      | InternalTypes.DOUBLE
+      | InternalTypes.BOOLEAN
+      | InternalTypes.CHAR => false
+    case _ => true
+  }
+
+  def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType)
+
 }
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCoercion.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCoercion.scala
new file mode 100644
index 0000000000000..e559dd2a4bbee
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCoercion.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.typeutils
+
+import org.apache.flink.table.`type`.{DecimalType, InternalType, InternalTypes, TimestampType}
+import org.apache.flink.table.typeutils.TypeCheckUtils._
+
+/**
+ * Utilities for type conversions.
+ */
+object TypeCoercion {
+
+  val numericWideningPrecedence: IndexedSeq[InternalType] =
+    IndexedSeq(
+      InternalTypes.BYTE,
+      InternalTypes.SHORT,
+      InternalTypes.INT,
+      InternalTypes.LONG,
+      InternalTypes.FLOAT,
+      InternalTypes.DOUBLE)
+
+  def widerTypeOf(tp1: InternalType, tp2: InternalType): Option[InternalType] = {
+    (tp1, tp2) match {
+      case (ti1, ti2) if ti1 == ti2 => Some(ti1)
+
+      case (_, InternalTypes.STRING) => Some(InternalTypes.STRING)
+      case (InternalTypes.STRING, _) => Some(InternalTypes.STRING)
+
+      case (_, dt: DecimalType) => Some(dt)
+      case (dt: DecimalType, _) => Some(dt)
+
+      case (a, b) if isTimePoint(a) && isTimeInterval(b) => Some(a)
+      case (a, b) if isTimeInterval(a) && isTimePoint(b) => Some(b)
+
+      case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) =>
+        val higherIndex = numericWideningPrecedence.lastIndexWhere(t => t == tp1 || t == tp2)
+        Some(numericWideningPrecedence(higherIndex))
+
+      case _ => None
+    }
+  }
+
+  /**
+   * Test if we can do cast safely without loss of type information.
+   */
+  def canSafelyCast(from: InternalType, to: InternalType): Boolean = (from, to) match {
+    case (_, InternalTypes.STRING) => true
+
+    case (a, _: DecimalType) if isNumeric(a) => true
+
+    case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) =>
+      if (numericWideningPrecedence.indexOf(from) < numericWideningPrecedence.indexOf(to)) {
+        true
+      } else {
+        false
+      }
+
+    case _ => false
+  }
+
+  /**
+   * All the supported cast types in flink-table.
+   *
+   * Note: No distinction between explicit and implicit conversions
+   * Note: This is a subset of SqlTypeAssignmentRule
+   * Note: This may lose type information during the cast.
+   */
+  def canCast(from: InternalType, to: InternalType): Boolean = (from, to) match {
+    case (fromTp, toTp) if fromTp == toTp => true
+
+    case (_, InternalTypes.STRING) => true
+
+    case (_, InternalTypes.CHAR) => false // Character type not supported.
+
+    case (InternalTypes.STRING, b) if isNumeric(b) => true
+    case (InternalTypes.STRING, InternalTypes.BOOLEAN) => true
+    case (InternalTypes.STRING, _: DecimalType) => true
+    case (InternalTypes.STRING, InternalTypes.DATE) => true
+    case (InternalTypes.STRING, InternalTypes.TIME) => true
+    case (InternalTypes.STRING, _: TimestampType) => true
+
+    case (InternalTypes.BOOLEAN, b) if isNumeric(b) => true
+    case (InternalTypes.BOOLEAN, _: DecimalType) => true
+    case (a, InternalTypes.BOOLEAN) if isNumeric(a) => true
+    case (_: DecimalType, InternalTypes.BOOLEAN) => true
+
+    case (a, b) if isNumeric(a) && isNumeric(b) => true
+    case (a, _: DecimalType) if isNumeric(a) => true
+    case (_: DecimalType, b) if isNumeric(b) => true
+    case (_: DecimalType, _: DecimalType) => true
+    case (InternalTypes.INT, InternalTypes.DATE) => true
+    case (InternalTypes.INT, InternalTypes.TIME) => true
+    case (InternalTypes.BYTE, _: TimestampType) => true
+    case (InternalTypes.SHORT, _: TimestampType) => true
+    case (InternalTypes.INT, _: TimestampType) => true
+    case (InternalTypes.LONG, _: TimestampType) => true
+    case (InternalTypes.DOUBLE, _: TimestampType) => true
+    case (InternalTypes.FLOAT, _: TimestampType) => true
+    case (InternalTypes.INT, InternalTypes.INTERVAL_MONTHS) => true
+    case (InternalTypes.LONG, InternalTypes.INTERVAL_MILLIS) => true
+
+    case (InternalTypes.DATE, InternalTypes.TIME) => false
+    case (InternalTypes.TIME, InternalTypes.DATE) => false
+    case (a, b) if isTimePoint(a) && isTimePoint(b) => true
+    case (InternalTypes.DATE, InternalTypes.INT) => true
+    case (InternalTypes.TIME, InternalTypes.INT) => true
+    case (_: TimestampType, InternalTypes.BYTE) => true
+    case (_: TimestampType, InternalTypes.INT) => true
+    case (_: TimestampType, InternalTypes.SHORT) => true
+    case (_: TimestampType, InternalTypes.LONG) => true
+    case (_: TimestampType, InternalTypes.DOUBLE) => true
+    case (_: TimestampType, InternalTypes.FLOAT) => true
+
+    case (InternalTypes.INTERVAL_MONTHS, InternalTypes.INT) => true
+    case (InternalTypes.INTERVAL_MILLIS, InternalTypes.LONG) => true
+
+    case _ => false
+  }
+
+  /**
+   * All the supported reinterpret types in flink-table.
+   */
+  def canReinterpret(from: InternalType, to: InternalType): Boolean = (from, to) match {
+    case (fromTp, toTp) if fromTp == toTp => true
+
+    case (InternalTypes.DATE, InternalTypes.INT) => true
+    case (InternalTypes.TIME, InternalTypes.INT) => true
+    case (_: TimestampType, InternalTypes.LONG) => true
+    case (InternalTypes.INT, InternalTypes.DATE) => true
+    case (InternalTypes.INT, InternalTypes.TIME) => true
+    case (InternalTypes.LONG, _: TimestampType) => true
+    case (InternalTypes.INT, InternalTypes.INTERVAL_MONTHS) => true
+    case (InternalTypes.LONG, InternalTypes.INTERVAL_MILLIS) => true
+    case (InternalTypes.INTERVAL_MONTHS, InternalTypes.INT) => true
+    case (InternalTypes.INTERVAL_MILLIS, InternalTypes.LONG) => true
+
+    case (InternalTypes.DATE, InternalTypes.LONG) => true
+    case (InternalTypes.TIME, InternalTypes.LONG) => true
+    case (InternalTypes.INTERVAL_MONTHS, InternalTypes.LONG) => true
+
+    case _ => false
+  }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ArrayTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ArrayTypeTest.scala
new file mode 100644
index 0000000000000..8e8e7994291e3
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ArrayTypeTest.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.ArrayTypeTestBase +import org.junit.Test + +class ArrayTypeTest extends ArrayTypeTestBase { + + @Test + def testArrayLiterals(): Unit = { + // primitive literals + testSqlApi( + "ARRAY[1, 2, 3]", + "[1, 2, 3]") + + testSqlApi( + "ARRAY[TRUE, TRUE, TRUE]", + "[true, true, true]") + + testSqlApi( + "ARRAY[ARRAY[ARRAY[1], ARRAY[1]]]", + "[[[1], [1]]]") + + testSqlApi( + "ARRAY[1 + 1, 3 * 3]", + "[2, 9]") + + testSqlApi( + "ARRAY[NULLIF(1,1), 1]", + "[null, 1]") + + testSqlApi( + "ARRAY[ARRAY[NULLIF(1,1), 1]]", + "[[null, 1]]") + + testSqlApi( + "ARRAY[DATE '1985-04-11', DATE '2018-07-26']", + "[1985-04-11, 2018-07-26]") + + testSqlApi( + "ARRAY[TIME '14:15:16', TIME '17:18:19']", + "[14:15:16, 17:18:19]") + + testSqlApi( + "ARRAY[TIMESTAMP '1985-04-11 14:15:16', TIMESTAMP '2018-07-26 17:18:19']", + "[1985-04-11 14:15:16.000, 2018-07-26 17:18:19.000]") + + testSqlApi( + "ARRAY[CAST(2.0002 AS DECIMAL(10,4)), CAST(2.0003 AS DECIMAL(10,4))]", + "[2.0002, 2.0003]") + + testSqlApi( + "ARRAY[ARRAY[TRUE]]", + "[[true]]") + + testSqlApi( + "ARRAY[ARRAY[1, 2, 3], ARRAY[3, 2, 1]]", + "[[1, 2, 3], [3, 2, 1]]") + + // implicit type cast only works on SQL APIs. 
+ testSqlApi( + "ARRAY[CAST(1 AS DOUBLE), CAST(2 AS FLOAT)]", + "[1.0, 2.0]") + } + + @Test + def testArrayField(): Unit = { + testSqlApi( + "ARRAY[f0, f1]", + "[null, 42]") + + testSqlApi( + "ARRAY[f0, f1]", + "[null, 42]") + + testSqlApi( + "f2", + "[1, 2, 3]") + + testSqlApi( + "f3", + "[1984-03-12, 1984-02-10]") + + testSqlApi( + "f5", + "[[1, 2, 3], null]") + + testSqlApi( + "f6", + "[1, null, null, 4]") + + testSqlApi( + "f2", + "[1, 2, 3]") + + testSqlApi( + "f2[1]", + "1") + + testSqlApi( + "f3[1]", + "1984-03-12") + + testSqlApi( + "f3[2]", + "1984-02-10") + + testSqlApi( + "f5[1][2]", + "2") + + testSqlApi( + "f5[2][2]", + "null") + + testSqlApi( + "f4[2][2]", + "null") + + testSqlApi( + "f11[1]", + "1") + } + + @Test + def testArrayOperations(): Unit = { + // cardinality + testSqlApi( + "CARDINALITY(f2)", + "3") + + testSqlApi( + "CARDINALITY(f4)", + "null") + + testSqlApi( + "CARDINALITY(f11)", + "1") + + // element + testSqlApi( + "ELEMENT(f9)", + "1") + + testSqlApi( + "ELEMENT(f8)", + "4.0") + + testSqlApi( + "ELEMENT(f10)", + "null") + + testSqlApi( + "ELEMENT(f4)", + "null") + + testSqlApi( + "ELEMENT(f11)", + "1") + + // comparison + testSqlApi( + "f2 = f5[1]", + "true") + + testSqlApi( + "f6 = ARRAY[1, 2, 3]", + "false") + + testSqlApi( + "f2 <> f5[1]", + "false") + + testSqlApi( + "f2 = f7", + "false") + + testSqlApi( + "f2 <> f7", + "true") + + testSqlApi( + "f11 = f11", + "true") + + testSqlApi( + "f11 = f9", + "true") + + testSqlApi( + "f11 <> f11", + "false") + + testSqlApi( + "f11 <> f9", + "false") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala new file mode 100644 index 0000000000000..c77132b78fbbf --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala @@ -0,0 
+1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions
+
+import org.apache.flink.table.expressions.utils.CompositeTypeTestBase
+import org.junit.Test
+
+class CompositeAccessTest extends CompositeTypeTestBase {
+
+  @Test
+  def testGetField(): Unit = {
+
+    // single field by string key
+    testSqlApi(
+      "testTable.f0.intField",
+      "42")
+    testSqlApi("f0.intField", "42")
+
+    testSqlApi("testTable.f0.stringField", "Bob")
+    testSqlApi("f0.stringField", "Bob")
+
+    testSqlApi("testTable.f0.booleanField", "true")
+    testSqlApi("f0.booleanField", "true")
+
+    // nested single field
+    testSqlApi(
+      "testTable.f1.objectField.intField",
+      "25")
+    testSqlApi("f1.objectField.intField", "25")
+
+    testSqlApi("testTable.f1.objectField.stringField", "Timo")
+    testSqlApi("f1.objectField.stringField", "Timo")
+
+    testSqlApi("testTable.f1.objectField.booleanField", "false")
+    testSqlApi("f1.objectField.booleanField", "false")
+
+    testSqlApi(
+      "testTable.f2._1",
+      "a")
+    testSqlApi("f2._1", "a")
+
+    testSqlApi("testTable.f3.f1", "b")
+    testSqlApi("f3.f1", "b")
+
+    testSqlApi("testTable.f4.myString", "Hello")
+    testSqlApi("f4.myString", "Hello")
+
+    testSqlApi("testTable.f5", "13")
+    testSqlApi("f5", "13")
+
+    testSqlApi(
+      "testTable.f7._1",
+      "true")
+
+    // composite field return type
+    testSqlApi("testTable.f6", "MyCaseClass2(null)")
+    testSqlApi("f6", "MyCaseClass2(null)")
+
+    // MyCaseClass is converted to BaseRow
+    // so the result of "toString" doesn't contain MyCaseClass prefix
+    testSqlApi(
+      "testTable.f1.objectField",
+      "(25,Timo,false)")
+    testSqlApi("f1.objectField", "(25,Timo,false)")
+
+    testSqlApi(
+      "testTable.f0",
+      "(42,Bob,true)")
+    testSqlApi("f0", "(42,Bob,true)")
+
+    // flattening (test base only returns first column)
+    testSqlApi(
+      "testTable.f1.objectField.*",
+      "25")
+    testSqlApi("f1.objectField.*", "25")
+
+    testSqlApi(
+      "testTable.f0.*",
+      "42")
+    testSqlApi("f0.*", "42")
+
+    // array of composites
+    testSqlApi(
+      "f8[1]._1",
+      "true"
+    )
+
+    testSqlApi(
+      "f8[1]._2",
+      "23"
+    )
+
+    testSqlApi(
+      "f9[2]._1",
+      "null"
+    )
+
+    testSqlApi(
+      "f10[1].stringField",
+      "Bob"
+    )
+
+    testSqlApi(
+      "f11[1].myString",
+      "Hello"
+    )
+
+    testSqlApi(
+      "f11[2]",
+      "null"
+    )
+
+    testSqlApi(
+      "f12[1].arrayField[1].stringField",
+      "Alice"
+    )
+
+    testSqlApi(
+      "f13[1].objectField.stringField",
+      "Bob"
+    )
+  }
+}
+
+
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala
new file mode 100644
index 0000000000000..06608fc2a5593
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.dataformat.Decimal +import org.apache.flink.table.expressions.utils.ExpressionTestBase +import org.apache.flink.table.typeutils.DecimalTypeInfo +import org.apache.flink.types.Row +import org.junit.Test + +class DecimalTypeTest extends ExpressionTestBase { + + @Test + def testDecimalLiterals(): Unit = { + // implicit double + testSqlApi( + "11.2", + "11.2") + + // implicit double + testSqlApi( + "0.7623533651719233", + "0.7623533651719233") + + // explicit decimal (with precision of 19) + testSqlApi( + "1234567891234567891", + "1234567891234567891") + } + + @Test + def testDecimalBorders(): Unit = { + testSqlApi( + Double.MaxValue.toString, + Double.MaxValue.toString) + + testSqlApi( + Double.MinValue.toString, + Double.MinValue.toString) + + testSqlApi( + s"CAST(${Double.MinValue} AS FLOAT)", + Float.NegativeInfinity.toString) + + testSqlApi( + s"CAST(${Byte.MinValue} AS TINYINT)", + Byte.MinValue.toString) + + testSqlApi( + s"CAST(${Byte.MinValue} AS TINYINT) - CAST(1 AS TINYINT)", + Byte.MaxValue.toString) + + testSqlApi( + s"CAST(${Short.MinValue} AS SMALLINT)", + Short.MinValue.toString) + + testSqlApi( + s"CAST(${Int.MinValue} AS INT) - 1", + Int.MaxValue.toString) + + testSqlApi( + s"CAST(${Long.MinValue} AS BIGINT)", + Long.MinValue.toString) + } + + @Test + def testDecimalCasting(): Unit = { + // from String + testSqlApi( + 
"CAST('123456789123456789123456789' AS DECIMAL(27, 0))", + "123456789123456789123456789") + + // from double + testSqlApi( + "CAST(f3 AS DECIMAL)", + "4") + + testSqlApi( + "CAST(f3 AS DECIMAL(10,2))", + "4.20" + ) + + // to double + testSqlApi( + "CAST(f0 AS DOUBLE)", + "1.2345678912345679E8") + + // to int + testSqlApi( + "CAST(f4 AS INT)", + "123456789") + + // to long + testSqlApi( + "CAST(f4 AS BIGINT)", + "123456789") + } + + @Test + def testDecimalArithmetic(): Unit = { + + // note: calcite type inference: + // Decimal+ExactNumeric => Decimal + // Decimal+Double => Double. + + // implicit cast to decimal + testSqlApi( + "f1 + 12", + "123456789123456789123456801") + + // implicit cast to decimal + testSqlApi( + "12 + f1", + "123456789123456789123456801") + + testSqlApi( + "f1 + 12.3", + "123456789123456789123456801.3" + ) + + testSqlApi( + "12.3 + f1", + "123456789123456789123456801.3") + + testSqlApi( + "f1 + f1", + "246913578246913578246913578") + + testSqlApi( + "f1 - f1", + "0") + + testSqlApi( + "f1 / f1", + "1.00000000") + + testSqlApi( + "MOD(f1, f1)", + "0") + + testSqlApi( + "-f0", + "-123456789.123456789123456789") + } + + @Test + def testDecimalComparison(): Unit = { + testSqlApi( + "f1 < 12", + "false") + + testSqlApi( + "f1 > 12", + "true") + + testSqlApi( + "f1 = 12", + "false") + + testSqlApi( + "f5 = 0", + "true") + + testSqlApi( + "f1 = CAST('123456789123456789123456789' AS DECIMAL(30, 0))", + "true") + + testSqlApi( + "f1 <> CAST('123456789123456789123456789' AS DECIMAL(30, 0))", + "false") + + testSqlApi( + "f4 < f0", + "true") + + // TODO add all tests if FLINK-4070 is fixed + testSqlApi( + "12 < f1", + "true") + } + + // ---------------------------------------------------------------------------------------------- + + override def testData: Row = { + val testData = new Row(6) + testData.setField(0, Decimal.castFrom("123456789.123456789123456789", 30, 18)) + testData.setField(1, Decimal.castFrom("123456789123456789123456789", 30, 0)) + 
testData.setField(2, 42) + testData.setField(3, 4.2) + testData.setField(4, Decimal.castFrom("123456789", 10, 0)) + testData.setField(5, Decimal.castFrom("0.000", 10, 3)) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + DecimalTypeInfo.of(30, 18), + DecimalTypeInfo.of(30, 0), + Types.INT, + Types.DOUBLE, + DecimalTypeInfo.of(10, 0), + DecimalTypeInfo.of(10, 3)) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/LiteralTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/LiteralTest.scala new file mode 100644 index 0000000000000..f1c4c25527552 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/LiteralTest.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions + +import org.apache.flink.api.common.typeinfo.{TypeInformation, Types} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.expressions.utils.ExpressionTestBase +import org.apache.flink.types.Row +import org.junit.{Ignore, Test} + +class LiteralTest extends ExpressionTestBase { + + @Test + def testFieldWithBooleanPrefix(): Unit = { + + testSqlApi( + "trUeX", + "trUeX_value" + ) + + testSqlApi( + "FALSE_A", + "FALSE_A_value" + ) + + testSqlApi( + "FALSE_AB", + "FALSE_AB_value" + ) + + testSqlApi( + "trUe", + "true" + ) + + testSqlApi( + "FALSE", + "false" + ) + } + + @Ignore("TODO: FLINK-11898") + @Test + def testNonAsciiLiteral(): Unit = { + testSqlApi( + "f4 LIKE '%测试%'", + "true") + + testSqlApi( + "'Абвгде' || '谢谢'", + "Абвгде谢谢") + } + + @Ignore("TODO: FLINK-11898") + @Test + def testDoubleQuote(): Unit = { + val hello = "\"\"" + testSqlApi( + s"concat('a', '$hello')", + s"a and $hello") + } + + @Test + def testStringLiterals(): Unit = { + + // these tests use Java/Scala escaping for non-quoting unicode characters + + testSqlApi( + "'>\n<'", + ">\n<") + + testSqlApi( + "'>\u263A<'", + ">\u263A<") + + testSqlApi( + "'>\u263A<'", + ">\u263A<") + + testSqlApi( + "'>\\<'", + ">\\<") + + testSqlApi( + "'>''<'", + ">'<") + + testSqlApi( + "' '", + " ") + + testSqlApi( + "''", + "") + + testSqlApi( + "'>foo([\\w]+)<'", + ">foo([\\w]+)<") + + testSqlApi( + "'>\\''\n<'", + ">\\'\n<") + + testSqlApi( + "'It''s me.'", + "It's me.") + + // these test use SQL for describing unicode characters + + testSqlApi( + "U&'>\\263A<'", // default escape backslash + ">\u263A<") + + testSqlApi( + "U&'>#263A<' UESCAPE '#'", // custom escape '#' + ">\u263A<") + + testSqlApi( + """'>\\<'""", + ">\\\\<") + } + + override def testData: Row = { + val testData = new Row(4) + testData.setField(0, "trUeX_value") + testData.setField(1, "FALSE_A_value") + testData.setField(2, "FALSE_AB_value") + 
testData.setField(3, "这是个测试字符串") + testData + } + + override def typeInfo : RowTypeInfo = { + new RowTypeInfo( + Array( + Types.STRING, + Types.STRING, + Types.STRING, + Types.STRING + ).asInstanceOf[Array[TypeInformation[_]]], + Array("trUeX", "FALSE_A", "FALSE_AB", "f4") + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MapTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MapTypeTest.scala new file mode 100644 index 0000000000000..a8d2729dde778 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MapTypeTest.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.MapTypeTestBase +import org.junit.Test + +class MapTypeTest extends MapTypeTestBase { + + @Test + def testItem(): Unit = { + testSqlApi("f0['map is null']", "null") + testSqlApi("f1['map is empty']", "null") + testSqlApi("f2['b']", "13") + testSqlApi("f3[1]", "null") + testSqlApi("f3[12]", "a") + } + + @Test + def testMapLiteral(): Unit = { + // primitive literals + testSqlApi( + "MAP[1, 1]", + "{1=1}") + + testSqlApi( + "map[TRUE, TRUE]", + "{true=true}") + + testSqlApi( + "MAP[MAP[1, 2], MAP[3, 4]]", + "{{1=2}={3=4}}") + + testSqlApi( + "map[1 + 2, 3 * 3, 3 - 6, 4 - 2]", + "{3=9, -3=2}") + + testSqlApi( + "map[1, NULLIF(1,1)]", + "{1=null}") + + // explicit conversion + testSqlApi( + "MAP[1, CAST(2 AS BIGINT), 3, CAST(4 AS BIGINT)]", + "{1=2, 3=4}") + + testSqlApi( + "MAP[DATE '1985-04-11', TIME '14:15:16', DATE '2018-07-26', TIME '17:18:19']", + "{1985-04-11=14:15:16, 2018-07-26=17:18:19}") + + testSqlApi( + "MAP[TIME '14:15:16', TIMESTAMP '1985-04-11 14:15:16', " + + "TIME '17:18:19', TIMESTAMP '2018-07-26 17:18:19']", + "{14:15:16=1985-04-11 14:15:16.000, 17:18:19=2018-07-26 17:18:19.000}") + + testSqlApi( + "MAP[CAST(2.0002 AS DECIMAL(5, 4)), CAST(2.0003 AS DECIMAL(5, 4))]", + "{2.0002=2.0003}") + + // implicit type cast only works on SQL API + testSqlApi( + "MAP['k1', CAST(1 AS DOUBLE), 'k2', CAST(2 AS FLOAT)]", + "{k1=1.0, k2=2.0}") + } + + @Test + def testMapField(): Unit = { + testSqlApi( + "MAP[f4, f5]", + "{foo=12}") + + testSqlApi( + "MAP[f4, f1]", + "{foo={}}") + + testSqlApi( + "MAP[f2, f3]", + "{{a=12, b=13}={12=a, 13=b}}") + + testSqlApi( + "MAP[f1['a'], f5]", + "{null=12}") + + testSqlApi( + "f1", + "{}") + + testSqlApi( + "f2", + "{a=12, b=13}") + + testSqlApi( + "f2['a']", + "12") + + testSqlApi( + "f3[12]", + "a") + + testSqlApi( + "MAP[f4, f3]['foo'][13]", + "b") + } + + @Test + def testMapOperations(): Unit = { + + // comparison + testSqlApi( 
+ "f1 = f2", + "false") + + testSqlApi( + "f3 = f7", + "true") + + testSqlApi( + "f5 = f2['a']", + "true") + + testSqlApi( + "f8 = f9", + "true") + + testSqlApi( + "f10 = f11", + "true") + + testSqlApi( + "f8 <> f9", + "false") + + testSqlApi( + "f10 <> f11", + "false") + + testSqlApi( + "f0['map is null']", + "null") + + testSqlApi( + "f1['map is empty']", + "null") + + testSqlApi( + "f2['b']", + "13") + + testSqlApi( + "f3[1]", + "null") + + testSqlApi( + "f3[12]", + "a") + + testSqlApi( + "CARDINALITY(f3)", + "2") + + testSqlApi( + "f2['a'] IS NOT NULL", + "true") + + testSqlApi( + "f2['a'] IS NULL", + "false") + + testSqlApi( + "f2['c'] IS NOT NULL", + "false") + + testSqlApi( + "f2['c'] IS NULL", + "true") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MathFunctionsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MathFunctionsTest.scala new file mode 100644 index 0000000000000..ae603ab2b1000 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MathFunctionsTest.scala @@ -0,0 +1,694 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.ScalarTypesTestBase +import org.junit.{Ignore, Test} + +@Ignore("TODO: [FLINK-11898] the math functions will be supported in the future") +class MathFunctionsTest extends ScalarTypesTestBase { + + // ---------------------------------------------------------------------------------------------- + // Math functions + // ---------------------------------------------------------------------------------------------- + + @Test + def testMod(): Unit = { + testSqlApi( + "MOD(f4, f7)", + "2") + + testSqlApi( + "MOD(f4, 3)", + "2") + + testSqlApi( + "MOD(44, 3)", + "2") + } + + @Test + def testExp(): Unit = { + testSqlApi( + "EXP(f2)", + math.exp(42.toByte).toString) + + testSqlApi( + "EXP(f3)", + math.exp(43.toShort).toString) + + testSqlApi( + "EXP(f4)", + math.exp(44.toLong).toString) + + testSqlApi( + "EXP(f5)", + math.exp(4.5.toFloat).toString) + + testSqlApi( + "EXP(f6)", + math.exp(4.6).toString) + + testSqlApi( + "EXP(f7)", + math.exp(3).toString) + + testSqlApi( + "EXP(3)", + math.exp(3).toString) + } + + @Test + def testLog10(): Unit = { + testSqlApi( + "LOG10(f2)", + math.log10(42.toByte).toString) + + testSqlApi( + "LOG10(f3)", + math.log10(43.toShort).toString) + + testSqlApi( + "LOG10(f4)", + math.log10(44.toLong).toString) + + testSqlApi( + "LOG10(f5)", + math.log10(4.5.toFloat).toString) + + testSqlApi( + "LOG10(f6)", + math.log10(4.6).toString) + + testSqlApi( + "LOG10(f32)", + math.log10(-1).toString) + + testSqlApi( + "LOG10(f27)", + math.log10(0).toString) + } + + @Test + def testPower(): Unit = { + // f7: int , f4: long, f6: double + testSqlApi( + "POWER(f2, f7)", + math.pow(42.toByte, 3).toString) + + testSqlApi( + "POWER(f3, f6)", + math.pow(43.toShort, 4.6D).toString) + + testSqlApi( + "POWER(f4, f5)", + math.pow(44.toLong, 4.5.toFloat).toString) + + testSqlApi( + "POWER(f4, f5)", + math.pow(44.toLong, 4.5.toFloat).toString) + + // f5: float + 
testSqlApi( + "power(f5, f5)", + math.pow(4.5F, 4.5F).toString) + + testSqlApi( + "power(f5, f6)", + math.pow(4.5F, 4.6D).toString) + + testSqlApi( + "power(f5, f7)", + math.pow(4.5F, 3).toString) + + testSqlApi( + "power(f5, f4)", + math.pow(4.5F, 44L).toString) + + // f22: bigDecimal + // TODO delete casting in SQL when CALCITE-1467 is fixed + testSqlApi( + "power(CAST(f22 AS DOUBLE), f5)", + math.pow(2, 4.5F).toString) + + testSqlApi( + "power(CAST(f22 AS DOUBLE), f6)", + math.pow(2, 4.6D).toString) + + testSqlApi( + "power(CAST(f22 AS DOUBLE), f7)", + math.pow(2, 3).toString) + + testSqlApi( + "power(CAST(f22 AS DOUBLE), f4)", + math.pow(2, 44L).toString) + + testSqlApi( + "power(f6, f22)", + math.pow(4.6D, 2).toString) + } + + @Test + def testSqrt(): Unit = { + testSqlApi( + "SQRT(f6)", + math.sqrt(4.6D).toString) + + testSqlApi( + "SQRT(f7)", + math.sqrt(3).toString) + + testSqlApi( + "SQRT(f4)", + math.sqrt(44L).toString) + + testSqlApi( + "SQRT(CAST(f22 AS DOUBLE))", + math.sqrt(2.0).toString) + + testSqlApi( + "SQRT(f5)", + math.pow(4.5F, 0.5).toString) + + testSqlApi( + "SQRT(25)", + "5.0") + + testSqlApi( + "POWER(CAST(2.2 AS DOUBLE), CAST(0.5 AS DOUBLE))", // TODO fix FLINK-4621 + math.sqrt(2.2).toString) + } + + @Test + def testLn(): Unit = { + testSqlApi( + "LN(f2)", + math.log(42.toByte).toString) + + testSqlApi( + "LN(f3)", + math.log(43.toShort).toString) + + testSqlApi( + "LN(f4)", + math.log(44.toLong).toString) + + testSqlApi( + "LN(f5)", + math.log(4.5.toFloat).toString) + + testSqlApi( + "LN(f6)", + math.log(4.6).toString) + + testSqlApi( + "LN(f32)", + math.log(-1).toString) + + testSqlApi( + "LN(f27)", + math.log(0).toString) + } + + @Test + def testAbs(): Unit = { + testSqlApi( + "ABS(f2)", + "42") + + testSqlApi( + "ABS(f3)", + "43") + + testSqlApi( + "ABS(f4)", + "44") + + testSqlApi( + "ABS(f5)", + "4.5") + + testSqlApi( + "ABS(f6)", + "4.6") + + testSqlApi( + "ABS(f9)", + "42") + + testSqlApi( + "ABS(f10)", + "43") + + testSqlApi( + 
"ABS(f11)", + "44") + + testSqlApi( + "ABS(f12)", + "4.5") + + testSqlApi( + "ABS(f13)", + "4.6") + + testSqlApi( + "ABS(f15)", + "1231.1231231321321321111") + } + + @Test + def testArithmeticFloorCeil(): Unit = { + testSqlApi( + "FLOOR(f5)", + "4.0") + + testSqlApi( + "CEIL(f5)", + "5.0") + + testSqlApi( + "FLOOR(f3)", + "43") + + testSqlApi( + "CEIL(f3)", + "43") + + testSqlApi( + "FLOOR(f15)", + "-1232") + + testSqlApi( + "CEIL(f15)", + "-1231") + } + + @Test + def testSin(): Unit = { + testSqlApi( + "SIN(f2)", + math.sin(42.toByte).toString) + + testSqlApi( + "SIN(f3)", + math.sin(43.toShort).toString) + + testSqlApi( + "SIN(f4)", + math.sin(44.toLong).toString) + + testSqlApi( + "SIN(f5)", + math.sin(4.5.toFloat).toString) + + testSqlApi( + "SIN(f6)", + math.sin(4.6).toString) + + testSqlApi( + "SIN(f15)", + math.sin(-1231.1231231321321321111).toString) + } + + @Test + def testCos(): Unit = { + testSqlApi( + "COS(f2)", + math.cos(42.toByte).toString) + + testSqlApi( + "COS(f3)", + math.cos(43.toShort).toString) + + testSqlApi( + "COS(f4)", + math.cos(44.toLong).toString) + + testSqlApi( + "COS(f5)", + math.cos(4.5.toFloat).toString) + + testSqlApi( + "COS(f6)", + math.cos(4.6).toString) + + testSqlApi( + "COS(f15)", + math.cos(-1231.1231231321321321111).toString) + } + + @Test + def testTan(): Unit = { + testSqlApi( + "TAN(f2)", + math.tan(42.toByte).toString) + + testSqlApi( + "TAN(f3)", + math.tan(43.toShort).toString) + + testSqlApi( + "TAN(f4)", + math.tan(44.toLong).toString) + + testSqlApi( + "TAN(f5)", + math.tan(4.5.toFloat).toString) + + testSqlApi( + "TAN(f6)", + math.tan(4.6).toString) + + testSqlApi( + "TAN(f15)", + math.tan(-1231.1231231321321321111).toString) + } + + @Test + def testCot(): Unit = { + testSqlApi( + "COT(f2)", + (1.0d / math.tan(42.toByte)).toString) + + testSqlApi( + "COT(f3)", + (1.0d / math.tan(43.toShort)).toString) + + testSqlApi( + "COT(f4)", + (1.0d / math.tan(44.toLong)).toString) + + testSqlApi( + "COT(f5)", + (1.0d / 
math.tan(4.5.toFloat)).toString) + + testSqlApi( + "COT(f6)", + (1.0d / math.tan(4.6)).toString) + + testSqlApi( + "COT(f15)", + (1.0d / math.tan(-1231.1231231321321321111)).toString) + } + + @Test + def testAsin(): Unit = { + testSqlApi( + "ASIN(f25)", + math.asin(0.42.toByte).toString) + + testSqlApi( + "ASIN(f26)", + math.asin(0.toShort).toString) + + testSqlApi( + "ASIN(f27)", + math.asin(0.toLong).toString) + + testSqlApi( + "ASIN(f28)", + math.asin(0.45.toFloat).toString) + + testSqlApi( + "ASIN(f29)", + math.asin(0.46).toString) + + testSqlApi( + "ASIN(f30)", + math.asin(1).toString) + + testSqlApi( + "ASIN(f31)", + math.asin(-0.1231231321321321111).toString) + } + + @Test + def testAcos(): Unit = { + testSqlApi( + "ACOS(f25)", + math.acos(0.42.toByte).toString) + + testSqlApi( + "ACOS(f26)", + math.acos(0.toShort).toString) + + testSqlApi( + "ACOS(f27)", + math.acos(0.toLong).toString) + + testSqlApi( + "ACOS(f28)", + math.acos(0.45.toFloat).toString) + + testSqlApi( + "ACOS(f29)", + math.acos(0.46).toString) + + testSqlApi( + "ACOS(f30)", + math.acos(1).toString) + + testSqlApi( + "ACOS(f31)", + math.acos(-0.1231231321321321111).toString) + } + + @Test + def testAtan(): Unit = { + testSqlApi( + "ATAN(f25)", + math.atan(0.42.toByte).toString) + + testSqlApi( + "ATAN(f26)", + math.atan(0.toShort).toString) + + testSqlApi( + "ATAN(f27)", + math.atan(0.toLong).toString) + + testSqlApi( + "ATAN(f28)", + math.atan(0.45.toFloat).toString) + + testSqlApi( + "ATAN(f29)", + math.atan(0.46).toString) + + testSqlApi( + "ATAN(f30)", + math.atan(1).toString) + + testSqlApi( + "ATAN(f31)", + math.atan(-0.1231231321321321111).toString) + } + + @Test + def testDegrees(): Unit = { + testSqlApi( + "DEGREES(f2)", + math.toDegrees(42.toByte).toString) + + testSqlApi( + "DEGREES(f3)", + math.toDegrees(43.toShort).toString) + + testSqlApi( + "DEGREES(f4)", + math.toDegrees(44.toLong).toString) + + testSqlApi( + "DEGREES(f5)", + math.toDegrees(4.5.toFloat).toString) + + 
testSqlApi( + "DEGREES(f6)", + math.toDegrees(4.6).toString) + + testSqlApi( + "DEGREES(f15)", + math.toDegrees(-1231.1231231321321321111).toString) + } + + @Test + def testRadians(): Unit = { + testSqlApi( + "RADIANS(f2)", + math.toRadians(42.toByte).toString) + + testSqlApi( + "RADIANS(f3)", + math.toRadians(43.toShort).toString) + + testSqlApi( + "RADIANS(f4)", + math.toRadians(44.toLong).toString) + + testSqlApi( + "RADIANS(f5)", + math.toRadians(4.5.toFloat).toString) + + testSqlApi( + "RADIANS(f6)", + math.toRadians(4.6).toString) + + testSqlApi( + "RADIANS(f15)", + math.toRadians(-1231.1231231321321321111).toString) + } + + @Test + def testSign(): Unit = { + testSqlApi( + "SIGN(f4)", + 1.toString) + + testSqlApi( + "SIGN(f6)", + 1.0.toString) + + testSqlApi( + "SIGN(f15)", + "-1.0000000000000000000") // calcite: SIGN(Decimal(p,s)) => Decimal(p,s) + } + + @Test + def testRound(): Unit = { + testSqlApi( + "ROUND(f29, f30)", + 0.5.toString) + + testSqlApi( + "ROUND(f31, f7)", + "-0.123") + + testSqlApi( + "ROUND(f4, f32)", + 40.toString) + } + + @Test + def testPi(): Unit = { + testSqlApi( + "pi()", + math.Pi.toString) + } + + @Test + def testRandAndRandInteger(): Unit = { + val random1 = new java.util.Random(1) + testSqlApi( + "RAND(1)", + random1.nextDouble().toString) + + val random2 = new java.util.Random(3) + testSqlApi( + "RAND(f7)", + random2.nextDouble().toString) + + val random3 = new java.util.Random(1) + testSqlApi( + "RAND_INTEGER(1, 10)", + random3.nextInt(10).toString) + + val random4 = new java.util.Random(3) + testSqlApi( + "RAND_INTEGER(f7, CAST(f4 AS INT))", + random4.nextInt(44).toString) + } + + @Test + def testE(): Unit = { + testSqlApi( + "E()", + math.E.toString) + + testSqlApi( + "e()", + math.E.toString) + } + + @Test + def testLog(): Unit = { + testSqlApi( + "LOG(f6)", + "1.5260563034950492" + ) + + testSqlApi( + "LOG(f6-f6 + 10, f6-f6+100)", + "2.0" + ) + + testSqlApi( + "LOG(f6+20)", + "3.202746442938317" + ) + + testSqlApi( + 
"LOG(10)", + "2.302585092994046" + ) + + testSqlApi( + "LOG(10, 100)", + "2.0" + ) + + testSqlApi( + "log(f32, f32)", + (math.log(-1)/math.log(-1)).toString) + + testSqlApi( + "log(f27, f32)", + (math.log(0)/math.log(0)).toString) + } + + @Test + def testLog2(): Unit = { + testSqlApi( + "log2(f2)", + (math.log(42.toByte)/math.log(2.toByte)).toString) + + testSqlApi( + "log2(f3)", + (math.log(43.toShort)/math.log(2.toShort)).toString) + + testSqlApi( + "log2(f4)", + (math.log(44.toLong)/math.log(2.toLong)).toString) + + testSqlApi( + "log2(f5)", + (math.log(4.5.toFloat)/math.log(2.toFloat)).toString) + + testSqlApi( + "log2(f6)", + (math.log(4.6)/math.log(2)).toString) + + testSqlApi( + "log2(f32)", + (math.log(-1)/math.log(2)).toString) + + testSqlApi( + "log2(f27)", + (math.log(0)/math.log(2)).toString) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/RowTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/RowTypeTest.scala new file mode 100644 index 0000000000000..45b748277f126 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/RowTypeTest.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.RowTypeTestBase +import org.junit.Test + +class RowTypeTest extends RowTypeTestBase { + + @Test + def testRowLiteral(): Unit = { + + // primitive literal + testSqlApi( + "ROW(1, 'foo', true)", + "(1,foo,true)") + + // special literal + testSqlApi( + "ROW(DATE '1985-04-11', TIME '14:15:16', TIMESTAMP '1985-04-11 14:15:16', " + + "CAST(0.1 AS DECIMAL(2, 1)), ARRAY[1, 2, 3], MAP['foo', 'bar'], row(1, true))", + "(1985-04-11,14:15:16,1985-04-11 14:15:16.000,0.1,[1, 2, 3],{foo=bar},(1,true))") + + testSqlApi( + "ROW(1 + 1, 2 * 3, NULLIF(1, 1))", + "(2,6,null)" + ) + + testSqlApi("(1, 'foo', true)", "(1,foo,true)") + } + + @Test + def testRowField(): Unit = { + testSqlApi( + "(f0, f1)", + "(null,1)" + ) + + testSqlApi( + "f2", + "(2,foo,true)" + ) + + testSqlApi( + "(f2, f5)", + "((2,foo,true),(foo,null))" + ) + + testSqlApi( + "f4", + "(1984-03-12,0.00000000,[1, 2, 3])" + ) + + testSqlApi( + "(f1, 'foo',true)", + "(1,foo,true)" + ) + } + + @Test + def testRowOperations(): Unit = { + testSqlApi( + "f5.f0", + "foo" + ) + + testSqlApi( + "f3.f1.f2", + "true" + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala new file mode 100644 index 0000000000000..cd4d875c4b0e7 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.ScalarOperatorsTestBase +import org.junit.Test + +class ScalarOperatorsTest extends ScalarOperatorsTestBase { + + @Test + def testIn(): Unit = { + testSqlApi( + "f2 IN (1, 2, 42)", + "true" + ) + + testSqlApi( + "CAST(f0 AS DECIMAL) IN (42.0, 2.00, 3.01, 1.000000)", // SQL would downcast otherwise + "true" + ) + + testSqlApi( + "f10 IN ('This is a test String.', 'String', 'Hello world', 'Comment#1')", + "true" + ) + + testSqlApi( + "f14 IN ('This is a test String.', 'String', 'Hello world')", + "null" + ) + + testSqlApi( + "f15 IN (DATE '1996-11-10')", + "true" + ) + + testSqlApi( + "f15 IN (DATE '1996-11-10', DATE '1996-11-11')", + "true" + ) + + testSqlApi( + "f7 IN (f16, f17)", + "true" + ) + } + + @Test + def testOtherExpressions(): Unit = { + + // nested field null type + testSqlApi("CASE WHEN f13.f1 IS NULL THEN 'a' ELSE 'b' END", "a") + testSqlApi("CASE WHEN f13.f1 IS NOT NULL THEN 'a' ELSE 'b' END", "b") + testSqlApi("f13 IS NULL", "false") + testSqlApi("f13 IS NOT NULL", "true") + testSqlApi("f13.f0 IS NULL", "false") + testSqlApi("f13.f0 IS NOT NULL", "true") + testSqlApi("f13.f1 IS NULL", "true") + testSqlApi("f13.f1 IS NOT NULL", "false") + + // boolean literals + testSqlApi( + "true", + "true") + + testSqlApi( + "fAlse", + "false") + + testSqlApi( + "tRuE", + "true") + + // 
null + testSqlApi("CAST(NULL AS INT)", "null") + testSqlApi( + "CAST(NULL AS VARCHAR) = ''", + "null") + + // case when + testSqlApi("CASE 11 WHEN 1 THEN 'a' ELSE 'b' END", "b") + testSqlApi("CASE 2 WHEN 1 THEN 'a' ELSE 'b' END", "b") + testSqlApi( + "CASE 1 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 " + + "THEN '3' ELSE 'none of the above' END", + "1 or 2") + testSqlApi( + "CASE 2 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 " + + "THEN '3' ELSE 'none of the above' END", + "1 or 2") + testSqlApi( + "CASE 3 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 " + + "THEN '3' ELSE 'none of the above' END", + "3") + testSqlApi( + "CASE 4 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 " + + "THEN '3' ELSE 'none of the above' END", + "none of the above") + testSqlApi("CASE WHEN 'a'='a' THEN 1 END", "1") + testSqlApi("CASE 2 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END", "bcd") + testSqlApi("CASE 1 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END", "a") + testSqlApi("CASE 1 WHEN 1 THEN cast('a' as varchar(1)) WHEN 2 THEN " + + "cast('bcd' as varchar(3)) END", "a") + testSqlApi("CASE f2 WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END", "11") + testSqlApi("CASE f7 WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END", "null") + testSqlApi("CASE 42 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END", "null") + testSqlApi("CASE 1 WHEN 1 THEN true WHEN 2 THEN false ELSE NULL END", "true") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala new file mode 100644 index 0000000000000..08213f45f5f19 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.expressions.utils.ExpressionTestBase +import org.apache.flink.types.Row +import org.junit.{Ignore, Test} + +/** + * Tests all SQL expressions that are currently supported according to the documentation. + * These tests should be kept in sync with the documentation to reduce confusion due to the + * large number of SQL functions. + * + * The tests do not test every parameter combination of a function. + * They are rather a function existence test and simple functional test. + * + * The tests are split up and ordered like the sections in the documentation. 
+ */ +@Ignore("TODO: [FLINK-11898] the scalar functions will be supported in the future") +class SqlExpressionTest extends ExpressionTestBase { + + @Test + def testComparisonFunctions(): Unit = { + testSqlApi("1 = 1", "true") + testSqlApi("1 <> 1", "false") + testSqlApi("5 > 2", "true") + testSqlApi("2 >= 2", "true") + testSqlApi("5 < 2", "false") + testSqlApi("2 <= 2", "true") + testSqlApi("1 IS NULL", "false") + testSqlApi("1 IS NOT NULL", "true") + testSqlApi("NULLIF(1,1) IS DISTINCT FROM NULLIF(1,1)", "false") + testSqlApi("NULLIF(1,1) IS NOT DISTINCT FROM NULLIF(1,1)", "true") + testSqlApi("NULLIF(1,1) IS NOT DISTINCT FROM NULLIF(1,1)", "true") + testSqlApi("12 BETWEEN 11 AND 13", "true") + testSqlApi("12 BETWEEN ASYMMETRIC 13 AND 11", "false") + testSqlApi("12 BETWEEN SYMMETRIC 13 AND 11", "true") + testSqlApi("12 NOT BETWEEN 11 AND 13", "false") + testSqlApi("12 NOT BETWEEN ASYMMETRIC 13 AND 11", "true") + testSqlApi("12 NOT BETWEEN SYMMETRIC 13 AND 11", "false") + testSqlApi("'TEST' LIKE '%EST'", "true") + //testSqlApi("'%EST' LIKE '.%EST' ESCAPE '.'", "true") // TODO + testSqlApi("'TEST' NOT LIKE '%EST'", "false") + //testSqlApi("'%EST' NOT LIKE '.%EST' ESCAPE '.'", "false") // TODO + testSqlApi("'TEST' SIMILAR TO '.EST'", "true") + //testSqlApi("'TEST' SIMILAR TO ':.EST' ESCAPE ':'", "true") // TODO + testSqlApi("'TEST' NOT SIMILAR TO '.EST'", "false") + //testSqlApi("'TEST' NOT SIMILAR TO ':.EST' ESCAPE ':'", "false") // TODO + testSqlApi("'TEST' IN ('west', 'TEST', 'rest')", "true") + testSqlApi("'TEST' IN ('west', 'rest')", "false") + testSqlApi("'TEST' NOT IN ('west', 'TEST', 'rest')", "false") + testSqlApi("'TEST' NOT IN ('west', 'rest')", "true") + + // sub-query functions are not listed here + } + + @Test + def testLogicalFunctions(): Unit = { + testSqlApi("TRUE OR FALSE", "true") + testSqlApi("TRUE AND FALSE", "false") + testSqlApi("NOT TRUE", "false") + testSqlApi("TRUE IS FALSE", "false") + testSqlApi("TRUE IS NOT FALSE", "true") + 
testSqlApi("TRUE IS TRUE", "true") + testSqlApi("TRUE IS NOT TRUE", "false") + testSqlApi("NULLIF(TRUE,TRUE) IS UNKNOWN", "true") + testSqlApi("NULLIF(TRUE,TRUE) IS NOT UNKNOWN", "false") + } + + @Test + def testArithmeticFunctions(): Unit = { + testSqlApi("+5", "5") + testSqlApi("-5", "-5") + testSqlApi("5+5", "10") + testSqlApi("5-5", "0") + testSqlApi("5*5", "25") + testSqlApi("5/5", "1.0") + testSqlApi("POWER(5, 5)", "3125.0") + testSqlApi("ABS(-5)", "5") + testSqlApi("MOD(-26, 5)", "-1") + testSqlApi("SQRT(4)", "2.0") + testSqlApi("LN(1)", "0.0") + testSqlApi("LOG10(1)", "0.0") + testSqlApi("EXP(0)", "1.0") + testSqlApi("CEIL(2.5)", "3") + testSqlApi("CEILING(2.5)", "3") + testSqlApi("FLOOR(2.5)", "2") + testSqlApi("SIN(2.5)", "0.5984721441039564") + testSqlApi("SINH(2.5)", "6.0502044810397875") + testSqlApi("COS(2.5)", "-0.8011436155469337") + testSqlApi("TAN(2.5)", "-0.7470222972386603") + testSqlApi("COT(2.5)", "-1.3386481283041514") + testSqlApi("ASIN(0.5)", "0.5235987755982989") + testSqlApi("ACOS(0.5)", "1.0471975511965979") + testSqlApi("ATAN(0.5)", "0.4636476090008061") + testSqlApi("ATAN2(0.5, 0.5)", "0.7853981633974483") + testSqlApi("COSH(2.5)", "6.132289479663686") + testSqlApi("TANH(2.5)", "0.9866142981514303") + testSqlApi("DEGREES(0.5)", "28.64788975654116") + testSqlApi("RADIANS(0.5)", "0.008726646259971648") + testSqlApi("SIGN(-1.1)", "-1.0") // calcite: SIGN(Decimal(p,s)) => Decimal(p,s) + testSqlApi("ROUND(-12.345, 2)", "-12.35") + testSqlApi("PI()", "3.141592653589793") + testSqlApi("E()", "2.718281828459045") + } + + @Test + def testDivideFunctions(): Unit = { + + //slash + + // Decimal(2,1) / Decimal(2,1) => Decimal(8,6) + testSqlApi("1.0/8.0", "0.125000") + testSqlApi("2.0/3.0", "0.666667") + + // Integer => Decimal(10, 0) + // Decimal(10,0) / Decimal(2,1) => Decimal(17,6) + testSqlApi("-2/3.0", "-0.666667") + + // Decimal(2,1) / Decimal(10,0) => Decimal(23,12) + testSqlApi("2.0/(-3)", "-0.666666666667") + testSqlApi("-7.9/2", 
"-3.950000000000") + + //div function + testSqlApi("div(7, 2)", "3") + testSqlApi("div(7.9, 2.009)", "3") + testSqlApi("div(7, -2.009)", "-3") + testSqlApi("div(-7.9, 2)", "-3") + } + + @Test + def testStringFunctions(): Unit = { + testSqlApi("'test' || 'string'", "teststring") + testSqlApi("CHAR_LENGTH('string')", "6") + testSqlApi("CHARACTER_LENGTH('string')", "6") + testSqlApi("UPPER('string')", "STRING") + testSqlApi("LOWER('STRING')", "string") + testSqlApi("POSITION('STR' IN 'STRING')", "1") + testSqlApi("TRIM(BOTH ' STRING ')", "STRING") + testSqlApi("TRIM(LEADING 'x' FROM 'xxxxSTRINGxxxx')", "STRINGxxxx") + testSqlApi("TRIM(TRAILING 'x' FROM 'xxxxSTRINGxxxx')", "xxxxSTRING") + testSqlApi( + "OVERLAY('This is a old string' PLACING 'new' FROM 11 FOR 3)", + "This is a new string") + testSqlApi("SUBSTRING('hello world', 2)", "ello world") + testSqlApi("SUBSTRING('hello world', 2, 3)", "ell") + testSqlApi("SUBSTRING('hello world', 2, 300)", "ello world") + testSqlApi("SUBSTR('hello world', 2, 3)", "ell") + testSqlApi("SUBSTR('hello world', 2)", "ello world") + testSqlApi("SUBSTR('hello world', 2, 300)", "ello world") + testSqlApi("SUBSTR('hello world', 0, 3)", "hel") + testSqlApi("INITCAP('hello world')", "Hello World") + testSqlApi("REGEXP_REPLACE('foobar', 'oo|ar', '')", "fb") + testSqlApi("REGEXP_EXTRACT('foothebar', 'foo(.*?)(bar)', 2)", "bar") + testSqlApi( + "REPEAT('This is a test String.', 2)", + "This is a test String.This is a test String.") + testSqlApi("REPLACE('hello world', 'world', 'flink')", "hello flink") + } + + @Test + def testConditionalFunctions(): Unit = { + testSqlApi("CASE 2 WHEN 1, 2 THEN 2 ELSE 3 END", "2") + testSqlApi("CASE WHEN 1 = 2 THEN 2 WHEN 1 = 1 THEN 3 ELSE 3 END", "3") + testSqlApi("NULLIF(1, 1)", "null") + testSqlApi("COALESCE(NULL, 5)", "5") + testSqlApi("COALESCE(keyvalue('', ';', ':', 'isB2C'), '5')", "5") + testSqlApi("COALESCE(json_value('xx', '$x'), '5')", "5") + } + + @Test + def testTypeConversionFunctions(): Unit = { 
+ testSqlApi("CAST(2 AS DOUBLE)", "2.0") + } + + @Test + def testValueConstructorFunctions(): Unit = { + testSqlApi("ROW('hello world', 12)", "hello world,12") + testSqlApi("('hello world', 12)", "hello world,12") + testSqlApi("('foo', ('bar', 12))", "foo,bar,12") + testSqlApi("ARRAY[TRUE, FALSE][2]", "false") + testSqlApi("ARRAY[TRUE, TRUE]", "[true, true]") + testSqlApi("MAP['k1', 'v1', 'k2', 'v2']['k2']", "v2") + testSqlApi("MAP['k1', CAST(true AS VARCHAR(256)), 'k2', 'foo']['k1']", "true") + } + + @Test + def testDateTimeFunctions(): Unit = { + testSqlApi("DATE '1990-10-14'", "1990-10-14") + testSqlApi("TIME '12:12:12'", "12:12:12") + testSqlApi("TIMESTAMP '1990-10-14 12:12:12.123'", "1990-10-14 12:12:12.123") + testSqlApi("INTERVAL '10 00:00:00.004' DAY TO SECOND", "+10 00:00:00.004") + testSqlApi("INTERVAL '10 00:12' DAY TO MINUTE", "+10 00:12:00.000") + testSqlApi("INTERVAL '2-10' YEAR TO MONTH", "+2-10") + testSqlApi("EXTRACT(DAY FROM DATE '1990-12-01')", "1") + testSqlApi("EXTRACT(DAY FROM INTERVAL '19 12:10:10.123' DAY TO SECOND(3))", "19") + testSqlApi("FLOOR(TIME '12:44:31' TO MINUTE)", "12:44:00") + testSqlApi("CEIL(TIME '12:44:31' TO MINUTE)", "12:45:00") + testSqlApi("QUARTER(DATE '2016-04-12')", "2") + testSqlApi( + "(TIME '2:55:00', INTERVAL '1' HOUR) OVERLAPS (TIME '3:30:00', INTERVAL '2' HOUR)", + "true") + } + + @Test + def testArrayFunctions(): Unit = { + testSqlApi("CARDINALITY(ARRAY[TRUE, TRUE, FALSE])", "3") + testSqlApi("ELEMENT(ARRAY['HELLO WORLD'])", "HELLO WORLD") + } + + @Test + def testHashFunctions(): Unit = { + testSqlApi("MD5('')", "d41d8cd98f00b204e9800998ecf8427e") + testSqlApi("MD5('test')", "098f6bcd4621d373cade4e832627b4f6") + + testSqlApi("SHA1('')", "da39a3ee5e6b4b0d3255bfef95601890afd80709") + testSqlApi("SHA1('test')", "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3") + + testSqlApi("SHA224('')", "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f") + testSqlApi("SHA2('', 224)", 
"d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f") + + testSqlApi("SHA224('test')", "90a3ed9e32b2aaf4c61c410eb925426119e1a9dc53d4286ade99a809") + testSqlApi("SHA2('test', 224)", "90a3ed9e32b2aaf4c61c410eb925426119e1a9dc53d4286ade99a809") + + testSqlApi("SHA256('')", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + testSqlApi("SHA2('', 256)", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + + testSqlApi("SHA256('test')", "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08") + testSqlApi("SHA2('test', 256)", + "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08") + + testSqlApi("SHA384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc" + + "7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b") + testSqlApi("SHA2('', 384)", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0" + + "cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b") + + testSqlApi("SHA384('test')", "768412320f7b0aa5812fce428dc4706b3cae50e02a64caa16a782249bfe8efc" + + "4b7ef1ccb126255d196047dfedf17a0a9") + testSqlApi("SHA2('test', 384)", "768412320f7b0aa5812fce428dc4706b3cae50e02a64caa16a782249bfe8" + + "efc4b7ef1ccb126255d196047dfedf17a0a9") + + testSqlApi("SHA512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d" + + "0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e") + testSqlApi("SHA2('',512)", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce4" + + "7d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e") + + testSqlApi("SHA512('test')", "ee26b0dd4af7e749aa1a8ee3c10ae9923f618980772e473f8819a5d4940e0db" + + "27ac185f8a0e1d5f84f88bc887fd67b143732c304cc5fa9ad8e6f57f50028a8ff") + testSqlApi("SHA2('test',512)", "ee26b0dd4af7e749aa1a8ee3c10ae9923f618980772e473f8819a5d4940e0" + + "db27ac185f8a0e1d5f84f88bc887fd67b143732c304cc5fa9ad8e6f57f50028a8ff") + + testSqlApi("MD5(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA1(CAST(NULL AS VARCHAR))", "null") + 
testSqlApi("SHA224(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA256(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA384(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA512(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA2(CAST(NULL AS VARCHAR), 256)", "null") + } + + @Test + def testNullableCases(): Unit = { + testSqlApi( + "BITAND(cast(NUll as bigInt), cast(NUll as bigInt))", + nullable + ) + + testSqlApi( + "BITNOT(cast(NUll as bigInt))", + nullable + ) + + testSqlApi( + "BITOR(cast(NUll as bigInt), cast(NUll as bigInt))", + nullable + ) + + testSqlApi( + "BITXOR(cast(NUll as bigInt), cast(NUll as bigInt))", + nullable + ) + + testSqlApi( + "TO_BASE64(FROM_BASE64(cast(NUll as varchar)))", + nullable + ) + + testSqlApi( + "FROM_BASE64(cast(NUll as varchar))", + nullable + ) + } + + override def testData: Row = new Row(0) + + override def typeInfo: RowTypeInfo = new RowTypeInfo() +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ArrayTypeTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ArrayTypeTestBase.scala new file mode 100644 index 0000000000000..eb41325490eef --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ArrayTypeTestBase.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, PrimitiveArrayTypeInfo, Types} +import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, RowTypeInfo} +import org.apache.flink.table.util.DateTimeTestUtil._ +import org.apache.flink.types.Row + +abstract class ArrayTypeTestBase extends ExpressionTestBase { + + case class MyCaseClass(string: String, int: Int) + + override def testData: Row = { + val testData = new Row(12) + testData.setField(0, null) + testData.setField(1, 42) + testData.setField(2, Array(1, 2, 3)) + testData.setField(3, Array(UTCDate("1984-03-12"), UTCDate("1984-02-10"))) + testData.setField(4, null) + testData.setField(5, Array(Array(1, 2, 3), null)) + testData.setField(6, Array[Integer](1, null, null, 4)) + testData.setField(7, Array(1, 2, 3, 4)) + testData.setField(8, Array(4.0)) + testData.setField(9, Array[Integer](1)) + testData.setField(10, Array[Integer]()) + testData.setField(11, Array[Integer](1)) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + Types.INT, + Types.INT, + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + ObjectArrayTypeInfo.getInfoFor(Types.SQL_DATE), + ObjectArrayTypeInfo.getInfoFor(ObjectArrayTypeInfo.getInfoFor(Types.INT)), + ObjectArrayTypeInfo.getInfoFor(PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO), + ObjectArrayTypeInfo.getInfoFor(Types.INT), + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO, + 
ObjectArrayTypeInfo.getInfoFor(Types.INT), + ObjectArrayTypeInfo.getInfoFor(Types.INT), + BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala new file mode 100644 index 0000000000000..cc90dc0ce1c77 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo, TypeExtractor} +import org.apache.flink.api.scala.createTypeInformation +import org.apache.flink.table.expressions.utils.CompositeTypeTestBase.{MyCaseClass, MyCaseClass2, MyCaseClass3, MyPojo} +import org.apache.flink.types.Row + +class CompositeTypeTestBase extends ExpressionTestBase { + + override def testData: Row = { + val testData = new Row(14) + testData.setField(0, MyCaseClass(42, "Bob", booleanField = true)) + testData.setField(1, MyCaseClass2(MyCaseClass(25, "Timo", booleanField = false))) + testData.setField(2, ("a", "b")) + testData.setField(3, new org.apache.flink.api.java.tuple.Tuple2[String, String]("a", "b")) + testData.setField(4, new MyPojo()) + testData.setField(5, 13) + testData.setField(6, MyCaseClass2(null)) + testData.setField(7, Tuple1(true)) + testData.setField(8, Array(Tuple2(true, 23), Tuple2(false, 12))) + testData.setField(9, Array(Tuple1(true), null)) + testData.setField(10, Array(MyCaseClass(42, "Bob", booleanField = true))) + testData.setField(11, Array(new MyPojo(), null)) + testData.setField(12, Array(MyCaseClass3(Array(MyCaseClass(42, "Alice", booleanField = true))))) + testData.setField(13, Array(MyCaseClass2(MyCaseClass(42, "Bob", booleanField = true)))) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + createTypeInformation[MyCaseClass], + createTypeInformation[MyCaseClass2], + createTypeInformation[(String, String)], + new TupleTypeInfo(Types.STRING, Types.STRING), + TypeExtractor.createTypeInfo(classOf[MyPojo]), + Types.INT, + TypeExtractor.createTypeInfo(classOf[MyCaseClass2]), + createTypeInformation[Tuple1[Boolean]], + createTypeInformation[Array[Tuple2[Boolean, Int]]], + createTypeInformation[Array[Tuple1[Boolean]]], + createTypeInformation[Array[MyCaseClass]], + createTypeInformation[Array[MyPojo]], + 
createTypeInformation[Array[MyCaseClass3]], + createTypeInformation[Array[MyCaseClass2]] + ) + } +} + +object CompositeTypeTestBase { + case class MyCaseClass(intField: Int, stringField: String, booleanField: Boolean) + + case class MyCaseClass2(objectField: MyCaseClass) + + case class MyCaseClass3(arrayField: Array[MyCaseClass]) + + class MyPojo { + private var myInt: Int = 0 + private var myString: String = "Hello" + + def getMyInt: Int = myInt + + def setMyInt(value: Int): Unit = { + myInt = value + } + + def getMyString: String = myString + + def setMyString(value: String): Unit = { + myString = value + } + } +} + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala new file mode 100644 index 0000000000000..3b02567720fd6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.TaskInfo +import org.apache.flink.api.common.functions.util.RuntimeUDFContext +import org.apache.flink.api.common.functions.{MapFunction, RichFunction, RichMapFunction} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.table.`type`.{InternalTypes, RowType, TypeConverters} +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.api.java.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkPlannerImpl +import org.apache.flink.table.codegen.{CodeGeneratorContext, ExprCodeGenerator, FunctionCodeGenerator} +import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, DataFormatConverters} +import org.apache.flink.types.Row +import org.junit.Assert.{assertEquals, fail} +import org.junit.{After, Before} +import org.apache.calcite.plan.Convention +import org.apache.calcite.plan.hep.{HepPlanner, HepProgramBuilder} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.logical.{LogicalCalc, LogicalTableScan} +import org.apache.calcite.rel.rules._ +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR +import org.apache.calcite.tools.Programs +import org.apache.flink.table.plan.optimize.program.{FlinkHepProgram, FlinkOptimizeContext} + +import java.util.Collections + +import scala.collection.mutable + +abstract class ExpressionTestBase { + + val config = new TableConfig() + + // (originalExpr, optimizedExpr, expectedResult) + private val testExprs = mutable.ArrayBuffer[(String, RexNode, String)]() + private val env = StreamExecutionEnvironment.createLocalEnvironment(4) + private val tEnv = StreamTableEnvironment.create(env, config) + private val relBuilder = tEnv.getRelBuilder + private val planner = new FlinkPlannerImpl( + 
tEnv.getFrameworkConfig, + tEnv.getPlanner, + tEnv.getTypeFactory) + + + // setup test utils + private val tableName = "testTable" + protected val nullable = "null" + protected val notNullable = "not null" + + @Before + def prepare(): Unit = { + val ds = env.fromCollection(Collections.emptyList[Row](), typeInfo) + tEnv.registerDataStream(tableName, ds) + + // prepare RelBuilder + relBuilder.scan(tableName) + + // reset test exprs + testExprs.clear() + } + + @After + def evaluateExprs(): Unit = { + val ctx = CodeGeneratorContext(config) + val inputType = TypeConverters.createInternalTypeFromTypeInfo(typeInfo) + val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false).bindInput(inputType) + + // cast expressions to String + val stringTestExprs = testExprs.map(expr => relBuilder.cast(expr._2, VARCHAR)) + + // generate code + val resultType = new RowType(Seq.fill(testExprs.size)(InternalTypes.STRING): _*) + + val exprs = stringTestExprs.map(exprGenerator.generateExpression) + val genExpr = exprGenerator.generateResultExpression(exprs, resultType, classOf[BinaryRow]) + + val bodyCode = + s""" + |${genExpr.code} + |return ${genExpr.resultTerm}; + """.stripMargin + + val genFunc = FunctionCodeGenerator.generateFunction[MapFunction[BaseRow, BinaryRow]]( + ctx, + "TestFunction", + classOf[MapFunction[BaseRow, BinaryRow]], + bodyCode, + resultType, + inputType) + + val mapper = genFunc.newInstance(getClass.getClassLoader) + + val isRichFunction = mapper.isInstanceOf[RichFunction] + + // call setRuntimeContext method and open method for RichFunction + if (isRichFunction) { + val richMapper = mapper.asInstanceOf[RichMapFunction[_, _]] + val t = new RuntimeUDFContext( + new TaskInfo("ExpressionTest", 1, 0, 1, 1), + null, + env.getConfig, + Collections.emptyMap(), + Collections.emptyMap(), + null) + richMapper.setRuntimeContext(t) + richMapper.open(new Configuration()) + } + + val converter = DataFormatConverters + .getConverterForTypeInfo(typeInfo) + 
.asInstanceOf[DataFormatConverters.DataFormatConverter[BaseRow, Row]] + val testRow = converter.toInternal(testData) + val result = mapper.map(testRow) + + // call close method for RichFunction + if (isRichFunction) { + mapper.asInstanceOf[RichMapFunction[_, _]].close() + } + + // compare + testExprs + .zipWithIndex + .foreach { + case ((originalExpr, optimizedExpr, expected), index) => + + // adapt string result + val actual = if(!result.asInstanceOf[BinaryRow].isNullAt(index)) { + result.asInstanceOf[BinaryRow].getString(index).toString + } else { + null + } + + assertEquals( + s"Wrong result for: [$originalExpr] optimized to: [$optimizedExpr]", + expected, + if (actual == null) "null" else actual) + } + + } + + private def addSqlTestExpr(sqlExpr: String, expected: String): Unit = { + // create RelNode from SQL expression + val parsed = planner.parse(s"SELECT $sqlExpr FROM $tableName") + val validated = planner.validate(parsed) + val converted = planner.rel(validated).rel + + val builder = new HepProgramBuilder() + builder.addRuleInstance(ProjectToCalcRule.INSTANCE) + val hep = new HepPlanner(builder.build()) + hep.setRoot(converted) + val optimized = hep.findBestExp() + + // throw exception if plan contains more than a calc + if (!optimized.getInput(0).isInstanceOf[LogicalTableScan]) { + fail("Expression is converted into more than a Calc operation. 
Use a different test method.") + } + + testExprs += ((sqlExpr, extractRexNode(optimized), expected)) + } + + private def extractRexNode(node: RelNode): RexNode = { + val calcProgram = node + .asInstanceOf[LogicalCalc] + .getProgram + calcProgram.expandLocalRef(calcProgram.getProjectList.get(0)) + } + + def testSqlApi( + sqlExpr: String, + expected: String) + : Unit = { + addSqlTestExpr(sqlExpr, expected) + } + + def testData: Row + + def typeInfo: RowTypeInfo + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/MapTypeTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/MapTypeTestBase.scala new file mode 100644 index 0000000000000..57a3e749a36d5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/MapTypeTestBase.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import com.google.common.collect.ImmutableMap +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, Types} +import org.apache.flink.api.java.typeutils.{MapTypeInfo, RowTypeInfo} +import org.apache.flink.types.Row + +import java.util.{HashMap => JHashMap} + +abstract class MapTypeTestBase extends ExpressionTestBase { + + override def testData: Row = { + val map1 = new JHashMap[String, Int]() + map1.put("a", 12) + map1.put("b", 13) + val map2 = new JHashMap[Int, String]() + map2.put(12, "a") + map2.put(13, "b") + val map3 = new JHashMap[Long, Int]() + map3.put(10L, 1) + map3.put(20L, 2) + val map4 = new JHashMap[Int, Array[Int]]() + map4.put(1, Array(10, 100)) + map4.put(2, Array(20, 200)) + val testData = new Row(12) + testData.setField(0, null) + testData.setField(1, new JHashMap[String, Int]()) + testData.setField(2, map1) + testData.setField(3, map2) + testData.setField(4, "foo") + testData.setField(5, 12) + testData.setField(6, Array(1.2, 1.3)) + testData.setField(7, ImmutableMap.of(12, "a", 13, "b")) + testData.setField(8, map3) + testData.setField(9, map3) + testData.setField(10, map4) + testData.setField(11, map4) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + new MapTypeInfo(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), + Types.STRING, + Types.INT, + Types.PRIMITIVE_ARRAY(Types.DOUBLE), + new MapTypeInfo(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.INT_TYPE_INFO, Types.PRIMITIVE_ARRAY(Types.INT)), + new 
MapTypeInfo(BasicTypeInfo.INT_TYPE_INFO, Types.PRIMITIVE_ARRAY(Types.INT)) + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/RowTypeTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/RowTypeTestBase.scala new file mode 100644 index 0000000000000..ecee5ed1ab8b6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/RowTypeTestBase.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, RowTypeInfo} +import org.apache.flink.table.dataformat.Decimal +import org.apache.flink.table.typeutils.DecimalTypeInfo +import org.apache.flink.table.util.DateTimeTestUtil.UTCDate +import org.apache.flink.types.Row + +abstract class RowTypeTestBase extends ExpressionTestBase { + + override def testData: Row = { + val row = new Row(3) + row.setField(0, 2) + row.setField(1, "foo") + row.setField(2, true) + val nestedRow = new Row(2) + nestedRow.setField(0, 3) + nestedRow.setField(1, row) + val specialTypeRow = new Row(3) + specialTypeRow.setField(0, UTCDate("1984-03-12")) + specialTypeRow.setField(1, Decimal.castFrom("0.00000000", 9, 8)) + specialTypeRow.setField(2, Array[java.lang.Integer](1, 2, 3)) + val testData = new Row(7) + testData.setField(0, null) + testData.setField(1, 1) + testData.setField(2, row) + testData.setField(3, nestedRow) + testData.setField(4, specialTypeRow) + testData.setField(5, Row.of("foo", null)) + testData.setField(6, Row.of(null, null)) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + Types.STRING, + Types.INT, + Types.ROW(Types.INT, Types.STRING, Types.BOOLEAN), + Types.ROW(Types.INT, Types.ROW(Types.INT, Types.STRING, Types.BOOLEAN)), + Types.ROW( + Types.SQL_DATE, + DecimalTypeInfo.of(9, 8), + ObjectArrayTypeInfo.getInfoFor(Types.INT)), + Types.ROW(Types.STRING, Types.BOOLEAN), + Types.ROW(Types.STRING, Types.STRING) + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarOperatorsTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarOperatorsTestBase.scala new file mode 100644 index 0000000000000..503dadcdad1c3 --- /dev/null +++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarOperatorsTestBase.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.dataformat.Decimal +import org.apache.flink.table.typeutils.DecimalTypeInfo +import org.apache.flink.table.util.DateTimeTestUtil._ +import org.apache.flink.types.Row + +abstract class ScalarOperatorsTestBase extends ExpressionTestBase { + + override def testData: Row = { + val testData = new Row(18) + testData.setField(0, 1: Byte) + testData.setField(1, 1: Short) + testData.setField(2, 1) + testData.setField(3, 1L) + testData.setField(4, 1.0f) + testData.setField(5, 1.0d) + testData.setField(6, true) + testData.setField(7, 0.0d) + testData.setField(8, 5) + testData.setField(9, 10) + testData.setField(10, "String") + testData.setField(11, false) + testData.setField(12, null) + testData.setField(13, Row.of("foo", null)) + testData.setField(14, null) + testData.setField(15, UTCDate("1996-11-10")) + testData.setField(16, 
Decimal.castFrom("0.00000000", 19, 8)) + testData.setField(17, Decimal.castFrom("10.0", 19, 1)) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + Types.BYTE, + Types.SHORT, + Types.INT, + Types.LONG, + Types.FLOAT, + Types.DOUBLE, + Types.BOOLEAN, + Types.DOUBLE, + Types.INT, + Types.INT, + Types.STRING, + Types.BOOLEAN, + Types.BOOLEAN, + Types.ROW(Types.STRING, Types.STRING), + Types.STRING, + Types.SQL_DATE, + DecimalTypeInfo.of(19, 8), + DecimalTypeInfo.of(19, 1) + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala new file mode 100644 index 0000000000000..4c8a3d1fff7eb --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, Types} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.dataformat.Decimal +import org.apache.flink.table.typeutils.{DecimalTypeInfo, TimeIntervalTypeInfo} +import org.apache.flink.table.util.DateTimeTestUtil._ +import org.apache.flink.types.Row + +abstract class ScalarTypesTestBase extends ExpressionTestBase { + + override def testData: Row = { + val testData = new Row(46) + testData.setField(0, "This is a test String.") + testData.setField(1, true) + testData.setField(2, 42.toByte) + testData.setField(3, 43.toShort) + testData.setField(4, 44.toLong) + testData.setField(5, 4.5.toFloat) + testData.setField(6, 4.6) + testData.setField(7, 3) + testData.setField(8, " This is a test String. ") + testData.setField(9, -42.toByte) + testData.setField(10, -43.toShort) + testData.setField(11, -44.toLong) + testData.setField(12, -4.5.toFloat) + testData.setField(13, -4.6) + testData.setField(14, -3) + testData.setField(15, Decimal.castFrom("-1231.1231231321321321111", 38, 19)) + testData.setField(16, UTCDate("1996-11-10")) + testData.setField(17, UTCTime("06:55:44")) + testData.setField(18, UTCTimestamp("1996-11-10 06:55:44.333")) + testData.setField(19, 1467012213000L) // +16979 07:23:33.000 + testData.setField(20, 25) // +2-01 + testData.setField(21, null) + testData.setField(22, Decimal.castFrom("2", 38, 19)) + testData.setField(23, "%This is a test String.") + testData.setField(24, "*_This is a test String.") + testData.setField(25, 0.42.toByte) + testData.setField(26, 0.toShort) + testData.setField(27, 0.toLong) + testData.setField(28, 0.45.toFloat) + testData.setField(29, 0.46) + testData.setField(30, 1) + testData.setField(31, Decimal.castFrom("-0.1231231321321321111", 38, 0)) + testData.setField(32, -1) + testData.setField(33, null) + testData.setField(34, Decimal.castFrom("1514356320000", 38, 19)) + 
testData.setField(35, "a") + testData.setField(36, "b") + testData.setField(37, Array[Byte](1, 2, 3, 4)) + testData.setField(38, "AQIDBA==") + testData.setField(39, "1世3") + testData.setField(40, null) + testData.setField(41, null) + testData.setField(42, 256.toLong) + testData.setField(43, -1.toLong) + testData.setField(44, 256) + testData.setField(45, UTCTimestamp("1996-11-10 06:55:44.333").toString) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + Types.STRING, + Types.BOOLEAN, + Types.BYTE, + Types.SHORT, + Types.LONG, + Types.FLOAT, + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.BYTE, + Types.SHORT, + Types.LONG, + Types.FLOAT, + Types.DOUBLE, + Types.INT, + DecimalTypeInfo.of(38, 19), + Types.SQL_DATE, + Types.SQL_TIME, + Types.SQL_TIMESTAMP, + TimeIntervalTypeInfo.INTERVAL_MILLIS, + TimeIntervalTypeInfo.INTERVAL_MONTHS, + Types.BOOLEAN, + DecimalTypeInfo.of(38, 19), + Types.STRING, + Types.STRING, + Types.BYTE, + Types.SHORT, + Types.LONG, + Types.FLOAT, + Types.DOUBLE, + Types.INT, + DecimalTypeInfo.of(38, 19), + Types.INT, + Types.STRING, + DecimalTypeInfo.of(19, 0), + Types.STRING, + Types.STRING, + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, + Types.STRING, + Types.STRING, + Types.STRING, + DecimalTypeInfo.of(38, 19), + Types.LONG, + Types.LONG, + Types.INT, + Types.STRING) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ArrayTypeValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ArrayTypeValidationTest.scala new file mode 100644 index 0000000000000..951952532c4fa --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ArrayTypeValidationTest.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.validation + +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.expressions.utils.ArrayTypeTestBase +import org.junit.Test + +class ArrayTypeValidationTest extends ArrayTypeTestBase { + + @Test(expected = classOf[ValidationException]) + def testImplicitTypeCastArraySql(): Unit = { + testSqlApi("ARRAY['string', 12]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testEmptyArraySql(): Unit = { + testSqlApi("ARRAY[]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testNullArraySql(): Unit = { + testSqlApi("ARRAY[NULL]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testDifferentTypesArraySql(): Unit = { + testSqlApi("ARRAY[1, TRUE]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testElementNonArraySql(): Unit = { + testSqlApi( + "ELEMENT(f0)", + "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testCardinalityOnNonArraySql(): Unit = { + testSqlApi("CARDINALITY(f0)", "FAIL") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/CompositeAccessValidationTest.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/CompositeAccessValidationTest.scala new file mode 100644 index 0000000000000..cc0b171fa0e63 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/CompositeAccessValidationTest.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.validation + +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.expressions.utils.CompositeTypeTestBase +import org.junit.Test + +class CompositeAccessValidationTest extends CompositeTypeTestBase { + + @Test(expected = classOf[ValidationException]) + def testWrongSqlFieldFull(): Unit = { + testSqlApi("testTable.f5.test", "13") + } + + @Test(expected = classOf[ValidationException]) + def testWrongSqlField(): Unit = { + testSqlApi("f5.test", "13") + } +} + + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/MapTypeValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/MapTypeValidationTest.scala new file mode 100644 index 0000000000000..7e960549b488e --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/MapTypeValidationTest.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.validation + +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.expressions.utils.MapTypeTestBase +import org.junit.Test + +class MapTypeValidationTest extends MapTypeTestBase { + + @Test(expected = classOf[ValidationException]) + def testWrongKeyType(): Unit = { + testSqlApi("f2[12]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testUnsupportedComparisonType(): Unit = { + testSqlApi("f6 <> f2", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testEmptyMap(): Unit = { + testSqlApi("MAP[]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testUnsupportedMapImplicitTypeCastSql(): Unit = { + testSqlApi("MAP['k1', 'string', 'k2', 12]", "FAIL") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/RowTypeValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/RowTypeValidationTest.scala new file mode 100644 index 0000000000000..1a7123c09b5e6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/RowTypeValidationTest.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.validation + +import org.apache.flink.table.api.{SqlParserException, ValidationException} +import org.apache.flink.table.expressions.utils.RowTypeTestBase +import org.junit.Test + +class RowTypeValidationTest extends RowTypeTestBase { + + @Test(expected = classOf[SqlParserException]) + def testEmptyRowType(): Unit = { + testSqlApi("Row()", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testNullRowType(): Unit = { + testSqlApi("Row(NULL)", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testSqlRowIllegalAccess(): Unit = { + testSqlApi("f5.f2", "FAIL") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala new file mode 100644 index 0000000000000..d2a0997a3565d --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.validation + +import org.apache.calcite.avatica.util.TimeUnit +import org.apache.flink.table.api.{SqlParserException, ValidationException} +import org.apache.flink.table.expressions.utils.ScalarTypesTestBase +import org.junit.{Ignore, Test} + +class ScalarFunctionsValidationTest extends ScalarTypesTestBase { + + // ---------------------------------------------------------------------------------------------- + // Math functions + // ---------------------------------------------------------------------------------------------- + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[IllegalArgumentException]) + def testInvalidLog1(): Unit = { + // invalid arithmetic argument + testSqlApi( + "LOG(1, 100)", + "FAIL" + ) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[IllegalArgumentException]) + def testInvalidLog2(): Unit ={ + // invalid arithmetic argument + testSqlApi( + "LOG(-1)", + "FAIL" + ) + } + + @Test(expected = classOf[ValidationException]) + def testInvalidBin1(): Unit = { + testSqlApi("BIN(f12)", "101010") // float type + } + + @Test(expected = classOf[ValidationException]) + def testInvalidBin2(): Unit = { + testSqlApi("BIN(f15)", "101010") // BigDecimal type + } + + @Test(expected = classOf[ValidationException]) + def testInvalidBin3(): Unit = { + testSqlApi("BIN(f16)", "101010") // Date type + } + + + // ---------------------------------------------------------------------------------------------- + // Temporal functions + // ---------------------------------------------------------------------------------------------- + + @Test(expected = classOf[SqlParserException]) + def testTimestampAddWithWrongTimestampInterval(): Unit ={ + testSqlApi("TIMESTAMPADD(XXX, 1, timestamp '2016-02-24'))", "2016-06-16") + } + + @Test(expected = classOf[SqlParserException]) + def 
testTimestampAddWithWrongTimestampFormat(): Unit ={ + testSqlApi("TIMESTAMPADD(YEAR, 1, timestamp '2016-02-24'))", "2016-06-16") + } + + @Test(expected = classOf[ValidationException]) + def testTimestampAddWithWrongQuantity(): Unit ={ + testSqlApi("TIMESTAMPADD(YEAR, 1.0, timestamp '2016-02-24 12:42:25')", "2016-06-16") + } + + // ---------------------------------------------------------------------------------------------- + // Sub-query functions + // ---------------------------------------------------------------------------------------------- + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testDOWWithTimeWhichIsUnsupported(): Unit = { + testSqlApi("EXTRACT(DOW FROM TIME '12:42:25')", "0") + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testDOYWithTimeWhichIsUnsupported(): Unit = { + testSqlApi("EXTRACT(DOY FROM TIME '12:42:25')", "0") + } + + private def testExtractFromTimeZeroResult(unit: TimeUnit): Unit = { + testSqlApi("EXTRACT(" + unit + " FROM TIME '00:00:00')", "0") + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testMillenniumWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.MILLENNIUM) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testCenturyWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.CENTURY) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testYearWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.YEAR) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testMonthWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.MONTH) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testDayWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.DAY) + } +} diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/DateTimeTestUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/DateTimeTestUtil.scala new file mode 100644 index 0000000000000..05b3f0ecff3ba --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/DateTimeTestUtil.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.util + +import org.apache.calcite.avatica.util.DateTimeUtils + +import java.sql.{Date, Time, Timestamp} + +object DateTimeTestUtil { + + def UTCDate(s: String): Date = { + new Date(DateTimeUtils.dateStringToUnixDate(s) * DateTimeUtils.MILLIS_PER_DAY) + } + + def UTCTime(s: String): Time = { + new Time(DateTimeUtils.timeStringToUnixDate(s).longValue()) + } + + def UTCTimestamp(s: String): Timestamp = { + new Timestamp(DateTimeUtils.timestampStringToUnixDate(s)) + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java index e3f371b2cb2c6..af71d56500175 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java @@ -46,6 +46,7 @@ import org.apache.flink.table.typeutils.DecimalTypeInfo; import org.apache.flink.types.Row; +import java.io.Serializable; import java.lang.reflect.Array; import java.math.BigDecimal; import java.sql.Date; @@ -111,6 +112,7 @@ public class DataFormatConverters { * lost its specific Java format. Only TypeInformation retains all its * Java format information. */ + @SuppressWarnings("unchecked") public static DataFormatConverter getConverterForTypeInfo(TypeInformation typeInfo) { DataFormatConverter converter = TYPE_INFO_TO_CONVERTER.get(typeInfo); if (converter != null) { @@ -158,7 +160,9 @@ public static DataFormatConverter getConverterForTypeInfo(TypeInformation typeIn * @param Internal data format. * @param External data format. 
*/ - public abstract static class DataFormatConverter { + public abstract static class DataFormatConverter implements Serializable { + + private static final long serialVersionUID = 1L; /** * Converts a external(Java) data format to its internal equivalent while automatically handling nulls. @@ -203,6 +207,8 @@ public final External toExternal(BaseRow row, int column) { */ public abstract static class IdentityConverter extends DataFormatConverter { + private static final long serialVersionUID = 6146619729108124872L; + @Override T toInternalImpl(T value) { return value; @@ -219,6 +225,8 @@ T toExternalImpl(T value) { */ public static class BooleanConverter extends IdentityConverter { + private static final long serialVersionUID = 3618373319753553272L; + public static final BooleanConverter INSTANCE = new BooleanConverter(); private BooleanConverter() {} @@ -234,6 +242,8 @@ Boolean toExternalImpl(BaseRow row, int column) { */ public static class ByteConverter extends IdentityConverter { + private static final long serialVersionUID = 1880134895918999433L; + public static final ByteConverter INSTANCE = new ByteConverter(); private ByteConverter() {} @@ -249,6 +259,8 @@ Byte toExternalImpl(BaseRow row, int column) { */ public static class ShortConverter extends IdentityConverter { + private static final long serialVersionUID = 8055034507232206636L; + public static final ShortConverter INSTANCE = new ShortConverter(); private ShortConverter() {} @@ -264,6 +276,8 @@ Short toExternalImpl(BaseRow row, int column) { */ public static class IntConverter extends IdentityConverter { + private static final long serialVersionUID = -7749307898273403416L; + public static final IntConverter INSTANCE = new IntConverter(); private IntConverter() {} @@ -279,6 +293,8 @@ Integer toExternalImpl(BaseRow row, int column) { */ public static class LongConverter extends IdentityConverter { + private static final long serialVersionUID = 7373868336730797650L; + public static final LongConverter 
INSTANCE = new LongConverter(); private LongConverter() {} @@ -294,6 +310,8 @@ Long toExternalImpl(BaseRow row, int column) { */ public static class FloatConverter extends IdentityConverter { + private static final long serialVersionUID = -1119035126939832966L; + public static final FloatConverter INSTANCE = new FloatConverter(); private FloatConverter() {} @@ -309,6 +327,8 @@ Float toExternalImpl(BaseRow row, int column) { */ public static class DoubleConverter extends IdentityConverter { + private static final long serialVersionUID = 2801171640313215040L; + public static final DoubleConverter INSTANCE = new DoubleConverter(); private DoubleConverter() {} @@ -324,6 +344,8 @@ Double toExternalImpl(BaseRow row, int column) { */ public static class CharConverter extends IdentityConverter { + private static final long serialVersionUID = -7631466361315237011L; + public static final CharConverter INSTANCE = new CharConverter(); private CharConverter() {} @@ -339,6 +361,8 @@ Character toExternalImpl(BaseRow row, int column) { */ public static class BinaryStringConverter extends IdentityConverter { + private static final long serialVersionUID = 5565684451615599206L; + public static final BinaryStringConverter INSTANCE = new BinaryStringConverter(); private BinaryStringConverter() {} @@ -354,6 +378,8 @@ BinaryString toExternalImpl(BaseRow row, int column) { */ public static class BinaryArrayConverter extends IdentityConverter { + private static final long serialVersionUID = -7790350668043604641L; + public static final BinaryArrayConverter INSTANCE = new BinaryArrayConverter(); private BinaryArrayConverter() {} @@ -369,6 +395,8 @@ BinaryArray toExternalImpl(BaseRow row, int column) { */ public static class BinaryMapConverter extends IdentityConverter { + private static final long serialVersionUID = -9114231688474126815L; + public static final BinaryMapConverter INSTANCE = new BinaryMapConverter(); private BinaryMapConverter() {} @@ -384,6 +412,8 @@ BinaryMap 
toExternalImpl(BaseRow row, int column) { */ public static class DecimalConverter extends IdentityConverter { + private static final long serialVersionUID = 3825744951173809617L; + private final int precision; private final int scale; @@ -403,6 +433,8 @@ Decimal toExternalImpl(BaseRow row, int column) { */ public static class BinaryGenericConverter extends IdentityConverter { + private static final long serialVersionUID = 1436229503920584273L; + public static final BinaryGenericConverter INSTANCE = new BinaryGenericConverter(); private BinaryGenericConverter() {} @@ -418,6 +450,8 @@ BinaryGeneric toExternalImpl(BaseRow row, int column) { */ public static class StringConverter extends DataFormatConverter { + private static final long serialVersionUID = 4713165079099282774L; + public static final StringConverter INSTANCE = new StringConverter(); private StringConverter() {} @@ -443,6 +477,8 @@ String toExternalImpl(BaseRow row, int column) { */ public static class BigDecimalConverter extends DataFormatConverter { + private static final long serialVersionUID = -6586239704060565834L; + private final int precision; private final int scale; @@ -472,6 +508,8 @@ BigDecimal toExternalImpl(BaseRow row, int column) { */ public static class GenericConverter extends DataFormatConverter, T> { + private static final long serialVersionUID = -3611718364918053384L; + private final TypeSerializer serializer; public GenericConverter(TypeSerializer serializer) { @@ -499,6 +537,8 @@ T toExternalImpl(BaseRow row, int column) { */ public static class DateConverter extends DataFormatConverter { + private static final long serialVersionUID = 1343457113582411650L; + public static final DateConverter INSTANCE = new DateConverter(); private DateConverter() {} @@ -524,6 +564,8 @@ Date toExternalImpl(BaseRow row, int column) { */ public static class TimeConverter extends DataFormatConverter { + private static final long serialVersionUID = -8061475784916442483L; + public static final 
TimeConverter INSTANCE = new TimeConverter(); private TimeConverter() {} @@ -549,6 +591,8 @@ Time toExternalImpl(BaseRow row, int column) { */ public static class TimestampConverter extends DataFormatConverter { + private static final long serialVersionUID = -779956524906131757L; + public static final TimestampConverter INSTANCE = new TimestampConverter(); private TimestampConverter() {} @@ -574,6 +618,8 @@ Timestamp toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveIntArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = 1780941126232395638L; + public static final PrimitiveIntArrayConverter INSTANCE = new PrimitiveIntArrayConverter(); private PrimitiveIntArrayConverter() {} @@ -599,6 +645,8 @@ int[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveBooleanArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -4037693692440282141L; + public static final PrimitiveBooleanArrayConverter INSTANCE = new PrimitiveBooleanArrayConverter(); private PrimitiveBooleanArrayConverter() {} @@ -624,6 +672,8 @@ boolean[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveByteArrayConverter extends IdentityConverter { + private static final long serialVersionUID = -2007960927801689921L; + public static final PrimitiveByteArrayConverter INSTANCE = new PrimitiveByteArrayConverter(); private PrimitiveByteArrayConverter() {} @@ -639,6 +689,8 @@ byte[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveShortArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -1343184089311186834L; + public static final PrimitiveShortArrayConverter INSTANCE = new PrimitiveShortArrayConverter(); private PrimitiveShortArrayConverter() {} @@ -664,6 +716,8 @@ short[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveLongArrayConverter extends DataFormatConverter { + private 
static final long serialVersionUID = 4061982985342526078L; + public static final PrimitiveLongArrayConverter INSTANCE = new PrimitiveLongArrayConverter(); private PrimitiveLongArrayConverter() {} @@ -689,6 +743,8 @@ long[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveFloatArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -3237695040861141459L; + public static final PrimitiveFloatArrayConverter INSTANCE = new PrimitiveFloatArrayConverter(); private PrimitiveFloatArrayConverter() {} @@ -714,6 +770,8 @@ float[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveDoubleArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = 6333670535356315691L; + public static final PrimitiveDoubleArrayConverter INSTANCE = new PrimitiveDoubleArrayConverter(); private PrimitiveDoubleArrayConverter() {} @@ -739,6 +797,8 @@ BinaryArray toInternalImpl(double[] value) { */ public static class PrimitiveCharArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -5438377988505771316L; + public static final PrimitiveCharArrayConverter INSTANCE = new PrimitiveCharArrayConverter(); private PrimitiveCharArrayConverter() {} @@ -764,6 +824,8 @@ char[] toExternalImpl(BaseRow row, int column) { */ public static class ObjectArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -7434682160639380078L; + private final Class arrayClass; private final InternalType elementType; private final DataFormatConverter elementConverter; @@ -825,6 +887,8 @@ private static T[] binaryArrayToJavaArray(BinaryArray value, InternalType e */ public static class MapConverter extends DataFormatConverter { + private static final long serialVersionUID = -916429669828309919L; + private final InternalType keyType; private final InternalType valueType; @@ -898,6 +962,8 @@ Map toExternalImpl(BaseRow row, int column) { */ public 
abstract static class AbstractBaseRowConverter extends DataFormatConverter { + private static final long serialVersionUID = 4365740929854771618L; + protected final DataFormatConverter[] converters; public AbstractBaseRowConverter(CompositeType t) { @@ -918,6 +984,8 @@ E toExternalImpl(BaseRow row, int column) { */ public static class BaseRowConverter extends IdentityConverter { + private static final long serialVersionUID = -4470307402371540680L; + public static final BaseRowConverter INSTANCE = new BaseRowConverter(); private BaseRowConverter() {} @@ -933,6 +1001,8 @@ BaseRow toExternalImpl(BaseRow row, int column) { */ public static class PojoConverter extends AbstractBaseRowConverter { + private static final long serialVersionUID = 6821541780176167135L; + private final PojoTypeInfo t; private final PojoField[] fields; @@ -979,6 +1049,8 @@ T toExternalImpl(BaseRow value) { */ public static class RowConverter extends AbstractBaseRowConverter { + private static final long serialVersionUID = -56553502075225785L; + private final RowTypeInfo t; public RowConverter(RowTypeInfo t) { @@ -1010,6 +1082,8 @@ Row toExternalImpl(BaseRow value) { */ public static class TupleConverter extends AbstractBaseRowConverter { + private static final long serialVersionUID = 2794892691010934194L; + private final TupleTypeInfo t; public TupleConverter(TupleTypeInfo t) { @@ -1046,6 +1120,8 @@ Tuple toExternalImpl(BaseRow value) { */ public static class CaseClassConverter extends AbstractBaseRowConverter { + private static final long serialVersionUID = -966598627968372952L; + private final TupleTypeInfoBase t; private final TupleSerializerBase serializer; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/Decimal.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/Decimal.java index aeaa41a0e1306..f6aa4bbff4bf8 100644 --- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/Decimal.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/Decimal.java @@ -339,6 +339,30 @@ public static long castToIntegral(Decimal dec) { return bd.longValue(); } + public static long castToLong(Decimal dec) { + return castToIntegral(dec); + } + + public static int castToInt(Decimal dec) { + return (int) castToIntegral(dec); + } + + public static short castToShort(Decimal dec) { + return (short) castToIntegral(dec); + } + + public static byte castToByte(Decimal dec) { + return (byte) castToIntegral(dec); + } + + public static float castToFloat(Decimal dec) { + return (float) dec.doubleValue(); + } + + public static double castToDouble(Decimal dec) { + return dec.doubleValue(); + } + public static Decimal castToDecimal(Decimal dec, int precision, int scale) { return fromBigDecimal(dec.toBigDecimal(), precision, scale); } @@ -383,6 +407,10 @@ public static Decimal sign(Decimal b0) { } } + public static int compare(Decimal b1, Decimal b2){ + return b1.compareTo(b2); + } + public static int compare(Decimal b1, long n2) { if (!b1.isCompact()) { return b1.decimalVal.compareTo(BigDecimal.valueOf(n2)); @@ -400,6 +428,18 @@ public static int compare(Decimal b1, long n2) { } } + public static int compare(Decimal b1, double n2) { + return Double.compare(b1.doubleValue(), n2); + } + + public static int compare(long n1, Decimal b2) { + return -compare(b2, n1); + } + + public static int compare(double n1, Decimal b2) { + return -compare(b2, n1); + } + /** * SQL ROUND operator applied to BigDecimal values. 
*/ diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java index ba7478f6d6d2f..b50e0f1a0c5ac 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java @@ -57,6 +57,7 @@ public LazyBinaryFormat(T javaObject) { } public T getJavaObject() { + // TODO: ensure deserialize ? return javaObject; } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/DateTimeUtils.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/DateTimeUtils.java index c86f821b16823..d796706015cd2 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/DateTimeUtils.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/DateTimeUtils.java @@ -17,7 +17,15 @@ package org.apache.flink.table.runtime.functions; +import org.apache.flink.api.java.tuple.Tuple2; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.sql.Timestamp; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Date; import java.util.TimeZone; /** @@ -27,7 +35,16 @@ */ public class DateTimeUtils { - public static final TimeZone LOCAL_TZ = TimeZone.getDefault(); + private static final Logger LOG = LoggerFactory.getLogger(DateTimeUtils.class); + + private static final TimeZone LOCAL_TZ = TimeZone.getDefault(); + + private static final String[] DEFAULT_DATETIME_FORMATS = new String[]{ + "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.S", + "yyyy-MM-dd HH:mm:ss.SS", + "yyyy-MM-dd HH:mm:ss.SSS" + }; /** * The number of milliseconds in a day. 
@@ -35,7 +52,21 @@ public class DateTimeUtils { *

This is the modulo 'mask' used when converting * TIMESTAMP values to DATE and TIME values. */ - public static final long MILLIS_PER_DAY = 86400000L; // = 24 * 60 * 60 * 1000 + private static final long MILLIS_PER_DAY = 86400000L; // = 24 * 60 * 60 * 1000 + + /** + * A ThreadLocal cache map for SimpleDateFormat, because SimpleDateFormat is not thread-safe. + * (format, timezone) => formatter + */ + private static final ThreadLocalCache, SimpleDateFormat> FORMATTER_CACHE = + new ThreadLocalCache, SimpleDateFormat>(64) { + @Override + public SimpleDateFormat getNewInstance(Tuple2 key) { + SimpleDateFormat sdf = new SimpleDateFormat(key.f0); + sdf.setTimeZone(key.f1); + return sdf; + } + }; /** Converts the Java type used for UDF parameters of SQL TIME type * ({@link java.sql.Time}) to internal representation (int). @@ -93,4 +124,99 @@ public static java.sql.Timestamp internalToTimestamp(long v) { return new java.sql.Timestamp(v - LOCAL_TZ.getOffset(v)); } + /** + * Parse date time string to timestamp based on the given time zone and + * "yyyy-MM-dd HH:mm:ss" format. Returns null if parsing failed. + * + * @param dateText the date time string + * @param tz the time zone + */ + public static Long strToTimestamp(String dateText, TimeZone tz) { + return strToTimestamp(dateText, DEFAULT_DATETIME_FORMATS[0], tz); + } + + /** + * Parse date time string to timestamp based on the given time zone and format. + * Returns null if parsing failed. + * + * @param dateText the date time string + * @param format date time string format + * @param tz the time zone + */ + public static Long strToTimestamp(String dateText, String format, TimeZone tz) { + SimpleDateFormat formatter = FORMATTER_CACHE.get(Tuple2.of(format, tz)); + try { + return formatter.parse(dateText).getTime(); + } catch (ParseException e) { + return null; + } + } + + /** + * Format a timestamp as specific. + * @param ts the timestamp to format. + * @param format the string formatter. 
+ * @param tz the time zone + */ + public static String dateFormat(long ts, String format, TimeZone tz) { + SimpleDateFormat formatter = FORMATTER_CACHE.get(Tuple2.of(format, tz)); + Date dateTime = new Date(ts); + return formatter.format(dateTime); + } + + /** + * Convert a timestamp to string. + * @param ts the timestamp to convert. + * @param precision the milli second precision to preserve + * @param tz the time zone + */ + public static String timestampToString(long ts, int precision, TimeZone tz) { + int p = (precision <= 3 && precision >= 0) ? precision : 3; + String format = DEFAULT_DATETIME_FORMATS[p]; + return dateFormat(ts, format, tz); + } + + /** Helper for CAST({time} AS VARCHAR(n)). */ + public static String timeToString(int time) { + final StringBuilder buf = new StringBuilder(8); + timeToString(buf, time, 0); // set milli second precision to 0 + return buf.toString(); + } + + private static void timeToString(StringBuilder buf, int time, int precision) { + while (time < 0) { + time += MILLIS_PER_DAY; + } + int h = time / 3600000; + int time2 = time % 3600000; + int m = time2 / 60000; + int time3 = time2 % 60000; + int s = time3 / 1000; + int ms = time3 % 1000; + int2(buf, h); + buf.append(':'); + int2(buf, m); + buf.append(':'); + int2(buf, s); + if (precision > 0) { + buf.append('.'); + while (precision > 0) { + buf.append((char) ('0' + (ms / 100))); + ms = ms % 100; + ms = ms * 10; + + // keep consistent with Timestamp.toString() + if (ms == 0) { + break; + } + + --precision; + } + } + } + + private static void int2(StringBuilder buf, int i) { + buf.append((char) ('0' + (i / 10) % 10)); + buf.append((char) ('0' + i % 10)); + } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ThreadLocalCache.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ThreadLocalCache.java new file mode 100644 index 0000000000000..7425c0270b2a4 --- /dev/null +++ 
/**
 * Provides a ThreadLocal cache with a maximum cache size per thread.
 * Values must not be null.
 *
 * <p>Each thread sees its own bounded LRU-by-insertion map, so subclasses may cache
 * non-thread-safe objects (e.g. {@code SimpleDateFormat}) without synchronization.
 *
 * @param <K> cache key type
 * @param <V> cache value type (must not be null, since null is used as the "miss" marker)
 */
public abstract class ThreadLocalCache<K, V> {

    // NOTE: the <K, V> type parameters were lost in the extracted source; restored here.
    private final ThreadLocal<BoundedMap<K, V>> cache = new ThreadLocal<>();

    private final int maxSizePerThread;

    protected ThreadLocalCache(int maxSizePerThread) {
        this.maxSizePerThread = maxSizePerThread;
    }

    /**
     * Returns the cached value for {@code key} in the current thread, creating and caching
     * it via {@link #getNewInstance(Object)} on a miss.
     */
    public V get(K key) {
        BoundedMap<K, V> map = cache.get();
        if (map == null) {
            // lazily create the per-thread map on first access from this thread
            map = new BoundedMap<>(maxSizePerThread);
            cache.set(map);
        }
        V value = map.get(key);
        if (value == null) {
            value = getNewInstance(key);
            map.put(key, value);
        }
        return value;
    }

    /** Factory for a fresh value; invoked once per (thread, key) until the entry is evicted. */
    public abstract V getNewInstance(K key);

    /** A LinkedHashMap that evicts its eldest entry once {@code maxSize} is exceeded. */
    private static class BoundedMap<K, V> extends LinkedHashMap<K, V> {

        private static final long serialVersionUID = -211630219014422361L;

        private final int maxSize;

        private BoundedMap(int maxSize) {
            this.maxSize = maxSize;
        }

        @Override
        protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
            // evict in insertion order once the cap is exceeded
            return this.size() > maxSize;
        }
    }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Byte hash set. + */ +public class ByteHashSet extends OptimizableHashSet { + + private byte[] key; + + private byte min = Byte.MAX_VALUE; + private byte max = Byte.MIN_VALUE; + + public ByteHashSet(final int expected, final float f) { + super(expected, f); + this.key = new byte[this.n + 1]; + } + + public ByteHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public ByteHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(final byte k) { + if (k == 0) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + byte[] key = this.key; + int pos; + byte curr; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) != 0) { + if (curr == k) { + return false; + } + + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (curr == k) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + if (k < min) { + min = k; + } + if (k > max) { + max = k; + } + return true; + } + + public boolean contains(final byte k) { + if (isDense) { + return k >= min && k <= max && used[k - min]; + } else { + if (k == 0) { + return this.containsZero; + } else { + byte[] key = this.key; + byte curr; + int pos; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) == 0) { + return false; + } else if (k == curr) { + return true; + } else 
{ + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (k == curr) { + return true; + } + } + + return false; + } + } + } + } + + private void rehash(final int newN) { + byte[] key = this.key; + int mask = newN - 1; + byte[] newKey = new byte[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while(key[i] == 0); + + if (newKey[pos = MurmurHashUtil.fmix(key[i]) & mask] != 0) { + while (newKey[pos = pos + 1 & mask] != 0) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + int range = max - min; + if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) { + this.used = new boolean[max - min + 1]; + for (byte v : key) { + if (v != 0) { + used[v - min] = true; + } + } + if (containsZero) { + used[-min] = true; + } + isDense = true; + key = null; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/DoubleHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/DoubleHashSet.java new file mode 100644 index 0000000000000..0f1c056dd0649 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/DoubleHashSet.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Double hash set. + */ +public class DoubleHashSet extends OptimizableHashSet { + + private double[] key; + + public DoubleHashSet(final int expected, final float f) { + super(expected, f); + this.key = new double[this.n + 1]; + } + + public DoubleHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public DoubleHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + /** + * See {@link Double#equals(Object)}. + */ + public boolean add(final double k) { + long longKey = Double.doubleToLongBits(k); + if (longKey == 0L) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + double[] key = this.key; + int pos; + long curr; + if ((curr = Double.doubleToLongBits(key[pos = (int) MurmurHashUtil.fmix(longKey) & this.mask])) != 0L) { + if (curr == longKey) { + return false; + } + + while ((curr = Double.doubleToLongBits(key[pos = pos + 1 & this.mask])) != 0L) { + if (curr == longKey) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + return true; + } + + /** + * See {@link Double#equals(Object)}. 
+ */ + public boolean contains(final double k) { + long longKey = Double.doubleToLongBits(k); + if (longKey == 0L) { + return this.containsZero; + } else { + double[] key = this.key; + long curr; + int pos; + if ((curr = Double.doubleToLongBits(key[pos = (int) MurmurHashUtil.fmix(longKey) & this.mask])) == 0L) { + return false; + } else if (longKey == curr) { + return true; + } else { + while ((curr = Double.doubleToLongBits(key[pos = pos + 1 & this.mask])) != 0L) { + if (longKey == curr) { + return true; + } + } + + return false; + } + } + } + + private void rehash(final int newN) { + double[] key = this.key; + int mask = newN - 1; + double[] newKey = new double[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (Double.doubleToLongBits(key[i]) == 0L); + + if (Double.doubleToLongBits(newKey[pos = + (int) MurmurHashUtil.fmix(Double.doubleToLongBits(key[i])) & mask]) != 0L) { + while (Double.doubleToLongBits(newKey[pos = pos + 1 & mask]) != 0L) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/FloatHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/FloatHashSet.java new file mode 100644 index 0000000000000..8d7e26e4c6aed --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/FloatHashSet.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Float hash set. + */ +public class FloatHashSet extends OptimizableHashSet { + + private float[] key; + + public FloatHashSet(final int expected, final float f) { + super(expected, f); + this.key = new float[this.n + 1]; + } + + public FloatHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public FloatHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + /** + * See {@link Float#equals(Object)}. + */ + public boolean add(final float k) { + int intKey = Float.floatToIntBits(k); + if (intKey == 0) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + float[] key = this.key; + int pos; + int curr; + if ((curr = Float.floatToIntBits(key[pos = MurmurHashUtil.fmix(intKey) & this.mask])) != 0) { + if (curr == intKey) { + return false; + } + + while ((curr = Float.floatToIntBits(key[pos = pos + 1 & this.mask])) != 0) { + if (curr == intKey) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + return true; + } + + /** + * See {@link Float#equals(Object)}. 
+ */ + public boolean contains(final float k) { + int intKey = Float.floatToIntBits(k); + if (intKey == 0) { + return this.containsZero; + } else { + float[] key = this.key; + int curr; + int pos; + if ((curr = Float.floatToIntBits(key[pos = MurmurHashUtil.fmix(intKey) & this.mask])) == 0) { + return false; + } else if (intKey == curr) { + return true; + } else { + while ((curr = Float.floatToIntBits(key[pos = pos + 1 & this.mask])) != 0) { + if (intKey == curr) { + return true; + } + } + + return false; + } + } + } + + private void rehash(final int newN) { + float[] key = this.key; + int mask = newN - 1; + float[] newKey = new float[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (Float.floatToIntBits(key[i]) == 0); + + if (Float.floatToIntBits(newKey[pos = + MurmurHashUtil.fmix(Float.floatToIntBits(key[i])) & mask]) != 0) { + while (Float.floatToIntBits(newKey[pos = pos + 1 & mask]) != 0) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/IntHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/IntHashSet.java new file mode 100644 index 0000000000000..1159917cdba5b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/IntHashSet.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Int hash set. + */ +public class IntHashSet extends OptimizableHashSet { + + private int[] key; + + private int min = Integer.MAX_VALUE; + private int max = Integer.MIN_VALUE; + + public IntHashSet(final int expected, final float f) { + super(expected, f); + this.key = new int[this.n + 1]; + } + + public IntHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public IntHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(final int k) { + if (k == 0) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + int[] key = this.key; + int pos; + int curr; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) != 0) { + if (curr == k) { + return false; + } + + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (curr == k) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + if (k < min) { + min = k; + } + if (k > max) { + max = k; + } + return true; + } + + public boolean contains(final int k) { + if (isDense) { + return k >= min && k <= max && used[k - min]; + } else { + if (k == 0) { + return this.containsZero; + } else { + int[] key = this.key; 
+ int curr; + int pos; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) == 0) { + return false; + } else if (k == curr) { + return true; + } else { + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (k == curr) { + return true; + } + } + + return false; + } + } + } + } + + private void rehash(final int newN) { + int[] key = this.key; + int mask = newN - 1; + int[] newKey = new int[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (key[i] == 0); + + if (newKey[pos = MurmurHashUtil.fmix(key[i]) & mask] != 0) { + while (newKey[pos = pos + 1 & mask] != 0) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + int range = max - min; + if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) { + this.used = new boolean[max - min + 1]; + for (int v : key) { + if (v != 0) { + used[v - min] = true; + } + } + if (containsZero) { + used[-min] = true; + } + isDense = true; + key = null; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/LongHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/LongHashSet.java new file mode 100644 index 0000000000000..37a8cd0244963 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/LongHashSet.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Long hash set. + */ +public class LongHashSet extends OptimizableHashSet { + + private long[] key; + + private long min = Long.MAX_VALUE; + private long max = Long.MIN_VALUE; + + public LongHashSet(final int expected, final float f) { + super(expected, f); + this.key = new long[this.n + 1]; + } + + public LongHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public LongHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(final long k) { + if (k == 0L) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + long[] key = this.key; + int pos; + long curr; + if ((curr = key[pos = (int) MurmurHashUtil.fmix(k) & this.mask]) != 0L) { + if (curr == k) { + return false; + } + + while ((curr = key[pos = pos + 1 & this.mask]) != 0L) { + if (curr == k) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + if (k < min) { + min = k; + } + if (k > max) { + max = k; + } + return true; + } + + public boolean contains(final long k) { + if (isDense) { + return k >= min && k <= max && used[(int) (k - min)]; + } else { + if (k == 0L) { + return this.containsZero; + } else { 
+ long[] key = this.key; + long curr; + int pos; + if ((curr = key[pos = (int) MurmurHashUtil.fmix(k) & this.mask]) == 0L) { + return false; + } else if (k == curr) { + return true; + } else { + while ((curr = key[pos = pos + 1 & this.mask]) != 0L) { + if (k == curr) { + return true; + } + } + + return false; + } + } + } + } + + private void rehash(final int newN) { + long[] key = this.key; + int mask = newN - 1; + long[] newKey = new long[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (key[i] == 0L); + + if (newKey[pos = (int) MurmurHashUtil.fmix(key[i]) & mask] != 0L) { + while (newKey[pos = pos + 1 & mask] != 0L) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + long range = max - min; + if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) { + this.used = new boolean[(int) (max - min + 1)]; + for (long v : key) { + if (v != 0) { + used[(int) (v - min)] = true; + } + } + if (containsZero) { + used[(int) (-min)] = true; + } + isDense = true; + key = null; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ObjectHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ObjectHashSet.java new file mode 100644 index 0000000000000..2a85b22cf0cf3 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ObjectHashSet.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import java.util.HashSet; + +/** + * Wrap {@link HashSet} with hashSet interface. + */ +public class ObjectHashSet extends OptimizableHashSet { + + private HashSet set; + + public ObjectHashSet(final int expected, final float f) { + super(expected, f); + this.set = new HashSet<>(expected, f); + } + + public ObjectHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public ObjectHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(T t) { + return set.add(t); + } + + public boolean contains(final T t) { + return set.contains(t); + } + + @Override + public void optimize() { + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/OptimizableHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/OptimizableHashSet.java new file mode 100644 index 0000000000000..66b55232e9527 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/OptimizableHashSet.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import static org.apache.flink.util.Preconditions.checkArgument; + +/** + * A type-specific hash set with with a fast, small-footprint implementation. + * Refer to the implementation of fastutil. + * + *

Instances of this class use a hash table to represent a set. The table is + * enlarged as needed by doubling its size when new entries are created. + * + *

/**
 * Base class for primitive hash sets that can optimize themselves into a dense
 * array representation.
 *
 * <p>The difference with fastutil is that if the range between the maximum and minimum
 * values is small, or the data is dense, a dense boolean array is used instead of the
 * open-addressed table, which greatly improves access speed.
 */
public abstract class OptimizableHashSet {

	/** The initial default size of a hash table. */
	public static final int DEFAULT_INITIAL_SIZE = 16;

	/** The default load factor of a hash table. */
	public static final float DEFAULT_LOAD_FACTOR = 0.75f;

	/**
	 * Decide whether to convert to dense mode if it does not require more memory or
	 * could fit within L1 cache.
	 */
	public static final int DENSE_THRESHOLD = 8192;

	/** The acceptable load factor. */
	protected final float f;

	/** The mask for wrapping a position counter. */
	protected int mask;

	/** The current table size. */
	protected int n;

	/** Threshold after which we rehash. It must be the table size times {@link #f}. */
	protected int maxFill;

	/** Whether this set contains a null key. */
	protected boolean containsNull;

	/** Whether this set contains a zero key (zero is the empty-slot sentinel). */
	protected boolean containsZero;

	/** Number of entries in the set. */
	protected int size;

	/** Whether the set is currently in dense mode. */
	protected boolean isDense = false;

	/** Membership array used in dense mode. */
	protected boolean[] used;

	/**
	 * Creates a hash set.
	 *
	 * @param expected the expected number of elements in the set.
	 * @param f the load factor, must be in (0, 1].
	 * @throws IllegalArgumentException if {@code f} or {@code expected} is out of range.
	 */
	public OptimizableHashSet(final int expected, final float f) {
		// Explicit validation with informative messages (was: bare checkArgument
		// calls, which threw IllegalArgumentException without any detail).
		if (!(f > 0 && f <= 1)) {
			throw new IllegalArgumentException("Load factor must be in (0, 1]: " + f);
		}
		if (expected < 0) {
			throw new IllegalArgumentException("Expected size must not be negative: " + expected);
		}
		this.f = f;
		this.n = OptimizableHashSet.arraySize(expected, f);
		this.mask = this.n - 1;
		this.maxFill = OptimizableHashSet.maxFill(this.n, f);
	}

	/**
	 * Add a null key.
	 */
	public void addNull() {
		this.containsNull = true;
	}

	/**
	 * Is there a null key.
	 */
	public boolean containsNull() {
		return containsNull;
	}

	/** Number of entries physically stored in the table (zero is tracked by a flag). */
	protected int realSize() {
		return this.containsZero ? this.size - 1 : this.size;
	}

	/**
	 * Decide whether to convert to dense mode.
	 */
	public abstract void optimize();

	/**
	 * Return the least power of two greater than or equal to the specified value.
	 *
	 * <p>Note that this function will return 1 when the argument is 0.
	 *
	 * @param x a long integer smaller than or equal to 2^62.
	 * @return the least power of two greater than or equal to the specified value.
	 */
	public static long nextPowerOfTwo(long x) {
		if (x == 0L) {
			return 1L;
		} else {
			// Smear the highest set bit into every lower position, then add one.
			--x;
			x |= x >> 1;
			x |= x >> 2;
			x |= x >> 4;
			x |= x >> 8;
			x |= x >> 16;
			return (x | x >> 32) + 1L;
		}
	}

	/**
	 * Returns the maximum number of entries that can be filled before rehashing.
	 *
	 * @param n the size of the backing array.
	 * @param f the load factor.
	 * @return the maximum number of entries before rehashing.
	 */
	public static int maxFill(int n, float f) {
		// Never allow a completely full table: keep at least one free slot so that
		// linear-probe loops searching for an empty slot always terminate.
		return Math.min((int) Math.ceil((double) ((float) n * f)), n - 1);
	}

	/**
	 * Returns the least power of two smaller than or equal to 2^30 and larger than
	 * or equal to {@code Math.ceil(expected / f)}.
	 *
	 * @param expected the expected number of elements in a hash table.
	 * @param f the load factor.
	 * @return the minimum possible size for a backing array.
	 * @throws IllegalArgumentException if the necessary size is larger than 2^30.
	 */
	public static int arraySize(int expected, float f) {
		long s = Math.max(2L, nextPowerOfTwo((long) Math.ceil((double) ((float) expected / f))));
		// Integer.MAX_VALUE / 2 + 1 == 2^30, the largest power-of-two int array size.
		if (s > (Integer.MAX_VALUE / 2 + 1)) {
			throw new IllegalArgumentException(
				"Too large (" + expected + " expected elements with load factor " + f + ")");
		} else {
			return (int) s;
		}
	}
}
/**
 * Short hash set.
 *
 * <p>Open-addressed (linear probing) set of primitive shorts. The slot value 0 is used
 * as the "empty" marker, so membership of the zero key is tracked separately via
 * {@code containsZero}. After {@link #optimize()} the set may switch to a dense
 * boolean-array representation covering the observed [min, max] value range.
 */
public class ShortHashSet extends OptimizableHashSet {

	// Backing table; a slot holding 0 is empty (the zero key itself is a flag).
	private short[] key;

	// Smallest and largest keys ever added; optimize() uses them to size the
	// dense array. Initialized so that an empty set yields a negative range.
	private short min = Short.MAX_VALUE;
	private short max = Short.MIN_VALUE;

	public ShortHashSet(final int expected, final float f) {
		super(expected, f);
		// NOTE(review): the extra slot presumably mirrors fastutil's layout (slot n
		// reserved for the null key); it is never written in this class — confirm.
		this.key = new short[this.n + 1];
	}

	public ShortHashSet(final int expected) {
		this(expected, DEFAULT_LOAD_FACTOR);
	}

	public ShortHashSet() {
		this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR);
	}

	/**
	 * Adds the given key to this set.
	 *
	 * @return true if the key was not already present.
	 */
	public boolean add(final short k) {
		if (k == 0) {
			// Zero cannot live in the table (it marks empty slots), so it is a flag.
			if (this.containsZero) {
				return false;
			}

			this.containsZero = true;
		} else {
			short[] key = this.key;
			int pos;
			short curr;
			// Probe the hashed slot, then scan linearly (wrapping via mask) until
			// an empty slot or the key itself is found.
			if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) != 0) {
				if (curr == k) {
					return false;
				}

				while ((curr = key[pos = pos + 1 & this.mask]) != 0) {
					if (curr == k) {
						return false;
					}
				}
			}

			key[pos] = k;
		}

		// Grow the table once the load-factor threshold is exceeded.
		if (this.size++ >= this.maxFill) {
			this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f));
		}

		// Track the value range for optimize(); this runs for the zero key too.
		if (k < min) {
			min = k;
		}
		if (k > max) {
			max = k;
		}
		return true;
	}

	/**
	 * Returns true if this set contains the given key.
	 */
	public boolean contains(final short k) {
		if (isDense) {
			// Dense mode: direct lookup, index 0 of 'used' corresponds to 'min'.
			return k >= min && k <= max && used[k - min];
		} else {
			if (k == 0) {
				return this.containsZero;
			} else {
				short[] key = this.key;
				short curr;
				int pos;
				// Same linear-probe sequence as add(): stop at an empty slot (absent)
				// or at the key itself (present).
				if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) == 0) {
					return false;
				} else if (k == curr) {
					return true;
				} else {
					while ((curr = key[pos = pos + 1 & this.mask]) != 0) {
						if (k == curr) {
							return true;
						}
					}

					return false;
				}
			}
		}
	}

	/**
	 * Rehashes the table into a new backing array of size {@code newN}.
	 */
	private void rehash(final int newN) {
		short[] key = this.key;
		int mask = newN - 1;
		short[] newKey = new short[newN + 1];
		int i = this.n;

		// Walk the old table backwards from index n; for each of realSize() stored
		// entries, find its slot in the new table, then (in the for-update clause)
		// copy the entry there.
		int pos;
		for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) {
			do {
				--i;
			} while (key[i] == 0);

			// Collision in the new table: linear-probe for the next free slot
			// (empty loop body is intentional — the side effect is in the condition).
			if (newKey[pos = MurmurHashUtil.fmix(key[i]) & mask] != 0) {
				while (newKey[pos = pos + 1 & mask] != 0) {}
			}
		}

		this.n = newN;
		this.mask = mask;
		this.maxFill = OptimizableHashSet.maxFill(this.n, this.f);
		this.key = newKey;
	}

	@Override
	public void optimize() {
		// 'range' is computed in int, so short arithmetic cannot overflow here.
		// A negative range means the set is empty (min/max untouched).
		int range = max - min;
		if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) {
			// Switch to dense mode: one boolean per value in [min, max].
			this.used = new boolean[max - min + 1];
			for (short v : key) {
				if (v != 0) {
					used[v - min] = true;
				}
			}
			// The zero key never sits in the table; mark it from its flag.
			// (If zero was added, min <= 0, so -min is a valid index.)
			if (containsZero) {
				used[-min] = true;
			}
			isDense = true;
			// Release the table. NOTE(review): add() would NPE after this point —
			// assumed the build phase is finished before optimize() is called; confirm.
			key = null;
		}
	}
}
+ */ + public static int getArity(InternalType t) { + if (t instanceof RowType) { + return ((RowType) t).getArity(); + } else { + return 1; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/RowType.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/RowType.java index 380ec4493d889..db89521bed51d 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/RowType.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/RowType.java @@ -106,7 +106,8 @@ public boolean equals(Object o) { RowType that = (RowType) o; // RowType comparisons should not compare names and are compatible with the behavior of CompositeTypeInfo. - return Arrays.equals(getFieldTypes(), that.getFieldTypes()); + return Arrays.equals(getFieldTypes(), that.getFieldTypes()) && + Arrays.equals(getFieldNames(), that.getFieldNames()); } @Override diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java index ea90a77d77df9..69628dc86a7af 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java @@ -120,6 +120,9 @@ public static InternalType createInternalTypeFromTypeInfo(TypeInformation typeIn .toArray(InternalType[]::new), compositeType.getFieldNames() ); + } else if (typeInfo instanceof DecimalTypeInfo) { + DecimalTypeInfo decimalType = (DecimalTypeInfo) typeInfo; + return InternalTypes.createDecimalType(decimalType.precision(), decimalType.scale()); } else if (typeInfo instanceof PrimitiveArrayTypeInfo) { PrimitiveArrayTypeInfo arrayType = (PrimitiveArrayTypeInfo) typeInfo; return InternalTypes.createArrayType( diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/util/MurmurHashUtil.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/util/MurmurHashUtil.java index b4e0b41a271fd..dbae3e87b3851 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/util/MurmurHashUtil.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/util/MurmurHashUtil.java @@ -148,7 +148,7 @@ private static int fmix(int h1, int length) { return fmix(h1); } - private static int fmix(int h) { + public static int fmix(int h) { h ^= h >>> 16; h *= 0x85ebca6b; h ^= h >>> 13; @@ -156,4 +156,13 @@ private static int fmix(int h) { h ^= h >>> 16; return h; } + + public static long fmix(long h) { + h ^= (h >>> 33); + h *= 0xff51afd7ed558ccdL; + h ^= (h >>> 33); + h *= 0xc4ceb9fe1a85ec53L; + h ^= (h >>> 33); + return h; + } }