diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/FlinkSqlOperatorTable.java new file mode 100644 index 0000000000000..19af644be0fc8 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/FlinkSqlOperatorTable.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.functions.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlBinaryOperator; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; + +/** + * Operator table that contains only Flink-specific functions and operators. + */ +public class FlinkSqlOperatorTable extends ReflectiveSqlOperatorTable { + + /** + * The table of contains Flink-specific operators. + */ + private static FlinkSqlOperatorTable instance; + + /** + * Returns the Flink operator table, creating it if necessary. + */ + public static synchronized FlinkSqlOperatorTable instance() { + if (instance == null) { + // Creates and initializes the standard operator table. + // Uses two-phase construction, because we can't initialize the + // table until the constructor of the sub-class has completed. 
+ instance = new FlinkSqlOperatorTable(); + instance.init(); + } + return instance; + } + + // ----------------------------------------------------------------------------- + + private static final SqlReturnTypeInference FLINK_QUOTIENT_NULLABLE = new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + RelDataType type1 = opBinding.getOperandType(0); + RelDataType type2 = opBinding.getOperandType(1); + if (SqlTypeUtil.isDecimal(type1) || SqlTypeUtil.isDecimal(type2)) { + return ReturnTypes.QUOTIENT_NULLABLE.inferReturnType(opBinding); + } else { + RelDataType doubleType = opBinding.getTypeFactory().createSqlType(SqlTypeName.DOUBLE); + if (type1.isNullable() || type2.isNullable()) { + return opBinding.getTypeFactory().createTypeWithNullability(doubleType, true); + } else { + return doubleType; + } + } + } + }; + + // ----------------------------------------------------------------------------- + // Flink specific built-in scalar SQL functions + // ----------------------------------------------------------------------------- + + /** + * Arithmetic division operator, '/'. + */ + public static final SqlBinaryOperator DIVIDE = new SqlBinaryOperator( + "/", + SqlKind.DIVIDE, + 60, + true, + FLINK_QUOTIENT_NULLABLE, + InferTypes.FIRST_KNOWN, + OperandTypes.DIVISION_OPERATOR); + + /** Function used to access a processing time attribute. */ + public static final SqlFunction PROCTIME = new ProctimeSqlFunction(); + + /** Function used to materialize a processing time attribute. */ + public static final SqlFunction PROCTIME_MATERIALIZE = new ProctimeMaterializeSqlFunction(); + + /** Function to access the timestamp of a StreamRecord. 
*/ + public static final SqlFunction STREAMRECORD_TIMESTAMP = new StreamRecordTimestampSqlFunction(); + + // ----------------------------------------------------------------------------- + // Window SQL functions + // ----------------------------------------------------------------------------- + + // TODO: add window functions here + + // ----------------------------------------------------------------------------- + // operators extend from Calcite + // ----------------------------------------------------------------------------- + + // SET OPERATORS + public static final SqlOperator UNION = SqlStdOperatorTable.UNION; + public static final SqlOperator UNION_ALL = SqlStdOperatorTable.UNION_ALL; + public static final SqlOperator EXCEPT = SqlStdOperatorTable.EXCEPT; + public static final SqlOperator EXCEPT_ALL = SqlStdOperatorTable.EXCEPT_ALL; + public static final SqlOperator INTERSECT = SqlStdOperatorTable.INTERSECT; + public static final SqlOperator INTERSECT_ALL = SqlStdOperatorTable.INTERSECT_ALL; + + // BINARY OPERATORS + public static final SqlOperator AND = SqlStdOperatorTable.AND; + public static final SqlOperator AS = SqlStdOperatorTable.AS; + public static final SqlOperator CONCAT = SqlStdOperatorTable.CONCAT; + public static final SqlOperator DIVIDE_INTEGER = SqlStdOperatorTable.DIVIDE_INTEGER; + public static final SqlOperator DOT = SqlStdOperatorTable.DOT; + public static final SqlOperator EQUALS = SqlStdOperatorTable.EQUALS; + public static final SqlOperator GREATER_THAN = SqlStdOperatorTable.GREATER_THAN; + public static final SqlOperator IS_DISTINCT_FROM = SqlStdOperatorTable.IS_DISTINCT_FROM; + public static final SqlOperator IS_NOT_DISTINCT_FROM = SqlStdOperatorTable.IS_NOT_DISTINCT_FROM; + public static final SqlOperator GREATER_THAN_OR_EQUAL = SqlStdOperatorTable.GREATER_THAN_OR_EQUAL; + public static final SqlOperator LESS_THAN = SqlStdOperatorTable.LESS_THAN; + public static final SqlOperator LESS_THAN_OR_EQUAL = 
SqlStdOperatorTable.LESS_THAN_OR_EQUAL; + public static final SqlOperator MINUS = SqlStdOperatorTable.MINUS; + public static final SqlOperator MINUS_DATE = SqlStdOperatorTable.MINUS_DATE; + public static final SqlOperator MULTIPLY = SqlStdOperatorTable.MULTIPLY; + public static final SqlOperator NOT_EQUALS = SqlStdOperatorTable.NOT_EQUALS; + public static final SqlOperator OR = SqlStdOperatorTable.OR; + public static final SqlOperator PLUS = SqlStdOperatorTable.PLUS; + public static final SqlOperator DATETIME_PLUS = SqlStdOperatorTable.DATETIME_PLUS; + + // POSTFIX OPERATORS + public static final SqlOperator DESC = SqlStdOperatorTable.DESC; + public static final SqlOperator NULLS_FIRST = SqlStdOperatorTable.NULLS_FIRST; + public static final SqlOperator NULLS_LAST = SqlStdOperatorTable.NULLS_LAST; + public static final SqlOperator IS_NOT_NULL = SqlStdOperatorTable.IS_NOT_NULL; + public static final SqlOperator IS_NULL = SqlStdOperatorTable.IS_NULL; + public static final SqlOperator IS_NOT_TRUE = SqlStdOperatorTable.IS_NOT_TRUE; + public static final SqlOperator IS_TRUE = SqlStdOperatorTable.IS_TRUE; + public static final SqlOperator IS_NOT_FALSE = SqlStdOperatorTable.IS_NOT_FALSE; + public static final SqlOperator IS_FALSE = SqlStdOperatorTable.IS_FALSE; + public static final SqlOperator IS_NOT_UNKNOWN = SqlStdOperatorTable.IS_NOT_UNKNOWN; + public static final SqlOperator IS_UNKNOWN = SqlStdOperatorTable.IS_UNKNOWN; + + // PREFIX OPERATORS + public static final SqlOperator NOT = SqlStdOperatorTable.NOT; + public static final SqlOperator UNARY_MINUS = SqlStdOperatorTable.UNARY_MINUS; + public static final SqlOperator UNARY_PLUS = SqlStdOperatorTable.UNARY_PLUS; + + // GROUPING FUNCTIONS + public static final SqlFunction GROUP_ID = SqlStdOperatorTable.GROUP_ID; + public static final SqlFunction GROUPING = SqlStdOperatorTable.GROUPING; + public static final SqlFunction GROUPING_ID = SqlStdOperatorTable.GROUPING_ID; + + // AGGREGATE OPERATORS + public static final 
SqlAggFunction SUM = SqlStdOperatorTable.SUM; + public static final SqlAggFunction SUM0 = SqlStdOperatorTable.SUM0; + public static final SqlAggFunction COUNT = SqlStdOperatorTable.COUNT; + public static final SqlAggFunction APPROX_COUNT_DISTINCT = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; + public static final SqlAggFunction COLLECT = SqlStdOperatorTable.COLLECT; + public static final SqlAggFunction MIN = SqlStdOperatorTable.MIN; + public static final SqlAggFunction MAX = SqlStdOperatorTable.MAX; + public static final SqlAggFunction AVG = SqlStdOperatorTable.AVG; + public static final SqlAggFunction STDDEV = SqlStdOperatorTable.STDDEV; + public static final SqlAggFunction STDDEV_POP = SqlStdOperatorTable.STDDEV_POP; + public static final SqlAggFunction STDDEV_SAMP = SqlStdOperatorTable.STDDEV_SAMP; + public static final SqlAggFunction VARIANCE = SqlStdOperatorTable.VARIANCE; + public static final SqlAggFunction VAR_POP = SqlStdOperatorTable.VAR_POP; + public static final SqlAggFunction VAR_SAMP = SqlStdOperatorTable.VAR_SAMP; + + // ARRAY OPERATORS + public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR; + public static final SqlOperator ELEMENT = SqlStdOperatorTable.ELEMENT; + + // MAP OPERATORS + public static final SqlOperator MAP_VALUE_CONSTRUCTOR = SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR; + + // ARRAY MAP SHARED OPERATORS + public static final SqlOperator ITEM = SqlStdOperatorTable.ITEM; + public static final SqlOperator CARDINALITY = SqlStdOperatorTable.CARDINALITY; + + // SPECIAL OPERATORS + public static final SqlOperator MULTISET_VALUE = SqlStdOperatorTable.MULTISET_VALUE; + public static final SqlOperator ROW = SqlStdOperatorTable.ROW; + public static final SqlOperator OVERLAPS = SqlStdOperatorTable.OVERLAPS; + public static final SqlOperator LITERAL_CHAIN = SqlStdOperatorTable.LITERAL_CHAIN; + public static final SqlOperator BETWEEN = SqlStdOperatorTable.BETWEEN; + public static final SqlOperator 
SYMMETRIC_BETWEEN = SqlStdOperatorTable.SYMMETRIC_BETWEEN; + public static final SqlOperator NOT_BETWEEN = SqlStdOperatorTable.NOT_BETWEEN; + public static final SqlOperator SYMMETRIC_NOT_BETWEEN = SqlStdOperatorTable.SYMMETRIC_NOT_BETWEEN; + public static final SqlOperator NOT_LIKE = SqlStdOperatorTable.NOT_LIKE; + public static final SqlOperator LIKE = SqlStdOperatorTable.LIKE; + public static final SqlOperator NOT_SIMILAR_TO = SqlStdOperatorTable.NOT_SIMILAR_TO; + public static final SqlOperator SIMILAR_TO = SqlStdOperatorTable.SIMILAR_TO; + public static final SqlOperator CASE = SqlStdOperatorTable.CASE; + public static final SqlOperator REINTERPRET = SqlStdOperatorTable.REINTERPRET; + public static final SqlOperator EXTRACT = SqlStdOperatorTable.EXTRACT; + public static final SqlOperator IN = SqlStdOperatorTable.IN; + public static final SqlOperator NOT_IN = SqlStdOperatorTable.NOT_IN; + + // FUNCTIONS + public static final SqlFunction OVERLAY = SqlStdOperatorTable.OVERLAY; + public static final SqlFunction TRIM = SqlStdOperatorTable.TRIM; + public static final SqlFunction POSITION = SqlStdOperatorTable.POSITION; + public static final SqlFunction CHAR_LENGTH = SqlStdOperatorTable.CHAR_LENGTH; + public static final SqlFunction CHARACTER_LENGTH = SqlStdOperatorTable.CHARACTER_LENGTH; + public static final SqlFunction UPPER = SqlStdOperatorTable.UPPER; + public static final SqlFunction LOWER = SqlStdOperatorTable.LOWER; + public static final SqlFunction INITCAP = SqlStdOperatorTable.INITCAP; + public static final SqlFunction POWER = SqlStdOperatorTable.POWER; + public static final SqlFunction SQRT = SqlStdOperatorTable.SQRT; + public static final SqlFunction MOD = SqlStdOperatorTable.MOD; + public static final SqlFunction LN = SqlStdOperatorTable.LN; + public static final SqlFunction LOG10 = SqlStdOperatorTable.LOG10; + public static final SqlFunction ABS = SqlStdOperatorTable.ABS; + public static final SqlFunction EXP = SqlStdOperatorTable.EXP; + public static 
final SqlFunction NULLIF = SqlStdOperatorTable.NULLIF; + public static final SqlFunction COALESCE = SqlStdOperatorTable.COALESCE; + public static final SqlFunction FLOOR = SqlStdOperatorTable.FLOOR; + public static final SqlFunction CEIL = SqlStdOperatorTable.CEIL; + public static final SqlFunction LOCALTIME = SqlStdOperatorTable.LOCALTIME; + public static final SqlFunction LOCALTIMESTAMP = SqlStdOperatorTable.LOCALTIMESTAMP; + public static final SqlFunction CURRENT_TIME = SqlStdOperatorTable.CURRENT_TIME; + public static final SqlFunction CURRENT_TIMESTAMP = SqlStdOperatorTable.CURRENT_TIMESTAMP; + public static final SqlFunction CURRENT_DATE = SqlStdOperatorTable.CURRENT_DATE; + public static final SqlFunction CAST = SqlStdOperatorTable.CAST; + public static final SqlFunction QUARTER = SqlStdOperatorTable.QUARTER; + public static final SqlOperator SCALAR_QUERY = SqlStdOperatorTable.SCALAR_QUERY; + public static final SqlOperator EXISTS = SqlStdOperatorTable.EXISTS; + public static final SqlFunction SIN = SqlStdOperatorTable.SIN; + public static final SqlFunction COS = SqlStdOperatorTable.COS; + public static final SqlFunction TAN = SqlStdOperatorTable.TAN; + public static final SqlFunction COT = SqlStdOperatorTable.COT; + public static final SqlFunction ASIN = SqlStdOperatorTable.ASIN; + public static final SqlFunction ACOS = SqlStdOperatorTable.ACOS; + public static final SqlFunction ATAN = SqlStdOperatorTable.ATAN; + public static final SqlFunction ATAN2 = SqlStdOperatorTable.ATAN2; + public static final SqlFunction DEGREES = SqlStdOperatorTable.DEGREES; + public static final SqlFunction RADIANS = SqlStdOperatorTable.RADIANS; + public static final SqlFunction SIGN = SqlStdOperatorTable.SIGN; + public static final SqlFunction PI = SqlStdOperatorTable.PI; + public static final SqlFunction RAND = SqlStdOperatorTable.RAND; + public static final SqlFunction RAND_INTEGER = SqlStdOperatorTable.RAND_INTEGER; + public static final SqlFunction TIMESTAMP_ADD = 
SqlStdOperatorTable.TIMESTAMP_ADD; + public static final SqlFunction TIMESTAMP_DIFF = SqlStdOperatorTable.TIMESTAMP_DIFF; + + // MATCH_RECOGNIZE + public static final SqlFunction FIRST = SqlStdOperatorTable.FIRST; + public static final SqlFunction LAST = SqlStdOperatorTable.LAST; + public static final SqlFunction PREV = SqlStdOperatorTable.PREV; + public static final SqlFunction NEXT = SqlStdOperatorTable.NEXT; + public static final SqlFunction CLASSIFIER = SqlStdOperatorTable.CLASSIFIER; + public static final SqlOperator FINAL = SqlStdOperatorTable.FINAL; + public static final SqlOperator RUNNING = SqlStdOperatorTable.RUNNING; + + // OVER FUNCTIONS + public static final SqlAggFunction RANK = SqlStdOperatorTable.RANK; + public static final SqlAggFunction DENSE_RANK = SqlStdOperatorTable.DENSE_RANK; + public static final SqlAggFunction ROW_NUMBER = SqlStdOperatorTable.ROW_NUMBER; + public static final SqlAggFunction LEAD = SqlStdOperatorTable.LEAD; + public static final SqlAggFunction LAG = SqlStdOperatorTable.LAG; +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeMaterializeSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeMaterializeSqlFunction.java new file mode 100644 index 0000000000000..17b9ea3356c82 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeMaterializeSqlFunction.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.sql; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlMonotonicity; + +/** + * Function that materializes a processing time attribute. + * After materialization the result can be used in regular arithmetical calculations. 
+ */ +public class ProctimeMaterializeSqlFunction extends SqlFunction { + + public ProctimeMaterializeSqlFunction() { + super( + "PROCTIME_MATERIALIZE", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.TIMESTAMP), + InferTypes.RETURN_TYPE, + OperandTypes.family(SqlTypeFamily.TIMESTAMP), + SqlFunctionCategory.SYSTEM); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.FUNCTION; + } + + @Override + public SqlMonotonicity getMonotonicity(SqlOperatorBinding call) { + return SqlMonotonicity.INCREASING; + } + + @Override + public boolean isDeterministic() { + return false; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java new file mode 100644 index 0000000000000..45827deafcb9c --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/ProctimeSqlFunction.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.functions.sql; + +import org.apache.flink.table.calcite.FlinkTypeFactory; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelProtoDataType; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; + +/** + * Function used to access a proctime attribute. + */ +public class ProctimeSqlFunction extends SqlFunction { + public ProctimeSqlFunction() { + super( + "PROCTIME", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(new ProctimeRelProtoDataType()), + null, + OperandTypes.NILADIC, + SqlFunctionCategory.TIMEDATE); + } + + private static class ProctimeRelProtoDataType implements RelProtoDataType { + @Override + public RelDataType apply(RelDataTypeFactory factory) { + return ((FlinkTypeFactory) factory).createProctimeIndicatorType(); + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/StreamRecordTimestampSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/StreamRecordTimestampSqlFunction.java new file mode 100644 index 0000000000000..9a794257c4710 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/functions/sql/StreamRecordTimestampSqlFunction.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.sql; + +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; + +/** + * Function to access the timestamp of a StreamRecord. + */ +public class StreamRecordTimestampSqlFunction extends SqlFunction { + + public StreamRecordTimestampSqlFunction() { + super( + "STREAMRECORD_TIMESTAMP", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.BIGINT), + InferTypes.RETURN_TYPE, + OperandTypes.family(SqlTypeFamily.NUMERIC), + SqlFunctionCategory.SYSTEM); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.FUNCTION; + } + + @Override + public boolean isDeterministic() { + return true; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala index 97cf376eed868..4411f6f234741 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala @@ -32,6 +32,7 @@ import 
org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.{RowTypeInfo, _} import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkRelBuilder, FlinkTypeFactory, FlinkTypeSystem} +import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable import org.apache.flink.table.plan.cost.FlinkCostFactory import org.apache.flink.table.plan.schema.RelTable import org.apache.flink.types.Row @@ -63,6 +64,7 @@ abstract class TableEnvironment(val config: TableConfig) { .costFactory(new FlinkCostFactory) .typeSystem(new FlinkTypeSystem) .sqlToRelConverterConfig(getSqlToRelConverterConfig) + .operatorTable(FlinkSqlOperatorTable.instance()) // TODO: introduce ExpressionReducer after codegen // set the executor to evaluate constant expressions // .executor(new ExpressionReducer(config)) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLocalRef.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLocalRef.scala new file mode 100644 index 0000000000000..62831dd629802 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkLocalRef.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.calcite + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rex.RexLocalRef +import org.apache.flink.table.`type`.InternalType + +/** + * Special reference which represents a local field, such as aggregate buffers or constants. + * They are stored as class members, so the field can be referenced directly. + * We should use a unique name to locate the field. + * + * See [[org.apache.flink.table.codegen.ExprCodeGenerator.visitLocalRef()]] + */ +case class RexAggLocalVariable( + fieldTerm: String, + nullTerm: String, + dataType: RelDataType, + internalType: InternalType) + extends RexLocalRef(0, dataType) + +/** + * Special reference which represents a distinct key input field. + * We use the name to locate the distinct key field. 
+ * + * See [[org.apache.flink.table.codegen.ExprCodeGenerator.visitLocalRef()]] + */ +case class RexDistinctKeyVariable( + keyTerm: String, + dataType: RelDataType, + internalType: InternalType) + extends RexLocalRef(0, dataType) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala index 2b683119dced8..0d525c7359fb4 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala @@ -18,15 +18,19 @@ package org.apache.flink.table.calcite -import org.apache.flink.table.`type`.{ArrayType, DecimalType, InternalType, InternalTypes, MapType, RowType} +import org.apache.flink.table.`type`._ import org.apache.flink.table.api.{TableException, TableSchema} -import org.apache.flink.table.plan.schema.{ArrayRelDataType, MapRelDataType, RowRelDataType, RowSchema, TimeIndicatorRelDataType} - +import org.apache.flink.table.plan.schema.{GenericRelDataType, _} +import org.apache.calcite.avatica.util.TimeUnit import org.apache.calcite.jdbc.JavaTypeFactoryImpl import org.apache.calcite.rel.`type`._ +import org.apache.calcite.sql.SqlIntervalQualifier import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.`type`.{BasicSqlType, SqlTypeName} +import org.apache.calcite.sql.parser.SqlParserPos +import org.apache.calcite.util.ConversionUtil +import java.nio.charset.Charset import java.util import scala.collection.JavaConverters._ @@ -64,12 +68,25 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp case InternalTypes.DATE => createSqlType(DATE) case InternalTypes.TIME => createSqlType(TIME) case InternalTypes.TIMESTAMP => createSqlType(TIMESTAMP) + case InternalTypes.PROCTIME_INDICATOR => 
createProctimeIndicatorType() + case InternalTypes.ROWTIME_INDICATOR => createRowtimeIndicatorType() + + // interval types + case InternalTypes.INTERVAL_MONTHS => + createSqlIntervalType( + new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)) + case InternalTypes.INTERVAL_MILLIS => + createSqlIntervalType( + new SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.SECOND, SqlParserPos.ZERO)) case InternalTypes.BINARY => createSqlType(VARBINARY) case InternalTypes.CHAR => throw new TableException("Character type is not supported.") + case decimal: DecimalType => + createSqlType(DECIMAL, decimal.precision(), decimal.scale()) + case rowType: RowType => new RowRelDataType(rowType, isNullable, this) case arrayType: ArrayType => new ArrayRelDataType(arrayType, @@ -81,6 +98,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp createTypeFromInternalType(mapType.getValueType, isNullable = true), isNullable) + case generic: GenericType[_] => + new GenericRelDataType(generic, isNullable, getTypeSystem) + case _@t => throw new TableException(s"Type is not supported: $t") } @@ -91,6 +111,19 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp createTypeWithNullability(relType, isNullable) } + /** + * Creates a indicator type for processing-time, but with similar properties as SQL timestamp. + */ + def createProctimeIndicatorType(): RelDataType = { + val originalType = createTypeFromInternalType(InternalTypes.TIMESTAMP, isNullable = false) + canonize( + new TimeIndicatorRelDataType( + getTypeSystem, + originalType.asInstanceOf[BasicSqlType], + isEventTime = false) + ) + } + /** * Creates a indicator type for event-time, but with similar properties as SQL timestamp. 
*/ @@ -209,6 +242,28 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp } } + + override def createArrayType(elementType: RelDataType, maxCardinality: Long): RelDataType = { + val arrayType = InternalTypes.createArrayType(FlinkTypeFactory.toInternalType(elementType)) + val relType = new ArrayRelDataType( + arrayType, + elementType, + isNullable = false) + canonize(relType) + } + + override def createMapType(keyType: RelDataType, valueType: RelDataType): RelDataType = { + val internalKeyType = FlinkTypeFactory.toInternalType(keyType) + val internalValueType = FlinkTypeFactory.toInternalType(valueType) + val internalMapType = InternalTypes.createMapType(internalKeyType, internalValueType) + val relType = new MapRelDataType( + internalMapType, + keyType, + valueType, + isNullable = false) + canonize(relType) + } + override def createSqlType(typeName: SqlTypeName): RelDataType = { if (typeName == DECIMAL) { // if we got here, the precision and scale are not specified, here we @@ -231,7 +286,22 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp } // change nullability - val newType = super.createTypeWithNullability(relDataType, isNullable) + val newType = relDataType match { + case array: ArrayRelDataType => + new ArrayRelDataType(array.arrayType, array.getComponentType, isNullable) + + case map: MapRelDataType => + new MapRelDataType(map.mapType, map.keyType, map.valueType, isNullable) + + case generic: GenericRelDataType => + new GenericRelDataType(generic.genericType, isNullable, typeSystem) + + case timeIndicator: TimeIndicatorRelDataType => + timeIndicator + + case _ => + super.createTypeWithNullability(relDataType, isNullable) + } canonize(newType) } @@ -273,6 +343,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp } } + override def getDefaultCharset: Charset = { + Charset.forName(ConversionUtil.NATIVE_UTF16_CHARSET_NAME) + } } object FlinkTypeFactory { @@ -291,22 
+364,32 @@ object FlinkTypeFactory { case FLOAT => InternalTypes.FLOAT case DOUBLE => InternalTypes.DOUBLE case VARCHAR | CHAR => InternalTypes.STRING - case DECIMAL => throw new RuntimeException("Not support yet.") + case VARBINARY | BINARY => InternalTypes.BINARY + case DECIMAL => InternalTypes.createDecimalType(relDataType.getPrecision, relDataType.getScale) + + // time indicators + case TIMESTAMP if relDataType.isInstanceOf[TimeIndicatorRelDataType] => + val indicator = relDataType.asInstanceOf[TimeIndicatorRelDataType] + if (indicator.isEventTime) { + InternalTypes.ROWTIME_INDICATOR + } else { + InternalTypes.PROCTIME_INDICATOR + } // temporal types case DATE => InternalTypes.DATE case TIME => InternalTypes.TIME case TIMESTAMP => InternalTypes.TIMESTAMP - - case VARBINARY => InternalTypes.BINARY + case typeName if YEAR_INTERVAL_TYPES.contains(typeName) => InternalTypes.INTERVAL_MONTHS + case typeName if DAY_INTERVAL_TYPES.contains(typeName) => InternalTypes.INTERVAL_MILLIS case NULL => throw new TableException( "Type NULL is not supported. Null values must have a supported type.") // symbol for special flags e.g. 
TRIM's BOTH, LEADING, TRAILING - // are represented as integer - case SYMBOL => InternalTypes.INT + // are represented as Enum + case SYMBOL => InternalTypes.createGenericType(classOf[Enum[_]]) case ROW if relDataType.isInstanceOf[RowRelDataType] => val compositeRelDataType = relDataType.asInstanceOf[RowRelDataType] diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeSystem.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeSystem.scala index 99d8cab4f39aa..6b6e6d1952001 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeSystem.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/FlinkTypeSystem.scala @@ -20,19 +20,18 @@ package org.apache.flink.table.calcite import org.apache.calcite.rel.`type`.RelDataTypeSystemImpl import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.flink.table.`type`.DecimalType /** * Custom type system for Flink. */ class FlinkTypeSystem extends RelDataTypeSystemImpl { - // we cannot use Int.MaxValue because of an overflow in Calcite's type inference logic - // half should be enough for all use cases - override def getMaxNumericScale: Int = Int.MaxValue / 2 + // set the maximum precision of a NUMERIC or DECIMAL type to DecimalType.MAX_PRECISION. 
+ override def getMaxNumericPrecision: Int = DecimalType.MAX_PRECISION - // we cannot use Int.MaxValue because of an overflow in Calcite's type inference logic - // half should be enough for all use cases - override def getMaxNumericPrecision: Int = Int.MaxValue / 2 + // the max scale can't be greater than precision + override def getMaxNumericScale: Int = DecimalType.MAX_PRECISION override def getDefaultPrecision(typeName: SqlTypeName): Int = typeName match { @@ -48,4 +47,8 @@ class FlinkTypeSystem extends RelDataTypeSystemImpl { super.getDefaultPrecision(typeName) } + // when union a number of CHAR types of different lengths, we should cast to a VARCHAR + // this fixes the problem of CASE WHEN with different length string literals but get wrong + // result with additional space suffix + override def shouldConvertRaggedUnionTypesToVarying(): Boolean = true } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala index d07815fad5928..5d4d3f33896ff 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGenUtils.scala @@ -18,35 +18,54 @@ package org.apache.flink.table.codegen -import org.apache.flink.api.common.ExecutionConfig -import org.apache.flink.api.common.typeinfo.{AtomicType => AtomicTypeInfo} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.core.memory.MemorySegment import org.apache.flink.table.`type`._ -import org.apache.flink.table.calcite.FlinkPlannerImpl -import org.apache.flink.table.dataformat._ +import org.apache.flink.table.dataformat.{Decimal, _} import org.apache.flink.table.typeutils.TypeCheckUtils import java.lang.reflect.Method import java.lang.{Boolean => JBoolean, Byte => JByte, Character => 
JChar, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Short => JShort} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable - object CodeGenUtils { // ------------------------------- DEFAULT TERMS ------------------------------------------ val DEFAULT_TIMEZONE_TERM = "timeZone" + val DEFAULT_INPUT1_TERM = "in1" + + val DEFAULT_INPUT2_TERM = "in2" + + val DEFAULT_COLLECTOR_TERM = "c" + + val DEFAULT_OUT_RECORD_TERM = "out" + + val DEFAULT_OPERATOR_COLLECTOR_TERM = "output" + + val DEFAULT_OUT_RECORD_WRITER_TERM = "outWriter" + + val DEFAULT_CONTEXT_TERM = "ctx" + // -------------------------- CANONICAL CLASS NAMES --------------------------------------- val BINARY_ROW: String = className[BinaryRow] + val BINARY_ARRAY: String = className[BinaryArray] + val BINARY_GENERIC: String = className[BinaryGeneric[_]] + val BINARY_STRING: String = className[BinaryString] + + val BINARY_MAP: String = className[BinaryMap] + val BASE_ROW: String = className[BaseRow] + val GENERIC_ROW: String = className[GenericRow] + + val DECIMAL: String = className[Decimal] + val SEGMENT: String = className[MemorySegment] // ---------------------------------------------------------------------------------------- @@ -68,15 +87,6 @@ object CodeGenUtils { */ def className[T](implicit m: Manifest[T]): String = m.runtimeClass.getCanonicalName - def needCopyForType(t: InternalType): Boolean = t match { - case InternalTypes.STRING => true - case _: ArrayType => true - case _: MapType => true - case _: RowType => true - case _: GenericType[_] => true - case _ => false - } - // when casting we first need to unbox Primitives, for example, // float a = 1.0f; // byte b = (byte) a; @@ -120,7 +130,6 @@ object CodeGenUtils { case InternalTypes.BINARY => "byte[]" case _: DecimalType => className[Decimal] - // BINARY is also an ArrayType and uses BinaryArray internally too case _: ArrayType => className[BinaryArray] case _: MapType => className[BinaryMap] case 
_: RowType => className[BaseRow] @@ -195,6 +204,16 @@ object CodeGenUtils { throw new CodeGenException("Boolean expression type expected.") } + def requireTemporal(genExpr: GeneratedExpression): Unit = + if (!TypeCheckUtils.isTemporal(genExpr.resultType)) { + throw new CodeGenException("Temporal expression type expected.") + } + + def requireTimeInterval(genExpr: GeneratedExpression): Unit = + if (!TypeCheckUtils.isTimeInterval(genExpr.resultType)) { + throw new CodeGenException("Interval expression type expected.") + } + def requireArray(genExpr: GeneratedExpression): Unit = if (!TypeCheckUtils.isArray(genExpr.resultType)) { throw new CodeGenException("Array expression type expected.") @@ -210,227 +229,267 @@ object CodeGenUtils { throw new CodeGenException("Integer expression type expected.") } - // --------------------------- Generate Utils --------------------------------------- - - def generateOutputRecordStatement( - t: InternalType, - clazz: Class[_], - outRecordTerm: String, - outRecordWriterTerm: Option[String] = None): String = { - t match { - case rt: RowType if clazz == classOf[BinaryRow] => - val writerTerm = outRecordWriterTerm.getOrElse( - throw new CodeGenException("No writer is specified when writing BinaryRow record.") - ) - val binaryRowWriter = className[BinaryRowWriter] - val typeTerm = clazz.getCanonicalName - s""" - |final $typeTerm $outRecordTerm = new $typeTerm(${rt.getArity}); - |final $binaryRowWriter $writerTerm = new $binaryRowWriter($outRecordTerm); - |""".stripMargin.trim - case rt: RowType if classOf[ObjectArrayRow].isAssignableFrom(clazz) => - val typeTerm = clazz.getCanonicalName - s"final $typeTerm $outRecordTerm = new $typeTerm(${rt.getArity});" - case _: RowType if clazz == classOf[JoinedRow] => - val typeTerm = clazz.getCanonicalName - s"final $typeTerm $outRecordTerm = new $typeTerm();" - case _ => - val typeTerm = boxedTypeTermForType(t) - s"final $typeTerm $outRecordTerm = new $typeTerm();" - } - } + // 
-------------------------------------------------------------------------------- + // DataFormat Operations + // -------------------------------------------------------------------------------- + + // -------------------------- BaseRow Read Access ------------------------------- def baseRowFieldReadAccess( - ctx: CodeGeneratorContext, pos: Int, rowTerm: String, fieldType: InternalType) : String = - baseRowFieldReadAccess(ctx, pos.toString, rowTerm, fieldType) + ctx: CodeGeneratorContext, + index: Int, + rowTerm: String, + fieldType: InternalType) : String = + baseRowFieldReadAccess(ctx, index.toString, rowTerm, fieldType) def baseRowFieldReadAccess( - ctx: CodeGeneratorContext, pos: String, rowTerm: String, fieldType: InternalType) : String = + ctx: CodeGeneratorContext, + indexTerm: String, + rowTerm: String, + fieldType: InternalType) : String = fieldType match { - case InternalTypes.INT => s"$rowTerm.getInt($pos)" - case InternalTypes.LONG => s"$rowTerm.getLong($pos)" - case InternalTypes.SHORT => s"$rowTerm.getShort($pos)" - case InternalTypes.BYTE => s"$rowTerm.getByte($pos)" - case InternalTypes.FLOAT => s"$rowTerm.getFloat($pos)" - case InternalTypes.DOUBLE => s"$rowTerm.getDouble($pos)" - case InternalTypes.BOOLEAN => s"$rowTerm.getBoolean($pos)" - case InternalTypes.STRING => s"$rowTerm.getString($pos)" - case InternalTypes.BINARY => s"$rowTerm.getBinary($pos)" - case dt: DecimalType => s"$rowTerm.getDecimal($pos, ${dt.precision()}, ${dt.scale()})" - case InternalTypes.CHAR => s"$rowTerm.getChar($pos)" - case _: TimestampType => s"$rowTerm.getLong($pos)" - case _: DateType => s"$rowTerm.getInt($pos)" - case InternalTypes.TIME => s"$rowTerm.getInt($pos)" - case _: ArrayType => s"$rowTerm.getArray($pos)" - case _: MapType => s"$rowTerm.getMap($pos)" - case rt: RowType => s"$rowTerm.getRow($pos, ${rt.getArity})" - case _: GenericType[_] => s"$rowTerm.getGeneric($pos)" + // primitive types + case InternalTypes.BOOLEAN => s"$rowTerm.getBoolean($indexTerm)" + 
case InternalTypes.BYTE => s"$rowTerm.getByte($indexTerm)" + case InternalTypes.CHAR => s"$rowTerm.getChar($indexTerm)" + case InternalTypes.SHORT => s"$rowTerm.getShort($indexTerm)" + case InternalTypes.INT => s"$rowTerm.getInt($indexTerm)" + case InternalTypes.LONG => s"$rowTerm.getLong($indexTerm)" + case InternalTypes.FLOAT => s"$rowTerm.getFloat($indexTerm)" + case InternalTypes.DOUBLE => s"$rowTerm.getDouble($indexTerm)" + case InternalTypes.STRING => s"$rowTerm.getString($indexTerm)" + case InternalTypes.BINARY => s"$rowTerm.getBinary($indexTerm)" + case dt: DecimalType => s"$rowTerm.getDecimal($indexTerm, ${dt.precision()}, ${dt.scale()})" + + // temporal types + case _: DateType => s"$rowTerm.getInt($indexTerm)" + case InternalTypes.TIME => s"$rowTerm.getInt($indexTerm)" + case _: TimestampType => s"$rowTerm.getLong($indexTerm)" + + // complex types + case _: ArrayType => s"$rowTerm.getArray($indexTerm)" + case _: MapType => s"$rowTerm.getMap($indexTerm)" + case rt: RowType => s"$rowTerm.getRow($indexTerm, ${rt.getArity})" + + case _: GenericType[_] => s"$rowTerm.getGeneric($indexTerm)" } - /** - * Generates code for comparing two field. - */ - def genCompare( - ctx: CodeGeneratorContext, - t: InternalType, - nullsIsLast: Boolean, - c1: String, - c2: String): String = t match { - case InternalTypes.BOOLEAN => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" - case _: PrimitiveType | _: DateType | _: TimeType | _: TimestampType => - s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" - case InternalTypes.BINARY => - val sortUtil = classOf[org.apache.flink.table.runtime.sort.SortUtil].getCanonicalName - s"$sortUtil.compareBinary($c1, $c2)" - case at: ArrayType => - val compareFunc = newName("compareArray") - val compareCode = genArrayCompare( - ctx, - FlinkPlannerImpl.getNullDefaultOrder(true), at, "a", "b") - val funcCode: String = - s""" - public int $compareFunc($BINARY_ARRAY a, $BINARY_ARRAY b) { - $compareCode - return 0; - } - """ - ctx.addReusableMember(funcCode) - s"$compareFunc($c1, $c2)" - case rowType: RowType => - val orders = rowType.getFieldTypes.map(_ => true) - val comparisons = genRowCompare( - ctx, - rowType.getFieldTypes.indices.toArray, - rowType.getFieldTypes, - orders, - FlinkPlannerImpl.getNullDefaultOrders(orders), - "a", - "b") - val compareFunc = newName("compareRow") - val funcCode: String = - s""" - public int $compareFunc($BASE_ROW a, $BASE_ROW b) { - $comparisons - return 0; + // -------------------------- BaseRow Set Field ------------------------------- + + def baseRowSetField( + ctx: CodeGeneratorContext, + rowClass: Class[_ <: BaseRow], + rowTerm: String, + indexTerm: String, + fieldExpr: GeneratedExpression, + binaryRowWriterTerm: Option[String]): String = { + + val fieldType = fieldExpr.resultType + val fieldTerm = fieldExpr.resultTerm + + if (rowClass == classOf[BinaryRow]) { + binaryRowWriterTerm match { + case Some(writer) => + // use writer to set field + val writeField = binaryWriterWriteField(ctx, indexTerm, fieldTerm, writer, fieldType) + if (ctx.nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | ${binaryWriterWriteNull(indexTerm, writer, fieldType)}; + |} else { + | $writeField; + |} + """.stripMargin + } else { + s""" + |${fieldExpr.code} + |$writeField; + """.stripMargin } - """ - ctx.addReusableMember(funcCode) - s"$compareFunc($c1, $c2)" - case gt: GenericType[_] => - val ser = ctx.addReusableObject(gt.getSerializer, "serializer") - val comp = 
ctx.addReusableObject( - gt.getTypeInfo.asInstanceOf[AtomicTypeInfo[_]].createComparator(true, new ExecutionConfig), - "comparator") - s""" - |$comp.compare( - | $BINARY_GENERIC.getJavaObjectFromBinaryGeneric($c1, $ser), - | $BINARY_GENERIC.getJavaObjectFromBinaryGeneric($c2, $ser) - |) - """.stripMargin - case other if other.isInstanceOf[AtomicType] => s"$c1.compareTo($c2)" - } - /** - * Generates code for comparing array. - */ - def genArrayCompare( - ctx: CodeGeneratorContext, nullsIsLast: Boolean, t: ArrayType, a: String, b: String) - : String = { - val nullIsLastRet = if (nullsIsLast) 1 else -1 - val elementType = t.getElementType - val fieldA = newName("fieldA") - val isNullA = newName("isNullA") - val lengthA = newName("lengthA") - val fieldB = newName("fieldB") - val isNullB = newName("isNullB") - val lengthB = newName("lengthB") - val minLength = newName("minLength") - val i = newName("i") - val comp = newName("comp") - val typeTerm = primitiveTypeTermForType(elementType) - s""" - int $lengthA = a.numElements(); - int $lengthB = b.numElements(); - int $minLength = ($lengthA > $lengthB) ? 
$lengthB : $lengthA; - for (int $i = 0; $i < $minLength; $i++) { - boolean $isNullA = a.isNullAt($i); - boolean $isNullB = b.isNullAt($i); - if ($isNullA && $isNullB) { - // Continue to compare the next element - } else if ($isNullA) { - return $nullIsLastRet; - } else if ($isNullB) { - return ${-nullIsLastRet}; + case None => + // directly set field to BinaryRow, this depends on all the fields are fixed length + val writeField = binaryRowFieldSetAccess(indexTerm, rowTerm, fieldType, fieldTerm) + if (ctx.nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | ${binaryRowSetNull(indexTerm, rowTerm, fieldType)}; + |} else { + | $writeField; + |} + """.stripMargin } else { - $typeTerm $fieldA = ${baseRowFieldReadAccess(ctx, i, a, elementType)}; - $typeTerm $fieldB = ${baseRowFieldReadAccess(ctx, i, b, elementType)}; - int $comp = ${genCompare(ctx, elementType, nullsIsLast, fieldA, fieldB)}; - if ($comp != 0) { - return $comp; - } + s""" + |${fieldExpr.code} + |$writeField; + """.stripMargin } - } - - if ($lengthA < $lengthB) { - return -1; - } else if ($lengthA > $lengthB) { - return 1; - } - """ + } + } else if (rowClass == classOf[GenericRow] || rowClass == classOf[BoxedWrapperRow]) { + val writeField = if (rowClass == classOf[GenericRow]) { + s"$rowTerm.setField($indexTerm, $fieldTerm);" + } else { + boxedWrapperRowFieldSetAccess(rowTerm, indexTerm, fieldTerm, fieldType) + } + if (ctx.nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | $rowTerm.setNullAt($indexTerm); + |} else { + | $writeField; + |} + """.stripMargin + } else { + s""" + |${fieldExpr.code} + |$writeField; + """.stripMargin + } + } else { + throw new UnsupportedOperationException("Not support set field for " + rowClass) + } } - /** - * Generates code for comparing row keys. 
- */ - def genRowCompare( - ctx: CodeGeneratorContext, - keys: Array[Int], - keyTypes: Array[InternalType], - orders: Array[Boolean], - nullsIsLast: Array[Boolean], - row1: String, - row2: String): String = { + // -------------------------- BinaryRow Set Field ------------------------------- + + def binaryRowSetNull(index: Int, rowTerm: String, t: InternalType): String = + binaryRowSetNull(index.toString, rowTerm, t) + + def binaryRowSetNull(indexTerm: String, rowTerm: String, t: InternalType): String = t match { + case d: DecimalType if !Decimal.isCompact(d.precision()) => + s"$rowTerm.setDecimal($indexTerm, null, ${d.precision()}, ${d.scale()})" + case _ => s"$rowTerm.setNullAt($indexTerm)" + } - val compares = new mutable.ArrayBuffer[String] + def binaryRowFieldSetAccess( + index: Int, + binaryRowTerm: String, + fieldType: InternalType, + fieldValTerm: String): String = + binaryRowFieldSetAccess(index.toString, binaryRowTerm, fieldType, fieldValTerm) + + def binaryRowFieldSetAccess( + index: String, + binaryRowTerm: String, + fieldType: InternalType, + fieldValTerm: String): String = + fieldType match { + case InternalTypes.INT => s"$binaryRowTerm.setInt($index, $fieldValTerm)" + case InternalTypes.LONG => s"$binaryRowTerm.setLong($index, $fieldValTerm)" + case InternalTypes.SHORT => s"$binaryRowTerm.setShort($index, $fieldValTerm)" + case InternalTypes.BYTE => s"$binaryRowTerm.setByte($index, $fieldValTerm)" + case InternalTypes.FLOAT => s"$binaryRowTerm.setFloat($index, $fieldValTerm)" + case InternalTypes.DOUBLE => s"$binaryRowTerm.setDouble($index, $fieldValTerm)" + case InternalTypes.BOOLEAN => s"$binaryRowTerm.setBoolean($index, $fieldValTerm)" + case InternalTypes.CHAR => s"$binaryRowTerm.setChar($index, $fieldValTerm)" + case _: DateType => s"$binaryRowTerm.setInt($index, $fieldValTerm)" + case InternalTypes.TIME => s"$binaryRowTerm.setInt($index, $fieldValTerm)" + case _: TimestampType => s"$binaryRowTerm.setLong($index, $fieldValTerm)" + case d: 
DecimalType => + s"$binaryRowTerm.setDecimal($index, $fieldValTerm, ${d.precision()}, ${d.scale()})" + case _ => + throw new CodeGenException("Fail to find binary row field setter method of InternalType " + + fieldType + ".") + } - for (i <- keys.indices) { - val index = keys(i) + // -------------------------- BoxedWrapperRow Set Field ------------------------------- - val symbol = if (orders(i)) "" else "-" + def boxedWrapperRowFieldSetAccess( + rowTerm: String, + indexTerm: String, + fieldTerm: String, + fieldType: InternalType): String = + fieldType match { + case InternalTypes.INT => s"$rowTerm.setInt($indexTerm, $fieldTerm)" + case InternalTypes.LONG => s"$rowTerm.setLong($indexTerm, $fieldTerm)" + case InternalTypes.SHORT => s"$rowTerm.setShort($indexTerm, $fieldTerm)" + case InternalTypes.BYTE => s"$rowTerm.setByte($indexTerm, $fieldTerm)" + case InternalTypes.FLOAT => s"$rowTerm.setFloat($indexTerm, $fieldTerm)" + case InternalTypes.DOUBLE => s"$rowTerm.setDouble($indexTerm, $fieldTerm)" + case InternalTypes.BOOLEAN => s"$rowTerm.setBoolean($indexTerm, $fieldTerm)" + case InternalTypes.CHAR => s"$rowTerm.setChar($indexTerm, $fieldTerm)" + case _: DateType => s"$rowTerm.setInt($indexTerm, $fieldTerm)" + case InternalTypes.TIME => s"$rowTerm.setInt($indexTerm, $fieldTerm)" + case _: TimestampType => s"$rowTerm.setLong($indexTerm, $fieldTerm)" + case _ => s"$rowTerm.setNonPrimitiveValue($indexTerm, $fieldTerm)" + } - val nullIsLastRet = if (nullsIsLast(i)) 1 else -1 + // -------------------------- BinaryArray Set Access ------------------------------- + + def binaryArraySetNull( + index: Int, + arrayTerm: String, + elementType: InternalType): String = elementType match { + case InternalTypes.BOOLEAN => s"$arrayTerm.setNullBoolean($index)" + case InternalTypes.BYTE => s"$arrayTerm.setNullByte($index)" + case InternalTypes.CHAR => s"$arrayTerm.setNullChar($index)" + case InternalTypes.SHORT => s"$arrayTerm.setNullShort($index)" + case InternalTypes.INT => 
s"$arrayTerm.setNullInt($index)" + case InternalTypes.LONG => s"$arrayTerm.setNullLong($index)" + case InternalTypes.FLOAT => s"$arrayTerm.setNullFloat($index)" + case InternalTypes.DOUBLE => s"$arrayTerm.setNullDouble($index)" + case InternalTypes.TIME => s"$arrayTerm.setNullInt($index)" + case _: DateType => s"$arrayTerm.setNullInt($index)" + case _: TimestampType => s"$arrayTerm.setNullLong($index)" + case _ => s"$arrayTerm.setNullLong($index)" + } - val t = keyTypes(i) + // -------------------------- BinaryWriter Write ------------------------------- - val typeTerm = primitiveTypeTermForType(t) - val fieldA = newName("fieldA") - val isNullA = newName("isNullA") - val fieldB = newName("fieldB") - val isNullB = newName("isNullB") - val comp = newName("comp") + def binaryWriterWriteNull(index: Int, writerTerm: String, t: InternalType): String = + binaryWriterWriteNull(index.toString, writerTerm, t) - val code = - s""" - |boolean $isNullA = $row1.isNullAt($index); - |boolean $isNullB = $row2.isNullAt($index); - |if ($isNullA && $isNullB) { - | // Continue to compare the next element - |} else if ($isNullA) { - | return $nullIsLastRet; - |} else if ($isNullB) { - | return ${-nullIsLastRet}; - |} else { - | $typeTerm $fieldA = ${baseRowFieldReadAccess(ctx, index, row1, t)}; - | $typeTerm $fieldB = ${baseRowFieldReadAccess(ctx, index, row2, t)}; - | int $comp = ${genCompare(ctx, t, nullsIsLast(i), fieldA, fieldB)}; - | if ($comp != 0) { - | return $symbol$comp; - | } - |} - """.stripMargin - compares += code - } - compares.mkString + def binaryWriterWriteNull( + indexTerm: String, + writerTerm: String, + t: InternalType): String = t match { + case d: DecimalType if !Decimal.isCompact(d.precision()) => + s"$writerTerm.writeDecimal($indexTerm, null, ${d.precision()})" + case _ => s"$writerTerm.setNullAt($indexTerm)" } + + def binaryWriterWriteField( + ctx: CodeGeneratorContext, + index: Int, + fieldValTerm: String, + writerTerm: String, + fieldType: InternalType): 
String = + binaryWriterWriteField(ctx, index.toString, fieldValTerm, writerTerm, fieldType) + + def binaryWriterWriteField( + ctx: CodeGeneratorContext, + indexTerm: String, + fieldValTerm: String, + writerTerm: String, + fieldType: InternalType): String = + fieldType match { + case InternalTypes.INT => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" + case InternalTypes.LONG => s"$writerTerm.writeLong($indexTerm, $fieldValTerm)" + case InternalTypes.SHORT => s"$writerTerm.writeShort($indexTerm, $fieldValTerm)" + case InternalTypes.BYTE => s"$writerTerm.writeByte($indexTerm, $fieldValTerm)" + case InternalTypes.FLOAT => s"$writerTerm.writeFloat($indexTerm, $fieldValTerm)" + case InternalTypes.DOUBLE => s"$writerTerm.writeDouble($indexTerm, $fieldValTerm)" + case InternalTypes.BOOLEAN => s"$writerTerm.writeBoolean($indexTerm, $fieldValTerm)" + case InternalTypes.BINARY => s"$writerTerm.writeBinary($indexTerm, $fieldValTerm)" + case InternalTypes.STRING => s"$writerTerm.writeString($indexTerm, $fieldValTerm)" + case d: DecimalType => + s"$writerTerm.writeDecimal($indexTerm, $fieldValTerm, ${d.precision()})" + case InternalTypes.CHAR => s"$writerTerm.writeChar($indexTerm, $fieldValTerm)" + case _: DateType => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" + case InternalTypes.TIME => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" + case _: TimestampType => s"$writerTerm.writeLong($indexTerm, $fieldValTerm)" + + // complex types + case _: ArrayType => s"$writerTerm.writeArray($indexTerm, $fieldValTerm)" + case _: MapType => s"$writerTerm.writeMap($indexTerm, $fieldValTerm)" + case _: RowType => + val serializerTerm = ctx.addReusableTypeSerializer(fieldType) + s"$writerTerm.writeRow($indexTerm, $fieldValTerm, $serializerTerm)" + + case _: GenericType[_] => s"$writerTerm.writeGeneric($indexTerm, $fieldValTerm)" + } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGeneratorContext.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGeneratorContext.scala index 98f1eb8eef51d..f611aec28647f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/CodeGeneratorContext.scala @@ -18,15 +18,20 @@ package org.apache.flink.table.codegen -import org.apache.calcite.avatica.util.DateTimeUtils +import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.api.common.functions.{Function, RuntimeContext} -import org.apache.flink.table.`type`.{InternalType, InternalTypes, RowType} +import org.apache.flink.api.common.typeutils.TypeSerializer +import org.apache.flink.table.`type`.{InternalType, InternalTypes, RowType, TypeConverters} import org.apache.flink.table.api.TableConfig import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.GenerateUtils.generateRecordStatement import org.apache.flink.table.dataformat.GenericRow import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction} +import org.apache.flink.table.runtime.util.collections._ import org.apache.flink.util.InstantiationUtil +import org.apache.calcite.avatica.util.DateTimeUtils + import scala.collection.mutable /** @@ -92,12 +97,10 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { // string_constant -> reused_term private val reusableStringConstants: mutable.Map[String, String] = mutable.Map[String, String]() - // map of local variable statements. It will be placed in method if method code not excess - // max code length, otherwise will be placed in member area of the class. The statements - // are maintained for multiple methods, so that it's a map from method_name to variables. 
- // - // method_name -> local_variable_statements - private val reusableLocalVariableStatements = mutable.Map[String, mutable.LinkedHashSet[String]]() + // map of type serializer that will be added only once + // InternalType -> reused_term + private val reusableTypeSerializers: mutable.Map[InternalType, String] = + mutable.Map[InternalType, String]() /** * The current method name for [[reusableLocalVariableStatements]]. You can start a new @@ -105,6 +108,14 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { */ private var currentMethodNameForLocalVariables = "DEFAULT" + // map of local variable statements. It will be placed in method if method code not excess + // max code length, otherwise will be placed in member area of the class. The statements + // are maintained for multiple methods, so that it's a map from method_name to variables. + // + // method_name -> local_variable_statements + private val reusableLocalVariableStatements = mutable.Map[String, mutable.LinkedHashSet[String]]( + (currentMethodNameForLocalVariables, mutable.LinkedHashSet[String]())) + // --------------------------------------------------------------------------------- // Getter // --------------------------------------------------------------------------------- @@ -112,7 +123,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { def getReusableInputUnboxingExprs(inputTerm: String, index: Int): Option[GeneratedExpression] = reusableInputUnboxingExprs.get((inputTerm, index)) - def getNullCheck: Boolean = tableConfig.getNullCheck + def nullCheck: Boolean = tableConfig.getNullCheck // --------------------------------------------------------------------------------- // Local Variables for Code Split @@ -137,7 +148,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { * @param fieldTypeTerm the field type term * @return a new generated unique field name */ - def newReusableLocalVariable(fieldTypeTerm: String, fieldName: String): String = { + def 
addReusableLocalVariable(fieldTypeTerm: String, fieldName: String): String = { val fieldTerm = newName(fieldName) reusableLocalVariableStatements .getOrElse(currentMethodNameForLocalVariables, mutable.LinkedHashSet[String]()) @@ -154,7 +165,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { * left is field type term and right is field name * @return the new generated unique field names for each variable pairs */ - def newReusableLocalFields(fieldTypeAndNames: (String, String)*): Seq[String] = { + def addReusableLocalVariables(fieldTypeAndNames: (String, String)*): Seq[String] = { val fieldTerms = newNames(fieldTypeAndNames.map(_._2): _*) fieldTypeAndNames.map(_._1).zip(fieldTerms).foreach { case (fieldTypeTerm, fieldTerm) => reusableLocalVariableStatements @@ -296,6 +307,11 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { reusableMemberStatements.add(memberStatement) } + /** + * Adds a reusable init statement which will be placed in constructor. + */ + def addReusableInitStatement(s: String): Unit = reusableInitStatements.add(s) + /** * Adds a reusable per record statement */ @@ -338,7 +354,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { clazz: Class[_], outRecordTerm: String, outRecordWriterTerm: Option[String] = None): Unit = { - val statement = generateOutputRecordStatement(t, clazz, outRecordTerm, outRecordWriterTerm) + val statement = generateRecordStatement(t, clazz, outRecordTerm, outRecordWriterTerm) reusableMemberStatements.add(statement) } @@ -352,6 +368,42 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { rowTerm) } + /** + * Adds a reusable internal hash set to the member area of the generated class. 
+ */ + def addReusableHashSet(elements: Seq[GeneratedExpression], elementType: InternalType): String = { + val fieldTerm = newName("set") + + val setTypeTerm = elementType match { + case InternalTypes.BYTE => className[ByteHashSet] + case InternalTypes.SHORT => className[ShortHashSet] + case InternalTypes.INT => className[IntHashSet] + case InternalTypes.LONG => className[LongHashSet] + case InternalTypes.FLOAT => className[FloatHashSet] + case InternalTypes.DOUBLE => className[DoubleHashSet] + case _ => className[ObjectHashSet[_]] + } + + addReusableMember( + s"final $setTypeTerm $fieldTerm = new $setTypeTerm(${elements.size})") + + elements.foreach { element => + val content = + s""" + |${element.code} + |if (${element.nullTerm}) { + | $fieldTerm.addNull(); + |} else { + | $fieldTerm.add(${element.resultTerm}); + |} + |""".stripMargin + reusableInitStatements.add(content) + } + reusableInitStatements.add(s"$fieldTerm.optimize();") + + fieldTerm + } + /** * Adds a reusable timestamp to the beginning of the SAM of the generated class. */ @@ -458,7 +510,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { |""".stripMargin val fieldInit = seedExpr match { - case Some(s) if getNullCheck => + case Some(s) if nullCheck => s""" |${s.code} |if (!${s.nullTerm}) { @@ -550,6 +602,28 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { fieldTerm } + /** + * Adds a reusable [[TypeSerializer]] to the member area of the generated class. 
+ * + * @param t the internal type which used to generate internal type serializer + * @return member variable term + */ + def addReusableTypeSerializer(t: InternalType): String = { + // if type serializer has been used before, we can reuse the code that + // has already been generated + reusableTypeSerializers.get(t) match { + case Some(term) => term + + case None => + val term = newName("typeSerializer") + val ser = TypeConverters.createInternalTypeInfoFromInternalType(t) + .createSerializer(new ExecutionConfig) + addReusableObjectInternal(ser, term, ser.getClass.getCanonicalName) + reusableTypeSerializers(t) = term + term + } + } + /** * Adds a reusable static SLF4J Logger to the member area of the generated class. */ @@ -607,7 +681,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { val field = newName("str") val stmt = s""" - |private final $BINARY_STRING $field = $BINARY_STRING.fromString("$value");" + |private final $BINARY_STRING $field = $BINARY_STRING.fromString("$value"); """.stripMargin reusableMemberStatements.add(stmt) reusableStringConstants(value) = field @@ -686,3 +760,9 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { fieldTerm } } + +object CodeGeneratorContext { + def apply(config: TableConfig): CodeGeneratorContext = { + new CodeGeneratorContext(config) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ExprCodeGenerator.scala new file mode 100644 index 0000000000000..14a326c579139 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/ExprCodeGenerator.scala @@ -0,0 +1,721 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen + +import org.apache.calcite.rex._ +import org.apache.calcite.sql.SqlOperator +import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName} +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.table.`type`._ +import org.apache.flink.table.api.TableException +import org.apache.flink.table.calcite.{FlinkTypeFactory, RexAggLocalVariable, RexDistinctKeyVariable} +import org.apache.flink.table.codegen.CodeGenUtils.{requireTemporal, requireTimeInterval, _} +import org.apache.flink.table.codegen.GenerateUtils._ +import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE} +import org.apache.flink.table.codegen.calls.ScalarOperatorGens._ +import org.apache.flink.table.dataformat._ +import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable._ +import org.apache.flink.table.typeutils.TypeCheckUtils.{isNumeric, isTemporal, isTimeInterval} +import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TypeCheckUtils} + +import scala.collection.JavaConversions._ + +/** + * This code generator is mainly responsible for generating codes for a given calcite [[RexNode]]. + * It can also generate type conversion codes for the result converter. 
+ */ +class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) + extends RexVisitor[GeneratedExpression] { + + // check if nullCheck is enabled when inputs can be null + if (nullableInput && !ctx.nullCheck) { + throw new CodeGenException("Null check must be enabled if entire rows can be null.") + } + + /** + * term of the [[ProcessFunction]]'s context, can be changed when needed + */ + var contextTerm = "ctx" + + /** + * information of the first input + */ + var input1Type: InternalType = _ + var input1Term: String = _ + var input1FieldMapping: Option[Array[Int]] = None + + /** + * information of the optional second input + */ + var input2Type: Option[InternalType] = None + var input2Term: Option[String] = None + var input2FieldMapping: Option[Array[Int]] = None + + /** + * Bind the input information, should be called before generating expression. + */ + def bindInput( + inputType: InternalType, + inputTerm: String = DEFAULT_INPUT1_TERM, + inputFieldMapping: Option[Array[Int]] = None): ExprCodeGenerator = { + input1Type = inputType + input1Term = inputTerm + input1FieldMapping = inputFieldMapping + this + } + + /** + * In some cases, the expression will have two inputs (e.g. join condition and udtf). We should + * bind second input information before use. 
+ */ + def bindSecondInput( + inputType: InternalType, + inputTerm: String = DEFAULT_INPUT2_TERM, + inputFieldMapping: Option[Array[Int]] = None): ExprCodeGenerator = { + input2Type = Some(inputType) + input2Term = Some(inputTerm) + input2FieldMapping = inputFieldMapping + this + } + + private lazy val input1Mapping: Array[Int] = input1FieldMapping match { + case Some(mapping) => mapping + case _ => fieldIndices(input1Type) + } + + private lazy val input2Mapping: Array[Int] = input2FieldMapping match { + case Some(mapping) => mapping + case _ => input2Type match { + case Some(input) => fieldIndices(input) + case _ => Array[Int]() + } + } + + private def fieldIndices(t: InternalType): Array[Int] = t match { + case rt: RowType => (0 until rt.getArity).toArray + case _ => Array(0) + } + + /** + * Generates an expression from a RexNode. If objects or variables can be reused, they will be + * added to reusable code sections internally. + * + * @param rex Calcite row expression + * @return instance of GeneratedExpression + */ + def generateExpression(rex: RexNode): GeneratedExpression = { + rex.accept(this) + } + + /** + * Generates an expression that converts the first input (and second input) into the given type. + * If two inputs are converted, the second input is appended. If objects or variables can + * be reused, they will be added to reusable code sections internally. The evaluation result + * will be stored in the variable outRecordTerm. + * + * @param returnType conversion target type. Inputs and output must have the same arity. + * @param outRecordTerm the result term + * @param outRecordWriterTerm the result writer term + * @param reusedOutRow If objects or variables can be reused, they will be added to reusable + * code sections internally. 
+ * @return instance of GeneratedExpression + */ + def generateConverterResultExpression( + returnType: RowType, + returnTypeClazz: Class[_ <: BaseRow], + outRecordTerm: String = DEFAULT_OUT_RECORD_TERM, + outRecordWriterTerm: String = DEFAULT_OUT_RECORD_WRITER_TERM, + reusedOutRow: Boolean = true, + fieldCopy: Boolean = false, + rowtimeExpression: Option[RexNode] = None) + : GeneratedExpression = { + val input1AccessExprs = input1Mapping.map { + case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER | + TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER if rowtimeExpression.isDefined => + // generate rowtime attribute from expression + generateExpression(rowtimeExpression.get) + case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER | + TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER => + throw new TableException("Rowtime extraction expression missing. Please report a bug.") + case TimeIndicatorTypeInfo.PROCTIME_STREAM_MARKER => + // attribute is proctime indicator. + // we use a null literal and generate a timestamp when we need it. + generateNullLiteral(InternalTypes.PROCTIME_INDICATOR, ctx.nullCheck) + case TimeIndicatorTypeInfo.PROCTIME_BATCH_MARKER => + // attribute is proctime field in a batch query. + // it is initialized with the current time. + generateCurrentTimestamp(ctx) + case idx => + // get type of result field + generateInputAccess( + ctx, + input1Type, + input1Term, + idx, + nullableInput, + fieldCopy) + } + + val input2AccessExprs = input2Type match { + case Some(ti) => + input2Mapping.map(idx => generateInputAccess( + ctx, + ti, + input2Term.get, + idx, + nullableInput, + ctx.nullCheck) + ).toSeq + case None => Seq() // add nothing + } + + generateResultExpression( + input1AccessExprs ++ input2AccessExprs, + returnType, + returnTypeClazz, + outRow = outRecordTerm, + outRowWriter = Some(outRecordWriterTerm), + reusedOutRow = reusedOutRow) + } + + /** + * Generates an expression from a sequence of other expressions. 
The evaluation result + * may be stored in the variable outRecordTerm. + * + * @param fieldExprs field expressions to be converted + * @param returnType conversion target type. Type must have the same arity than fieldExprs. + * @param outRow the result term + * @param outRowWriter the result writer term for BinaryRow. + * @param reusedOutRow If objects or variables can be reused, they will be added to reusable + * code sections internally. + * @param outRowAlreadyExists Don't need addReusableRecord if out row already exists. + * @return instance of GeneratedExpression + */ + def generateResultExpression( + fieldExprs: Seq[GeneratedExpression], + returnType: RowType, + returnTypeClazz: Class[_ <: BaseRow], + outRow: String = DEFAULT_OUT_RECORD_TERM, + outRowWriter: Option[String] = Some(DEFAULT_OUT_RECORD_WRITER_TERM), + reusedOutRow: Boolean = true, + outRowAlreadyExists: Boolean = false): GeneratedExpression = { + val fieldExprIdxToOutputRowPosMap = fieldExprs.indices.map(i => i -> i).toMap + generateResultExpression(fieldExprs, fieldExprIdxToOutputRowPosMap, returnType, + returnTypeClazz, outRow, outRowWriter, reusedOutRow, outRowAlreadyExists) + } + + /** + * Generates an expression from a sequence of other expressions. The evaluation result + * may be stored in the variable outRecordTerm. + * + * @param fieldExprs field expressions to be converted + * @param fieldExprIdxToOutputRowPosMap Mapping index of fieldExpr in `fieldExprs` + * to position of output row. + * @param returnType conversion target type. Type must have the same arity than fieldExprs. + * @param outRow the result term + * @param outRowWriter the result writer term for BinaryRow. + * @param reusedOutRow If objects or variables can be reused, they will be added to reusable + * code sections internally. + * @param outRowAlreadyExists Don't need addReusableRecord if out row already exists. 
+ * @return instance of GeneratedExpression + */ + def generateResultExpression( + fieldExprs: Seq[GeneratedExpression], + fieldExprIdxToOutputRowPosMap: Map[Int, Int], + returnType: RowType, + returnTypeClazz: Class[_ <: BaseRow], + outRow: String, + outRowWriter: Option[String], + reusedOutRow: Boolean, + outRowAlreadyExists: Boolean) + : GeneratedExpression = { + // initial type check + if (returnType.getArity != fieldExprs.length) { + throw new CodeGenException( + s"Arity [${returnType.getArity}] of result type [$returnType] does not match " + + s"number [${fieldExprs.length}] of expressions [$fieldExprs].") + } + if (fieldExprIdxToOutputRowPosMap.size != fieldExprs.length) { + throw new CodeGenException( + s"Size [${returnType.getArity}] of fieldExprIdxToOutputRowPosMap does not match " + + s"number [${fieldExprs.length}] of expressions [$fieldExprs].") + } + // type check + fieldExprs.zipWithIndex foreach { + // timestamp type(Include TimeIndicator) and generic type can compatible with each other. + case (fieldExpr, i) + if fieldExpr.resultType.isInstanceOf[GenericType[_]] || + fieldExpr.resultType.isInstanceOf[TimestampType] => + if (returnType.getTypeAt(i).getClass != fieldExpr.resultType.getClass + && !returnType.getTypeAt(i).isInstanceOf[GenericType[_]]) { + throw new CodeGenException( + s"Incompatible types of expression and result type, Expression[$fieldExpr] type is " + + s"[${fieldExpr.resultType}], result type is [${returnType.getTypeAt(i)}]") + } + case (fieldExpr, i) if fieldExpr.resultType != returnType.getTypeAt(i) => + throw new CodeGenException( + s"Incompatible types of expression and result type. 
Expression[$fieldExpr] type is " + + s"[${fieldExpr.resultType}], result type is [${returnType.getTypeAt(i)}]") + case _ => // ok + } + + val setFieldsCode = fieldExprs.zipWithIndex.map { case (fieldExpr, index) => + val pos = fieldExprIdxToOutputRowPosMap.getOrElse(index, + throw new CodeGenException(s"Illegal field expr index: $index")) + baseRowSetField(ctx, returnTypeClazz, outRow, pos.toString, fieldExpr, outRowWriter) + }.mkString("\n") + + val outRowInitCode = if (!outRowAlreadyExists) { + val initCode = generateRecordStatement(returnType, returnTypeClazz, outRow, outRowWriter) + if (reusedOutRow) { + ctx.addReusableMember(initCode) + NO_CODE + } else { + initCode + } + } else { + NO_CODE + } + + val code = if (returnTypeClazz == classOf[BinaryRow] && outRowWriter.isDefined) { + val writer = outRowWriter.get + val resetWriter = if (ctx.nullCheck) s"$writer.reset();" else s"$writer.resetCursor();" + val completeWriter: String = s"$writer.complete();" + s""" + |$outRowInitCode + |$resetWriter + |$setFieldsCode + |$completeWriter + """.stripMargin + } else { + s""" + |$outRowInitCode + |$setFieldsCode + """.stripMargin + } + GeneratedExpression(outRow, NEVER_NULL, code, returnType) + } + + override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = { + // if inputRef index is within size of input1 we work with input1, input2 otherwise + val input = if (inputRef.getIndex < InternalTypeUtils.getArity(input1Type)) { + (input1Type, input1Term) + } else { + (input2Type.getOrElse(throw new CodeGenException("Invalid input access.")), + input2Term.getOrElse(throw new CodeGenException("Invalid input access."))) + } + + val index = if (input._2 == input1Term) { + inputRef.getIndex + } else { + inputRef.getIndex - InternalTypeUtils.getArity(input1Type) + } + + generateInputAccess(ctx, input._1, input._2, index, nullableInput, ctx.nullCheck) + } + + override def visitTableInputRef(rexTableInputRef: RexTableInputRef): GeneratedExpression = + 
visitInputRef(rexTableInputRef) + + override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = { + val refExpr = rexFieldAccess.getReferenceExpr.accept(this) + val index = rexFieldAccess.getField.getIndex + val fieldAccessExpr = generateFieldAccess( + ctx, + refExpr.resultType, + refExpr.resultTerm, + index) + + val resultTypeTerm = primitiveTypeTermForType(fieldAccessExpr.resultType) + val defaultValue = primitiveDefaultValue(fieldAccessExpr.resultType) + val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables( + (resultTypeTerm, "result"), + ("boolean", "isNull")) + + val resultCode = if (ctx.nullCheck) { + s""" + |${refExpr.code} + |if (${refExpr.nullTerm}) { + | $resultTerm = $defaultValue; + | $nullTerm = true; + |} + |else { + | ${fieldAccessExpr.code} + | $resultTerm = ${fieldAccessExpr.resultTerm}; + | $nullTerm = ${fieldAccessExpr.nullTerm}; + |} + |""".stripMargin + } else { + s""" + |${refExpr.code} + |${fieldAccessExpr.code} + |$resultTerm = ${fieldAccessExpr.resultTerm}; + |""".stripMargin + } + + GeneratedExpression(resultTerm, nullTerm, resultCode, fieldAccessExpr.resultType) + } + + override def visitLiteral(literal: RexLiteral): GeneratedExpression = { + val resultType = FlinkTypeFactory.toInternalType(literal.getType) + val value = literal.getValue3 + generateLiteral(ctx, resultType, value) + } + + override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = { + GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1Type) + } + + override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = localRef match { + case local: RexAggLocalVariable => + GeneratedExpression(local.fieldTerm, local.nullTerm, NO_CODE, local.internalType) + case value: RexDistinctKeyVariable => + val inputExpr = ctx.getReusableInputUnboxingExprs(input1Term, 0) match { + case Some(expr) => expr + case None => + val pType = primitiveTypeTermForType(value.internalType) + val defaultValue = 
primitiveDefaultValue(value.internalType) + val resultTerm = newName("field") + val nullTerm = newName("isNull") + val code = + s""" + |$pType $resultTerm = $defaultValue; + |boolean $nullTerm = true; + |if ($input1Term != null) { + | $nullTerm = false; + | $resultTerm = ($pType) $input1Term; + |} + """.stripMargin + val expr = GeneratedExpression(resultTerm, nullTerm, code, value.internalType) + ctx.addReusableInputUnboxingExprs(input1Term, 0, expr) + expr + } + // hide the generated code as it will be executed only once + GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, NO_CODE, inputExpr.resultType) + case _ => throw new CodeGenException("Local variables are not supported yet.") + } + + override def visitRangeRef(rangeRef: RexRangeRef): GeneratedExpression = + throw new CodeGenException("Range references are not supported yet.") + + override def visitDynamicParam(dynamicParam: RexDynamicParam): GeneratedExpression = + throw new CodeGenException("Dynamic parameter references are not supported yet.") + + override def visitCall(call: RexCall): GeneratedExpression = { + + val resultType = FlinkTypeFactory.toInternalType(call.getType) + + // convert operands and help giving untyped NULL literals a type + val operands = call.getOperands.zipWithIndex.map { + + // this helps e.g. 
for AS(null) + // we might need to extend this logic in case some rules do not create typed NULLs + case (operandLiteral: RexLiteral, 0) if + operandLiteral.getType.getSqlTypeName == SqlTypeName.NULL && + call.getOperator.getReturnTypeInference == ReturnTypes.ARG0 => + generateNullLiteral(resultType, ctx.nullCheck) + + case (o@_, _) => o.accept(this) + } + + generateCallExpression(ctx, call.getOperator, operands, resultType) + } + + override def visitOver(over: RexOver): GeneratedExpression = + throw new CodeGenException("Aggregate functions over windows are not supported yet.") + + override def visitSubQuery(subQuery: RexSubQuery): GeneratedExpression = + throw new CodeGenException("Subqueries are not supported yet.") + + override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = + throw new CodeGenException("Pattern field references are not supported yet.") + + // ---------------------------------------------------------------------------------------- + + private def generateCallExpression( + ctx: CodeGeneratorContext, + operator: SqlOperator, + operands: Seq[GeneratedExpression], + resultType: InternalType): GeneratedExpression = { + operator match { + // arithmetic + case PLUS if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "+", resultType, left, right) + + case PLUS | DATETIME_PLUS if isTemporal(resultType) => + val left = operands.head + val right = operands(1) + requireTemporal(left) + requireTemporal(right) + generateTemporalPlusMinus(ctx, plus = true, resultType, left, right) + + case MINUS if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "-", resultType, left, right) + + case MINUS | MINUS_DATE if isTemporal(resultType) => + val left = operands.head + val right = operands(1) + 
requireTemporal(left) + requireTemporal(right) + generateTemporalPlusMinus(ctx, plus = false, resultType, left, right) + + case MULTIPLY if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "*", resultType, left, right) + + case MULTIPLY if isTimeInterval(resultType) => + val left = operands.head + val right = operands(1) + requireTimeInterval(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "*", resultType, left, right) + + case DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "/", resultType, left, right) + + case MOD if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateBinaryArithmeticOperator(ctx, "%", resultType, left, right) + + case UNARY_MINUS if isNumeric(resultType) => + val operand = operands.head + requireNumeric(operand) + generateUnaryArithmeticOperator(ctx, "-", resultType, operand) + + case UNARY_MINUS if isTimeInterval(resultType) => + val operand = operands.head + requireTimeInterval(operand) + generateUnaryIntervalPlusMinus(ctx, plus = false, operand) + + case UNARY_PLUS if isNumeric(resultType) => + val operand = operands.head + requireNumeric(operand) + generateUnaryArithmeticOperator(ctx, "+", resultType, operand) + + case UNARY_PLUS if isTimeInterval(resultType) => + val operand = operands.head + requireTimeInterval(operand) + generateUnaryIntervalPlusMinus(ctx, plus = true, operand) + + // comparison + case EQUALS => + val left = operands.head + val right = operands(1) + generateEquals(ctx, left, right) + + case NOT_EQUALS => + val left = operands.head + val right = operands(1) + generateNotEquals(ctx, left, right) + + case GREATER_THAN => + val left = operands.head + val right = 
operands(1) + requireComparable(left) + requireComparable(right) + generateComparison(ctx, ">", left, right) + + case GREATER_THAN_OR_EQUAL => + val left = operands.head + val right = operands(1) + requireComparable(left) + requireComparable(right) + generateComparison(ctx, ">=", left, right) + + case LESS_THAN => + val left = operands.head + val right = operands(1) + requireComparable(left) + requireComparable(right) + generateComparison(ctx, "<", left, right) + + case LESS_THAN_OR_EQUAL => + val left = operands.head + val right = operands(1) + requireComparable(left) + requireComparable(right) + generateComparison(ctx, "<=", left, right) + + case IS_NULL => + val operand = operands.head + generateIsNull(ctx, operand) + + case IS_NOT_NULL => + val operand = operands.head + generateIsNotNull(ctx, operand) + + // logic + case AND => + operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) => + requireBoolean(left) + requireBoolean(right) + generateAnd(ctx, left, right) + } + + case OR => + operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) => + requireBoolean(left) + requireBoolean(right) + generateOr(ctx, left, right) + } + + case NOT => + val operand = operands.head + requireBoolean(operand) + generateNot(ctx, operand) + + case CASE => + generateIfElse(ctx, operands, resultType) + + case IS_TRUE => + val operand = operands.head + requireBoolean(operand) + generateIsTrue(operand) + + case IS_NOT_TRUE => + val operand = operands.head + requireBoolean(operand) + generateIsNotTrue(operand) + + case IS_FALSE => + val operand = operands.head + requireBoolean(operand) + generateIsFalse(operand) + + case IS_NOT_FALSE => + val operand = operands.head + requireBoolean(operand) + generateIsNotFalse(operand) + + case IN => + val left = operands.head + val right = operands.tail + generateIn(ctx, left, right) + + case NOT_IN => + val left = operands.head + val right = operands.tail + generateNot(ctx, generateIn(ctx, left, 
right)) + + // casting + case CAST => + val operand = operands.head + generateCast(ctx, operand, resultType) + + // Reinterpret + case REINTERPRET => + val operand = operands.head + generateReinterpret(ctx, operand, resultType) + + // as / renaming + case AS => + operands.head + + // rows + case ROW => + generateRow(ctx, resultType, operands) + + // arrays + case ARRAY_VALUE_CONSTRUCTOR => + generateArray(ctx, resultType, operands) + + // maps + case MAP_VALUE_CONSTRUCTOR => + generateMap(ctx, resultType, operands) + + case ITEM => + operands.head.resultType match { + case t: InternalType if TypeCheckUtils.isArray(t) => + val array = operands.head + val index = operands(1) + requireInteger(index) + generateArrayElementAt(ctx, array, index) + + case t: InternalType if TypeCheckUtils.isMap(t) => + val key = operands(1) + generateMapGet(ctx, operands.head, key) + + case _ => throw new CodeGenException("Expect an array or a map.") + } + + case CARDINALITY => + operands.head.resultType match { + case t: InternalType if TypeCheckUtils.isArray(t) => + val array = operands.head + generateArrayCardinality(ctx, array) + + case t: InternalType if TypeCheckUtils.isMap(t) => + val map = operands.head + generateMapCardinality(ctx, map) + + case _ => throw new CodeGenException("Expect an array or a map.") + } + + case ELEMENT => + val array = operands.head + requireArray(array) + generateArrayElement(ctx, array) + + case DOT => + generateDot(ctx, operands) + + case PROCTIME => + // attribute is proctime indicator. + // We use a null literal and generate a timestamp when we need it. 
+ generateNullLiteral(InternalTypes.PROCTIME_INDICATOR, ctx.nullCheck) + + case PROCTIME_MATERIALIZE => + generateProctimeTimestamp(ctx, contextTerm) + + case STREAMRECORD_TIMESTAMP => + generateRowtimeAccess(ctx, contextTerm) + + // advanced scalar functions + case sqlOperator: SqlOperator => + val explainCall = s"$sqlOperator(${operands.map(_.resultType).mkString(", ")})" + // TODO: support BinaryStringCallGen and FunctionGenerator + throw new CodeGenException(s"Unsupported call: $explainCall \n" + + s"If you think this function should be supported, " + + s"you can create an issue and start a discussion for it.") + + // unknown or invalid + case call@_ => + val explainCall = s"$call(${operands.map(_.resultType).mkString(", ")})" + throw new CodeGenException(s"Unsupported call: $explainCall") + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/FunctionCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/FunctionCodeGenerator.scala new file mode 100644 index 0000000000000..f123b48a29aeb --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/FunctionCodeGenerator.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.codegen + +import org.apache.flink.api.common.functions._ +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.functions.async.{AsyncFunction, RichAsyncFunction} +import org.apache.flink.table.`type`.InternalType +import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.Indenter.toISC +import org.apache.flink.table.generated.GeneratedFunction + +/** + * A code generator for generating Flink [[org.apache.flink.api.common.functions.Function]]s. + * Including [[MapFunction]], [[FlatMapFunction]], [[FlatJoinFunction]], [[ProcessFunction]], and + * the corresponding rich version of the functions. + */ +object FunctionCodeGenerator { + + /** + * Generates a [[org.apache.flink.api.common.functions.Function]] that can be passed to Java + * compiler. + * + * @param ctx The context of the code generator + * @param name Class name of the Function. Must not be unique but has to be a valid Java class + * identifier. + * @param clazz Flink Function to be generated. + * @param bodyCode code contents of the SAM (Single Abstract Method). Inputs, collector, or + * output record can be accessed via the given term methods. + * @param returnType expected return type + * @param input1Type the first input type + * @param input1Term the first input term + * @param input2Type the second input type, optional. + * @param input2Term the second input term. + * @param collectorTerm the collector term + * @param contextTerm the context term + * @tparam F Flink Function to be generated. 
+ * @return instance of GeneratedFunction + */ + def generateFunction[F <: Function]( + ctx: CodeGeneratorContext, + name: String, + clazz: Class[F], + bodyCode: String, + returnType: InternalType, + input1Type: InternalType, + input1Term: String = DEFAULT_INPUT1_TERM, + input2Type: Option[InternalType] = None, + input2Term: Option[String] = Some(DEFAULT_INPUT2_TERM), + collectorTerm: String = DEFAULT_COLLECTOR_TERM, + contextTerm: String = DEFAULT_CONTEXT_TERM) + : GeneratedFunction[F] = { + val funcName = newName(name) + val inputTypeTerm = boxedTypeTermForType(input1Type) + + // Janino does not support generics, that's why we need + // manual casting here + val samHeader = + // FlatMapFunction + if (clazz == classOf[FlatMapFunction[_, _]]) { + val baseClass = classOf[RichFlatMapFunction[_, _]] + (baseClass, + s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) + } + + // MapFunction + else if (clazz == classOf[MapFunction[_, _]]) { + val baseClass = classOf[RichMapFunction[_, _]] + (baseClass, + "Object map(Object _in1)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) + } + + // FlatJoinFunction + else if (clazz == classOf[FlatJoinFunction[_, _, _]]) { + val baseClass = classOf[RichFlatJoinFunction[_, _, _]] + val inputTypeTerm2 = boxedTypeTermForType(input2Type.getOrElse( + throw new CodeGenException("Input 2 for FlatJoinFunction should not be null"))) + (baseClass, + s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;", + s"$inputTypeTerm2 ${input2Term.get} = ($inputTypeTerm2) _in2;")) + } + + // ProcessFunction + else if (clazz == classOf[ProcessFunction[_, _]]) { + val baseClass = classOf[ProcessFunction[_, _]] + (baseClass, + s"void processElement(Object _in1, " + + s"org.apache.flink.streaming.api.functions.ProcessFunction.Context $contextTerm," + + 
s"org.apache.flink.util.Collector $collectorTerm)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) + } + + // AsyncFunction + else if (clazz == classOf[AsyncFunction[_, _]]) { + val baseClass = classOf[RichAsyncFunction[_, _]] + (baseClass, + s"void asyncInvoke(Object _in1, " + + s"org.apache.flink.streaming.api.functions.async.ResultFuture $collectorTerm)", + List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) + } + + else { + // TODO more functions + throw new CodeGenException("Unsupported Function.") + } + + val funcCode = + j""" + public class $funcName + extends ${samHeader._1.getCanonicalName} { + + ${ctx.reuseMemberCode()} + + public $funcName(Object[] references) throws Exception { + ${ctx.reuseInitCode()} + } + + ${ctx.reuseConstructorCode(funcName)} + + @Override + public void open(${classOf[Configuration].getCanonicalName} parameters) throws Exception { + ${ctx.reuseOpenCode()} + } + + @Override + public ${samHeader._2} throws Exception { + ${samHeader._3.mkString("\n")} + ${ctx.reusePerRecordCode()} + ${ctx.reuseLocalVariableCode()} + ${ctx.reuseInputUnboxingCode()} + $bodyCode + } + + @Override + public void close() throws Exception { + ${ctx.reuseCloseCode()} + } + } + """.stripMargin + + new GeneratedFunction(funcName, funcCode, ctx.references.toArray) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GenerateUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GenerateUtils.scala new file mode 100644 index 0000000000000..28fe506dddeff --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GenerateUtils.scala @@ -0,0 +1,751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.codegen + +import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.typeinfo.{AtomicType => AtomicTypeInfo} +import org.apache.flink.table.`type`._ +import org.apache.flink.table.calcite.FlinkPlannerImpl +import org.apache.flink.table.codegen.CodeGenUtils._ +import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE} +import org.apache.flink.table.codegen.calls.CurrentTimePointCallGen +import org.apache.flink.table.dataformat._ +import org.apache.flink.table.typeutils.TypeCheckUtils.{isReference, isTemporal} + +import org.apache.calcite.avatica.util.ByteString +import org.apache.commons.lang3.StringEscapeUtils + +import java.math.{BigDecimal => JBigDecimal} + +import scala.collection.mutable + +/** + * Utilities to generate code for general purpose. + */ +object GenerateUtils { + + // ---------------------------------------------------------------------------------------- + // basic call generate utils + // ---------------------------------------------------------------------------------------- + + /** + * Generates a call with a single result statement. 
   */
  def generateCallIfArgsNotNull(
      ctx: CodeGeneratorContext,
      returnType: InternalType,
      operands: Seq[GeneratedExpression],
      resultNullable: Boolean = false)
      (call: Seq[String] => String): GeneratedExpression = {
    // delegate to the statement variant with an empty auxiliary-statement part
    generateCallWithStmtIfArgsNotNull(ctx, returnType, operands, resultNullable) {
      args => ("", call(args))
    }
  }

  /**
   * Generates a call with auxiliary statements and result expression.
   *
   * @param ctx code generator context which maintains reusable statements
   * @param returnType internal type of the call result
   * @param operands already generated expressions for the call operands
   * @param resultNullable true if the result may be null even when all operands are non-null
   * @param call maps the operand result terms to (auxiliary statements, result expression)
   */
  def generateCallWithStmtIfArgsNotNull(
      ctx: CodeGeneratorContext,
      returnType: InternalType,
      operands: Seq[GeneratedExpression],
      resultNullable: Boolean = false)
      (call: Seq[String] => (String, String)): GeneratedExpression = {
    val resultTypeTerm = primitiveTypeTermForType(returnType)
    val nullTerm = ctx.addReusableLocalVariable("boolean", "isNull")
    val resultTerm = ctx.addReusableLocalVariable(resultTypeTerm, "result")
    val defaultValue = primitiveDefaultValue(returnType)
    // a reference-typed (non-temporal) result may still be null even when no operand is null
    val isResultNullable = resultNullable || (isReference(returnType) && !isTemporal(returnType))
    val nullTermCode = if (ctx.nullCheck && isResultNullable) {
      s"$nullTerm = ($resultTerm == null);"
    } else {
      ""
    }

    val (stmt, result) = call(operands.map(_.resultTerm))

    val resultCode = if (ctx.nullCheck && operands.nonEmpty) {
      // the call is only evaluated when all operands are non-null
      s"""
         |${operands.map(_.code).mkString("\n")}
         |$nullTerm = ${operands.map(_.nullTerm).mkString(" || ")};
         |$resultTerm = $defaultValue;
         |if (!$nullTerm) {
         |  $stmt
         |  $resultTerm = $result;
         |  $nullTermCode
         |}
         |""".stripMargin
    } else if (ctx.nullCheck && operands.isEmpty) {
      // no operands: evaluate unconditionally (operand code fragment is empty here)
      s"""
         |${operands.map(_.code).mkString("\n")}
         |$nullTerm = false;
         |$stmt
         |$resultTerm = $result;
         |$nullTermCode
         |""".stripMargin
    } else {
      // null checking disabled: evaluate directly
      s"""
         |$nullTerm = false;
         |${operands.map(_.code).mkString("\n")}
         |$stmt
         |$resultTerm = $result;
         |""".stripMargin
    }

    GeneratedExpression(resultTerm, nullTerm, resultCode, returnType)
  }

  /**
   * Generates a string result call with a single result statement.
   * This will convert the String result to BinaryString.
   */
  def generateStringResultCallIfArgsNotNull(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression])
      (call: Seq[String] => String): GeneratedExpression = {
    generateCallIfArgsNotNull(ctx, InternalTypes.STRING, operands) {
      args => s"$BINARY_STRING.fromString(${call(args)})"
    }
  }


  /**
   * Generates a string result call with auxiliary statements and result expression.
   * This will convert the String result to BinaryString.
   */
  def generateStringResultCallWithStmtIfArgsNotNull(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression])
      (call: Seq[String] => (String, String)): GeneratedExpression = {
    generateCallWithStmtIfArgsNotNull(ctx, InternalTypes.STRING, operands) {
      args =>
        val (stmt, result) = call(args)
        (stmt, s"$BINARY_STRING.fromString($result)")
    }
  }

  // --------------------------- General Generate Utils ----------------------------------

  /**
   * Generates a record declaration statement. The record can be any type of BaseRow or
   * other types.
+ * @param t the record type + * @param clazz the specified class of the type (only used when RowType) + * @param recordTerm the record term to be declared + * @param recordWriterTerm the record writer term (only used when BinaryRow type) + * @return the record declaration statement + */ + def generateRecordStatement( + t: InternalType, + clazz: Class[_], + recordTerm: String, + recordWriterTerm: Option[String] = None): String = { + t match { + case rt: RowType if clazz == classOf[BinaryRow] => + val writerTerm = recordWriterTerm.getOrElse( + throw new CodeGenException("No writer is specified when writing BinaryRow record.") + ) + val binaryRowWriter = className[BinaryRowWriter] + val typeTerm = clazz.getCanonicalName + s""" + |final $typeTerm $recordTerm = new $typeTerm(${rt.getArity}); + |final $binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm); + |""".stripMargin.trim + case rt: RowType if classOf[ObjectArrayRow].isAssignableFrom(clazz) => + val typeTerm = clazz.getCanonicalName + s"final $typeTerm $recordTerm = new $typeTerm(${rt.getArity});" + case _: RowType if clazz == classOf[JoinedRow] => + val typeTerm = clazz.getCanonicalName + s"final $typeTerm $recordTerm = new $typeTerm();" + case _ => + val typeTerm = boxedTypeTermForType(t) + s"final $typeTerm $recordTerm = new $typeTerm();" + } + } + + def generateNullLiteral( + resultType: InternalType, + nullCheck: Boolean): GeneratedExpression = { + val defaultValue = primitiveDefaultValue(resultType) + val resultTypeTerm = primitiveTypeTermForType(resultType) + if (nullCheck) { + GeneratedExpression( + s"(($resultTypeTerm) $defaultValue)", + ALWAYS_NULL, + NO_CODE, + resultType, + literalValue = Some(null)) // the literal is null + } else { + throw new CodeGenException("Null literals are not allowed if nullCheck is disabled.") + } + } + + def generateNonNullLiteral( + literalType: InternalType, + literalCode: String, + literalValue: Any): GeneratedExpression = { + val resultTypeTerm = 
primitiveTypeTermForType(literalType) + GeneratedExpression( + s"(($resultTypeTerm) $literalCode)", + NEVER_NULL, + NO_CODE, + literalType, + literalValue = Some(literalValue)) + } + + def generateLiteral( + ctx: CodeGeneratorContext, + literalType: InternalType, + literalValue: Any): GeneratedExpression = { + if (literalValue == null) { + return generateNullLiteral(literalType, ctx.nullCheck) + } + // non-null values + literalType match { + + case InternalTypes.BOOLEAN => + generateNonNullLiteral(literalType, literalValue.toString, literalValue) + + case InternalTypes.BYTE => + val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal]) + generateNonNullLiteral(literalType, decimal.byteValue().toString, decimal.byteValue()) + + case InternalTypes.SHORT => + val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal]) + generateNonNullLiteral(literalType, decimal.shortValue().toString, decimal.shortValue()) + + case InternalTypes.INT => + val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal]) + generateNonNullLiteral(literalType, decimal.intValue().toString, decimal.intValue()) + + case InternalTypes.LONG => + val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal]) + generateNonNullLiteral( + literalType, decimal.longValue().toString + "L", decimal.longValue()) + + case InternalTypes.FLOAT => + val floatValue = literalValue.asInstanceOf[JBigDecimal].floatValue() + floatValue match { + case Float.NaN => generateNonNullLiteral( + literalType, "java.lang.Float.NaN", Float.NaN) + case Float.NegativeInfinity => + generateNonNullLiteral( + literalType, + "java.lang.Float.NEGATIVE_INFINITY", + Float.NegativeInfinity) + case Float.PositiveInfinity => generateNonNullLiteral( + literalType, + "java.lang.Float.POSITIVE_INFINITY", + Float.PositiveInfinity) + case _ => generateNonNullLiteral( + literalType, floatValue.toString + "f", floatValue) + } + + case InternalTypes.DOUBLE => + val doubleValue = 
literalValue.asInstanceOf[JBigDecimal].doubleValue() + doubleValue match { + case Double.NaN => generateNonNullLiteral( + literalType, "java.lang.Double.NaN", Double.NaN) + case Double.NegativeInfinity => + generateNonNullLiteral( + literalType, + "java.lang.Double.NEGATIVE_INFINITY", + Double.NegativeInfinity) + case Double.PositiveInfinity => + generateNonNullLiteral( + literalType, + "java.lang.Double.POSITIVE_INFINITY", + Double.PositiveInfinity) + case _ => generateNonNullLiteral( + literalType, doubleValue.toString + "d", doubleValue) + } + case decimal: DecimalType => + val precision = decimal.precision() + val scale = decimal.scale() + val fieldTerm = newName("decimal") + val decimalClass = className[Decimal] + val fieldDecimal = + s""" + |$decimalClass $fieldTerm = + | $decimalClass.castFrom("${literalValue.toString}", $precision, $scale); + |""".stripMargin + ctx.addReusableMember(fieldDecimal) + val value = Decimal.fromBigDecimal( + literalValue.asInstanceOf[JBigDecimal], precision, scale) + generateNonNullLiteral(literalType, fieldTerm, value) + + case InternalTypes.STRING => + val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(literalValue.toString) + val field = ctx.addReusableStringConstants(escapedValue) + generateNonNullLiteral(literalType, field, BinaryString.fromString(escapedValue)) + + case InternalTypes.BINARY => + val bytesVal = literalValue.asInstanceOf[ByteString].getBytes + val fieldTerm = ctx.addReusableObject( + bytesVal, "binary", bytesVal.getClass.getCanonicalName) + generateNonNullLiteral(literalType, fieldTerm, bytesVal) + + case InternalTypes.DATE => + generateNonNullLiteral(literalType, literalValue.toString, literalValue) + + case InternalTypes.TIME => + generateNonNullLiteral(literalType, literalValue.toString, literalValue) + + case InternalTypes.TIMESTAMP => + // Hack + // Currently, in RexLiteral/SqlLiteral(Calcite), TimestampString has no time zone. 
+ // TimeString, DateString TimestampString are treated as UTC time/(unix time) + // when they are converted/formatted/validated + // Here, we adjust millis before Calcite solve TimeZone perfectly + val millis = literalValue.asInstanceOf[Long] + val adjustedValue = millis - ctx.tableConfig.getTimeZone.getOffset(millis) + generateNonNullLiteral(literalType, adjustedValue.toString + "L", adjustedValue) + + case InternalTypes.INTERVAL_MONTHS => + val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal]) + if (decimal.isValidInt) { + generateNonNullLiteral(literalType, decimal.intValue().toString, decimal.intValue()) + } else { + throw new CodeGenException( + s"Decimal '$decimal' can not be converted to interval of months.") + } + + case InternalTypes.INTERVAL_MILLIS => + val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal]) + if (decimal.isValidLong) { + generateNonNullLiteral( + literalType, + decimal.longValue().toString + "L", + decimal.longValue()) + } else { + throw new CodeGenException( + s"Decimal '$decimal' can not be converted to interval of milliseconds.") + } + + // Symbol type for special flags e.g. TRIM's BOTH, LEADING, TRAILING + case symbol: GenericType[_] if symbol.getTypeClass.isAssignableFrom(classOf[Enum[_]]) => + generateSymbol(literalValue.asInstanceOf[Enum[_]]) + + case t@_ => + throw new CodeGenException(s"Type not supported: $t") + } + } + + def generateSymbol(enum: Enum[_]): GeneratedExpression = { + GeneratedExpression( + qualifyEnum(enum), + NEVER_NULL, + NO_CODE, + new GenericType(enum.getDeclaringClass), + literalValue = Some(enum)) + } + + /** + * Generates access to a non-null field that does not require unboxing logic. 
   *
   * @param fieldType type of field
   * @param fieldTerm expression term of field (already unboxed)
   * @return internal unboxed field representation
   */
  private[flink] def generateNonNullField(
      fieldType: InternalType,
      fieldTerm: String)
    : GeneratedExpression = {
    val resultTypeTerm = primitiveTypeTermForType(fieldType)
    GeneratedExpression(s"(($resultTypeTerm) $fieldTerm)", NEVER_NULL, NO_CODE, fieldType)
  }

  /**
   * Generates an expression that reads the current processing time from the
   * runtime context's timer service.
   */
  def generateProctimeTimestamp(
      ctx: CodeGeneratorContext,
      contextTerm: String): GeneratedExpression = {
    val resultTerm = ctx.addReusableLocalVariable("long", "result")
    val resultCode =
      s"""
         |$resultTerm = $contextTerm.timerService().currentProcessingTime();
         |""".stripMargin.trim
    // the proctime has been materialized, so it's TIMESTAMP now, not PROCTIME_INDICATOR
    GeneratedExpression(resultTerm, NEVER_NULL, resultCode, InternalTypes.TIMESTAMP)
  }

  /**
   * Generates an expression for the current timestamp by delegating to the
   * non-local [[CurrentTimePointCallGen]].
   */
  def generateCurrentTimestamp(
      ctx: CodeGeneratorContext): GeneratedExpression = {
    new CurrentTimePointCallGen(false).generate(ctx, Seq(), InternalTypes.TIMESTAMP)
  }

  /**
   * Generates access to the rowtime timestamp of the current record.
   * The generated code throws at runtime if the record carries no timestamp.
   */
  def generateRowtimeAccess(
      ctx: CodeGeneratorContext,
      contextTerm: String): GeneratedExpression = {
    val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
      ("Long", "result"),
      ("boolean", "isNull"))

    val accessCode =
      s"""
         |$resultTerm = $contextTerm.timestamp();
         |if ($resultTerm == null) {
         |  throw new RuntimeException("Rowtime timestamp is null. Please make sure that a " +
         |    "proper TimestampAssigner is defined and the stream environment uses the EventTime " +
         |    "time characteristic.");
         |}
         |$nullTerm = false;
       """.stripMargin.trim

    GeneratedExpression(resultTerm, nullTerm, accessCode, InternalTypes.ROWTIME_INDICATOR)
  }

  /**
   * Generates access to a field of the input.
   *
   * @param ctx code generator context which maintains various code statements.
   * @param inputType input type
   * @param inputTerm input term
   * @param index the field index to access
   * @param nullableInput whether the input is nullable
   * @param deepCopy whether to copy the accessed field (usually needed when buffered)
   */
  def generateInputAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      nullableInput: Boolean,
      deepCopy: Boolean = false): GeneratedExpression = {
    // if input has been used before, we can reuse the code that
    // has already been generated
    val inputExpr = ctx.getReusableInputUnboxingExprs(inputTerm, index) match {
      // input access and unboxing has already been generated
      case Some(expr) => expr

      // generate input access and unboxing if necessary
      case None =>
        val expr = if (nullableInput) {
          generateNullableInputFieldAccess(ctx, inputType, inputTerm, index, deepCopy)
        } else {
          generateFieldAccess(ctx, inputType, inputTerm, index, deepCopy)
        }

        ctx.addReusableInputUnboxingExprs(inputTerm, index, expr)
        expr
    }
    // hide the generated code as it will be executed only once
    GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, "", inputExpr.resultType)
  }

  /**
   * Generates field access that additionally guards against the whole input
   * record being null; the result defaults to null in that case.
   */
  def generateNullableInputFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      deepCopy: Boolean = false): GeneratedExpression = {

    // for non-row inputs the "field" is the input itself
    val fieldType = inputType match {
      case ct: RowType => ct.getFieldTypes()(index)
      case _ => inputType
    }
    val resultTypeTerm = primitiveTypeTermForType(fieldType)
    val defaultValue = primitiveDefaultValue(fieldType)
    val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
      (resultTypeTerm, "result"),
      ("boolean", "isNull"))

    val fieldAccessExpr = generateFieldAccess(
      ctx, inputType, inputTerm, index, deepCopy)

    val inputCheckCode =
      s"""
         |$resultTerm = $defaultValue;
         |$nullTerm = true;
         |if ($inputTerm != null) {
         |  ${fieldAccessExpr.code}
         |  $resultTerm = ${fieldAccessExpr.resultTerm};
         |  $nullTerm = ${fieldAccessExpr.nullTerm};
         |}
         |""".stripMargin.trim

    GeneratedExpression(resultTerm, nullTerm, inputCheckCode, fieldType)
  }

  /**
   * Converts the external boxed format to an internal mostly primitive field representation.
   * Wrapper types can be autoboxed to their corresponding primitive type (Integer -> int).
   *
   * @param ctx code generator context which maintains various code statements.
   * @param fieldType type of field
   * @param fieldTerm expression term of field to be unboxed
   * @return internal unboxed field representation
   */
  def generateInputFieldUnboxing(
      ctx: CodeGeneratorContext,
      fieldType: InternalType,
      fieldTerm: String): GeneratedExpression = {

    val resultTypeTerm = primitiveTypeTermForType(fieldType)
    val defaultValue = primitiveDefaultValue(fieldType)

    val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
      (resultTypeTerm, "result"),
      ("boolean", "isNull"))

    val wrappedCode = if (ctx.nullCheck) {
      s"""
         |$nullTerm = $fieldTerm == null;
         |$resultTerm = $defaultValue;
         |if (!$nullTerm) {
         |  $resultTerm = $fieldTerm;
         |}
         |""".stripMargin.trim
    } else {
      s"""
         |$resultTerm = $fieldTerm;
         |""".stripMargin.trim
    }

    GeneratedExpression(resultTerm, nullTerm, wrappedCode, fieldType)
  }

  /**
   * Generates field access code expression. The difference between this method and
   * [[generateFieldAccess(ctx, inputType, inputTerm, index)]] is that this method
   * accepts an additional `deepCopy` parameter. When deepCopy is set to true, the returned
   * result will be copied.
   *
   * NOTE: Please set `deepCopy` to true when the result will be buffered.
   */
  def generateFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      deepCopy: Boolean): GeneratedExpression = {
    val expr = generateFieldAccess(ctx, inputType, inputTerm, index)
    if (deepCopy) {
      // copy mutable results so they survive buffering (see GeneratedExpression.deepCopy)
      expr.deepCopy(ctx)
    } else {
      expr
    }
  }

  /**
   * Generates access to a field of the given input term. For RowType inputs the
   * field is read from the row at `index`; otherwise the input itself is unboxed.
   */
  def generateFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int): GeneratedExpression =
    inputType match {
      case ct: RowType =>
        val fieldType = ct.getFieldTypes()(index)
        val resultTypeTerm = primitiveTypeTermForType(fieldType)
        val defaultValue = primitiveDefaultValue(fieldType)
        val readCode = baseRowFieldReadAccess(ctx, index.toString, inputTerm, fieldType)
        val Seq(fieldTerm, nullTerm) = ctx.addReusableLocalVariables(
          (resultTypeTerm, "field"),
          ("boolean", "isNull"))

        val inputCode = if (ctx.nullCheck) {
          s"""
             |$nullTerm = $inputTerm.isNullAt($index);
             |$fieldTerm = $defaultValue;
             |if (!$nullTerm) {
             |  $fieldTerm = $readCode;
             |}
           """.stripMargin.trim
        } else {
          s"""
             |$nullTerm = false;
             |$fieldTerm = $readCode;
           """.stripMargin
        }
        GeneratedExpression(fieldTerm, nullTerm, inputCode, fieldType)

      case _ =>
        val fieldTypeTerm = boxedTypeTermForType(inputType)
        val inputCode = s"($fieldTypeTerm) $inputTerm"
        generateInputFieldUnboxing(ctx, inputType, inputCode)
    }


  /**
   * Generates code for comparing two fields.
   *
   * NOTE(review): this match is not exhaustive -- types not listed below
   * (e.g. map types) raise a scala.MatchError at code-generation time;
   * confirm callers only pass comparable types.
   */
  def generateCompare(
      ctx: CodeGeneratorContext,
      t: InternalType,
      nullsIsLast: Boolean,
      leftTerm: String,
      rightTerm: String): String = t match {
    case InternalTypes.BOOLEAN => s"($leftTerm == $rightTerm ? 0 : ($leftTerm ? 1 : -1))"
    case _: PrimitiveType | _: DateType | _: TimeType | _: TimestampType =>
      s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)"
    case InternalTypes.BINARY =>
      val sortUtil = classOf[org.apache.flink.table.runtime.sort.SortUtil].getCanonicalName
      s"$sortUtil.compareBinary($leftTerm, $rightTerm)"
    case at: ArrayType =>
      // a dedicated member function is generated per array comparison
      // NOTE(review): the `nullsIsLast` parameter is ignored here in favor of
      // FlinkPlannerImpl.getNullDefaultOrder(true) -- confirm this is intended
      val compareFunc = newName("compareArray")
      val compareCode = generateArrayCompare(
        ctx,
        FlinkPlannerImpl.getNullDefaultOrder(true), at, "a", "b")
      val funcCode: String =
        s"""
          public int $compareFunc($BINARY_ARRAY a, $BINARY_ARRAY b) {
            $compareCode
            return 0;
          }
        """
      ctx.addReusableMember(funcCode)
      s"$compareFunc($leftTerm, $rightTerm)"
    case rowType: RowType =>
      // nested rows compare all fields in ascending order
      val orders = rowType.getFieldTypes.map(_ => true)
      val comparisons = generateRowCompare(
        ctx,
        rowType.getFieldTypes.indices.toArray,
        rowType.getFieldTypes,
        orders,
        FlinkPlannerImpl.getNullDefaultOrders(orders),
        "a",
        "b")
      val compareFunc = newName("compareRow")
      val funcCode: String =
        s"""
          public int $compareFunc($BASE_ROW a, $BASE_ROW b) {
            $comparisons
            return 0;
          }
        """
      ctx.addReusableMember(funcCode)
      s"$compareFunc($leftTerm, $rightTerm)"
    case gt: GenericType[_] =>
      // generic values are deserialized and compared via a (reusable) comparator;
      // the comparator is created ascending (argument `true`)
      val ser = ctx.addReusableObject(gt.getSerializer, "serializer")
      val comp = ctx.addReusableObject(
        gt.getTypeInfo.asInstanceOf[AtomicTypeInfo[_]].createComparator(true, new ExecutionConfig),
        "comparator")
      s"""
         |$comp.compare(
         |  $BINARY_GENERIC.getJavaObjectFromBinaryGeneric($leftTerm, $ser),
         |  $BINARY_GENERIC.getJavaObjectFromBinaryGeneric($rightTerm, $ser)
         |)
       """.stripMargin
    case other if other.isInstanceOf[AtomicType] => s"$leftTerm.compareTo($rightTerm)"
  }

  /**
   * Generates code for comparing arrays element-wise; shorter arrays order
   * before longer ones when all shared elements are equal.
   */
  def generateArrayCompare(
      ctx: CodeGeneratorContext,
      nullsIsLast: Boolean,
      arrayType: ArrayType,
      leftTerm: String,
      rightTerm: String)
    : String = {
    // return value used when exactly one side is null
    val nullIsLastRet = if (nullsIsLast) 1 else -1
    val elementType = arrayType.getElementType
    val fieldA = newName("fieldA")
    val isNullA = newName("isNullA")
    val lengthA = newName("lengthA")
    val fieldB = newName("fieldB")
    val isNullB = newName("isNullB")
    val lengthB = newName("lengthB")
    val minLength = newName("minLength")
    val i = newName("i")
    val comp = newName("comp")
    val typeTerm = primitiveTypeTermForType(elementType)
    s"""
        int $lengthA = a.numElements();
        int $lengthB = b.numElements();
        int $minLength = ($lengthA > $lengthB) ? $lengthB : $lengthA;
        for (int $i = 0; $i < $minLength; $i++) {
          boolean $isNullA = a.isNullAt($i);
          boolean $isNullB = b.isNullAt($i);
          if ($isNullA && $isNullB) {
            // Continue to compare the next element
          } else if ($isNullA) {
            return $nullIsLastRet;
          } else if ($isNullB) {
            return ${-nullIsLastRet};
          } else {
            $typeTerm $fieldA = ${baseRowFieldReadAccess(ctx, i, leftTerm, elementType)};
            $typeTerm $fieldB = ${baseRowFieldReadAccess(ctx, i, rightTerm, elementType)};
            int $comp = ${generateCompare(ctx, elementType, nullsIsLast, fieldA, fieldB)};
            if ($comp != 0) {
              return $comp;
            }
          }
        }

        if ($lengthA < $lengthB) {
          return -1;
        } else if ($lengthA > $lengthB) {
          return 1;
        }
     """
  }

  /**
   * Generates code for comparing row keys.
+ */ + def generateRowCompare( + ctx: CodeGeneratorContext, + keys: Array[Int], + keyTypes: Array[InternalType], + orders: Array[Boolean], + nullsIsLast: Array[Boolean], + leftTerm: String, + rightTerm: String): String = { + + val compares = new mutable.ArrayBuffer[String] + + for (i <- keys.indices) { + val index = keys(i) + + val symbol = if (orders(i)) "" else "-" + + val nullIsLastRet = if (nullsIsLast(i)) 1 else -1 + + val t = keyTypes(i) + + val typeTerm = primitiveTypeTermForType(t) + val fieldA = newName("fieldA") + val isNullA = newName("isNullA") + val fieldB = newName("fieldB") + val isNullB = newName("isNullB") + val comp = newName("comp") + + val code = + s""" + |boolean $isNullA = $leftTerm.isNullAt($index); + |boolean $isNullB = $rightTerm.isNullAt($index); + |if ($isNullA && $isNullB) { + | // Continue to compare the next element + |} else if ($isNullA) { + | return $nullIsLastRet; + |} else if ($isNullB) { + | return ${-nullIsLastRet}; + |} else { + | $typeTerm $fieldA = ${baseRowFieldReadAccess(ctx, index, leftTerm, t)}; + | $typeTerm $fieldB = ${baseRowFieldReadAccess(ctx, index, rightTerm, t)}; + | int $comp = ${generateCompare(ctx, t, nullsIsLast(i), fieldA, fieldB)}; + | if ($comp != 0) { + | return $symbol$comp; + | } + |} + """.stripMargin + compares += code + } + compares.mkString + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GeneratedExpression.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GeneratedExpression.scala index 35a3999ad9dc5..f8285c4b5abdc 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GeneratedExpression.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/GeneratedExpression.scala @@ -19,6 +19,8 @@ package org.apache.flink.table.codegen import org.apache.flink.table.`type`.InternalType +import 
org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForType, newName} +import org.apache.flink.table.typeutils.TypeCheckUtils /** * Describes a generated expression. @@ -27,16 +29,51 @@ import org.apache.flink.table.`type`.InternalType * @param nullTerm boolean term that indicates if expression is null * @param code code necessary to produce resultTerm and nullTerm * @param resultType type of the resultTerm - * @param literal flag to indicate a constant expression do not reference input and can thus - * be used in the member area (e.g. as constructor parameter of a reusable - * instance) + * @param literalValue None if the expression is not literal. Otherwise it represent the + * original object of the literal. */ case class GeneratedExpression( resultTerm: String, nullTerm: String, code: String, resultType: InternalType, - literal: Boolean = false) + literalValue: Option[Any] = None) { + + /** + * Indicates a constant expression do not reference input and can thus be used + * in the member area (e.g. as constructor parameter of a reusable instance) + * + * @return true if the expression is literal + */ + def literal: Boolean = literalValue.isDefined + + /** + * Deep copy the generated expression. + * + * NOTE: Please use this method when the result will be buffered. + * This method makes sure a new object/data is created when the type is mutable. 
+ */ + def deepCopy(ctx: CodeGeneratorContext): GeneratedExpression = { + // only copy when type is mutable + if (TypeCheckUtils.isMutable(resultType)) { + val newResultTerm = newName("field") + // if the type need copy, it must be a boxed type + val typeTerm = boxedTypeTermForType(resultType) + val serTerm = ctx.addReusableTypeSerializer(resultType) + val newCode = + s""" + |$code + |$typeTerm $newResultTerm = $resultTerm; + |if (!$nullTerm) { + | $newResultTerm = ($typeTerm) ($serTerm.copy($newResultTerm)); + |} + """.stripMargin + GeneratedExpression(newResultTerm, nullTerm, newCode, resultType, literalValue) + } else { + this + } + } +} object GeneratedExpression { val ALWAYS_NULL = "true" diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SortCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SortCodeGenerator.scala index 2e3df9577bf0b..0f6b2d230c09d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SortCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SortCodeGenerator.scala @@ -383,7 +383,7 @@ class SortCodeGenerator( val baseClass = classOf[RecordComparator] val ctx = new CodeGeneratorContext(conf) - val compareCode = CodeGenUtils.genRowCompare( + val compareCode = GenerateUtils.generateRowCompare( ctx, keys, keyTypes, orders, nullsIsLast, "o1", "o2") val code = diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala new file mode 100644 index 0000000000000..ce181ab44b57f --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/BuiltInMethods.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
 under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen.calls

import org.apache.calcite.linq4j.tree.Types
import org.apache.flink.table.runtime.functions.DateTimeUtils

import java.util.TimeZone

/**
 * Reflective method handles to [[DateTimeUtils]] runtime functions that are
 * referenced from generated code.
 */
object BuiltInMethods {

  // DateTimeUtils.strToTimestamp(String, TimeZone)
  val STRING_TO_TIMESTAMP = Types.lookupMethod(
    classOf[DateTimeUtils],
    "strToTimestamp",
    classOf[String], classOf[TimeZone])

  // DateTimeUtils.timeToString(int)
  val UNIX_TIME_TO_STRING = Types.lookupMethod(
    classOf[DateTimeUtils],
    "timeToString",
    classOf[Int])

  // DateTimeUtils.timestampToString(long, int, TimeZone)
  val TIMESTAMP_TO_STRING = Types.lookupMethod(
    classOf[DateTimeUtils],
    "timestampToString",
    classOf[Long], classOf[Int], classOf[TimeZone])
}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala
new file mode 100644
index 0000000000000..a35bc5d798c9c
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CallGenerator.scala
@@ -0,0 +1,35 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.
See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen.calls

import org.apache.flink.table.`type`.InternalType
import org.apache.flink.table.codegen.{CodeGeneratorContext, GeneratedExpression}

/**
 * Generator to generate a call expression. It is usually used when the generation
 * depends on some other parameters or the generation is too complex to be a util method.
 */
trait CallGenerator {

  /**
   * Generates the expression for one call.
   *
   * @param ctx code generator context which maintains reusable statements
   * @param operands already generated expressions for the call operands
   * @param returnType internal result type of the call
   * @return the generated expression for the call
   */
  def generate(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression],
      returnType: InternalType): GeneratedExpression

}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ConstantCallGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ConstantCallGen.scala
new file mode 100644
index 0000000000000..d51674a66236a
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ConstantCallGen.scala
@@ -0,0 +1,37 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen.calls

import org.apache.flink.table.`type`.InternalType
import org.apache.flink.table.codegen.{CodeGeneratorContext, GeneratedExpression}
import org.apache.flink.table.codegen.GenerateUtils.generateNonNullLiteral

/**
 * Generates a function call which returns a constant.
 */
class ConstantCallGen(constantCode: String, constantValue: Any) extends CallGenerator {

  /**
   * Ignores the operands and emits the pre-rendered constant as a never-null
   * literal of the requested return type.
   */
  override def generate(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression],
      returnType: InternalType): GeneratedExpression = {
    generateNonNullLiteral(returnType, constantCode, constantValue)
  }

}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala
new file mode 100644
index 0000000000000..6b81d9de92401
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/CurrentTimePointCallGen.scala
@@ -0,0 +1,56 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen.calls

import org.apache.flink.table.`type`.{InternalType, InternalTypes}
import org.apache.flink.table.codegen.{CodeGenException, CodeGeneratorContext, GeneratedExpression}
import org.apache.flink.table.codegen.GenerateUtils.generateNonNullField

/**
 * Generates a function call that determines the current time point as a
 * date/time/timestamp, either in the local time zone or in UTC.
 *
 * @param local whether the time point is evaluated in the local time zone
 *              (LOCALTIME / LOCALTIMESTAMP) instead of UTC
 */
class CurrentTimePointCallGen(local: Boolean) extends CallGenerator {

  override def generate(
      ctx: CodeGeneratorContext,
      operands: Seq[GeneratedExpression],
      returnType: InternalType): GeneratedExpression = returnType match {
    // the `if local` guards must precede the unguarded TIME/TIMESTAMP cases
    case InternalTypes.TIME if local =>
      val time = ctx.addReusableLocalTime()
      generateNonNullField(returnType, time)

    case InternalTypes.TIMESTAMP if local =>
      val timestamp = ctx.addReusableLocalTimestamp()
      generateNonNullField(returnType, timestamp)

    case InternalTypes.DATE =>
      val date = ctx.addReusableDate()
      generateNonNullField(returnType, date)

    case InternalTypes.TIME =>
      val time = ctx.addReusableTime()
      generateNonNullField(returnType, time)

    case InternalTypes.TIMESTAMP =>
      val timestamp = ctx.addReusableTimestamp()
      generateNonNullField(returnType, timestamp)

    // Previously a non-exhaustive match: any other type surfaced as an opaque
    // scala.MatchError during code generation. Fail with a descriptive
    // code generation error instead.
    case other =>
      throw new CodeGenException(
        s"Unsupported type '$other' for a current time point call.")
  }
}
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperatorGens.scala new file mode 100644 index 0000000000000..b9b5ae07591d3 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperatorGens.scala @@ -0,0 +1,2017 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.codegen.calls + +import org.apache.flink.table.`type`._ +import org.apache.flink.table.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull, binaryWriterWriteField, binaryWriterWriteNull, _} +import org.apache.flink.table.codegen.GenerateUtils._ +import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE, ALWAYS_NULL} +import org.apache.flink.table.codegen.{CodeGenException, CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.dataformat._ +import org.apache.flink.table.typeutils.TypeCheckUtils._ +import org.apache.flink.table.typeutils.{TypeCheckUtils, TypeCoercion} +import org.apache.flink.util.Preconditions.checkArgument + +import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY +import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange} +import org.apache.calcite.util.BuiltInMethod + +import java.lang.{StringBuilder => JStringBuilder} +import java.nio.charset.StandardCharsets + +/** + * Utilities to generate SQL scalar operators, e.g. arithmetic operator, + * compare operator, equal operator, etc. + */ +object ScalarOperatorGens { + + // ---------------------------------------------------------------------------------------- + // scalar operators generate utils + // ---------------------------------------------------------------------------------------- + + /** + * Generates a binary arithmetic operator, e.g. 
+ - * / % + */ + def generateBinaryArithmeticOperator( + ctx: CodeGeneratorContext, + operator: String, + resultType: InternalType, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + + resultType match { + case dt: DecimalType => + return generateDecimalBinaryArithmeticOperator(ctx, operator, dt, left, right) + case _ => + } + + val leftCasting = operator match { + case "%" => + if (left.resultType == right.resultType) { + numericCasting(left.resultType, resultType) + } else { + val castedType = if (isDecimal(left.resultType)) { + InternalTypes.LONG + } else { + left.resultType + } + numericCasting(left.resultType, castedType) + } + case _ => numericCasting(left.resultType, resultType) + } + + val rightCasting = numericCasting(right.resultType, resultType) + val resultTypeTerm = primitiveTypeTermForType(resultType) + + generateOperatorIfNotNull(ctx, resultType, left, right) { + (leftTerm, rightTerm) => + s"($resultTypeTerm) (${leftCasting(leftTerm)} $operator ${rightCasting(rightTerm)})" + } + } + + /** + * Generates a binary arithmetic operator for Decimal, e.g. + - * / % + */ + private def generateDecimalBinaryArithmeticOperator( + ctx: CodeGeneratorContext, + operator: String, + resultType: DecimalType, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + + // do not cast a decimal operand to resultType, which may change its value. + // use it as is during calculation. 
+ def castToDec(t: InternalType): String => String = t match { + case _: DecimalType => (operandTerm: String) => s"$operandTerm" + case _ => numericCasting(t, resultType) + } + val methods = Map( + "+" -> "add", + "-" -> "subtract", + "*" -> "multiply", + "/" -> "divide", + "%" -> "mod") + + generateOperatorIfNotNull(ctx, resultType, left, right) { + (leftTerm, rightTerm) => { + val method = methods(operator) + val leftCasted = castToDec(left.resultType)(leftTerm) + val rightCasted = castToDec(right.resultType)(rightTerm) + val precision = resultType.precision() + val scale = resultType.scale() + s"$DECIMAL.$method($leftCasted, $rightCasted, $precision, $scale)" + } + } + } + + /** + * Generates an unary arithmetic operator, e.g. -num + */ + def generateUnaryArithmeticOperator( + ctx: CodeGeneratorContext, + operator: String, + resultType: InternalType, + operand: GeneratedExpression) + : GeneratedExpression = { + generateUnaryOperatorIfNotNull(ctx, resultType, operand) { + operandTerm => + if (isDecimal(operand.resultType) && operator == "-") { + s"$operandTerm.negate()" + } else if (isDecimal(operand.resultType) && operator == "+") { + s"$operandTerm" + } else { + s"$operator($operandTerm)" + } + } + } + + + def generateTemporalPlusMinus( + ctx: CodeGeneratorContext, + plus: Boolean, + resultType: InternalType, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + + val op = if (plus) "+" else "-" + + (left.resultType, right.resultType) match { + // arithmetic of time point and time interval + case (InternalTypes.INTERVAL_MONTHS, InternalTypes.INTERVAL_MONTHS) | + (InternalTypes.INTERVAL_MILLIS, InternalTypes.INTERVAL_MILLIS) => + generateBinaryArithmeticOperator(ctx, op, left.resultType, left, right) + + case (InternalTypes.DATE, InternalTypes.INTERVAL_MILLIS) => + generateOperatorIfNotNull(ctx, InternalTypes.DATE, left, right) { + (l, r) => s"$l $op ((int) ($r / ${MILLIS_PER_DAY}L))" + } + + case (InternalTypes.DATE, 
InternalTypes.INTERVAL_MONTHS) => + generateOperatorIfNotNull(ctx, InternalTypes.DATE, left, right) { + (l, r) => s"${qualifyMethod(BuiltInMethod.ADD_MONTHS.method)}($l, $op($r))" + } + + case (InternalTypes.TIME, InternalTypes.INTERVAL_MILLIS) => + generateOperatorIfNotNull(ctx, InternalTypes.TIME, left, right) { + (l, r) => s"$l $op ((int) ($r))" + } + + case (InternalTypes.TIMESTAMP, InternalTypes.INTERVAL_MILLIS) => + generateOperatorIfNotNull(ctx, InternalTypes.TIMESTAMP, left, right) { + (l, r) => s"$l $op $r" + } + + case (InternalTypes.TIMESTAMP, InternalTypes.INTERVAL_MONTHS) => + generateOperatorIfNotNull(ctx, InternalTypes.TIMESTAMP, left, right) { + (l, r) => s"${qualifyMethod(BuiltInMethod.ADD_MONTHS.method)}($l, $op($r))" + } + + // minus arithmetic of time points (i.e. for TIMESTAMPDIFF) + case (InternalTypes.TIMESTAMP | InternalTypes.TIME | InternalTypes.DATE, + InternalTypes.TIMESTAMP | InternalTypes.TIME | InternalTypes.DATE) if !plus => + resultType match { + case InternalTypes.INTERVAL_MONTHS => + generateOperatorIfNotNull(ctx, resultType, left, right) { + (ll, rr) => (left.resultType, right.resultType) match { + case (InternalTypes.TIMESTAMP, InternalTypes.DATE) => + s"${qualifyMethod(BuiltInMethod.SUBTRACT_MONTHS.method)}" + + s"($ll, $rr * ${MILLIS_PER_DAY}L)" + case (InternalTypes.DATE, InternalTypes.TIMESTAMP) => + s"${qualifyMethod(BuiltInMethod.SUBTRACT_MONTHS.method)}" + + s"($ll * ${MILLIS_PER_DAY}L, $rr)" + case _ => + s"${qualifyMethod(BuiltInMethod.SUBTRACT_MONTHS.method)}($ll, $rr)" + } + } + + case InternalTypes.INTERVAL_MILLIS => + generateOperatorIfNotNull(ctx, resultType, left, right) { + (ll, rr) => (left.resultType, right.resultType) match { + case (InternalTypes.TIMESTAMP, InternalTypes.TIMESTAMP) => + s"$ll $op $rr" + case (InternalTypes.DATE, InternalTypes.DATE) => + s"($ll * ${MILLIS_PER_DAY}L) $op ($rr * ${MILLIS_PER_DAY}L)" + case (InternalTypes.TIMESTAMP, InternalTypes.DATE) => + s"$ll $op ($rr * ${MILLIS_PER_DAY}L)" + 
case (InternalTypes.DATE, InternalTypes.TIMESTAMP) => + s"($ll * ${MILLIS_PER_DAY}L) $op $rr" + } + } + } + + case _ => + throw new CodeGenException("Unsupported temporal arithmetic.") + } + } + + def generateUnaryIntervalPlusMinus( + ctx: CodeGeneratorContext, + plus: Boolean, + operand: GeneratedExpression) + : GeneratedExpression = { + val operator = if (plus) "+" else "-" + generateUnaryArithmeticOperator(ctx, operator, operand.resultType, operand) + } + + // ---------------------------------------------------------------------------------------- + // scalar expression generate utils + // ---------------------------------------------------------------------------------------- + + /** + * Generates IN expression using a HashSet + */ + def generateIn( + ctx: CodeGeneratorContext, + needle: GeneratedExpression, + haystack: Seq[GeneratedExpression]) + : GeneratedExpression = { + + // add elements to hash set if they are constant + if (haystack.forall(_.literal)) { + + // determine common numeric type + val widerType = TypeCoercion.widerTypeOf( + needle.resultType, + haystack.head.resultType) + + // we need to normalize the values for the hash set + val castNumeric = widerType match { + case Some(t) => (value: GeneratedExpression) => + numericCasting(value.resultType, t)(value.resultTerm) + case None => (value: GeneratedExpression) => value.resultTerm + } + + val resultType = widerType match { + case Some(t) => t + case None => needle.resultType + } + + val elements = haystack.map { element => + element.copy( + castNumeric(element), // cast element to wider type + element.nullTerm, + element.code, + resultType) + } + val setTerm = ctx.addReusableHashSet(elements, resultType) + + val castedNeedle = needle.copy( + castNumeric(needle), // cast needle to wider type + needle.nullTerm, + needle.code, + resultType) + + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val resultTypeTerm = primitiveTypeTermForType(InternalTypes.BOOLEAN) + val defaultValue = 
primitiveDefaultValue(InternalTypes.BOOLEAN) + + val operatorCode = if (ctx.nullCheck) { + s""" + |${castedNeedle.code} + |$resultTypeTerm $resultTerm = $defaultValue; + |boolean $nullTerm = true; + |if (!${castedNeedle.nullTerm}) { + | $resultTerm = $setTerm.contains(${castedNeedle.resultTerm}); + | $nullTerm = !$resultTerm && $setTerm.containsNull(); + |} + |""".stripMargin.trim + } + else { + s""" + |${castedNeedle.code} + |$resultTypeTerm $resultTerm = $setTerm.contains(${castedNeedle.resultTerm}); + |""".stripMargin.trim + } + + GeneratedExpression(resultTerm, nullTerm, operatorCode, InternalTypes.BOOLEAN) + } else { + // we use a chain of ORs for a set that contains non-constant elements + haystack + .map(generateEquals(ctx, needle, _)) + .reduce((left, right) => + generateOr(ctx, left, right) + ) + } + } + + def generateEquals( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + if (left.resultType == InternalTypes.STRING && right.resultType == InternalTypes.STRING) { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + (leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)" + } + } + // numeric types + else if (isNumeric(left.resultType) && isNumeric(right.resultType)) { + generateComparison(ctx, "==", left, right) + } + // temporal types + else if (isTemporal(left.resultType) && left.resultType == right.resultType) { + generateComparison(ctx, "==", left, right) + } + // array types + else if (isArray(left.resultType) && left.resultType == right.resultType) { + generateArrayComparison(ctx, left, right) + } + // map types + else if (isMap(left.resultType) && left.resultType == right.resultType) { + generateMapComparison(ctx, left, right) + } + // comparable types of same type + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + generateComparison(ctx, "==", left, right) + } + // support date/time/timestamp equalTo string. 
+ // for performance, we cast literal string to literal time. + else if (isTimePoint(left.resultType) && right.resultType == InternalTypes.STRING) { + if (right.literal) { + generateEquals(ctx, left, generateCastStringLiteralToDateTime(ctx, right, left.resultType)) + } else { + generateEquals(ctx, left, generateCast(ctx, right, left.resultType)) + } + } + else if (isTimePoint(right.resultType) && left.resultType == InternalTypes.STRING) { + if (left.literal) { + generateEquals( + ctx, + generateCastStringLiteralToDateTime(ctx, left, right.resultType), + right) + } else { + generateEquals(ctx, generateCast(ctx, left, right.resultType), right) + } + } + // non comparable types + else { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + if (isReference(left)) { + (leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)" + } + else if (isReference(right)) { + (leftTerm, rightTerm) => s"$rightTerm.equals($leftTerm)" + } + else { + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") + } + } + } + } + + def generateNotEquals( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + if (left.resultType == InternalTypes.STRING && right.resultType == InternalTypes.STRING) { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + (leftTerm, rightTerm) => s"!$leftTerm.equals($rightTerm)" + } + } + // numeric types + else if (isNumeric(left.resultType) && isNumeric(right.resultType)) { + generateComparison(ctx, "!=", left, right) + } + // temporal types + else if (isTemporal(left.resultType) && left.resultType == right.resultType) { + generateComparison(ctx, "!=", left, right) + } + // array types + else if (isArray(left.resultType) && left.resultType == right.resultType) { + val equalsExpr = generateEquals(ctx, left, right) + GeneratedExpression( + s"(!${equalsExpr.resultTerm})", equalsExpr.nullTerm, equalsExpr.code, 
InternalTypes.BOOLEAN) + } + // map types + else if (isMap(left.resultType) && left.resultType == right.resultType) { + val equalsExpr = generateEquals(ctx, left, right) + GeneratedExpression( + s"(!${equalsExpr.resultTerm})", equalsExpr.nullTerm, equalsExpr.code, InternalTypes.BOOLEAN) + } + // comparable types + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + generateComparison(ctx, "!=", left, right) + } + // non-comparable types + else { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + if (isReference(left)) { + (leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))" + } + else if (isReference(right)) { + (leftTerm, rightTerm) => s"!($rightTerm.equals($leftTerm))" + } + else { + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") + } + } + } + } + + /** + * Generates comparison code for numeric types and comparable types of same type. + */ + def generateComparison( + ctx: CodeGeneratorContext, + operator: String, + left: GeneratedExpression, + right: GeneratedExpression) + : GeneratedExpression = { + generateOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, left, right) { + // either side is decimal + if (isDecimal(left.resultType) || isDecimal(right.resultType)) { + (leftTerm, rightTerm) => { + s"${className[Decimal]}.compare($leftTerm, $rightTerm) $operator 0" + } + } + // both sides are numeric + else if (isNumeric(left.resultType) && isNumeric(right.resultType)) { + (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" + } + // both sides are temporal of same type + else if (isTemporal(left.resultType) && left.resultType == right.resultType) { + (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" + } + // both sides are boolean + else if (isBoolean(left.resultType) && left.resultType == right.resultType) { + operator match { + case "==" | "!=" => (leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm" + case ">" | "<" | "<=" | ">=" => 
+ (leftTerm, rightTerm) => + s"java.lang.Boolean.compare($leftTerm, $rightTerm) $operator 0" + case _ => throw new CodeGenException(s"Unsupported boolean comparison '$operator'.") + } + } + // both sides are binary type + else if (isBinary(left.resultType) && left.resultType == right.resultType) { + (leftTerm, rightTerm) => + s"java.util.Arrays.equals($leftTerm, $rightTerm)" + } + // both sides are same comparable type + else if (isComparable(left.resultType) && left.resultType == right.resultType) { + (leftTerm, rightTerm) => + s"(($leftTerm == null) ? (($rightTerm == null) ? 0 : -1) : (($rightTerm == null) ? " + + s"1 : ($leftTerm.compareTo($rightTerm)))) $operator 0" + } + else { + throw new CodeGenException(s"Incomparable types: ${left.resultType} and " + + s"${right.resultType}") + } + } + } + + def generateIsNull( + ctx: CodeGeneratorContext, + operand: GeneratedExpression): GeneratedExpression = { + if (ctx.nullCheck) { + GeneratedExpression(operand.nullTerm, NEVER_NULL, operand.code, InternalTypes.BOOLEAN) + } + else if (!ctx.nullCheck && isReference(operand)) { + val resultTerm = newName("isNull") + val operatorCode = + s""" + |${operand.code} + |boolean $resultTerm = ${operand.resultTerm} == null; + |""".stripMargin + GeneratedExpression(resultTerm, NEVER_NULL, operatorCode, InternalTypes.BOOLEAN) + } + else { + GeneratedExpression("false", NEVER_NULL, operand.code, InternalTypes.BOOLEAN) + } + } + + def generateIsNotNull( + ctx: CodeGeneratorContext, + operand: GeneratedExpression): GeneratedExpression = { + if (ctx.nullCheck) { + val resultTerm = newName("result") + val operatorCode = + s""" + |${operand.code} + |boolean $resultTerm = !${operand.nullTerm}; + |""".stripMargin.trim + GeneratedExpression(resultTerm, NEVER_NULL, operatorCode, InternalTypes.BOOLEAN) + } + else if (!ctx.nullCheck && isReference(operand)) { + val resultTerm = newName("result") + val operatorCode = + s""" + |${operand.code} + |boolean $resultTerm = ${operand.resultTerm} != 
null; + |""".stripMargin.trim + GeneratedExpression(resultTerm, NEVER_NULL, operatorCode, InternalTypes.BOOLEAN) + } + else { + GeneratedExpression("true", NEVER_NULL, operand.code, InternalTypes.BOOLEAN) + } + } + + def generateAnd( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + + val operatorCode = if (ctx.nullCheck) { + // Three-valued logic: + // no Unknown -> Two-valued logic + // True && Unknown -> Unknown + // False && Unknown -> False + // Unknown && True -> Unknown + // Unknown && False -> False + // Unknown && Unknown -> Unknown + s""" + |${left.code} + | + |boolean $resultTerm = false; + |boolean $nullTerm = false; + |if (!${left.nullTerm} && !${left.resultTerm}) { + | // left expr is false, skip right expr + |} else { + | ${right.code} + | + | if (!${left.nullTerm} && !${right.nullTerm}) { + | $resultTerm = ${left.resultTerm} && ${right.resultTerm}; + | $nullTerm = false; + | } + | else if (!${left.nullTerm} && ${left.resultTerm} && ${right.nullTerm}) { + | $resultTerm = false; + | $nullTerm = true; + | } + | else if (!${left.nullTerm} && !${left.resultTerm} && ${right.nullTerm}) { + | $resultTerm = false; + | $nullTerm = false; + | } + | else if (${left.nullTerm} && !${right.nullTerm} && ${right.resultTerm}) { + | $resultTerm = false; + | $nullTerm = true; + | } + | else if (${left.nullTerm} && !${right.nullTerm} && !${right.resultTerm}) { + | $resultTerm = false; + | $nullTerm = false; + | } + | else { + | $resultTerm = false; + | $nullTerm = true; + | } + |} + """.stripMargin.trim + } + else { + s""" + |${left.code} + |boolean $resultTerm = false; + |if (${left.resultTerm}) { + | ${right.code} + | $resultTerm = ${right.resultTerm}; + |} + |""".stripMargin.trim + } + + GeneratedExpression(resultTerm, nullTerm, operatorCode, InternalTypes.BOOLEAN) + } + + def generateOr( + ctx: CodeGeneratorContext, + left: 
GeneratedExpression, + right: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + + val operatorCode = if (ctx.nullCheck) { + // Three-valued logic: + // no Unknown -> Two-valued logic + // True || Unknown -> True + // False || Unknown -> Unknown + // Unknown || True -> True + // Unknown || False -> Unknown + // Unknown || Unknown -> Unknown + s""" + |${left.code} + | + |boolean $resultTerm = true; + |boolean $nullTerm = false; + |if (!${left.nullTerm} && ${left.resultTerm}) { + | // left expr is true, skip right expr + |} else { + | ${right.code} + | + | if (!${left.nullTerm} && !${right.nullTerm}) { + | $resultTerm = ${left.resultTerm} || ${right.resultTerm}; + | $nullTerm = false; + | } + | else if (!${left.nullTerm} && ${left.resultTerm} && ${right.nullTerm}) { + | $resultTerm = true; + | $nullTerm = false; + | } + | else if (!${left.nullTerm} && !${left.resultTerm} && ${right.nullTerm}) { + | $resultTerm = false; + | $nullTerm = true; + | } + | else if (${left.nullTerm} && !${right.nullTerm} && ${right.resultTerm}) { + | $resultTerm = true; + | $nullTerm = false; + | } + | else if (${left.nullTerm} && !${right.nullTerm} && !${right.resultTerm}) { + | $resultTerm = false; + | $nullTerm = true; + | } + | else { + | $resultTerm = false; + | $nullTerm = true; + | } + |} + |""".stripMargin.trim + } + else { + s""" + |${left.code} + |boolean $resultTerm = true; + |if (!${left.resultTerm}) { + | ${right.code} + | $resultTerm = ${right.resultTerm}; + |} + |""".stripMargin.trim + } + + GeneratedExpression(resultTerm, nullTerm, operatorCode, InternalTypes.BOOLEAN) + } + + def generateNot( + ctx: CodeGeneratorContext, + operand: GeneratedExpression) + : GeneratedExpression = { + // Three-valued logic: + // no Unknown -> Two-valued logic + // Unknown -> Unknown + generateUnaryOperatorIfNotNull(ctx, InternalTypes.BOOLEAN, operand) { + operandTerm => s"!($operandTerm)" + } + } + + def generateIsTrue(operand: 
GeneratedExpression): GeneratedExpression = { + GeneratedExpression( + operand.resultTerm, // unknown is always false by default + GeneratedExpression.NEVER_NULL, + operand.code, + InternalTypes.BOOLEAN) + } + + def generateIsNotTrue(operand: GeneratedExpression): GeneratedExpression = { + GeneratedExpression( + s"(!${operand.resultTerm})", // unknown is always false by default + GeneratedExpression.NEVER_NULL, + operand.code, + InternalTypes.BOOLEAN) + } + + def generateIsFalse(operand: GeneratedExpression): GeneratedExpression = { + GeneratedExpression( + s"(!${operand.resultTerm} && !${operand.nullTerm})", + GeneratedExpression.NEVER_NULL, + operand.code, + InternalTypes.BOOLEAN) + } + + def generateIsNotFalse(operand: GeneratedExpression): GeneratedExpression = { + GeneratedExpression( + s"(${operand.resultTerm} || ${operand.nullTerm})", + GeneratedExpression.NEVER_NULL, + operand.code, + InternalTypes.BOOLEAN) + } + + def generateReinterpret( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + targetType: InternalType) + : GeneratedExpression = (operand.resultType, targetType) match { + + case (fromTp, toTp) if fromTp == toTp => + operand + + // internal reinterpretation of temporal types + // Date -> Integer + // Time -> Integer + // Timestamp -> Long + // Integer -> Date + // Integer -> Time + // Long -> Timestamp + // Integer -> Interval Months + // Long -> Interval Millis + // Interval Months -> Integer + // Interval Millis -> Long + // Date -> Long + // Time -> Long + // Interval Months -> Long + case (InternalTypes.DATE, InternalTypes.INT) | + (InternalTypes.TIME, InternalTypes.INT) | + (_: TimestampType, InternalTypes.LONG) | + (InternalTypes.INT, InternalTypes.DATE) | + (InternalTypes.INT, InternalTypes.TIME) | + (InternalTypes.LONG, _: TimestampType) | + (InternalTypes.INT, InternalTypes.INTERVAL_MONTHS) | + (InternalTypes.LONG, InternalTypes.INTERVAL_MILLIS) | + (InternalTypes.INTERVAL_MONTHS, InternalTypes.INT) | + 
(InternalTypes.INTERVAL_MILLIS, InternalTypes.LONG) | + (InternalTypes.DATE, InternalTypes.LONG) | + (InternalTypes.TIME, InternalTypes.LONG) | + (InternalTypes.INTERVAL_MONTHS, InternalTypes.LONG) => + internalExprCasting(operand, targetType) + + case (from, to) => + throw new CodeGenException(s"Unsupported reinterpret from '$from' to '$to'.") + } + + def generateCast( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + targetType: InternalType) + : GeneratedExpression = (operand.resultType, targetType) match { + + // special case: cast from TimeIndicatorTypeInfo to SqlTimeTypeInfo + case (InternalTypes.PROCTIME_INDICATOR, InternalTypes.TIMESTAMP) | + (InternalTypes.ROWTIME_INDICATOR, InternalTypes.TIMESTAMP) | + (InternalTypes.TIMESTAMP, InternalTypes.PROCTIME_INDICATOR) | + (InternalTypes.TIMESTAMP, InternalTypes.ROWTIME_INDICATOR) => + operand.copy(resultType = InternalTypes.TIMESTAMP) // just replace the DataType + + // identity casting + case (fromTp, toTp) if fromTp == toTp => + operand + + // Date/Time/Timestamp -> String + case (left, InternalTypes.STRING) if TypeCheckUtils.isTimePoint(left) => + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + operandTerm => + val zoneTerm = ctx.addReusableTimeZone() + s"${internalToStringCode(left, operandTerm.head, zoneTerm)}" + } + + // Interval Months -> String + case (InternalTypes.INTERVAL_MONTHS, InternalTypes.STRING) => + val method = qualifyMethod(BuiltInMethod.INTERVAL_YEAR_MONTH_TO_STRING.method) + val timeUnitRange = qualifyEnum(TimeUnitRange.YEAR_TO_MONTH) + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => s"$method(${terms.head}, $timeUnitRange)" + } + + // Interval Millis -> String + case (InternalTypes.INTERVAL_MILLIS, InternalTypes.STRING) => + val method = qualifyMethod(BuiltInMethod.INTERVAL_DAY_TIME_TO_STRING.method) + val timeUnitRange = qualifyEnum(TimeUnitRange.DAY_TO_SECOND) + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => 
s"$method(${terms.head}, $timeUnitRange, 3)" // milli second precision + } + + // Array -> String + case (at: ArrayType, InternalTypes.STRING) => + generateCastArrayToString(ctx, operand, at) + + // Byte array -> String UTF-8 + case (InternalTypes.BINARY, InternalTypes.STRING) => + val charset = classOf[StandardCharsets].getCanonicalName + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => s"(new String(${terms.head}.toByteArray(), $charset.UTF_8))" + } + + + // Map -> String + case (mt: MapType, InternalTypes.STRING) => + generateCastMapToString(ctx, operand, mt) + + // composite type -> String + case (brt: RowType, InternalTypes.STRING) => + generateCastBaseRowToString(ctx, operand, brt) + + case (g: GenericType[_], InternalTypes.STRING) => + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => + val converter = DataFormatConverters.getConverterForTypeInfo(g.getTypeInfo) + val converterTerm = ctx.addReusableObject(converter, "converter") + s""" "" + $converterTerm.toExternal(${terms.head})""" + } + + // * (not Date/Time/Timestamp) -> String + // TODO: GenericType with Date/Time/Timestamp -> String would call toString implicitly + case (_, InternalTypes.STRING) => + generateStringResultCallIfArgsNotNull(ctx, Seq(operand)) { + terms => s""" "" + ${terms.head}""" + } + + // * -> Character + case (_, InternalTypes.CHAR) => + throw new CodeGenException("Character type not supported.") + + // String -> Boolean + case (InternalTypes.STRING, InternalTypes.BOOLEAN) => + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => s"$operandTerm.toBooleanSQL()" + } + + // String -> NUMERIC TYPE (not Character) + case (InternalTypes.STRING, _) + if TypeCheckUtils.isNumeric(targetType) => + targetType match { + case dt: DecimalType => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$operandTerm.toDecimal(${dt.precision}, ${dt.scale})" + } + case _ => + val 
methodName = targetType match { + case InternalTypes.BYTE => "toByte" + case InternalTypes.SHORT => "toShort" + case InternalTypes.INT => "toInt" + case InternalTypes.LONG => "toLong" + case InternalTypes.DOUBLE => "toDouble" + case InternalTypes.FLOAT => "toFloat" + case _ => null + } + assert(methodName != null, "Unexpected data type.") + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => s"($operandTerm.trim().$methodName())" + } + } + + // String -> Date + case (InternalTypes.STRING, InternalTypes.DATE) => + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => + s"${qualifyMethod(BuiltInMethod.STRING_TO_DATE.method)}($operandTerm.toString())" + } + + // String -> Time + case (InternalTypes.STRING, InternalTypes.TIME) => + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => + s"${qualifyMethod(BuiltInMethod.STRING_TO_TIME.method)}($operandTerm.toString())" + } + + // String -> Timestamp + case (InternalTypes.STRING, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull( + ctx, + targetType, + operand, + resultNullable = true) { + operandTerm => + val zoneTerm = ctx.addReusableTimeZone() + s"""${qualifyMethod(BuiltInMethods.STRING_TO_TIMESTAMP)}($operandTerm.toString(), + | $zoneTerm)""".stripMargin + } + + // String -> binary + case (InternalTypes.STRING, InternalTypes.BINARY) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$operandTerm.getBytes()" + } + + // Note: SQL2003 $6.12 - casting is not allowed between boolean and numeric types. + // Calcite does not allow it either. 
+ + // Boolean -> BigDecimal + case (InternalTypes.BOOLEAN, dt: DecimalType) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$DECIMAL.castFrom($operandTerm, ${dt.precision}, ${dt.scale})" + } + + // Boolean -> NUMERIC TYPE + case (InternalTypes.BOOLEAN, _) if TypeCheckUtils.isNumeric(targetType) => + val targetTypeTerm = primitiveTypeTermForType(targetType) + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"($targetTypeTerm) ($operandTerm ? 1 : 0)" + } + + // BigDecimal -> Boolean + case (_: DecimalType, InternalTypes.BOOLEAN) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$DECIMAL.castToBoolean($operandTerm)" + } + + // BigDecimal -> Timestamp + case (_: DecimalType, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$DECIMAL.castToTimestamp($operandTerm)" + } + + // NUMERIC TYPE -> Boolean + case (left, InternalTypes.BOOLEAN) if isNumeric(left) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$operandTerm != 0" + } + + // between NUMERIC TYPE | Decimal + case (left, right) if isNumeric(left) && isNumeric(right) => + val operandCasting = numericCasting(operand.resultType, targetType) + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"${operandCasting(operandTerm)}" + } + + // Date -> Timestamp + case (InternalTypes.DATE, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => + s"$operandTerm * ${classOf[DateTimeUtils].getCanonicalName}.MILLIS_PER_DAY" + } + + // Timestamp -> Date + case (InternalTypes.TIMESTAMP, InternalTypes.DATE) => + val targetTypeTerm = primitiveTypeTermForType(targetType) + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => + s"($targetTypeTerm) ($operandTerm / " + + s"${classOf[DateTimeUtils].getCanonicalName}.MILLIS_PER_DAY)" + } + + // Time -> 
Timestamp + case (InternalTypes.TIME, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$operandTerm" + } + + // Timestamp -> Time + case (InternalTypes.TIMESTAMP, InternalTypes.TIME) => + val targetTypeTerm = primitiveTypeTermForType(targetType) + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => + s"($targetTypeTerm) ($operandTerm % " + + s"${classOf[DateTimeUtils].getCanonicalName}.MILLIS_PER_DAY)" + } + + // Timestamp -> Decimal + case (InternalTypes.TIMESTAMP, dt: DecimalType) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"$DECIMAL.castFrom" + + s"(((double) ($operandTerm / 1000.0)), ${dt.precision}, ${dt.scale})" + } + + // Tinyint -> Timestamp + // Smallint -> Timestamp + // Int -> Timestamp + // Bigint -> Timestamp + case (InternalTypes.BYTE, InternalTypes.TIMESTAMP) | + (InternalTypes.SHORT,InternalTypes.TIMESTAMP) | + (InternalTypes.INT, InternalTypes.TIMESTAMP) | + (InternalTypes.LONG, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"(((long) $operandTerm) * 1000)" + } + + // Float -> Timestamp + // Double -> Timestamp + case (InternalTypes.FLOAT, InternalTypes.TIMESTAMP) | + (InternalTypes.DOUBLE, InternalTypes.TIMESTAMP) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((long) ($operandTerm * 1000))" + } + + // Timestamp -> Tinyint + case (InternalTypes.TIMESTAMP, InternalTypes.BYTE) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((byte) ($operandTerm / 1000))" + } + + // Timestamp -> Smallint + case (InternalTypes.TIMESTAMP, InternalTypes.SHORT) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((short) ($operandTerm / 1000))" + } + + // Timestamp -> Int + case (InternalTypes.TIMESTAMP, InternalTypes.INT) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + 
operandTerm => s"((int) ($operandTerm / 1000))" + } + + // Timestamp -> BigInt + case (InternalTypes.TIMESTAMP, InternalTypes.LONG) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((long) ($operandTerm / 1000))" + } + + // Timestamp -> Float + case (InternalTypes.TIMESTAMP, InternalTypes.FLOAT) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((float) ($operandTerm / 1000.0))" + } + + // Timestamp -> Double + case (InternalTypes.TIMESTAMP, InternalTypes.DOUBLE) => + generateUnaryOperatorIfNotNull(ctx, targetType, operand) { + operandTerm => s"((double) ($operandTerm / 1000.0))" + } + + // internal temporal casting + // Date -> Integer + // Time -> Integer + // Integer -> Date + // Integer -> Time + // Integer -> Interval Months + // Long -> Interval Millis + // Interval Months -> Integer + // Interval Millis -> Long + case (InternalTypes.DATE, InternalTypes.INT) | + (InternalTypes.TIME, InternalTypes.INT) | + (InternalTypes.INT, InternalTypes.DATE) | + (InternalTypes.INT, InternalTypes.TIME) | + (InternalTypes.INT, InternalTypes.INTERVAL_MONTHS) | + (InternalTypes.LONG, InternalTypes.INTERVAL_MILLIS) | + (InternalTypes.INTERVAL_MONTHS, InternalTypes.INT) | + (InternalTypes.INTERVAL_MILLIS, InternalTypes.LONG) => + internalExprCasting(operand, targetType) + + // internal reinterpretation of temporal types + // Date, Time, Interval Months -> Long + case (InternalTypes.DATE, InternalTypes.LONG) + | (InternalTypes.TIME, InternalTypes.LONG) + | (InternalTypes.INTERVAL_MONTHS, InternalTypes.LONG) => + internalExprCasting(operand, targetType) + + case (from, to) => + throw new CodeGenException(s"Unsupported cast from '$from' to '$to'.") + } + + def generateIfElse( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression], + resultType: InternalType, + i: Int = 0) + : GeneratedExpression = { + // else part + if (i == operands.size - 1) { + generateCast(ctx, operands(i), resultType) + } + else { + 
// check that the condition is boolean + // we do not check for null instead we use the default value + // thus null is false + requireBoolean(operands(i)) + val condition = operands(i) + val trueAction = generateCast(ctx, operands(i + 1), resultType) + val falseAction = generateIfElse(ctx, operands, resultType, i + 2) + + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val resultTypeTerm = primitiveTypeTermForType(resultType) + + val operatorCode = if (ctx.nullCheck) { + s""" + |${condition.code} + |$resultTypeTerm $resultTerm; + |boolean $nullTerm; + |if (${condition.resultTerm}) { + | ${trueAction.code} + | $resultTerm = ${trueAction.resultTerm}; + | $nullTerm = ${trueAction.nullTerm}; + |} + |else { + | ${falseAction.code} + | $resultTerm = ${falseAction.resultTerm}; + | $nullTerm = ${falseAction.nullTerm}; + |} + |""".stripMargin.trim + } + else { + s""" + |${condition.code} + |$resultTypeTerm $resultTerm; + |if (${condition.resultTerm}) { + | ${trueAction.code} + | $resultTerm = ${trueAction.resultTerm}; + |} + |else { + | ${falseAction.code} + | $resultTerm = ${falseAction.resultTerm}; + |} + |""".stripMargin.trim + } + + GeneratedExpression(resultTerm, nullTerm, operatorCode, resultType) + } + } + + def generateDot( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression]): GeneratedExpression = { + + // due to https://issues.apache.org/jira/browse/CALCITE-2162, expression such as + // "array[1].a.b" won't work now. 
+ if (operands.size > 2) { + throw new CodeGenException( + "A DOT operator with more than 2 operands is not supported yet.") + } + + checkArgument(operands(1).literal) + checkArgument(operands(1).resultType == InternalTypes.STRING) + checkArgument(operands.head.resultType.isInstanceOf[RowType]) + + val fieldName = operands(1).literalValue.get.toString + val fieldIdx = operands + .head + .resultType + .asInstanceOf[RowType] + .getFieldIndex(fieldName) + + val access = generateFieldAccess( + ctx, + operands.head.resultType, + operands.head.resultTerm, + fieldIdx) + + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val resultTypeTerm = primitiveTypeTermForType(access.resultType) + val defaultValue = primitiveDefaultValue(access.resultType) + + val resultCode = if (ctx.nullCheck) { + s""" + |${operands.map(_.code).mkString("\n")} + |$resultTypeTerm $resultTerm; + |boolean $nullTerm; + |if (${operands.map(_.nullTerm).mkString(" || ")}) { + | $resultTerm = $defaultValue; + | $nullTerm = true; + |} + |else { + | ${access.code} + | $resultTerm = ${access.resultTerm}; + | $nullTerm = ${access.nullTerm}; + |} + |""".stripMargin + } else { + s""" + |${operands.map(_.code).mkString("\n")} + |${access.code} + |$resultTypeTerm $resultTerm = ${access.resultTerm}; + |""".stripMargin + } + + + GeneratedExpression( + resultTerm, + nullTerm, + resultCode, + access.resultType + ) + } + + // ---------------------------------------------------------------------------------------- + // value construction and accessing generate utils + // ---------------------------------------------------------------------------------------- + + def generateRow( + ctx: CodeGeneratorContext, + resultType: InternalType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + checkArgument(resultType.isInstanceOf[RowType]) + val rowType = resultType.asInstanceOf[RowType] + val fieldTypes = rowType.getFieldTypes + val isLiteral = elements.forall(e => e.literal) + val isPrimitive = 
fieldTypes.forall(f => f.isInstanceOf[PrimitiveType]) + + if (isLiteral) { + // generate literal row + generateLiteralRow(ctx, rowType, elements) + } else { + if (isPrimitive) { + // generate primitive row + val mapped = elements.zipWithIndex.map { case (element, idx) => + if (element.literal) { + element + } else { + val tpe = fieldTypes(idx) + val resultTerm = primitiveDefaultValue(tpe) + GeneratedExpression(resultTerm, ALWAYS_NULL, NO_CODE, tpe, Some(null)) + } + } + val row = generateLiteralRow(ctx, rowType, mapped) + val code = elements.zipWithIndex.map { case (element, idx) => + val tpe = fieldTypes(idx) + if (element.literal) { + "" + } else if(ctx.nullCheck) { + s""" + |${element.code} + |if (${element.nullTerm}) { + | ${binaryRowSetNull(idx, row.resultTerm, tpe)}; + |} else { + | ${binaryRowFieldSetAccess(idx, row.resultTerm, tpe, element.resultTerm)}; + |} + """.stripMargin + } else { + s""" + |${element.code} + |${binaryRowFieldSetAccess(idx, row.resultTerm, tpe, element.resultTerm)}; + """.stripMargin + } + }.mkString("\n") + GeneratedExpression(row.resultTerm, NEVER_NULL, code, rowType) + } else { + // generate general row + generateNonLiteralRow(ctx, rowType, elements) + } + } + } + + private def generateLiteralRow( + ctx: CodeGeneratorContext, + rowType: RowType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + checkArgument(elements.forall(e => e.literal)) + val expr = generateNonLiteralRow(ctx, rowType, elements) + ctx.addReusableInitStatement(expr.code) + GeneratedExpression(expr.resultTerm, GeneratedExpression.NEVER_NULL, NO_CODE, rowType) + } + + private def generateNonLiteralRow( + ctx: CodeGeneratorContext, + rowType: RowType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + + val rowTerm = newName("row") + val writerTerm = newName("writer") + val writerCls = className[BinaryRowWriter] + + val writeCode = elements.zipWithIndex.map { + case (element, idx) => + val tpe = rowType.getTypeAt(idx) + if (ctx.nullCheck) 
{ + s""" + |${element.code} + |if (${element.nullTerm}) { + | ${binaryWriterWriteNull(idx, writerTerm, tpe)}; + |} else { + | ${binaryWriterWriteField(ctx, idx, element.resultTerm, writerTerm, tpe)}; + |} + """.stripMargin + } else { + s""" + |${element.code} + |${binaryWriterWriteField(ctx, idx, element.resultTerm, writerTerm, tpe)}; + """.stripMargin + } + }.mkString("\n") + + val code = + s""" + |$writerTerm.reset(); + |$writeCode + |$writerTerm.complete(); + """.stripMargin + + ctx.addReusableMember(s"$BINARY_ROW $rowTerm = new $BINARY_ROW(${rowType.getArity});") + ctx.addReusableMember(s"$writerCls $writerTerm = new $writerCls($rowTerm);") + GeneratedExpression(rowTerm, GeneratedExpression.NEVER_NULL, code, rowType) + } + + def generateArray( + ctx: CodeGeneratorContext, + resultType: InternalType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + + checkArgument(resultType.isInstanceOf[ArrayType]) + val arrayType = resultType.asInstanceOf[ArrayType] + val elementType = arrayType.getElementType + val isLiteral = elements.forall(e => e.literal) + val isPrimitive = elementType.isInstanceOf[PrimitiveType] + + if (isLiteral) { + // generate literal array + generateLiteralArray(ctx, arrayType, elements) + } else { + if (isPrimitive) { + // generate primitive array + val mapped = elements.map { element => + if (element.literal) { + element + } else { + val resultTerm = primitiveDefaultValue(elementType) + GeneratedExpression(resultTerm, ALWAYS_NULL, NO_CODE, elementType, Some(null)) + } + } + val array = generateLiteralArray(ctx, arrayType, mapped) + val code = generatePrimitiveArrayUpdateCode(ctx, array.resultTerm, elementType, elements) + GeneratedExpression(array.resultTerm, GeneratedExpression.NEVER_NULL, code, arrayType) + } else { + // generate general array + generateNonLiteralArray(ctx, arrayType, elements) + } + } + } + + private def generatePrimitiveArrayUpdateCode( + ctx: CodeGeneratorContext, + arrayTerm: String, + elementType: 
InternalType, + elements: Seq[GeneratedExpression]): String = { + elements.zipWithIndex.map { case (element, idx) => + if (element.literal) { + "" + } else if (ctx.nullCheck) { + s""" + |${element.code} + |if (${element.nullTerm}) { + | ${binaryArraySetNull(idx, arrayTerm, elementType)}; + |} else { + | ${binaryRowFieldSetAccess( + idx, arrayTerm, elementType, element.resultTerm)}; + |} + """.stripMargin + } else { + s""" + |${element.code} + |${binaryRowFieldSetAccess( + idx, arrayTerm, elementType, element.resultTerm)}; + """.stripMargin + } + }.mkString("\n") + } + + private def generateLiteralArray( + ctx: CodeGeneratorContext, + arrayType: ArrayType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + checkArgument(elements.forall(e => e.literal)) + val expr = generateNonLiteralArray(ctx, arrayType, elements) + ctx.addReusableInitStatement(expr.code) + GeneratedExpression(expr.resultTerm, GeneratedExpression.NEVER_NULL, NO_CODE, arrayType) + } + + private def generateNonLiteralArray( + ctx: CodeGeneratorContext, + arrayType: ArrayType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + + val elementType = arrayType.getElementType + val arrayTerm = newName("array") + val writerTerm = newName("writer") + val writerCls = className[BinaryArrayWriter] + val elementSize = BinaryArray.calculateFixLengthPartSize(elementType) + + val writeCode = elements.zipWithIndex.map { + case (element, idx) => + s""" + |${element.code} + |if (${element.nullTerm}) { + | ${binaryArraySetNull(idx, writerTerm, elementType)}; + |} else { + | ${binaryWriterWriteField(ctx, idx, element.resultTerm, writerTerm, elementType)}; + |} + """.stripMargin + }.mkString("\n") + + val code = + s""" + |$writerTerm.reset(); + |$writeCode + |$writerTerm.complete(); + """.stripMargin + + val memberStmt = + s""" + |$BINARY_ARRAY $arrayTerm = new $BINARY_ARRAY(); + |$writerCls $writerTerm = new $writerCls($arrayTerm, ${elements.length}, $elementSize); + """.stripMargin + + 
ctx.addReusableMember(memberStmt) + GeneratedExpression(arrayTerm, GeneratedExpression.NEVER_NULL, code, arrayType) + } + + def generateArrayElementAt( + ctx: CodeGeneratorContext, + array: GeneratedExpression, + index: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val componentInfo = array.resultType.asInstanceOf[ArrayType].getElementType + val resultTypeTerm = primitiveTypeTermForType(componentInfo) + val defaultTerm = primitiveDefaultValue(componentInfo) + + val idxStr = s"${index.resultTerm} - 1" + val arrayIsNull = s"${array.resultTerm}.isNullAt($idxStr)" + val arrayGet = + baseRowFieldReadAccess(ctx, idxStr, array.resultTerm, componentInfo) + + val arrayAccessCode = + s""" + |${array.code} + |${index.code} + |boolean $nullTerm = ${array.nullTerm} || ${index.nullTerm} || $arrayIsNull; + |$resultTypeTerm $resultTerm = $nullTerm ? $defaultTerm : $arrayGet; + |""".stripMargin + + GeneratedExpression(resultTerm, nullTerm, arrayAccessCode, componentInfo) + } + + def generateArrayElement( + ctx: CodeGeneratorContext, + array: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val resultType = array.resultType.asInstanceOf[ArrayType].getElementType + val resultTypeTerm = primitiveTypeTermForType(resultType) + val defaultValue = primitiveDefaultValue(resultType) + + val arrayLengthCode = s"${array.nullTerm} ? 0 : ${array.resultTerm}.numElements()" + + val arrayGet = baseRowFieldReadAccess(ctx, 0, array.resultTerm, resultType) + val arrayAccessCode = + s""" + |${array.code} + |boolean $nullTerm; + |$resultTypeTerm $resultTerm; + |switch ($arrayLengthCode) { + | case 0: + | $nullTerm = true; + | $resultTerm = $defaultValue; + | break; + | case 1: + | $nullTerm = ${array.resultTerm}.isNullAt(0); + | $resultTerm = $nullTerm ? 
$defaultValue : $arrayGet; + | break; + | default: + | throw new RuntimeException("Array has more than one element."); + |} + |""".stripMargin + + GeneratedExpression(resultTerm, nullTerm, arrayAccessCode, resultType) + } + + def generateArrayCardinality( + ctx: CodeGeneratorContext, + array: GeneratedExpression) + : GeneratedExpression = { + generateUnaryOperatorIfNotNull(ctx, InternalTypes.INT, array) { + _ => s"${array.resultTerm}.numElements()" + } + } + + def generateMap( + ctx: CodeGeneratorContext, + resultType: InternalType, + elements: Seq[GeneratedExpression]): GeneratedExpression = { + + checkArgument(resultType.isInstanceOf[MapType]) + val mapType = resultType.asInstanceOf[MapType] + val mapTerm = newName("map") + + // prepare map key array + val keyElements = elements.grouped(2).map { case Seq(key, _) => key }.toSeq + val keyType = mapType.getKeyType + val keyExpr = generateArray(ctx, InternalTypes.createArrayType(keyType), keyElements) + val isKeyFixLength = keyType.isInstanceOf[PrimitiveType] + + // prepare map value array + val valueElements = elements.grouped(2).map { case Seq(_, value) => value }.toSeq + val valueType = mapType.getValueType + val valueExpr = generateArray(ctx, InternalTypes.createArrayType(valueType), valueElements) + val isValueFixLength = valueType.isInstanceOf[PrimitiveType] + + // construct binary map + ctx.addReusableMember(s"$BINARY_MAP $mapTerm = null;") + + val code = if (isKeyFixLength && isValueFixLength) { + // the key and value are fixed length, initialize and reuse the map in constructor + val init = s"$mapTerm = $BINARY_MAP.valueOf(${keyExpr.resultTerm}, ${valueExpr.resultTerm});" + ctx.addReusableInitStatement(init) + // there are some non-literal primitive fields need to update + val keyArrayTerm = newName("keyArray") + val valueArrayTerm = newName("valueArray") + val keyUpdate = generatePrimitiveArrayUpdateCode( + ctx, keyArrayTerm, keyType, keyElements) + val valueUpdate = generatePrimitiveArrayUpdateCode( + ctx, 
valueArrayTerm, valueType, valueElements) + s""" + |$BINARY_ARRAY $keyArrayTerm = $mapTerm.keyArray(); + |$keyUpdate + |$BINARY_ARRAY $valueArrayTerm = $mapTerm.valueArray(); + |$valueUpdate + """.stripMargin + } else { + // the key or value is not fixed length, re-create the map on every update + s""" + |${keyExpr.code} + |${valueExpr.code} + |$mapTerm = $BINARY_MAP.valueOf(${keyExpr.resultTerm}, ${valueExpr.resultTerm}); + """.stripMargin + } + GeneratedExpression(mapTerm, NEVER_NULL, code, resultType) + } + + def generateMapGet( + ctx: CodeGeneratorContext, + map: GeneratedExpression, + key: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames("result", "isNull") + val tmpKey = newName("key") + val length = newName("length") + val keys = newName("keys") + val values = newName("values") + val index = newName("index") + val found = newName("found") + + val mapType = map.resultType.asInstanceOf[MapType] + val keyType = mapType.getKeyType + val valueType = mapType.getValueType + + // use primitive for key as key is not null + val keyTypeTerm = primitiveTypeTermForType(keyType) + val valueTypeTerm = primitiveTypeTermForType(valueType) + val valueDefault = primitiveDefaultValue(valueType) + + val mapTerm = map.resultTerm + + val equal = generateEquals(ctx, key, GeneratedExpression(tmpKey, NEVER_NULL, NO_CODE, keyType)) + val code = + s""" + |final int $length = $mapTerm.numElements(); + |final $BINARY_ARRAY $keys = $mapTerm.keyArray(); + |final $BINARY_ARRAY $values = $mapTerm.valueArray(); + | + |int $index = 0; + |boolean $found = false; + |if (${key.nullTerm}) { + | while ($index < $length && !$found) { + | if ($keys.isNullAt($index)) { + | $found = true; + | } else { + | $index++; + | } + | } + |} else { + | while ($index < $length && !$found) { + | final $keyTypeTerm $tmpKey = ${baseRowFieldReadAccess(ctx, index, keys, keyType)}; + | ${equal.code} + | if (${equal.resultTerm}) { + | $found = true; + | } else { + | $index++; + | 
} + | } + |} + | + |if (!$found || $values.isNullAt($index)) { + | $nullTerm = true; + |} else { + | $resultTerm = ${baseRowFieldReadAccess(ctx, index, values, valueType)}; + |} + """.stripMargin + + val accessCode = + s""" + |${map.code} + |${key.code} + |boolean $nullTerm = (${map.nullTerm} || ${key.nullTerm}); + |$valueTypeTerm $resultTerm = $valueDefault; + |if (!$nullTerm) { + | $code + |} + """.stripMargin + + GeneratedExpression(resultTerm, nullTerm, accessCode, valueType) + } + + def generateMapCardinality( + ctx: CodeGeneratorContext, + map: GeneratedExpression): GeneratedExpression = { + generateUnaryOperatorIfNotNull(ctx, InternalTypes.INT, map) { + _ => s"${map.resultTerm}.numElements()" + } + } + + // ---------------------------------------------------------------------------------------- + // private generate utils + // ---------------------------------------------------------------------------------------- + + private def generateCastStringLiteralToDateTime( + ctx: CodeGeneratorContext, + stringLiteral: GeneratedExpression, + expectType: InternalType): GeneratedExpression = { + checkArgument(stringLiteral.literal) + val rightTerm = stringLiteral.resultTerm + val typeTerm = primitiveTypeTermForType(expectType) + val defaultTerm = primitiveDefaultValue(expectType) + val term = newName("stringToTime") + val zoneTerm = ctx.addReusableTimeZone() + val code = stringToInternalCode(expectType, rightTerm, zoneTerm) + val stmt = s"$typeTerm $term = ${stringLiteral.nullTerm} ? 
$defaultTerm : $code;" + ctx.addReusableMember(stmt) + stringLiteral.copy(resultType = expectType, resultTerm = term) + } + + private def generateCastArrayToString( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + at: ArrayType): GeneratedExpression = + generateStringResultCallWithStmtIfArgsNotNull(ctx, Seq(operand)) { + terms => + val builderCls = classOf[JStringBuilder].getCanonicalName + val builderTerm = newName("builder") + ctx.addReusableMember(s"""$builderCls $builderTerm = new $builderCls();""") + + val arrayTerm = terms.head + + val indexTerm = newName("i") + val numTerm = newName("num") + + val elementType = at.getElementType + val elementCls = primitiveTypeTermForType(elementType) + val elementTerm = newName("element") + val elementNullTerm = newName("isNull") + val elementCode = + s""" + |$elementCls $elementTerm = ${primitiveDefaultValue(elementType)}; + |boolean $elementNullTerm = $arrayTerm.isNullAt($indexTerm); + |if (!$elementNullTerm) { + | $elementTerm = ($elementCls) ${ + baseRowFieldReadAccess(ctx, indexTerm, arrayTerm, elementType)}; + |} + """.stripMargin + val elementExpr = GeneratedExpression( + elementTerm, elementNullTerm, elementCode, elementType) + val castExpr = generateCast(ctx, elementExpr, InternalTypes.STRING) + + val stmt = + s""" + |$builderTerm.setLength(0); + |$builderTerm.append("["); + |int $numTerm = $arrayTerm.numElements(); + |for (int $indexTerm = 0; $indexTerm < $numTerm; $indexTerm++) { + | if ($indexTerm != 0) { + | $builderTerm.append(", "); + | } + | + | ${castExpr.code} + | if (${castExpr.nullTerm}) { + | $builderTerm.append("null"); + | } else { + | $builderTerm.append(${castExpr.resultTerm}); + | } + |} + |$builderTerm.append("]"); + """.stripMargin + (stmt, s"$builderTerm.toString()") + } + + private def generateCastMapToString( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + mt: MapType): GeneratedExpression = + generateStringResultCallWithStmtIfArgsNotNull(ctx, Seq(operand)) { + 
terms => + val resultTerm = newName("toStringResult") + + val builderCls = classOf[JStringBuilder].getCanonicalName + val builderTerm = newName("builder") + ctx.addReusableMember(s"$builderCls $builderTerm = new $builderCls();") + + val binaryMapTerm = terms.head + val arrayCls = classOf[BinaryArray].getCanonicalName + val keyArrayTerm = newName("keyArray") + val valueArrayTerm = newName("valueArray") + + val indexTerm = newName("i") + val numTerm = newName("num") + + val keyType = mt.getKeyType + val keyCls = primitiveTypeTermForType(keyType) + val keyTerm = newName("key") + val keyNullTerm = newName("isNull") + val keyCode = + s""" + |$keyCls $keyTerm = ${primitiveDefaultValue(keyType)}; + |boolean $keyNullTerm = $keyArrayTerm.isNullAt($indexTerm); + |if (!$keyNullTerm) { + | $keyTerm = ($keyCls) ${ + baseRowFieldReadAccess(ctx, indexTerm, keyArrayTerm, keyType)}; + |} + """.stripMargin + val keyExpr = GeneratedExpression(keyTerm, keyNullTerm, keyCode, keyType) + val keyCastExpr = generateCast(ctx, keyExpr, InternalTypes.STRING) + + val valueType = mt.getValueType + val valueCls = primitiveTypeTermForType(valueType) + val valueTerm = newName("value") + val valueNullTerm = newName("isNull") + val valueCode = + s""" + |$valueCls $valueTerm = ${primitiveDefaultValue(valueType)}; + |boolean $valueNullTerm = $valueArrayTerm.isNullAt($indexTerm); + |if (!$valueNullTerm) { + | $valueTerm = ($valueCls) ${ + baseRowFieldReadAccess(ctx, indexTerm, valueArrayTerm, valueType)}; + |} + """.stripMargin + val valueExpr = GeneratedExpression(valueTerm, valueNullTerm, valueCode, valueType) + val valueCastExpr = generateCast(ctx, valueExpr, InternalTypes.STRING) + + val stmt = + s""" + |String $resultTerm; + |$arrayCls $keyArrayTerm = $binaryMapTerm.keyArray(); + |$arrayCls $valueArrayTerm = $binaryMapTerm.valueArray(); + | + |$builderTerm.setLength(0); + |$builderTerm.append("{"); + | + |int $numTerm = $binaryMapTerm.numElements(); + |for (int $indexTerm = 0; $indexTerm < 
$numTerm; $indexTerm++) { + | if ($indexTerm != 0) { + | $builderTerm.append(", "); + | } + | + | ${keyCastExpr.code} + | if (${keyCastExpr.nullTerm}) { + | $builderTerm.append("null"); + | } else { + | $builderTerm.append(${keyCastExpr.resultTerm}); + | } + | $builderTerm.append("="); + | + | ${valueCastExpr.code} + | if (${valueCastExpr.nullTerm}) { + | $builderTerm.append("null"); + | } else { + | $builderTerm.append(${valueCastExpr.resultTerm}); + | } + |} + |$builderTerm.append("}"); + | + |$resultTerm = $builderTerm.toString(); + """.stripMargin + (stmt, resultTerm) + } + + private def generateCastBaseRowToString( + ctx: CodeGeneratorContext, + operand: GeneratedExpression, + brt: RowType): GeneratedExpression = + generateStringResultCallWithStmtIfArgsNotNull(ctx, Seq(operand)) { + terms => + val builderCls = classOf[JStringBuilder].getCanonicalName + val builderTerm = newName("builder") + ctx.addReusableMember(s"""$builderCls $builderTerm = new $builderCls();""") + + val rowTerm = terms.head + + val appendCode = brt.getFieldTypes.zipWithIndex.map { + case (elementType, idx) => + val elementCls = primitiveTypeTermForType(elementType) + val elementTerm = newName("element") + val elementExpr = GeneratedExpression( + elementTerm, s"$rowTerm.isNullAt($idx)", + s"$elementCls $elementTerm = ($elementCls) ${baseRowFieldReadAccess( + ctx, idx, rowTerm, elementType)};", elementType) + val castExpr = generateCast(ctx, elementExpr, InternalTypes.STRING) + s""" + |${if (idx != 0) s"""$builderTerm.append(",");""" else ""} + |${castExpr.code} + |if (${castExpr.nullTerm}) { + | $builderTerm.append("null"); + |} else { + | $builderTerm.append(${castExpr.resultTerm}); + |} + """.stripMargin + }.mkString("\n") + + val stmt = + s""" + |$builderTerm.setLength(0); + |$builderTerm.append("("); + |$appendCode + |$builderTerm.append(")"); + """.stripMargin + (stmt, s"$builderTerm.toString()") + } + + private def generateArrayComparison( + ctx: CodeGeneratorContext, + left: 
GeneratedExpression, + right: GeneratedExpression): GeneratedExpression = + generateCallWithStmtIfArgsNotNull(ctx, InternalTypes.BOOLEAN, Seq(left, right)) { + args => + val leftTerm = args.head + val rightTerm = args(1) + + val resultTerm = newName("compareResult") + val binaryArrayCls = classOf[BinaryArray].getCanonicalName + + val elementType = left.resultType.asInstanceOf[ArrayType].getElementType + val elementCls = primitiveTypeTermForType(elementType) + val elementDefault = primitiveDefaultValue(elementType) + + val leftElementTerm = newName("leftElement") + val leftElementNullTerm = newName("leftElementIsNull") + val leftElementExpr = + GeneratedExpression(leftElementTerm, leftElementNullTerm, "", elementType) + + val rightElementTerm = newName("rightElement") + val rightElementNullTerm = newName("rightElementIsNull") + val rightElementExpr = + GeneratedExpression(rightElementTerm, rightElementNullTerm, "", elementType) + + val indexTerm = newName("index") + val elementEqualsExpr = generateEquals(ctx, leftElementExpr, rightElementExpr) + + val stmt = + s""" + |boolean $resultTerm; + |if ($leftTerm instanceof $binaryArrayCls && $rightTerm instanceof $binaryArrayCls) { + | $resultTerm = $leftTerm.equals($rightTerm); + |} else { + | if ($leftTerm.numElements() == $rightTerm.numElements()) { + | $resultTerm = true; + | for (int $indexTerm = 0; $indexTerm < $leftTerm.numElements(); $indexTerm++) { + | $elementCls $leftElementTerm = $elementDefault; + | boolean $leftElementNullTerm = $leftTerm.isNullAt($indexTerm); + | if (!$leftElementNullTerm) { + | $leftElementTerm = + | ${baseRowFieldReadAccess(ctx, indexTerm, leftTerm, elementType)}; + | } + | + | $elementCls $rightElementTerm = $elementDefault; + | boolean $rightElementNullTerm = $rightTerm.isNullAt($indexTerm); + | if (!$rightElementNullTerm) { + | $rightElementTerm = + | ${baseRowFieldReadAccess(ctx, indexTerm, rightTerm, elementType)}; + | } + | + | ${elementEqualsExpr.code} + | if 
(!${elementEqualsExpr.resultTerm}) { + | $resultTerm = false; + | break; + | } + | } + | } else { + | $resultTerm = false; + | } + |} + """.stripMargin + (stmt, resultTerm) + } + + private def generateMapComparison( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression): GeneratedExpression = + generateCallWithStmtIfArgsNotNull(ctx, InternalTypes.BOOLEAN, Seq(left, right)) { + args => + val leftTerm = args.head + val rightTerm = args(1) + val resultTerm = newName("compareResult") + val stmt = s"boolean $resultTerm = $leftTerm.equals($rightTerm);" + (stmt, resultTerm) + } + + // ------------------------------------------------------------------------------------------ + + private def generateUnaryOperatorIfNotNull( + ctx: CodeGeneratorContext, + returnType: InternalType, + operand: GeneratedExpression, + resultNullable: Boolean = false) + (expr: String => String): GeneratedExpression = { + generateCallIfArgsNotNull(ctx, returnType, Seq(operand), resultNullable) { + args => expr(args.head) + } + } + + private def generateOperatorIfNotNull( + ctx: CodeGeneratorContext, + returnType: InternalType, + left: GeneratedExpression, + right: GeneratedExpression, + resultNullable: Boolean = false) + (expr: (String, String) => String) + : GeneratedExpression = { + generateCallIfArgsNotNull(ctx, returnType, Seq(left, right), resultNullable) { + args => expr(args.head, args(1)) + } + } + + // ---------------------------------------------------------------------------------------------- + + private def internalExprCasting( + expr: GeneratedExpression, + targetType: InternalType) + : GeneratedExpression = { + expr.copy(resultType = targetType) + } + + private def numericCasting( + operandType: InternalType, + resultType: InternalType): String => String = { + + val resultTypeTerm = primitiveTypeTermForType(resultType) + + def decToPrimMethod(targetType: InternalType): String = targetType match { + case InternalTypes.BYTE => "castToByte" + case 
InternalTypes.SHORT => "castToShort" + case InternalTypes.INT => "castToInt" + case InternalTypes.LONG => "castToLong" + case InternalTypes.FLOAT => "castToFloat" + case InternalTypes.DOUBLE => "castToDouble" + case InternalTypes.BOOLEAN => "castToBoolean" + case _ => throw new CodeGenException(s"Unsupported decimal casting type: '$targetType'") + } + + // no casting necessary + if (operandType == resultType) { + operandTerm => s"$operandTerm" + } + // decimal to decimal, may have different precision/scale + else if (isDecimal(resultType) && isDecimal(operandType)) { + val dt = resultType.asInstanceOf[DecimalType] + operandTerm => + s"$DECIMAL.castToDecimal($operandTerm, ${dt.precision()}, ${dt.scale()})" + } + // non_decimal_numeric to decimal + else if (isDecimal(resultType) && isNumeric(operandType)) { + val dt = resultType.asInstanceOf[DecimalType] + operandTerm => + s"$DECIMAL.castFrom($operandTerm, ${dt.precision()}, ${dt.scale()})" + } + // decimal to non_decimal_numeric + else if (isNumeric(resultType) && isDecimal(operandType) ) { + operandTerm => + s"$DECIMAL.${decToPrimMethod(resultType)}($operandTerm)" + } + // numeric to numeric + // TODO: Create a wrapper layer that handles type conversion between numeric. 
+ else if (isNumeric(operandType) && isNumeric(resultType)) { + val resultTypeValue = resultTypeTerm + "Value()" + val boxedTypeTerm = boxedTypeTermForType(operandType) + operandTerm => + s"(new $boxedTypeTerm($operandTerm)).$resultTypeValue" + } + // result type is time interval and operand type is integer + else if (isTimeInterval(resultType) && isInteger(operandType)){ + operandTerm => s"(($resultTypeTerm) $operandTerm)" + } + else { + throw new CodeGenException(s"Unsupported casting from $operandType to $resultType.") + } + } + + private def stringToInternalCode( + targetType: InternalType, + operandTerm: String, + zoneTerm: String): String = + targetType match { + case InternalTypes.DATE => + s"${qualifyMethod(BuiltInMethod.STRING_TO_DATE.method)}($operandTerm.toString())" + case InternalTypes.TIME => + s"${qualifyMethod(BuiltInMethod.STRING_TO_TIME.method)}($operandTerm.toString())" + case InternalTypes.TIMESTAMP => + s"""${qualifyMethod(BuiltInMethods.STRING_TO_TIMESTAMP)}($operandTerm.toString(), + | $zoneTerm)""".stripMargin + case _ => throw new UnsupportedOperationException + } + + private def internalToStringCode( + fromType: InternalType, + operandTerm: String, + zoneTerm: String): String = + fromType match { + case InternalTypes.DATE => + s"${qualifyMethod(BuiltInMethod.UNIX_DATE_TO_STRING.method)}($operandTerm)" + case InternalTypes.TIME => + s"${qualifyMethod(BuiltInMethods.UNIX_TIME_TO_STRING)}($operandTerm)" + case _: TimestampType => // including rowtime indicator + s"${qualifyMethod(BuiltInMethods.TIMESTAMP_TO_STRING)}($operandTerm, 3, $zoneTerm)" + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/GenericRelDataType.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/GenericRelDataType.scala new file mode 100644 index 0000000000000..b35967a7429ed --- /dev/null +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/GenericRelDataType.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.schema + +import org.apache.calcite.rel.`type`.RelDataTypeSystem +import org.apache.calcite.sql.`type`.{BasicSqlType, SqlTypeName} +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.`type`.GenericType + +/** + * Generic type for encapsulating Flink's [[TypeInformation]]. 
+ * + * @param genericType InternalType to encapsulate + * @param nullable flag if type can be nullable + * @param typeSystem Flink's type system + */ +class GenericRelDataType( + val genericType: GenericType[_], + val nullable: Boolean, + typeSystem: RelDataTypeSystem) + extends BasicSqlType( + typeSystem, + SqlTypeName.ANY) { + + isNullable = nullable + + override def toString = s"ANY($genericType)" + + def canEqual(other: Any): Boolean = other.isInstanceOf[GenericRelDataType] + + override def equals(other: Any): Boolean = other match { + case that: GenericRelDataType => + super.equals(that) && + (that canEqual this) && + genericType == that.genericType && + nullable == that.nullable + case _ => false + } + + override def hashCode(): Int = { + genericType.hashCode() + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala index 41676a84ac220..fa8c1faae622d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCheckUtils.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.typeutils import org.apache.flink.table.`type`._ +import org.apache.flink.table.codegen.GeneratedExpression object TypeCheckUtils { @@ -29,6 +30,26 @@ object TypeCheckUtils { case _ => false } + def isTemporal(dataType: InternalType): Boolean = + isTimePoint(dataType) || isTimeInterval(dataType) + + def isTimePoint(dataType: InternalType): Boolean = dataType match { + case InternalTypes.INTERVAL_MILLIS | InternalTypes.INTERVAL_MONTHS => false + case _: TimeType | _: DateType | _: TimestampType => true + case _ => false + } + + def isRowTime(dataType: InternalType): Boolean = + dataType == InternalTypes.ROWTIME_INDICATOR + + def isProcTime(dataType: 
InternalType): Boolean = + dataType == InternalTypes.PROCTIME_INDICATOR + + def isTimeInterval(dataType: InternalType): Boolean = dataType match { + case InternalTypes.INTERVAL_MILLIS | InternalTypes.INTERVAL_MONTHS => true + case _ => false + } + def isString(dataType: InternalType): Boolean = dataType == InternalTypes.STRING def isBinary(dataType: InternalType): Boolean = dataType == InternalTypes.BINARY @@ -51,4 +72,25 @@ object TypeCheckUtils { !dataType.isInstanceOf[RowType] && !isArray(dataType) + def isMutable(dataType: InternalType): Boolean = dataType match { + // the internal representation of String is BinaryString which is mutable + case InternalTypes.STRING => true + case _: ArrayType | _: MapType | _: RowType | _: GenericType[_] => true + case _ => false + } + + def isReference(t: InternalType): Boolean = t match { + case InternalTypes.INT + | InternalTypes.LONG + | InternalTypes.SHORT + | InternalTypes.BYTE + | InternalTypes.FLOAT + | InternalTypes.DOUBLE + | InternalTypes.BOOLEAN + | InternalTypes.CHAR => false + case _ => true + } + + def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType) + } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCoercion.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCoercion.scala new file mode 100644 index 0000000000000..e559dd2a4bbee --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/typeutils/TypeCoercion.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.typeutils + +import org.apache.flink.table.`type`.{DecimalType, InternalType, InternalTypes, TimestampType} +import org.apache.flink.table.typeutils.TypeCheckUtils._ + +/** + * Utilities for type conversions. + */ +object TypeCoercion { + + val numericWideningPrecedence: IndexedSeq[InternalType] = + IndexedSeq( + InternalTypes.BYTE, + InternalTypes.SHORT, + InternalTypes.INT, + InternalTypes.LONG, + InternalTypes.FLOAT, + InternalTypes.DOUBLE) + + def widerTypeOf(tp1: InternalType, tp2: InternalType): Option[InternalType] = { + (tp1, tp2) match { + case (ti1, ti2) if ti1 == ti2 => Some(ti1) + + case (_, InternalTypes.STRING) => Some(InternalTypes.STRING) + case (InternalTypes.STRING, _) => Some(InternalTypes.STRING) + + case (_, dt: DecimalType) => Some(dt) + case (dt: DecimalType, _) => Some(dt) + + case (a, b) if isTimePoint(a) && isTimeInterval(b) => Some(a) + case (a, b) if isTimeInterval(a) && isTimePoint(b) => Some(b) + + case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) => + val higherIndex = numericWideningPrecedence.lastIndexWhere(t => t == tp1 || t == tp2) + Some(numericWideningPrecedence(higherIndex)) + + case _ => None + } + } + + /** + * Test if we can cast safely without loss of information. 
+ */ + def canSafelyCast(from: InternalType, to: InternalType): Boolean = (from, to) match { + case (_, InternalTypes.STRING) => true + + case (a, _: DecimalType) if isNumeric(a) => true + + case tuple if tuple.productIterator.forall(numericWideningPrecedence.contains) => + if (numericWideningPrecedence.indexOf(from) < numericWideningPrecedence.indexOf(to)) { + true + } else { + false + } + + case _ => false + } + + /** + * All the supported cast types in flink-table. + * + * Note: No distinction between explicit and implicit conversions + * Note: This is a subset of SqlTypeAssignmentRule + * Note: This may lose type during the cast. + */ + def canCast(from: InternalType, to: InternalType): Boolean = (from, to) match { + case (fromTp, toTp) if fromTp == toTp => true + + case (_, InternalTypes.STRING) => true + + case (_, InternalTypes.CHAR) => false // Character type not supported. + + case (InternalTypes.STRING, b) if isNumeric(b) => true + case (InternalTypes.STRING, InternalTypes.BOOLEAN) => true + case (InternalTypes.STRING, _: DecimalType) => true + case (InternalTypes.STRING, InternalTypes.DATE) => true + case (InternalTypes.STRING, InternalTypes.TIME) => true + case (InternalTypes.STRING, _: TimestampType) => true + + case (InternalTypes.BOOLEAN, b) if isNumeric(b) => true + case (InternalTypes.BOOLEAN, _: DecimalType) => true + case (a, InternalTypes.BOOLEAN) if isNumeric(a) => true + case (_: DecimalType, InternalTypes.BOOLEAN) => true + + case (a, b) if isNumeric(a) && isNumeric(b) => true + case (a, _: DecimalType) if isNumeric(a) => true + case (_: DecimalType, b) if isNumeric(b) => true + case (_: DecimalType, _: DecimalType) => true + case (InternalTypes.INT, InternalTypes.DATE) => true + case (InternalTypes.INT, InternalTypes.TIME) => true + case (InternalTypes.BYTE, _: TimestampType) => true + case (InternalTypes.SHORT, _: TimestampType) => true + case (InternalTypes.INT, _: TimestampType) => true + case (InternalTypes.LONG, _: TimestampType) => 
true + case (InternalTypes.DOUBLE, _: TimestampType) => true + case (InternalTypes.FLOAT, _: TimestampType) => true + case (InternalTypes.INT, InternalTypes.INTERVAL_MONTHS) => true + case (InternalTypes.LONG, InternalTypes.INTERVAL_MILLIS) => true + + case (InternalTypes.DATE, InternalTypes.TIME) => false + case (InternalTypes.TIME, InternalTypes.DATE) => false + case (a, b) if isTimePoint(a) && isTimePoint(b) => true + case (InternalTypes.DATE, InternalTypes.INT) => true + case (InternalTypes.TIME, InternalTypes.INT) => true + case (_: TimestampType, InternalTypes.BYTE) => true + case (_: TimestampType, InternalTypes.INT) => true + case (_: TimestampType, InternalTypes.SHORT) => true + case (_: TimestampType, InternalTypes.LONG) => true + case (_: TimestampType, InternalTypes.DOUBLE) => true + case (_: TimestampType, InternalTypes.FLOAT) => true + + case (InternalTypes.INTERVAL_MONTHS, InternalTypes.INT) => true + case (InternalTypes.INTERVAL_MILLIS, InternalTypes.LONG) => true + + case _ => false + } + + /** + * All the supported reinterpret types in flink-table. 
+ */ + def canReinterpret(from: InternalType, to: InternalType): Boolean = (from, to) match { + case (fromTp, toTp) if fromTp == toTp => true + + case (InternalTypes.DATE, InternalTypes.INT) => true + case (InternalTypes.TIME, InternalTypes.INT) => true + case (_: TimestampType, InternalTypes.LONG) => true + case (InternalTypes.INT, InternalTypes.DATE) => true + case (InternalTypes.INT, InternalTypes.TIME) => true + case (InternalTypes.LONG, _: TimestampType) => true + case (InternalTypes.INT, InternalTypes.INTERVAL_MONTHS) => true + case (InternalTypes.LONG, InternalTypes.INTERVAL_MILLIS) => true + case (InternalTypes.INTERVAL_MONTHS, InternalTypes.INT) => true + case (InternalTypes.INTERVAL_MILLIS, InternalTypes.LONG) => true + + case (InternalTypes.DATE, InternalTypes.LONG) => true + case (InternalTypes.TIME, InternalTypes.LONG) => true + case (InternalTypes.INTERVAL_MONTHS, InternalTypes.LONG) => true + + case _ => false + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ArrayTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ArrayTypeTest.scala new file mode 100644 index 0000000000000..8e8e7994291e3 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ArrayTypeTest.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.ArrayTypeTestBase +import org.junit.Test + +class ArrayTypeTest extends ArrayTypeTestBase { + + @Test + def testArrayLiterals(): Unit = { + // primitive literals + testSqlApi( + "ARRAY[1, 2, 3]", + "[1, 2, 3]") + + testSqlApi( + "ARRAY[TRUE, TRUE, TRUE]", + "[true, true, true]") + + testSqlApi( + "ARRAY[ARRAY[ARRAY[1], ARRAY[1]]]", + "[[[1], [1]]]") + + testSqlApi( + "ARRAY[1 + 1, 3 * 3]", + "[2, 9]") + + testSqlApi( + "ARRAY[NULLIF(1,1), 1]", + "[null, 1]") + + testSqlApi( + "ARRAY[ARRAY[NULLIF(1,1), 1]]", + "[[null, 1]]") + + testSqlApi( + "ARRAY[DATE '1985-04-11', DATE '2018-07-26']", + "[1985-04-11, 2018-07-26]") + + testSqlApi( + "ARRAY[TIME '14:15:16', TIME '17:18:19']", + "[14:15:16, 17:18:19]") + + testSqlApi( + "ARRAY[TIMESTAMP '1985-04-11 14:15:16', TIMESTAMP '2018-07-26 17:18:19']", + "[1985-04-11 14:15:16.000, 2018-07-26 17:18:19.000]") + + testSqlApi( + "ARRAY[CAST(2.0002 AS DECIMAL(10,4)), CAST(2.0003 AS DECIMAL(10,4))]", + "[2.0002, 2.0003]") + + testSqlApi( + "ARRAY[ARRAY[TRUE]]", + "[[true]]") + + testSqlApi( + "ARRAY[ARRAY[1, 2, 3], ARRAY[3, 2, 1]]", + "[[1, 2, 3], [3, 2, 1]]") + + // implicit type cast only works on SQL APIs. 
+ testSqlApi( + "ARRAY[CAST(1 AS DOUBLE), CAST(2 AS FLOAT)]", + "[1.0, 2.0]") + } + + @Test + def testArrayField(): Unit = { + testSqlApi( + "ARRAY[f0, f1]", + "[null, 42]") + + testSqlApi( + "ARRAY[f0, f1]", + "[null, 42]") + + testSqlApi( + "f2", + "[1, 2, 3]") + + testSqlApi( + "f3", + "[1984-03-12, 1984-02-10]") + + testSqlApi( + "f5", + "[[1, 2, 3], null]") + + testSqlApi( + "f6", + "[1, null, null, 4]") + + testSqlApi( + "f2", + "[1, 2, 3]") + + testSqlApi( + "f2[1]", + "1") + + testSqlApi( + "f3[1]", + "1984-03-12") + + testSqlApi( + "f3[2]", + "1984-02-10") + + testSqlApi( + "f5[1][2]", + "2") + + testSqlApi( + "f5[2][2]", + "null") + + testSqlApi( + "f4[2][2]", + "null") + + testSqlApi( + "f11[1]", + "1") + } + + @Test + def testArrayOperations(): Unit = { + // cardinality + testSqlApi( + "CARDINALITY(f2)", + "3") + + testSqlApi( + "CARDINALITY(f4)", + "null") + + testSqlApi( + "CARDINALITY(f11)", + "1") + + // element + testSqlApi( + "ELEMENT(f9)", + "1") + + testSqlApi( + "ELEMENT(f8)", + "4.0") + + testSqlApi( + "ELEMENT(f10)", + "null") + + testSqlApi( + "ELEMENT(f4)", + "null") + + testSqlApi( + "ELEMENT(f11)", + "1") + + // comparison + testSqlApi( + "f2 = f5[1]", + "true") + + testSqlApi( + "f6 = ARRAY[1, 2, 3]", + "false") + + testSqlApi( + "f2 <> f5[1]", + "false") + + testSqlApi( + "f2 = f7", + "false") + + testSqlApi( + "f2 <> f7", + "true") + + testSqlApi( + "f11 = f11", + "true") + + testSqlApi( + "f11 = f9", + "true") + + testSqlApi( + "f11 <> f11", + "false") + + testSqlApi( + "f11 <> f9", + "false") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala new file mode 100644 index 0000000000000..c77132b78fbbf --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/CompositeAccessTest.scala @@ -0,0 
+1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.CompositeTypeTestBase +import org.junit.Test + +class CompositeAccessTest extends CompositeTypeTestBase { + + @Test + def testGetField(): Unit = { + + // single field by string key + testSqlApi( + "testTable.f0.intField", + "42") + testSqlApi("f0.intField", "42") + + testSqlApi("testTable.f0.stringField", "Bob") + testSqlApi("f0.stringField", "Bob") + + testSqlApi("testTable.f0.booleanField", "true") + testSqlApi("f0.booleanField", "true") + + // nested single field + testSqlApi( + "testTable.f1.objectField.intField", + "25") + testSqlApi("f1.objectField.intField", "25") + + testSqlApi("testTable.f1.objectField.stringField", "Timo") + testSqlApi("f1.objectField.stringField", "Timo") + + testSqlApi("testTable.f1.objectField.booleanField", "false") + testSqlApi("f1.objectField.booleanField", "false") + + testSqlApi( + "testTable.f2._1", + "a") + testSqlApi("f2._1", "a") + + testSqlApi("testTable.f3.f1", "b") + testSqlApi("f3.f1", "b") + + testSqlApi("testTable.f4.myString", "Hello") + testSqlApi("f4.myString", "Hello") + + testSqlApi("testTable.f5", "13") + 
testSqlApi("f5", "13") + + testSqlApi( + "testTable.f7._1", + "true") + + // composite field return type + testSqlApi("testTable.f6", "MyCaseClass2(null)") + testSqlApi("f6", "MyCaseClass2(null)") + + // MyCaseClass is converted to BaseRow + // so the result of "toString" doesn't contain a MyCaseClass prefix + testSqlApi( + "testTable.f1.objectField", + "(25,Timo,false)") + testSqlApi("f1.objectField", "(25,Timo,false)") + + testSqlApi( + "testTable.f0", + "(42,Bob,true)") + testSqlApi("f0", "(42,Bob,true)") + + // flattening (test base only returns first column) + testSqlApi( + "testTable.f1.objectField.*", + "25") + testSqlApi("f1.objectField.*", "25") + + testSqlApi( + "testTable.f0.*", + "42") + testSqlApi("f0.*", "42") + + // array of composites + testSqlApi( + "f8[1]._1", + "true" + ) + + testSqlApi( + "f8[1]._2", + "23" + ) + + testSqlApi( + "f9[2]._1", + "null" + ) + + testSqlApi( + "f10[1].stringField", + "Bob" + ) + + testSqlApi( + "f11[1].myString", + "Hello" + ) + + testSqlApi( + "f11[2]", + "null" + ) + + testSqlApi( + "f12[1].arrayField[1].stringField", + "Alice" + ) + + testSqlApi( + "f13[1].objectField.stringField", + "Bob" + ) + } +} + + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala new file mode 100644 index 0000000000000..06608fc2a5593 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.dataformat.Decimal +import org.apache.flink.table.expressions.utils.ExpressionTestBase +import org.apache.flink.table.typeutils.DecimalTypeInfo +import org.apache.flink.types.Row +import org.junit.Test + +class DecimalTypeTest extends ExpressionTestBase { + + @Test + def testDecimalLiterals(): Unit = { + // implicit double + testSqlApi( + "11.2", + "11.2") + + // implicit double + testSqlApi( + "0.7623533651719233", + "0.7623533651719233") + + // explicit decimal (with precision of 19) + testSqlApi( + "1234567891234567891", + "1234567891234567891") + } + + @Test + def testDecimalBorders(): Unit = { + testSqlApi( + Double.MaxValue.toString, + Double.MaxValue.toString) + + testSqlApi( + Double.MinValue.toString, + Double.MinValue.toString) + + testSqlApi( + s"CAST(${Double.MinValue} AS FLOAT)", + Float.NegativeInfinity.toString) + + testSqlApi( + s"CAST(${Byte.MinValue} AS TINYINT)", + Byte.MinValue.toString) + + testSqlApi( + s"CAST(${Byte.MinValue} AS TINYINT) - CAST(1 AS TINYINT)", + Byte.MaxValue.toString) + + testSqlApi( + s"CAST(${Short.MinValue} AS SMALLINT)", + Short.MinValue.toString) + + testSqlApi( + s"CAST(${Int.MinValue} AS INT) - 1", + Int.MaxValue.toString) + + testSqlApi( + s"CAST(${Long.MinValue} AS BIGINT)", + Long.MinValue.toString) + } + + @Test + def testDecimalCasting(): Unit = { + // from String + testSqlApi( + 
"CAST('123456789123456789123456789' AS DECIMAL(27, 0))", + "123456789123456789123456789") + + // from double + testSqlApi( + "CAST(f3 AS DECIMAL)", + "4") + + testSqlApi( + "CAST(f3 AS DECIMAL(10,2))", + "4.20" + ) + + // to double + testSqlApi( + "CAST(f0 AS DOUBLE)", + "1.2345678912345679E8") + + // to int + testSqlApi( + "CAST(f4 AS INT)", + "123456789") + + // to long + testSqlApi( + "CAST(f4 AS BIGINT)", + "123456789") + } + + @Test + def testDecimalArithmetic(): Unit = { + + // note: calcite type inference: + // Decimal+ExactNumeric => Decimal + // Decimal+Double => Double. + + // implicit cast to decimal + testSqlApi( + "f1 + 12", + "123456789123456789123456801") + + // implicit cast to decimal + testSqlApi( + "12 + f1", + "123456789123456789123456801") + + testSqlApi( + "f1 + 12.3", + "123456789123456789123456801.3" + ) + + testSqlApi( + "12.3 + f1", + "123456789123456789123456801.3") + + testSqlApi( + "f1 + f1", + "246913578246913578246913578") + + testSqlApi( + "f1 - f1", + "0") + + testSqlApi( + "f1 / f1", + "1.00000000") + + testSqlApi( + "MOD(f1, f1)", + "0") + + testSqlApi( + "-f0", + "-123456789.123456789123456789") + } + + @Test + def testDecimalComparison(): Unit = { + testSqlApi( + "f1 < 12", + "false") + + testSqlApi( + "f1 > 12", + "true") + + testSqlApi( + "f1 = 12", + "false") + + testSqlApi( + "f5 = 0", + "true") + + testSqlApi( + "f1 = CAST('123456789123456789123456789' AS DECIMAL(30, 0))", + "true") + + testSqlApi( + "f1 <> CAST('123456789123456789123456789' AS DECIMAL(30, 0))", + "false") + + testSqlApi( + "f4 < f0", + "true") + + // TODO add all tests if FLINK-4070 is fixed + testSqlApi( + "12 < f1", + "true") + } + + // ---------------------------------------------------------------------------------------------- + + override def testData: Row = { + val testData = new Row(6) + testData.setField(0, Decimal.castFrom("123456789.123456789123456789", 30, 18)) + testData.setField(1, Decimal.castFrom("123456789123456789123456789", 30, 0)) + 
testData.setField(2, 42) + testData.setField(3, 4.2) + testData.setField(4, Decimal.castFrom("123456789", 10, 0)) + testData.setField(5, Decimal.castFrom("0.000", 10, 3)) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + DecimalTypeInfo.of(30, 18), + DecimalTypeInfo.of(30, 0), + Types.INT, + Types.DOUBLE, + DecimalTypeInfo.of(10, 0), + DecimalTypeInfo.of(10, 3)) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/LiteralTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/LiteralTest.scala new file mode 100644 index 0000000000000..f1c4c25527552 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/LiteralTest.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions + +import org.apache.flink.api.common.typeinfo.{TypeInformation, Types} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.expressions.utils.ExpressionTestBase +import org.apache.flink.types.Row +import org.junit.{Ignore, Test} + +class LiteralTest extends ExpressionTestBase { + + @Test + def testFieldWithBooleanPrefix(): Unit = { + + testSqlApi( + "trUeX", + "trUeX_value" + ) + + testSqlApi( + "FALSE_A", + "FALSE_A_value" + ) + + testSqlApi( + "FALSE_AB", + "FALSE_AB_value" + ) + + testSqlApi( + "trUe", + "true" + ) + + testSqlApi( + "FALSE", + "false" + ) + } + + @Ignore("TODO: FLINK-11898") + @Test + def testNonAsciiLiteral(): Unit = { + testSqlApi( + "f4 LIKE '%测试%'", + "true") + + testSqlApi( + "'Абвгде' || '谢谢'", + "Абвгде谢谢") + } + + @Ignore("TODO: FLINK-11898") + @Test + def testDoubleQuote(): Unit = { + val hello = "\"\"" + testSqlApi( + s"concat('a', '$hello')", + s"a and $hello") + } + + @Test + def testStringLiterals(): Unit = { + + // these tests use Java/Scala escaping for non-quoting unicode characters + + testSqlApi( + "'>\n<'", + ">\n<") + + testSqlApi( + "'>\u263A<'", + ">\u263A<") + + testSqlApi( + "'>\u263A<'", + ">\u263A<") + + testSqlApi( + "'>\\<'", + ">\\<") + + testSqlApi( + "'>''<'", + ">'<") + + testSqlApi( + "' '", + " ") + + testSqlApi( + "''", + "") + + testSqlApi( + "'>foo([\\w]+)<'", + ">foo([\\w]+)<") + + testSqlApi( + "'>\\''\n<'", + ">\\'\n<") + + testSqlApi( + "'It''s me.'", + "It's me.") + + // these test use SQL for describing unicode characters + + testSqlApi( + "U&'>\\263A<'", // default escape backslash + ">\u263A<") + + testSqlApi( + "U&'>#263A<' UESCAPE '#'", // custom escape '#' + ">\u263A<") + + testSqlApi( + """'>\\<'""", + ">\\\\<") + } + + override def testData: Row = { + val testData = new Row(4) + testData.setField(0, "trUeX_value") + testData.setField(1, "FALSE_A_value") + testData.setField(2, "FALSE_AB_value") + 
testData.setField(3, "这是个测试字符串") + testData + } + + override def typeInfo : RowTypeInfo = { + new RowTypeInfo( + Array( + Types.STRING, + Types.STRING, + Types.STRING, + Types.STRING + ).asInstanceOf[Array[TypeInformation[_]]], + Array("trUeX", "FALSE_A", "FALSE_AB", "f4") + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MapTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MapTypeTest.scala new file mode 100644 index 0000000000000..a8d2729dde778 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MapTypeTest.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.MapTypeTestBase +import org.junit.Test + +class MapTypeTest extends MapTypeTestBase { + + @Test + def testItem(): Unit = { + testSqlApi("f0['map is null']", "null") + testSqlApi("f1['map is empty']", "null") + testSqlApi("f2['b']", "13") + testSqlApi("f3[1]", "null") + testSqlApi("f3[12]", "a") + } + + @Test + def testMapLiteral(): Unit = { + // primitive literals + testSqlApi( + "MAP[1, 1]", + "{1=1}") + + testSqlApi( + "map[TRUE, TRUE]", + "{true=true}") + + testSqlApi( + "MAP[MAP[1, 2], MAP[3, 4]]", + "{{1=2}={3=4}}") + + testSqlApi( + "map[1 + 2, 3 * 3, 3 - 6, 4 - 2]", + "{3=9, -3=2}") + + testSqlApi( + "map[1, NULLIF(1,1)]", + "{1=null}") + + // explicit conversion + testSqlApi( + "MAP[1, CAST(2 AS BIGINT), 3, CAST(4 AS BIGINT)]", + "{1=2, 3=4}") + + testSqlApi( + "MAP[DATE '1985-04-11', TIME '14:15:16', DATE '2018-07-26', TIME '17:18:19']", + "{1985-04-11=14:15:16, 2018-07-26=17:18:19}") + + testSqlApi( + "MAP[TIME '14:15:16', TIMESTAMP '1985-04-11 14:15:16', " + + "TIME '17:18:19', TIMESTAMP '2018-07-26 17:18:19']", + "{14:15:16=1985-04-11 14:15:16.000, 17:18:19=2018-07-26 17:18:19.000}") + + testSqlApi( + "MAP[CAST(2.0002 AS DECIMAL(5, 4)), CAST(2.0003 AS DECIMAL(5, 4))]", + "{2.0002=2.0003}") + + // implicit type cast only works on SQL API + testSqlApi( + "MAP['k1', CAST(1 AS DOUBLE), 'k2', CAST(2 AS FLOAT)]", + "{k1=1.0, k2=2.0}") + } + + @Test + def testMapField(): Unit = { + testSqlApi( + "MAP[f4, f5]", + "{foo=12}") + + testSqlApi( + "MAP[f4, f1]", + "{foo={}}") + + testSqlApi( + "MAP[f2, f3]", + "{{a=12, b=13}={12=a, 13=b}}") + + testSqlApi( + "MAP[f1['a'], f5]", + "{null=12}") + + testSqlApi( + "f1", + "{}") + + testSqlApi( + "f2", + "{a=12, b=13}") + + testSqlApi( + "f2['a']", + "12") + + testSqlApi( + "f3[12]", + "a") + + testSqlApi( + "MAP[f4, f3]['foo'][13]", + "b") + } + + @Test + def testMapOperations(): Unit = { + + // comparison + testSqlApi( 
+ "f1 = f2", + "false") + + testSqlApi( + "f3 = f7", + "true") + + testSqlApi( + "f5 = f2['a']", + "true") + + testSqlApi( + "f8 = f9", + "true") + + testSqlApi( + "f10 = f11", + "true") + + testSqlApi( + "f8 <> f9", + "false") + + testSqlApi( + "f10 <> f11", + "false") + + testSqlApi( + "f0['map is null']", + "null") + + testSqlApi( + "f1['map is empty']", + "null") + + testSqlApi( + "f2['b']", + "13") + + testSqlApi( + "f3[1]", + "null") + + testSqlApi( + "f3[12]", + "a") + + testSqlApi( + "CARDINALITY(f3)", + "2") + + testSqlApi( + "f2['a'] IS NOT NULL", + "true") + + testSqlApi( + "f2['a'] IS NULL", + "false") + + testSqlApi( + "f2['c'] IS NOT NULL", + "false") + + testSqlApi( + "f2['c'] IS NULL", + "true") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MathFunctionsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MathFunctionsTest.scala new file mode 100644 index 0000000000000..ae603ab2b1000 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/MathFunctionsTest.scala @@ -0,0 +1,694 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.ScalarTypesTestBase +import org.junit.{Ignore, Test} + +@Ignore("TODO: [FLINK-11898] the math functions will be supported in the future") +class MathFunctionsTest extends ScalarTypesTestBase { + + // ---------------------------------------------------------------------------------------------- + // Math functions + // ---------------------------------------------------------------------------------------------- + + @Test + def testMod(): Unit = { + testSqlApi( + "MOD(f4, f7)", + "2") + + testSqlApi( + "MOD(f4, 3)", + "2") + + testSqlApi( + "MOD(44, 3)", + "2") + } + + @Test + def testExp(): Unit = { + testSqlApi( + "EXP(f2)", + math.exp(42.toByte).toString) + + testSqlApi( + "EXP(f3)", + math.exp(43.toShort).toString) + + testSqlApi( + "EXP(f4)", + math.exp(44.toLong).toString) + + testSqlApi( + "EXP(f5)", + math.exp(4.5.toFloat).toString) + + testSqlApi( + "EXP(f6)", + math.exp(4.6).toString) + + testSqlApi( + "EXP(f7)", + math.exp(3).toString) + + testSqlApi( + "EXP(3)", + math.exp(3).toString) + } + + @Test + def testLog10(): Unit = { + testSqlApi( + "LOG10(f2)", + math.log10(42.toByte).toString) + + testSqlApi( + "LOG10(f3)", + math.log10(43.toShort).toString) + + testSqlApi( + "LOG10(f4)", + math.log10(44.toLong).toString) + + testSqlApi( + "LOG10(f5)", + math.log10(4.5.toFloat).toString) + + testSqlApi( + "LOG10(f6)", + math.log10(4.6).toString) + + testSqlApi( + "LOG10(f32)", + math.log10(-1).toString) + + testSqlApi( + "LOG10(f27)", + math.log10(0).toString) + } + + @Test + def testPower(): Unit = { + // f7: int , f4: long, f6: double + testSqlApi( + "POWER(f2, f7)", + math.pow(42.toByte, 3).toString) + + testSqlApi( + "POWER(f3, f6)", + math.pow(43.toShort, 4.6D).toString) + + testSqlApi( + "POWER(f4, f5)", + math.pow(44.toLong, 4.5.toFloat).toString) + + testSqlApi( + "POWER(f4, f5)", + math.pow(44.toLong, 4.5.toFloat).toString) + + // f5: float + 
testSqlApi( + "power(f5, f5)", + math.pow(4.5F, 4.5F).toString) + + testSqlApi( + "power(f5, f6)", + math.pow(4.5F, 4.6D).toString) + + testSqlApi( + "power(f5, f7)", + math.pow(4.5F, 3).toString) + + testSqlApi( + "power(f5, f4)", + math.pow(4.5F, 44L).toString) + + // f22: bigDecimal + // TODO delete casting in SQL when CALCITE-1467 is fixed + testSqlApi( + "power(CAST(f22 AS DOUBLE), f5)", + math.pow(2, 4.5F).toString) + + testSqlApi( + "power(CAST(f22 AS DOUBLE), f6)", + math.pow(2, 4.6D).toString) + + testSqlApi( + "power(CAST(f22 AS DOUBLE), f7)", + math.pow(2, 3).toString) + + testSqlApi( + "power(CAST(f22 AS DOUBLE), f4)", + math.pow(2, 44L).toString) + + testSqlApi( + "power(f6, f22)", + math.pow(4.6D, 2).toString) + } + + @Test + def testSqrt(): Unit = { + testSqlApi( + "SQRT(f6)", + math.sqrt(4.6D).toString) + + testSqlApi( + "SQRT(f7)", + math.sqrt(3).toString) + + testSqlApi( + "SQRT(f4)", + math.sqrt(44L).toString) + + testSqlApi( + "SQRT(CAST(f22 AS DOUBLE))", + math.sqrt(2.0).toString) + + testSqlApi( + "SQRT(f5)", + math.pow(4.5F, 0.5).toString) + + testSqlApi( + "SQRT(25)", + "5.0") + + testSqlApi( + "POWER(CAST(2.2 AS DOUBLE), CAST(0.5 AS DOUBLE))", // TODO fix FLINK-4621 + math.sqrt(2.2).toString) + } + + @Test + def testLn(): Unit = { + testSqlApi( + "LN(f2)", + math.log(42.toByte).toString) + + testSqlApi( + "LN(f3)", + math.log(43.toShort).toString) + + testSqlApi( + "LN(f4)", + math.log(44.toLong).toString) + + testSqlApi( + "LN(f5)", + math.log(4.5.toFloat).toString) + + testSqlApi( + "LN(f6)", + math.log(4.6).toString) + + testSqlApi( + "LN(f32)", + math.log(-1).toString) + + testSqlApi( + "LN(f27)", + math.log(0).toString) + } + + @Test + def testAbs(): Unit = { + testSqlApi( + "ABS(f2)", + "42") + + testSqlApi( + "ABS(f3)", + "43") + + testSqlApi( + "ABS(f4)", + "44") + + testSqlApi( + "ABS(f5)", + "4.5") + + testSqlApi( + "ABS(f6)", + "4.6") + + testSqlApi( + "ABS(f9)", + "42") + + testSqlApi( + "ABS(f10)", + "43") + + testSqlApi( + 
"ABS(f11)", + "44") + + testSqlApi( + "ABS(f12)", + "4.5") + + testSqlApi( + "ABS(f13)", + "4.6") + + testSqlApi( + "ABS(f15)", + "1231.1231231321321321111") + } + + @Test + def testArithmeticFloorCeil(): Unit = { + testSqlApi( + "FLOOR(f5)", + "4.0") + + testSqlApi( + "CEIL(f5)", + "5.0") + + testSqlApi( + "FLOOR(f3)", + "43") + + testSqlApi( + "CEIL(f3)", + "43") + + testSqlApi( + "FLOOR(f15)", + "-1232") + + testSqlApi( + "CEIL(f15)", + "-1231") + } + + @Test + def testSin(): Unit = { + testSqlApi( + "SIN(f2)", + math.sin(42.toByte).toString) + + testSqlApi( + "SIN(f3)", + math.sin(43.toShort).toString) + + testSqlApi( + "SIN(f4)", + math.sin(44.toLong).toString) + + testSqlApi( + "SIN(f5)", + math.sin(4.5.toFloat).toString) + + testSqlApi( + "SIN(f6)", + math.sin(4.6).toString) + + testSqlApi( + "SIN(f15)", + math.sin(-1231.1231231321321321111).toString) + } + + @Test + def testCos(): Unit = { + testSqlApi( + "COS(f2)", + math.cos(42.toByte).toString) + + testSqlApi( + "COS(f3)", + math.cos(43.toShort).toString) + + testSqlApi( + "COS(f4)", + math.cos(44.toLong).toString) + + testSqlApi( + "COS(f5)", + math.cos(4.5.toFloat).toString) + + testSqlApi( + "COS(f6)", + math.cos(4.6).toString) + + testSqlApi( + "COS(f15)", + math.cos(-1231.1231231321321321111).toString) + } + + @Test + def testTan(): Unit = { + testSqlApi( + "TAN(f2)", + math.tan(42.toByte).toString) + + testSqlApi( + "TAN(f3)", + math.tan(43.toShort).toString) + + testSqlApi( + "TAN(f4)", + math.tan(44.toLong).toString) + + testSqlApi( + "TAN(f5)", + math.tan(4.5.toFloat).toString) + + testSqlApi( + "TAN(f6)", + math.tan(4.6).toString) + + testSqlApi( + "TAN(f15)", + math.tan(-1231.1231231321321321111).toString) + } + + @Test + def testCot(): Unit = { + testSqlApi( + "COT(f2)", + (1.0d / math.tan(42.toByte)).toString) + + testSqlApi( + "COT(f3)", + (1.0d / math.tan(43.toShort)).toString) + + testSqlApi( + "COT(f4)", + (1.0d / math.tan(44.toLong)).toString) + + testSqlApi( + "COT(f5)", + (1.0d / 
math.tan(4.5.toFloat)).toString) + + testSqlApi( + "COT(f6)", + (1.0d / math.tan(4.6)).toString) + + testSqlApi( + "COT(f15)", + (1.0d / math.tan(-1231.1231231321321321111)).toString) + } + + @Test + def testAsin(): Unit = { + testSqlApi( + "ASIN(f25)", + math.asin(0.42.toByte).toString) + + testSqlApi( + "ASIN(f26)", + math.asin(0.toShort).toString) + + testSqlApi( + "ASIN(f27)", + math.asin(0.toLong).toString) + + testSqlApi( + "ASIN(f28)", + math.asin(0.45.toFloat).toString) + + testSqlApi( + "ASIN(f29)", + math.asin(0.46).toString) + + testSqlApi( + "ASIN(f30)", + math.asin(1).toString) + + testSqlApi( + "ASIN(f31)", + math.asin(-0.1231231321321321111).toString) + } + + @Test + def testAcos(): Unit = { + testSqlApi( + "ACOS(f25)", + math.acos(0.42.toByte).toString) + + testSqlApi( + "ACOS(f26)", + math.acos(0.toShort).toString) + + testSqlApi( + "ACOS(f27)", + math.acos(0.toLong).toString) + + testSqlApi( + "ACOS(f28)", + math.acos(0.45.toFloat).toString) + + testSqlApi( + "ACOS(f29)", + math.acos(0.46).toString) + + testSqlApi( + "ACOS(f30)", + math.acos(1).toString) + + testSqlApi( + "ACOS(f31)", + math.acos(-0.1231231321321321111).toString) + } + + @Test + def testAtan(): Unit = { + testSqlApi( + "ATAN(f25)", + math.atan(0.42.toByte).toString) + + testSqlApi( + "ATAN(f26)", + math.atan(0.toShort).toString) + + testSqlApi( + "ATAN(f27)", + math.atan(0.toLong).toString) + + testSqlApi( + "ATAN(f28)", + math.atan(0.45.toFloat).toString) + + testSqlApi( + "ATAN(f29)", + math.atan(0.46).toString) + + testSqlApi( + "ATAN(f30)", + math.atan(1).toString) + + testSqlApi( + "ATAN(f31)", + math.atan(-0.1231231321321321111).toString) + } + + @Test + def testDegrees(): Unit = { + testSqlApi( + "DEGREES(f2)", + math.toDegrees(42.toByte).toString) + + testSqlApi( + "DEGREES(f3)", + math.toDegrees(43.toShort).toString) + + testSqlApi( + "DEGREES(f4)", + math.toDegrees(44.toLong).toString) + + testSqlApi( + "DEGREES(f5)", + math.toDegrees(4.5.toFloat).toString) + + 
testSqlApi( + "DEGREES(f6)", + math.toDegrees(4.6).toString) + + testSqlApi( + "DEGREES(f15)", + math.toDegrees(-1231.1231231321321321111).toString) + } + + @Test + def testRadians(): Unit = { + testSqlApi( + "RADIANS(f2)", + math.toRadians(42.toByte).toString) + + testSqlApi( + "RADIANS(f3)", + math.toRadians(43.toShort).toString) + + testSqlApi( + "RADIANS(f4)", + math.toRadians(44.toLong).toString) + + testSqlApi( + "RADIANS(f5)", + math.toRadians(4.5.toFloat).toString) + + testSqlApi( + "RADIANS(f6)", + math.toRadians(4.6).toString) + + testSqlApi( + "RADIANS(f15)", + math.toRadians(-1231.1231231321321321111).toString) + } + + @Test + def testSign(): Unit = { + testSqlApi( + "SIGN(f4)", + 1.toString) + + testSqlApi( + "SIGN(f6)", + 1.0.toString) + + testSqlApi( + "SIGN(f15)", + "-1.0000000000000000000") // calcite: SIGN(Decimal(p,s)) => Decimal(p,s) + } + + @Test + def testRound(): Unit = { + testSqlApi( + "ROUND(f29, f30)", + 0.5.toString) + + testSqlApi( + "ROUND(f31, f7)", + "-0.123") + + testSqlApi( + "ROUND(f4, f32)", + 40.toString) + } + + @Test + def testPi(): Unit = { + testSqlApi( + "pi()", + math.Pi.toString) + } + + @Test + def testRandAndRandInteger(): Unit = { + val random1 = new java.util.Random(1) + testSqlApi( + "RAND(1)", + random1.nextDouble().toString) + + val random2 = new java.util.Random(3) + testSqlApi( + "RAND(f7)", + random2.nextDouble().toString) + + val random3 = new java.util.Random(1) + testSqlApi( + "RAND_INTEGER(1, 10)", + random3.nextInt(10).toString) + + val random4 = new java.util.Random(3) + testSqlApi( + "RAND_INTEGER(f7, CAST(f4 AS INT))", + random4.nextInt(44).toString) + } + + @Test + def testE(): Unit = { + testSqlApi( + "E()", + math.E.toString) + + testSqlApi( + "e()", + math.E.toString) + } + + @Test + def testLog(): Unit = { + testSqlApi( + "LOG(f6)", + "1.5260563034950492" + ) + + testSqlApi( + "LOG(f6-f6 + 10, f6-f6+100)", + "2.0" + ) + + testSqlApi( + "LOG(f6+20)", + "3.202746442938317" + ) + + testSqlApi( + 
"LOG(10)", + "2.302585092994046" + ) + + testSqlApi( + "LOG(10, 100)", + "2.0" + ) + + testSqlApi( + "log(f32, f32)", + (math.log(-1)/math.log(-1)).toString) + + testSqlApi( + "log(f27, f32)", + (math.log(0)/math.log(0)).toString) + } + + @Test + def testLog2(): Unit = { + testSqlApi( + "log2(f2)", + (math.log(42.toByte)/math.log(2.toByte)).toString) + + testSqlApi( + "log2(f3)", + (math.log(43.toShort)/math.log(2.toShort)).toString) + + testSqlApi( + "log2(f4)", + (math.log(44.toLong)/math.log(2.toLong)).toString) + + testSqlApi( + "log2(f5)", + (math.log(4.5.toFloat)/math.log(2.toFloat)).toString) + + testSqlApi( + "log2(f6)", + (math.log(4.6)/math.log(2)).toString) + + testSqlApi( + "log2(f32)", + (math.log(-1)/math.log(2)).toString) + + testSqlApi( + "log2(f27)", + (math.log(0)/math.log(2)).toString) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/RowTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/RowTypeTest.scala new file mode 100644 index 0000000000000..45b748277f126 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/RowTypeTest.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.RowTypeTestBase +import org.junit.Test + +class RowTypeTest extends RowTypeTestBase { + + @Test + def testRowLiteral(): Unit = { + + // primitive literal + testSqlApi( + "ROW(1, 'foo', true)", + "(1,foo,true)") + + // special literal + testSqlApi( + "ROW(DATE '1985-04-11', TIME '14:15:16', TIMESTAMP '1985-04-11 14:15:16', " + + "CAST(0.1 AS DECIMAL(2, 1)), ARRAY[1, 2, 3], MAP['foo', 'bar'], row(1, true))", + "(1985-04-11,14:15:16,1985-04-11 14:15:16.000,0.1,[1, 2, 3],{foo=bar},(1,true))") + + testSqlApi( + "ROW(1 + 1, 2 * 3, NULLIF(1, 1))", + "(2,6,null)" + ) + + testSqlApi("(1, 'foo', true)", "(1,foo,true)") + } + + @Test + def testRowField(): Unit = { + testSqlApi( + "(f0, f1)", + "(null,1)" + ) + + testSqlApi( + "f2", + "(2,foo,true)" + ) + + testSqlApi( + "(f2, f5)", + "((2,foo,true),(foo,null))" + ) + + testSqlApi( + "f4", + "(1984-03-12,0.00000000,[1, 2, 3])" + ) + + testSqlApi( + "(f1, 'foo',true)", + "(1,foo,true)" + ) + } + + @Test + def testRowOperations(): Unit = { + testSqlApi( + "f5.f0", + "foo" + ) + + testSqlApi( + "f3.f1.f2", + "true" + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala new file mode 100644 index 0000000000000..cd4d875c4b0e7 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/ScalarOperatorsTest.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.flink.table.expressions.utils.ScalarOperatorsTestBase +import org.junit.Test + +class ScalarOperatorsTest extends ScalarOperatorsTestBase { + + @Test + def testIn(): Unit = { + testSqlApi( + "f2 IN (1, 2, 42)", + "true" + ) + + testSqlApi( + "CAST(f0 AS DECIMAL) IN (42.0, 2.00, 3.01, 1.000000)", // SQL would downcast otherwise + "true" + ) + + testSqlApi( + "f10 IN ('This is a test String.', 'String', 'Hello world', 'Comment#1')", + "true" + ) + + testSqlApi( + "f14 IN ('This is a test String.', 'String', 'Hello world')", + "null" + ) + + testSqlApi( + "f15 IN (DATE '1996-11-10')", + "true" + ) + + testSqlApi( + "f15 IN (DATE '1996-11-10', DATE '1996-11-11')", + "true" + ) + + testSqlApi( + "f7 IN (f16, f17)", + "true" + ) + } + + @Test + def testOtherExpressions(): Unit = { + + // nested field null type + testSqlApi("CASE WHEN f13.f1 IS NULL THEN 'a' ELSE 'b' END", "a") + testSqlApi("CASE WHEN f13.f1 IS NOT NULL THEN 'a' ELSE 'b' END", "b") + testSqlApi("f13 IS NULL", "false") + testSqlApi("f13 IS NOT NULL", "true") + testSqlApi("f13.f0 IS NULL", "false") + testSqlApi("f13.f0 IS NOT NULL", "true") + testSqlApi("f13.f1 IS NULL", "true") + testSqlApi("f13.f1 IS NOT NULL", "false") + + // boolean literals + testSqlApi( + "true", + "true") + + testSqlApi( + "fAlse", + "false") + + testSqlApi( + "tRuE", + "true") + + // 
null + testSqlApi("CAST(NULL AS INT)", "null") + testSqlApi( + "CAST(NULL AS VARCHAR) = ''", + "null") + + // case when + testSqlApi("CASE 11 WHEN 1 THEN 'a' ELSE 'b' END", "b") + testSqlApi("CASE 2 WHEN 1 THEN 'a' ELSE 'b' END", "b") + testSqlApi( + "CASE 1 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 " + + "THEN '3' ELSE 'none of the above' END", + "1 or 2") + testSqlApi( + "CASE 2 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 " + + "THEN '3' ELSE 'none of the above' END", + "1 or 2") + testSqlApi( + "CASE 3 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 " + + "THEN '3' ELSE 'none of the above' END", + "3") + testSqlApi( + "CASE 4 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 " + + "THEN '3' ELSE 'none of the above' END", + "none of the above") + testSqlApi("CASE WHEN 'a'='a' THEN 1 END", "1") + testSqlApi("CASE 2 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END", "bcd") + testSqlApi("CASE 1 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END", "a") + testSqlApi("CASE 1 WHEN 1 THEN cast('a' as varchar(1)) WHEN 2 THEN " + + "cast('bcd' as varchar(3)) END", "a") + testSqlApi("CASE f2 WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END", "11") + testSqlApi("CASE f7 WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END", "null") + testSqlApi("CASE 42 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END", "null") + testSqlApi("CASE 1 WHEN 1 THEN true WHEN 2 THEN false ELSE NULL END", "true") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala new file mode 100644 index 0000000000000..08213f45f5f19 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/SqlExpressionTest.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.table.expressions.utils.ExpressionTestBase
+import org.apache.flink.types.Row
+import org.junit.{Ignore, Test}
+
+/**
+ * Tests all SQL expressions that are currently supported according to the documentation.
+ * These tests should be kept in sync with the documentation to reduce confusion due to the
+ * large number of SQL functions.
+ *
+ * The tests do not test every parameter combination of a function.
+ * They are rather a function existence test and simple functional test.
+ *
+ * The tests are split up and ordered like the sections in the documentation.
+ */ +@Ignore("TODO: [FLINK-11898] the scalar functions will be supported in the future") +class SqlExpressionTest extends ExpressionTestBase { + + @Test + def testComparisonFunctions(): Unit = { + testSqlApi("1 = 1", "true") + testSqlApi("1 <> 1", "false") + testSqlApi("5 > 2", "true") + testSqlApi("2 >= 2", "true") + testSqlApi("5 < 2", "false") + testSqlApi("2 <= 2", "true") + testSqlApi("1 IS NULL", "false") + testSqlApi("1 IS NOT NULL", "true") + testSqlApi("NULLIF(1,1) IS DISTINCT FROM NULLIF(1,1)", "false") + testSqlApi("NULLIF(1,1) IS NOT DISTINCT FROM NULLIF(1,1)", "true") + testSqlApi("NULLIF(1,1) IS NOT DISTINCT FROM NULLIF(1,1)", "true") + testSqlApi("12 BETWEEN 11 AND 13", "true") + testSqlApi("12 BETWEEN ASYMMETRIC 13 AND 11", "false") + testSqlApi("12 BETWEEN SYMMETRIC 13 AND 11", "true") + testSqlApi("12 NOT BETWEEN 11 AND 13", "false") + testSqlApi("12 NOT BETWEEN ASYMMETRIC 13 AND 11", "true") + testSqlApi("12 NOT BETWEEN SYMMETRIC 13 AND 11", "false") + testSqlApi("'TEST' LIKE '%EST'", "true") + //testSqlApi("'%EST' LIKE '.%EST' ESCAPE '.'", "true") // TODO + testSqlApi("'TEST' NOT LIKE '%EST'", "false") + //testSqlApi("'%EST' NOT LIKE '.%EST' ESCAPE '.'", "false") // TODO + testSqlApi("'TEST' SIMILAR TO '.EST'", "true") + //testSqlApi("'TEST' SIMILAR TO ':.EST' ESCAPE ':'", "true") // TODO + testSqlApi("'TEST' NOT SIMILAR TO '.EST'", "false") + //testSqlApi("'TEST' NOT SIMILAR TO ':.EST' ESCAPE ':'", "false") // TODO + testSqlApi("'TEST' IN ('west', 'TEST', 'rest')", "true") + testSqlApi("'TEST' IN ('west', 'rest')", "false") + testSqlApi("'TEST' NOT IN ('west', 'TEST', 'rest')", "false") + testSqlApi("'TEST' NOT IN ('west', 'rest')", "true") + + // sub-query functions are not listed here + } + + @Test + def testLogicalFunctions(): Unit = { + testSqlApi("TRUE OR FALSE", "true") + testSqlApi("TRUE AND FALSE", "false") + testSqlApi("NOT TRUE", "false") + testSqlApi("TRUE IS FALSE", "false") + testSqlApi("TRUE IS NOT FALSE", "true") + 
testSqlApi("TRUE IS TRUE", "true") + testSqlApi("TRUE IS NOT TRUE", "false") + testSqlApi("NULLIF(TRUE,TRUE) IS UNKNOWN", "true") + testSqlApi("NULLIF(TRUE,TRUE) IS NOT UNKNOWN", "false") + } + + @Test + def testArithmeticFunctions(): Unit = { + testSqlApi("+5", "5") + testSqlApi("-5", "-5") + testSqlApi("5+5", "10") + testSqlApi("5-5", "0") + testSqlApi("5*5", "25") + testSqlApi("5/5", "1.0") + testSqlApi("POWER(5, 5)", "3125.0") + testSqlApi("ABS(-5)", "5") + testSqlApi("MOD(-26, 5)", "-1") + testSqlApi("SQRT(4)", "2.0") + testSqlApi("LN(1)", "0.0") + testSqlApi("LOG10(1)", "0.0") + testSqlApi("EXP(0)", "1.0") + testSqlApi("CEIL(2.5)", "3") + testSqlApi("CEILING(2.5)", "3") + testSqlApi("FLOOR(2.5)", "2") + testSqlApi("SIN(2.5)", "0.5984721441039564") + testSqlApi("SINH(2.5)", "6.0502044810397875") + testSqlApi("COS(2.5)", "-0.8011436155469337") + testSqlApi("TAN(2.5)", "-0.7470222972386603") + testSqlApi("COT(2.5)", "-1.3386481283041514") + testSqlApi("ASIN(0.5)", "0.5235987755982989") + testSqlApi("ACOS(0.5)", "1.0471975511965979") + testSqlApi("ATAN(0.5)", "0.4636476090008061") + testSqlApi("ATAN2(0.5, 0.5)", "0.7853981633974483") + testSqlApi("COSH(2.5)", "6.132289479663686") + testSqlApi("TANH(2.5)", "0.9866142981514303") + testSqlApi("DEGREES(0.5)", "28.64788975654116") + testSqlApi("RADIANS(0.5)", "0.008726646259971648") + testSqlApi("SIGN(-1.1)", "-1.0") // calcite: SIGN(Decimal(p,s)) => Decimal(p,s) + testSqlApi("ROUND(-12.345, 2)", "-12.35") + testSqlApi("PI()", "3.141592653589793") + testSqlApi("E()", "2.718281828459045") + } + + @Test + def testDivideFunctions(): Unit = { + + //slash + + // Decimal(2,1) / Decimal(2,1) => Decimal(8,6) + testSqlApi("1.0/8.0", "0.125000") + testSqlApi("2.0/3.0", "0.666667") + + // Integer => Decimal(10, 0) + // Decimal(10,0) / Decimal(2,1) => Decimal(17,6) + testSqlApi("-2/3.0", "-0.666667") + + // Decimal(2,1) / Decimal(10,0) => Decimal(23,12) + testSqlApi("2.0/(-3)", "-0.666666666667") + testSqlApi("-7.9/2", 
"-3.950000000000") + + //div function + testSqlApi("div(7, 2)", "3") + testSqlApi("div(7.9, 2.009)", "3") + testSqlApi("div(7, -2.009)", "-3") + testSqlApi("div(-7.9, 2)", "-3") + } + + @Test + def testStringFunctions(): Unit = { + testSqlApi("'test' || 'string'", "teststring") + testSqlApi("CHAR_LENGTH('string')", "6") + testSqlApi("CHARACTER_LENGTH('string')", "6") + testSqlApi("UPPER('string')", "STRING") + testSqlApi("LOWER('STRING')", "string") + testSqlApi("POSITION('STR' IN 'STRING')", "1") + testSqlApi("TRIM(BOTH ' STRING ')", "STRING") + testSqlApi("TRIM(LEADING 'x' FROM 'xxxxSTRINGxxxx')", "STRINGxxxx") + testSqlApi("TRIM(TRAILING 'x' FROM 'xxxxSTRINGxxxx')", "xxxxSTRING") + testSqlApi( + "OVERLAY('This is a old string' PLACING 'new' FROM 11 FOR 3)", + "This is a new string") + testSqlApi("SUBSTRING('hello world', 2)", "ello world") + testSqlApi("SUBSTRING('hello world', 2, 3)", "ell") + testSqlApi("SUBSTRING('hello world', 2, 300)", "ello world") + testSqlApi("SUBSTR('hello world', 2, 3)", "ell") + testSqlApi("SUBSTR('hello world', 2)", "ello world") + testSqlApi("SUBSTR('hello world', 2, 300)", "ello world") + testSqlApi("SUBSTR('hello world', 0, 3)", "hel") + testSqlApi("INITCAP('hello world')", "Hello World") + testSqlApi("REGEXP_REPLACE('foobar', 'oo|ar', '')", "fb") + testSqlApi("REGEXP_EXTRACT('foothebar', 'foo(.*?)(bar)', 2)", "bar") + testSqlApi( + "REPEAT('This is a test String.', 2)", + "This is a test String.This is a test String.") + testSqlApi("REPLACE('hello world', 'world', 'flink')", "hello flink") + } + + @Test + def testConditionalFunctions(): Unit = { + testSqlApi("CASE 2 WHEN 1, 2 THEN 2 ELSE 3 END", "2") + testSqlApi("CASE WHEN 1 = 2 THEN 2 WHEN 1 = 1 THEN 3 ELSE 3 END", "3") + testSqlApi("NULLIF(1, 1)", "null") + testSqlApi("COALESCE(NULL, 5)", "5") + testSqlApi("COALESCE(keyvalue('', ';', ':', 'isB2C'), '5')", "5") + testSqlApi("COALESCE(json_value('xx', '$x'), '5')", "5") + } + + @Test + def testTypeConversionFunctions(): Unit = { 
+ testSqlApi("CAST(2 AS DOUBLE)", "2.0") + } + + @Test + def testValueConstructorFunctions(): Unit = { + testSqlApi("ROW('hello world', 12)", "hello world,12") + testSqlApi("('hello world', 12)", "hello world,12") + testSqlApi("('foo', ('bar', 12))", "foo,bar,12") + testSqlApi("ARRAY[TRUE, FALSE][2]", "false") + testSqlApi("ARRAY[TRUE, TRUE]", "[true, true]") + testSqlApi("MAP['k1', 'v1', 'k2', 'v2']['k2']", "v2") + testSqlApi("MAP['k1', CAST(true AS VARCHAR(256)), 'k2', 'foo']['k1']", "true") + } + + @Test + def testDateTimeFunctions(): Unit = { + testSqlApi("DATE '1990-10-14'", "1990-10-14") + testSqlApi("TIME '12:12:12'", "12:12:12") + testSqlApi("TIMESTAMP '1990-10-14 12:12:12.123'", "1990-10-14 12:12:12.123") + testSqlApi("INTERVAL '10 00:00:00.004' DAY TO SECOND", "+10 00:00:00.004") + testSqlApi("INTERVAL '10 00:12' DAY TO MINUTE", "+10 00:12:00.000") + testSqlApi("INTERVAL '2-10' YEAR TO MONTH", "+2-10") + testSqlApi("EXTRACT(DAY FROM DATE '1990-12-01')", "1") + testSqlApi("EXTRACT(DAY FROM INTERVAL '19 12:10:10.123' DAY TO SECOND(3))", "19") + testSqlApi("FLOOR(TIME '12:44:31' TO MINUTE)", "12:44:00") + testSqlApi("CEIL(TIME '12:44:31' TO MINUTE)", "12:45:00") + testSqlApi("QUARTER(DATE '2016-04-12')", "2") + testSqlApi( + "(TIME '2:55:00', INTERVAL '1' HOUR) OVERLAPS (TIME '3:30:00', INTERVAL '2' HOUR)", + "true") + } + + @Test + def testArrayFunctions(): Unit = { + testSqlApi("CARDINALITY(ARRAY[TRUE, TRUE, FALSE])", "3") + testSqlApi("ELEMENT(ARRAY['HELLO WORLD'])", "HELLO WORLD") + } + + @Test + def testHashFunctions(): Unit = { + testSqlApi("MD5('')", "d41d8cd98f00b204e9800998ecf8427e") + testSqlApi("MD5('test')", "098f6bcd4621d373cade4e832627b4f6") + + testSqlApi("SHA1('')", "da39a3ee5e6b4b0d3255bfef95601890afd80709") + testSqlApi("SHA1('test')", "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3") + + testSqlApi("SHA224('')", "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f") + testSqlApi("SHA2('', 224)", 
"d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f") + + testSqlApi("SHA224('test')", "90a3ed9e32b2aaf4c61c410eb925426119e1a9dc53d4286ade99a809") + testSqlApi("SHA2('test', 224)", "90a3ed9e32b2aaf4c61c410eb925426119e1a9dc53d4286ade99a809") + + testSqlApi("SHA256('')", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + testSqlApi("SHA2('', 256)", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + + testSqlApi("SHA256('test')", "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08") + testSqlApi("SHA2('test', 256)", + "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08") + + testSqlApi("SHA384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc" + + "7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b") + testSqlApi("SHA2('', 384)", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0" + + "cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b") + + testSqlApi("SHA384('test')", "768412320f7b0aa5812fce428dc4706b3cae50e02a64caa16a782249bfe8efc" + + "4b7ef1ccb126255d196047dfedf17a0a9") + testSqlApi("SHA2('test', 384)", "768412320f7b0aa5812fce428dc4706b3cae50e02a64caa16a782249bfe8" + + "efc4b7ef1ccb126255d196047dfedf17a0a9") + + testSqlApi("SHA512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d" + + "0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e") + testSqlApi("SHA2('',512)", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce4" + + "7d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e") + + testSqlApi("SHA512('test')", "ee26b0dd4af7e749aa1a8ee3c10ae9923f618980772e473f8819a5d4940e0db" + + "27ac185f8a0e1d5f84f88bc887fd67b143732c304cc5fa9ad8e6f57f50028a8ff") + testSqlApi("SHA2('test',512)", "ee26b0dd4af7e749aa1a8ee3c10ae9923f618980772e473f8819a5d4940e0" + + "db27ac185f8a0e1d5f84f88bc887fd67b143732c304cc5fa9ad8e6f57f50028a8ff") + + testSqlApi("MD5(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA1(CAST(NULL AS VARCHAR))", "null") + 
testSqlApi("SHA224(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA256(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA384(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA512(CAST(NULL AS VARCHAR))", "null") + testSqlApi("SHA2(CAST(NULL AS VARCHAR), 256)", "null") + } + + @Test + def testNullableCases(): Unit = { + testSqlApi( + "BITAND(cast(NUll as bigInt), cast(NUll as bigInt))", + nullable + ) + + testSqlApi( + "BITNOT(cast(NUll as bigInt))", + nullable + ) + + testSqlApi( + "BITOR(cast(NUll as bigInt), cast(NUll as bigInt))", + nullable + ) + + testSqlApi( + "BITXOR(cast(NUll as bigInt), cast(NUll as bigInt))", + nullable + ) + + testSqlApi( + "TO_BASE64(FROM_BASE64(cast(NUll as varchar)))", + nullable + ) + + testSqlApi( + "FROM_BASE64(cast(NUll as varchar))", + nullable + ) + } + + override def testData: Row = new Row(0) + + override def typeInfo: RowTypeInfo = new RowTypeInfo() +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ArrayTypeTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ArrayTypeTestBase.scala new file mode 100644 index 0000000000000..eb41325490eef --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ArrayTypeTestBase.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, PrimitiveArrayTypeInfo, Types} +import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, RowTypeInfo} +import org.apache.flink.table.util.DateTimeTestUtil._ +import org.apache.flink.types.Row + +abstract class ArrayTypeTestBase extends ExpressionTestBase { + + case class MyCaseClass(string: String, int: Int) + + override def testData: Row = { + val testData = new Row(12) + testData.setField(0, null) + testData.setField(1, 42) + testData.setField(2, Array(1, 2, 3)) + testData.setField(3, Array(UTCDate("1984-03-12"), UTCDate("1984-02-10"))) + testData.setField(4, null) + testData.setField(5, Array(Array(1, 2, 3), null)) + testData.setField(6, Array[Integer](1, null, null, 4)) + testData.setField(7, Array(1, 2, 3, 4)) + testData.setField(8, Array(4.0)) + testData.setField(9, Array[Integer](1)) + testData.setField(10, Array[Integer]()) + testData.setField(11, Array[Integer](1)) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + Types.INT, + Types.INT, + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + ObjectArrayTypeInfo.getInfoFor(Types.SQL_DATE), + ObjectArrayTypeInfo.getInfoFor(ObjectArrayTypeInfo.getInfoFor(Types.INT)), + ObjectArrayTypeInfo.getInfoFor(PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO), + ObjectArrayTypeInfo.getInfoFor(Types.INT), + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO, + 
ObjectArrayTypeInfo.getInfoFor(Types.INT), + ObjectArrayTypeInfo.getInfoFor(Types.INT), + BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala new file mode 100644 index 0000000000000..cc90dc0ce1c77 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/CompositeTypeTestBase.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo, TypeExtractor} +import org.apache.flink.api.scala.createTypeInformation +import org.apache.flink.table.expressions.utils.CompositeTypeTestBase.{MyCaseClass, MyCaseClass2, MyCaseClass3, MyPojo} +import org.apache.flink.types.Row + +class CompositeTypeTestBase extends ExpressionTestBase { + + override def testData: Row = { + val testData = new Row(14) + testData.setField(0, MyCaseClass(42, "Bob", booleanField = true)) + testData.setField(1, MyCaseClass2(MyCaseClass(25, "Timo", booleanField = false))) + testData.setField(2, ("a", "b")) + testData.setField(3, new org.apache.flink.api.java.tuple.Tuple2[String, String]("a", "b")) + testData.setField(4, new MyPojo()) + testData.setField(5, 13) + testData.setField(6, MyCaseClass2(null)) + testData.setField(7, Tuple1(true)) + testData.setField(8, Array(Tuple2(true, 23), Tuple2(false, 12))) + testData.setField(9, Array(Tuple1(true), null)) + testData.setField(10, Array(MyCaseClass(42, "Bob", booleanField = true))) + testData.setField(11, Array(new MyPojo(), null)) + testData.setField(12, Array(MyCaseClass3(Array(MyCaseClass(42, "Alice", booleanField = true))))) + testData.setField(13, Array(MyCaseClass2(MyCaseClass(42, "Bob", booleanField = true)))) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + createTypeInformation[MyCaseClass], + createTypeInformation[MyCaseClass2], + createTypeInformation[(String, String)], + new TupleTypeInfo(Types.STRING, Types.STRING), + TypeExtractor.createTypeInfo(classOf[MyPojo]), + Types.INT, + TypeExtractor.createTypeInfo(classOf[MyCaseClass2]), + createTypeInformation[Tuple1[Boolean]], + createTypeInformation[Array[Tuple2[Boolean, Int]]], + createTypeInformation[Array[Tuple1[Boolean]]], + createTypeInformation[Array[MyCaseClass]], + createTypeInformation[Array[MyPojo]], + 
createTypeInformation[Array[MyCaseClass3]], + createTypeInformation[Array[MyCaseClass2]] + ) + } +} + +object CompositeTypeTestBase { + case class MyCaseClass(intField: Int, stringField: String, booleanField: Boolean) + + case class MyCaseClass2(objectField: MyCaseClass) + + case class MyCaseClass3(arrayField: Array[MyCaseClass]) + + class MyPojo { + private var myInt: Int = 0 + private var myString: String = "Hello" + + def getMyInt: Int = myInt + + def setMyInt(value: Int): Unit = { + myInt = value + } + + def getMyString: String = myString + + def setMyString(value: String): Unit = { + myString = value // fixed: was self-assignment (myString = myString), setter dropped its argument + } + } +} + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala new file mode 100644 index 0000000000000..3b02567720fd6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.TaskInfo +import org.apache.flink.api.common.functions.util.RuntimeUDFContext +import org.apache.flink.api.common.functions.{MapFunction, RichFunction, RichMapFunction} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.table.`type`.{InternalTypes, RowType, TypeConverters} +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.api.java.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkPlannerImpl +import org.apache.flink.table.codegen.{CodeGeneratorContext, ExprCodeGenerator, FunctionCodeGenerator} +import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, DataFormatConverters} +import org.apache.flink.types.Row +import org.junit.Assert.{assertEquals, fail} +import org.junit.{After, Before} +import org.apache.calcite.plan.Convention +import org.apache.calcite.plan.hep.{HepPlanner, HepProgramBuilder} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.logical.{LogicalCalc, LogicalTableScan} +import org.apache.calcite.rel.rules._ +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR +import org.apache.calcite.tools.Programs +import org.apache.flink.table.plan.optimize.program.{FlinkHepProgram, FlinkOptimizeContext} + +import java.util.Collections + +import scala.collection.mutable + +abstract class ExpressionTestBase { + + val config = new TableConfig() + + // (originalExpr, optimizedExpr, expectedResult) + private val testExprs = mutable.ArrayBuffer[(String, RexNode, String)]() + private val env = StreamExecutionEnvironment.createLocalEnvironment(4) + private val tEnv = StreamTableEnvironment.create(env, config) + private val relBuilder = tEnv.getRelBuilder + private val planner = new FlinkPlannerImpl( + 
tEnv.getFrameworkConfig, + tEnv.getPlanner, + tEnv.getTypeFactory) + + + // setup test utils + private val tableName = "testTable" + protected val nullable = "null" + protected val notNullable = "not null" + + @Before + def prepare(): Unit = { + val ds = env.fromCollection(Collections.emptyList[Row](), typeInfo) + tEnv.registerDataStream(tableName, ds) + + // prepare RelBuilder + relBuilder.scan(tableName) + + // reset test exprs + testExprs.clear() + } + + @After + def evaluateExprs(): Unit = { + val ctx = CodeGeneratorContext(config) + val inputType = TypeConverters.createInternalTypeFromTypeInfo(typeInfo) + val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false).bindInput(inputType) + + // cast expressions to String + val stringTestExprs = testExprs.map(expr => relBuilder.cast(expr._2, VARCHAR)) + + // generate code + val resultType = new RowType(Seq.fill(testExprs.size)(InternalTypes.STRING): _*) + + val exprs = stringTestExprs.map(exprGenerator.generateExpression) + val genExpr = exprGenerator.generateResultExpression(exprs, resultType, classOf[BinaryRow]) + + val bodyCode = + s""" + |${genExpr.code} + |return ${genExpr.resultTerm}; + """.stripMargin + + val genFunc = FunctionCodeGenerator.generateFunction[MapFunction[BaseRow, BinaryRow]]( + ctx, + "TestFunction", + classOf[MapFunction[BaseRow, BinaryRow]], + bodyCode, + resultType, + inputType) + + val mapper = genFunc.newInstance(getClass.getClassLoader) + + val isRichFunction = mapper.isInstanceOf[RichFunction] + + // call setRuntimeContext method and open method for RichFunction + if (isRichFunction) { + val richMapper = mapper.asInstanceOf[RichMapFunction[_, _]] + val t = new RuntimeUDFContext( + new TaskInfo("ExpressionTest", 1, 0, 1, 1), + null, + env.getConfig, + Collections.emptyMap(), + Collections.emptyMap(), + null) + richMapper.setRuntimeContext(t) + richMapper.open(new Configuration()) + } + + val converter = DataFormatConverters + .getConverterForTypeInfo(typeInfo) + 
.asInstanceOf[DataFormatConverters.DataFormatConverter[BaseRow, Row]] + val testRow = converter.toInternal(testData) + val result = mapper.map(testRow) + + // call close method for RichFunction + if (isRichFunction) { + mapper.asInstanceOf[RichMapFunction[_, _]].close() + } + + // compare + testExprs + .zipWithIndex + .foreach { + case ((originalExpr, optimizedExpr, expected), index) => + + // adapt string result + val actual = if(!result.asInstanceOf[BinaryRow].isNullAt(index)) { + result.asInstanceOf[BinaryRow].getString(index).toString + } else { + null + } + + assertEquals( + s"Wrong result for: [$originalExpr] optimized to: [$optimizedExpr]", + expected, + if (actual == null) "null" else actual) + } + + } + + private def addSqlTestExpr(sqlExpr: String, expected: String): Unit = { + // create RelNode from SQL expression + val parsed = planner.parse(s"SELECT $sqlExpr FROM $tableName") + val validated = planner.validate(parsed) + val converted = planner.rel(validated).rel + + val builder = new HepProgramBuilder() + builder.addRuleInstance(ProjectToCalcRule.INSTANCE) + val hep = new HepPlanner(builder.build()) + hep.setRoot(converted) + val optimized = hep.findBestExp() + + // throw exception if plan contains more than a calc + if (!optimized.getInput(0).isInstanceOf[LogicalTableScan]) { + fail("Expression is converted into more than a Calc operation. 
Use a different test method.") + } + + testExprs += ((sqlExpr, extractRexNode(optimized), expected)) + } + + private def extractRexNode(node: RelNode): RexNode = { + val calcProgram = node + .asInstanceOf[LogicalCalc] + .getProgram + calcProgram.expandLocalRef(calcProgram.getProjectList.get(0)) + } + + def testSqlApi( + sqlExpr: String, + expected: String) + : Unit = { + addSqlTestExpr(sqlExpr, expected) + } + + def testData: Row + + def typeInfo: RowTypeInfo + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/MapTypeTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/MapTypeTestBase.scala new file mode 100644 index 0000000000000..57a3e749a36d5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/MapTypeTestBase.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import com.google.common.collect.ImmutableMap +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, Types} +import org.apache.flink.api.java.typeutils.{MapTypeInfo, RowTypeInfo} +import org.apache.flink.types.Row + +import java.util.{HashMap => JHashMap} + +abstract class MapTypeTestBase extends ExpressionTestBase { + + override def testData: Row = { + val map1 = new JHashMap[String, Int]() + map1.put("a", 12) + map1.put("b", 13) + val map2 = new JHashMap[Int, String]() + map2.put(12, "a") + map2.put(13, "b") + val map3 = new JHashMap[Long, Int]() + map3.put(10L, 1) + map3.put(20L, 2) + val map4 = new JHashMap[Int, Array[Int]]() + map4.put(1, Array(10, 100)) + map4.put(2, Array(20, 200)) + val testData = new Row(12) + testData.setField(0, null) + testData.setField(1, new JHashMap[String, Int]()) + testData.setField(2, map1) + testData.setField(3, map2) + testData.setField(4, "foo") + testData.setField(5, 12) + testData.setField(6, Array(1.2, 1.3)) + testData.setField(7, ImmutableMap.of(12, "a", 13, "b")) + testData.setField(8, map3) + testData.setField(9, map3) + testData.setField(10, map4) + testData.setField(11, map4) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + new MapTypeInfo(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), + Types.STRING, + Types.INT, + Types.PRIMITIVE_ARRAY(Types.DOUBLE), + new MapTypeInfo(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO), + new MapTypeInfo(BasicTypeInfo.INT_TYPE_INFO, Types.PRIMITIVE_ARRAY(Types.INT)), + new 
MapTypeInfo(BasicTypeInfo.INT_TYPE_INFO, Types.PRIMITIVE_ARRAY(Types.INT)) + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/RowTypeTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/RowTypeTestBase.scala new file mode 100644 index 0000000000000..ecee5ed1ab8b6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/RowTypeTestBase.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.api.java.typeutils.{ObjectArrayTypeInfo, RowTypeInfo} +import org.apache.flink.table.dataformat.Decimal +import org.apache.flink.table.typeutils.DecimalTypeInfo +import org.apache.flink.table.util.DateTimeTestUtil.UTCDate +import org.apache.flink.types.Row + +abstract class RowTypeTestBase extends ExpressionTestBase { + + override def testData: Row = { + val row = new Row(3) + row.setField(0, 2) + row.setField(1, "foo") + row.setField(2, true) + val nestedRow = new Row(2) + nestedRow.setField(0, 3) + nestedRow.setField(1, row) + val specialTypeRow = new Row(3) + specialTypeRow.setField(0, UTCDate("1984-03-12")) + specialTypeRow.setField(1, Decimal.castFrom("0.00000000", 9, 8)) + specialTypeRow.setField(2, Array[java.lang.Integer](1, 2, 3)) + val testData = new Row(7) + testData.setField(0, null) + testData.setField(1, 1) + testData.setField(2, row) + testData.setField(3, nestedRow) + testData.setField(4, specialTypeRow) + testData.setField(5, Row.of("foo", null)) + testData.setField(6, Row.of(null, null)) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + Types.STRING, + Types.INT, + Types.ROW(Types.INT, Types.STRING, Types.BOOLEAN), + Types.ROW(Types.INT, Types.ROW(Types.INT, Types.STRING, Types.BOOLEAN)), + Types.ROW( + Types.SQL_DATE, + DecimalTypeInfo.of(9, 8), + ObjectArrayTypeInfo.getInfoFor(Types.INT)), + Types.ROW(Types.STRING, Types.BOOLEAN), + Types.ROW(Types.STRING, Types.STRING) + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarOperatorsTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarOperatorsTestBase.scala new file mode 100644 index 0000000000000..503dadcdad1c3 --- /dev/null +++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarOperatorsTestBase.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.Types +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.dataformat.Decimal +import org.apache.flink.table.typeutils.DecimalTypeInfo +import org.apache.flink.table.util.DateTimeTestUtil._ +import org.apache.flink.types.Row + +abstract class ScalarOperatorsTestBase extends ExpressionTestBase { + + override def testData: Row = { + val testData = new Row(18) + testData.setField(0, 1: Byte) + testData.setField(1, 1: Short) + testData.setField(2, 1) + testData.setField(3, 1L) + testData.setField(4, 1.0f) + testData.setField(5, 1.0d) + testData.setField(6, true) + testData.setField(7, 0.0d) + testData.setField(8, 5) + testData.setField(9, 10) + testData.setField(10, "String") + testData.setField(11, false) + testData.setField(12, null) + testData.setField(13, Row.of("foo", null)) + testData.setField(14, null) + testData.setField(15, UTCDate("1996-11-10")) + testData.setField(16, 
Decimal.castFrom("0.00000000", 19, 8)) + testData.setField(17, Decimal.castFrom("10.0", 19, 1)) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + Types.BYTE, + Types.SHORT, + Types.INT, + Types.LONG, + Types.FLOAT, + Types.DOUBLE, + Types.BOOLEAN, + Types.DOUBLE, + Types.INT, + Types.INT, + Types.STRING, + Types.BOOLEAN, + Types.BOOLEAN, + Types.ROW(Types.STRING, Types.STRING), + Types.STRING, + Types.SQL_DATE, + DecimalTypeInfo.of(19, 8), + DecimalTypeInfo.of(19, 1) + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala new file mode 100644 index 0000000000000..4c8a3d1fff7eb --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/utils/ScalarTypesTestBase.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.utils + +import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, Types} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.dataformat.Decimal +import org.apache.flink.table.typeutils.{DecimalTypeInfo, TimeIntervalTypeInfo} +import org.apache.flink.table.util.DateTimeTestUtil._ +import org.apache.flink.types.Row + +abstract class ScalarTypesTestBase extends ExpressionTestBase { + + override def testData: Row = { + val testData = new Row(46) + testData.setField(0, "This is a test String.") + testData.setField(1, true) + testData.setField(2, 42.toByte) + testData.setField(3, 43.toShort) + testData.setField(4, 44.toLong) + testData.setField(5, 4.5.toFloat) + testData.setField(6, 4.6) + testData.setField(7, 3) + testData.setField(8, " This is a test String. ") + testData.setField(9, -42.toByte) + testData.setField(10, -43.toShort) + testData.setField(11, -44.toLong) + testData.setField(12, -4.5.toFloat) + testData.setField(13, -4.6) + testData.setField(14, -3) + testData.setField(15, Decimal.castFrom("-1231.1231231321321321111", 38, 19)) + testData.setField(16, UTCDate("1996-11-10")) + testData.setField(17, UTCTime("06:55:44")) + testData.setField(18, UTCTimestamp("1996-11-10 06:55:44.333")) + testData.setField(19, 1467012213000L) // +16979 07:23:33.000 + testData.setField(20, 25) // +2-01 + testData.setField(21, null) + testData.setField(22, Decimal.castFrom("2", 38, 19)) + testData.setField(23, "%This is a test String.") + testData.setField(24, "*_This is a test String.") + testData.setField(25, 0.42.toByte) + testData.setField(26, 0.toShort) + testData.setField(27, 0.toLong) + testData.setField(28, 0.45.toFloat) + testData.setField(29, 0.46) + testData.setField(30, 1) + testData.setField(31, Decimal.castFrom("-0.1231231321321321111", 38, 0)) + testData.setField(32, -1) + testData.setField(33, null) + testData.setField(34, Decimal.castFrom("1514356320000", 38, 19)) + 
testData.setField(35, "a") + testData.setField(36, "b") + testData.setField(37, Array[Byte](1, 2, 3, 4)) + testData.setField(38, "AQIDBA==") + testData.setField(39, "1世3") + testData.setField(40, null) + testData.setField(41, null) + testData.setField(42, 256.toLong) + testData.setField(43, -1.toLong) + testData.setField(44, 256) + testData.setField(45, UTCTimestamp("1996-11-10 06:55:44.333").toString) + testData + } + + override def typeInfo: RowTypeInfo = { + new RowTypeInfo( + Types.STRING, + Types.BOOLEAN, + Types.BYTE, + Types.SHORT, + Types.LONG, + Types.FLOAT, + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.BYTE, + Types.SHORT, + Types.LONG, + Types.FLOAT, + Types.DOUBLE, + Types.INT, + DecimalTypeInfo.of(38, 19), + Types.SQL_DATE, + Types.SQL_TIME, + Types.SQL_TIMESTAMP, + TimeIntervalTypeInfo.INTERVAL_MILLIS, + TimeIntervalTypeInfo.INTERVAL_MONTHS, + Types.BOOLEAN, + DecimalTypeInfo.of(38, 19), + Types.STRING, + Types.STRING, + Types.BYTE, + Types.SHORT, + Types.LONG, + Types.FLOAT, + Types.DOUBLE, + Types.INT, + DecimalTypeInfo.of(38, 19), + Types.INT, + Types.STRING, + DecimalTypeInfo.of(19, 0), + Types.STRING, + Types.STRING, + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, + Types.STRING, + Types.STRING, + Types.STRING, + DecimalTypeInfo.of(38, 19), + Types.LONG, + Types.LONG, + Types.INT, + Types.STRING) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ArrayTypeValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ArrayTypeValidationTest.scala new file mode 100644 index 0000000000000..951952532c4fa --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ArrayTypeValidationTest.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.validation + +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.expressions.utils.ArrayTypeTestBase +import org.junit.Test + +class ArrayTypeValidationTest extends ArrayTypeTestBase { + + @Test(expected = classOf[ValidationException]) + def testImplicitTypeCastArraySql(): Unit = { + testSqlApi("ARRAY['string', 12]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testEmptyArraySql(): Unit = { + testSqlApi("ARRAY[]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testNullArraySql(): Unit = { + testSqlApi("ARRAY[NULL]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testDifferentTypesArraySql(): Unit = { + testSqlApi("ARRAY[1, TRUE]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testElementNonArraySql(): Unit = { + testSqlApi( + "ELEMENT(f0)", + "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testCardinalityOnNonArraySql(): Unit = { + testSqlApi("CARDINALITY(f0)", "FAIL") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/CompositeAccessValidationTest.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/CompositeAccessValidationTest.scala new file mode 100644 index 0000000000000..cc0b171fa0e63 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/CompositeAccessValidationTest.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.validation + +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.expressions.utils.CompositeTypeTestBase +import org.junit.Test + +class CompositeAccessValidationTest extends CompositeTypeTestBase { + + @Test(expected = classOf[ValidationException]) + def testWrongSqlFieldFull(): Unit = { + testSqlApi("testTable.f5.test", "13") + } + + @Test(expected = classOf[ValidationException]) + def testWrongSqlField(): Unit = { + testSqlApi("f5.test", "13") + } +} + + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/MapTypeValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/MapTypeValidationTest.scala new file mode 100644 index 0000000000000..7e960549b488e --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/MapTypeValidationTest.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.expressions.validation + +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.expressions.utils.MapTypeTestBase +import org.junit.Test + +class MapTypeValidationTest extends MapTypeTestBase { + + @Test(expected = classOf[ValidationException]) + def testWrongKeyType(): Unit = { + testSqlApi("f2[12]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testUnsupportedComparisonType(): Unit = { + testSqlApi("f6 <> f2", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testEmptyMap(): Unit = { + testSqlApi("MAP[]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testUnsupportedMapImplicitTypeCastSql(): Unit = { + testSqlApi("MAP['k1', 'string', 'k2', 12]", "FAIL") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/RowTypeValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/RowTypeValidationTest.scala new file mode 100644 index 0000000000000..1a7123c09b5e6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/RowTypeValidationTest.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.validation + +import org.apache.flink.table.api.{SqlParserException, ValidationException} +import org.apache.flink.table.expressions.utils.RowTypeTestBase +import org.junit.Test + +class RowTypeValidationTest extends RowTypeTestBase { + + @Test(expected = classOf[SqlParserException]) + def testEmptyRowType(): Unit = { + testSqlApi("Row()", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testNullRowType(): Unit = { + testSqlApi("Row(NULL)", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testSqlRowIllegalAccess(): Unit = { + testSqlApi("f5.f2", "FAIL") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala new file mode 100644 index 0000000000000..d2a0997a3565d --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/expressions/validation/ScalarFunctionsValidationTest.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions.validation + +import org.apache.calcite.avatica.util.TimeUnit +import org.apache.flink.table.api.{SqlParserException, ValidationException} +import org.apache.flink.table.expressions.utils.ScalarTypesTestBase +import org.junit.{Ignore, Test} + +class ScalarFunctionsValidationTest extends ScalarTypesTestBase { + + // ---------------------------------------------------------------------------------------------- + // Math functions + // ---------------------------------------------------------------------------------------------- + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[IllegalArgumentException]) + def testInvalidLog1(): Unit = { + // invalid arithmetic argument + testSqlApi( + "LOG(1, 100)", + "FAIL" + ) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[IllegalArgumentException]) + def testInvalidLog2(): Unit ={ + // invalid arithmetic argument + testSqlApi( + "LOG(-1)", + "FAIL" + ) + } + + @Test(expected = classOf[ValidationException]) + def testInvalidBin1(): Unit = { + testSqlApi("BIN(f12)", "101010") // float type + } + + @Test(expected = classOf[ValidationException]) + def testInvalidBin2(): Unit = { + testSqlApi("BIN(f15)", "101010") // BigDecimal type + } + + @Test(expected = classOf[ValidationException]) + def testInvalidBin3(): Unit = { + testSqlApi("BIN(f16)", "101010") // Date type + } + + + // ---------------------------------------------------------------------------------------------- + // Temporal functions + // ---------------------------------------------------------------------------------------------- + + @Test(expected = classOf[SqlParserException]) + def testTimestampAddWithWrongTimestampInterval(): Unit ={ + testSqlApi("TIMESTAMPADD(XXX, 1, timestamp '2016-02-24'))", "2016-06-16") + } + + @Test(expected = classOf[SqlParserException]) + def 
testTimestampAddWithWrongTimestampFormat(): Unit ={ + testSqlApi("TIMESTAMPADD(YEAR, 1, timestamp '2016-02-24'))", "2016-06-16") + } + + @Test(expected = classOf[ValidationException]) + def testTimestampAddWithWrongQuantity(): Unit ={ + testSqlApi("TIMESTAMPADD(YEAR, 1.0, timestamp '2016-02-24 12:42:25')", "2016-06-16") + } + + // ---------------------------------------------------------------------------------------------- + // Sub-query functions + // ---------------------------------------------------------------------------------------------- + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testDOWWithTimeWhichIsUnsupported(): Unit = { + testSqlApi("EXTRACT(DOW FROM TIME '12:42:25')", "0") + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testDOYWithTimeWhichIsUnsupported(): Unit = { + testSqlApi("EXTRACT(DOY FROM TIME '12:42:25')", "0") + } + + private def testExtractFromTimeZeroResult(unit: TimeUnit): Unit = { + testSqlApi("EXTRACT(" + unit + " FROM TIME '00:00:00')", "0") + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testMillenniumWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.MILLENNIUM) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testCenturyWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.CENTURY) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testYearWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.YEAR) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testMonthWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.MONTH) + } + + @Ignore("TODO: FLINK-11898") + @Test(expected = classOf[ValidationException]) + def testDayWithTime(): Unit = { + testExtractFromTimeZeroResult(TimeUnit.DAY) + } +} diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/DateTimeTestUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/DateTimeTestUtil.scala new file mode 100644 index 0000000000000..05b3f0ecff3ba --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/DateTimeTestUtil.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.util + +import org.apache.calcite.avatica.util.DateTimeUtils + +import java.sql.{Date, Time, Timestamp} + +object DateTimeTestUtil { + + def UTCDate(s: String): Date = { + new Date(DateTimeUtils.dateStringToUnixDate(s) * DateTimeUtils.MILLIS_PER_DAY) + } + + def UTCTime(s: String): Time = { + new Time(DateTimeUtils.timeStringToUnixDate(s).longValue()) + } + + def UTCTimestamp(s: String): Timestamp = { + new Timestamp(DateTimeUtils.timestampStringToUnixDate(s)) + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java index e3f371b2cb2c6..af71d56500175 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/DataFormatConverters.java @@ -46,6 +46,7 @@ import org.apache.flink.table.typeutils.DecimalTypeInfo; import org.apache.flink.types.Row; +import java.io.Serializable; import java.lang.reflect.Array; import java.math.BigDecimal; import java.sql.Date; @@ -111,6 +112,7 @@ public class DataFormatConverters { * lost its specific Java format. Only TypeInformation retains all its * Java format information. */ + @SuppressWarnings("unchecked") public static DataFormatConverter getConverterForTypeInfo(TypeInformation typeInfo) { DataFormatConverter converter = TYPE_INFO_TO_CONVERTER.get(typeInfo); if (converter != null) { @@ -158,7 +160,9 @@ public static DataFormatConverter getConverterForTypeInfo(TypeInformation typeIn * @param Internal data format. * @param External data format. 
*/ - public abstract static class DataFormatConverter { + public abstract static class DataFormatConverter implements Serializable { + + private static final long serialVersionUID = 1L; /** * Converts a external(Java) data format to its internal equivalent while automatically handling nulls. @@ -203,6 +207,8 @@ public final External toExternal(BaseRow row, int column) { */ public abstract static class IdentityConverter extends DataFormatConverter { + private static final long serialVersionUID = 6146619729108124872L; + @Override T toInternalImpl(T value) { return value; @@ -219,6 +225,8 @@ T toExternalImpl(T value) { */ public static class BooleanConverter extends IdentityConverter { + private static final long serialVersionUID = 3618373319753553272L; + public static final BooleanConverter INSTANCE = new BooleanConverter(); private BooleanConverter() {} @@ -234,6 +242,8 @@ Boolean toExternalImpl(BaseRow row, int column) { */ public static class ByteConverter extends IdentityConverter { + private static final long serialVersionUID = 1880134895918999433L; + public static final ByteConverter INSTANCE = new ByteConverter(); private ByteConverter() {} @@ -249,6 +259,8 @@ Byte toExternalImpl(BaseRow row, int column) { */ public static class ShortConverter extends IdentityConverter { + private static final long serialVersionUID = 8055034507232206636L; + public static final ShortConverter INSTANCE = new ShortConverter(); private ShortConverter() {} @@ -264,6 +276,8 @@ Short toExternalImpl(BaseRow row, int column) { */ public static class IntConverter extends IdentityConverter { + private static final long serialVersionUID = -7749307898273403416L; + public static final IntConverter INSTANCE = new IntConverter(); private IntConverter() {} @@ -279,6 +293,8 @@ Integer toExternalImpl(BaseRow row, int column) { */ public static class LongConverter extends IdentityConverter { + private static final long serialVersionUID = 7373868336730797650L; + public static final LongConverter 
INSTANCE = new LongConverter(); private LongConverter() {} @@ -294,6 +310,8 @@ Long toExternalImpl(BaseRow row, int column) { */ public static class FloatConverter extends IdentityConverter { + private static final long serialVersionUID = -1119035126939832966L; + public static final FloatConverter INSTANCE = new FloatConverter(); private FloatConverter() {} @@ -309,6 +327,8 @@ Float toExternalImpl(BaseRow row, int column) { */ public static class DoubleConverter extends IdentityConverter { + private static final long serialVersionUID = 2801171640313215040L; + public static final DoubleConverter INSTANCE = new DoubleConverter(); private DoubleConverter() {} @@ -324,6 +344,8 @@ Double toExternalImpl(BaseRow row, int column) { */ public static class CharConverter extends IdentityConverter { + private static final long serialVersionUID = -7631466361315237011L; + public static final CharConverter INSTANCE = new CharConverter(); private CharConverter() {} @@ -339,6 +361,8 @@ Character toExternalImpl(BaseRow row, int column) { */ public static class BinaryStringConverter extends IdentityConverter { + private static final long serialVersionUID = 5565684451615599206L; + public static final BinaryStringConverter INSTANCE = new BinaryStringConverter(); private BinaryStringConverter() {} @@ -354,6 +378,8 @@ BinaryString toExternalImpl(BaseRow row, int column) { */ public static class BinaryArrayConverter extends IdentityConverter { + private static final long serialVersionUID = -7790350668043604641L; + public static final BinaryArrayConverter INSTANCE = new BinaryArrayConverter(); private BinaryArrayConverter() {} @@ -369,6 +395,8 @@ BinaryArray toExternalImpl(BaseRow row, int column) { */ public static class BinaryMapConverter extends IdentityConverter { + private static final long serialVersionUID = -9114231688474126815L; + public static final BinaryMapConverter INSTANCE = new BinaryMapConverter(); private BinaryMapConverter() {} @@ -384,6 +412,8 @@ BinaryMap 
toExternalImpl(BaseRow row, int column) { */ public static class DecimalConverter extends IdentityConverter { + private static final long serialVersionUID = 3825744951173809617L; + private final int precision; private final int scale; @@ -403,6 +433,8 @@ Decimal toExternalImpl(BaseRow row, int column) { */ public static class BinaryGenericConverter extends IdentityConverter { + private static final long serialVersionUID = 1436229503920584273L; + public static final BinaryGenericConverter INSTANCE = new BinaryGenericConverter(); private BinaryGenericConverter() {} @@ -418,6 +450,8 @@ BinaryGeneric toExternalImpl(BaseRow row, int column) { */ public static class StringConverter extends DataFormatConverter { + private static final long serialVersionUID = 4713165079099282774L; + public static final StringConverter INSTANCE = new StringConverter(); private StringConverter() {} @@ -443,6 +477,8 @@ String toExternalImpl(BaseRow row, int column) { */ public static class BigDecimalConverter extends DataFormatConverter { + private static final long serialVersionUID = -6586239704060565834L; + private final int precision; private final int scale; @@ -472,6 +508,8 @@ BigDecimal toExternalImpl(BaseRow row, int column) { */ public static class GenericConverter extends DataFormatConverter, T> { + private static final long serialVersionUID = -3611718364918053384L; + private final TypeSerializer serializer; public GenericConverter(TypeSerializer serializer) { @@ -499,6 +537,8 @@ T toExternalImpl(BaseRow row, int column) { */ public static class DateConverter extends DataFormatConverter { + private static final long serialVersionUID = 1343457113582411650L; + public static final DateConverter INSTANCE = new DateConverter(); private DateConverter() {} @@ -524,6 +564,8 @@ Date toExternalImpl(BaseRow row, int column) { */ public static class TimeConverter extends DataFormatConverter { + private static final long serialVersionUID = -8061475784916442483L; + public static final 
TimeConverter INSTANCE = new TimeConverter(); private TimeConverter() {} @@ -549,6 +591,8 @@ Time toExternalImpl(BaseRow row, int column) { */ public static class TimestampConverter extends DataFormatConverter { + private static final long serialVersionUID = -779956524906131757L; + public static final TimestampConverter INSTANCE = new TimestampConverter(); private TimestampConverter() {} @@ -574,6 +618,8 @@ Timestamp toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveIntArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = 1780941126232395638L; + public static final PrimitiveIntArrayConverter INSTANCE = new PrimitiveIntArrayConverter(); private PrimitiveIntArrayConverter() {} @@ -599,6 +645,8 @@ int[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveBooleanArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -4037693692440282141L; + public static final PrimitiveBooleanArrayConverter INSTANCE = new PrimitiveBooleanArrayConverter(); private PrimitiveBooleanArrayConverter() {} @@ -624,6 +672,8 @@ boolean[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveByteArrayConverter extends IdentityConverter { + private static final long serialVersionUID = -2007960927801689921L; + public static final PrimitiveByteArrayConverter INSTANCE = new PrimitiveByteArrayConverter(); private PrimitiveByteArrayConverter() {} @@ -639,6 +689,8 @@ byte[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveShortArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -1343184089311186834L; + public static final PrimitiveShortArrayConverter INSTANCE = new PrimitiveShortArrayConverter(); private PrimitiveShortArrayConverter() {} @@ -664,6 +716,8 @@ short[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveLongArrayConverter extends DataFormatConverter { + private 
static final long serialVersionUID = 4061982985342526078L; + public static final PrimitiveLongArrayConverter INSTANCE = new PrimitiveLongArrayConverter(); private PrimitiveLongArrayConverter() {} @@ -689,6 +743,8 @@ long[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveFloatArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -3237695040861141459L; + public static final PrimitiveFloatArrayConverter INSTANCE = new PrimitiveFloatArrayConverter(); private PrimitiveFloatArrayConverter() {} @@ -714,6 +770,8 @@ float[] toExternalImpl(BaseRow row, int column) { */ public static class PrimitiveDoubleArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = 6333670535356315691L; + public static final PrimitiveDoubleArrayConverter INSTANCE = new PrimitiveDoubleArrayConverter(); private PrimitiveDoubleArrayConverter() {} @@ -739,6 +797,8 @@ BinaryArray toInternalImpl(double[] value) { */ public static class PrimitiveCharArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -5438377988505771316L; + public static final PrimitiveCharArrayConverter INSTANCE = new PrimitiveCharArrayConverter(); private PrimitiveCharArrayConverter() {} @@ -764,6 +824,8 @@ char[] toExternalImpl(BaseRow row, int column) { */ public static class ObjectArrayConverter extends DataFormatConverter { + private static final long serialVersionUID = -7434682160639380078L; + private final Class arrayClass; private final InternalType elementType; private final DataFormatConverter elementConverter; @@ -825,6 +887,8 @@ private static T[] binaryArrayToJavaArray(BinaryArray value, InternalType e */ public static class MapConverter extends DataFormatConverter { + private static final long serialVersionUID = -916429669828309919L; + private final InternalType keyType; private final InternalType valueType; @@ -898,6 +962,8 @@ Map toExternalImpl(BaseRow row, int column) { */ public 
abstract static class AbstractBaseRowConverter extends DataFormatConverter { + private static final long serialVersionUID = 4365740929854771618L; + protected final DataFormatConverter[] converters; public AbstractBaseRowConverter(CompositeType t) { @@ -918,6 +984,8 @@ E toExternalImpl(BaseRow row, int column) { */ public static class BaseRowConverter extends IdentityConverter { + private static final long serialVersionUID = -4470307402371540680L; + public static final BaseRowConverter INSTANCE = new BaseRowConverter(); private BaseRowConverter() {} @@ -933,6 +1001,8 @@ BaseRow toExternalImpl(BaseRow row, int column) { */ public static class PojoConverter extends AbstractBaseRowConverter { + private static final long serialVersionUID = 6821541780176167135L; + private final PojoTypeInfo t; private final PojoField[] fields; @@ -979,6 +1049,8 @@ T toExternalImpl(BaseRow value) { */ public static class RowConverter extends AbstractBaseRowConverter { + private static final long serialVersionUID = -56553502075225785L; + private final RowTypeInfo t; public RowConverter(RowTypeInfo t) { @@ -1010,6 +1082,8 @@ Row toExternalImpl(BaseRow value) { */ public static class TupleConverter extends AbstractBaseRowConverter { + private static final long serialVersionUID = 2794892691010934194L; + private final TupleTypeInfo t; public TupleConverter(TupleTypeInfo t) { @@ -1046,6 +1120,8 @@ Tuple toExternalImpl(BaseRow value) { */ public static class CaseClassConverter extends AbstractBaseRowConverter { + private static final long serialVersionUID = -966598627968372952L; + private final TupleTypeInfoBase t; private final TupleSerializerBase serializer; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/Decimal.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/Decimal.java index aeaa41a0e1306..f6aa4bbff4bf8 100644 --- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/Decimal.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/Decimal.java @@ -339,6 +339,30 @@ public static long castToIntegral(Decimal dec) { return bd.longValue(); } + public static long castToLong(Decimal dec) { + return castToIntegral(dec); + } + + public static int castToInt(Decimal dec) { + return (int) castToIntegral(dec); + } + + public static short castToShort(Decimal dec) { + return (short) castToIntegral(dec); + } + + public static byte castToByte(Decimal dec) { + return (byte) castToIntegral(dec); + } + + public static float castToFloat(Decimal dec) { + return (float) dec.doubleValue(); + } + + public static double castToDouble(Decimal dec) { + return dec.doubleValue(); + } + public static Decimal castToDecimal(Decimal dec, int precision, int scale) { return fromBigDecimal(dec.toBigDecimal(), precision, scale); } @@ -383,6 +407,10 @@ public static Decimal sign(Decimal b0) { } } + public static int compare(Decimal b1, Decimal b2){ + return b1.compareTo(b2); + } + public static int compare(Decimal b1, long n2) { if (!b1.isCompact()) { return b1.decimalVal.compareTo(BigDecimal.valueOf(n2)); @@ -400,6 +428,18 @@ public static int compare(Decimal b1, long n2) { } } + public static int compare(Decimal b1, double n2) { + return Double.compare(b1.doubleValue(), n2); + } + + public static int compare(long n1, Decimal b2) { + return -compare(b2, n1); + } + + public static int compare(double n1, Decimal b2) { + return -compare(b2, n1); + } + /** * SQL ROUND operator applied to BigDecimal values. 
*/ diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java index ba7478f6d6d2f..b50e0f1a0c5ac 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/LazyBinaryFormat.java @@ -57,6 +57,7 @@ public LazyBinaryFormat(T javaObject) { } public T getJavaObject() { + // TODO: ensure deserialize ? return javaObject; } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/DateTimeUtils.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/DateTimeUtils.java index c86f821b16823..d796706015cd2 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/DateTimeUtils.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/DateTimeUtils.java @@ -17,7 +17,15 @@ package org.apache.flink.table.runtime.functions; +import org.apache.flink.api.java.tuple.Tuple2; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.sql.Timestamp; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Date; import java.util.TimeZone; /** @@ -27,7 +35,16 @@ */ public class DateTimeUtils { - public static final TimeZone LOCAL_TZ = TimeZone.getDefault(); + private static final Logger LOG = LoggerFactory.getLogger(DateTimeUtils.class); + + private static final TimeZone LOCAL_TZ = TimeZone.getDefault(); + + private static final String[] DEFAULT_DATETIME_FORMATS = new String[]{ + "yyyy-MM-dd HH:mm:ss", + "yyyy-MM-dd HH:mm:ss.S", + "yyyy-MM-dd HH:mm:ss.SS", + "yyyy-MM-dd HH:mm:ss.SSS" + }; /** * The number of milliseconds in a day. 
@@ -35,7 +52,21 @@ public class DateTimeUtils { *

This is the modulo 'mask' used when converting * TIMESTAMP values to DATE and TIME values. */ - public static final long MILLIS_PER_DAY = 86400000L; // = 24 * 60 * 60 * 1000 + private static final long MILLIS_PER_DAY = 86400000L; // = 24 * 60 * 60 * 1000 + + /** + * A ThreadLocal cache map for SimpleDateFormat, because SimpleDateFormat is not thread-safe. + * (format, timezone) => formatter + */ + private static final ThreadLocalCache, SimpleDateFormat> FORMATTER_CACHE = + new ThreadLocalCache, SimpleDateFormat>(64) { + @Override + public SimpleDateFormat getNewInstance(Tuple2 key) { + SimpleDateFormat sdf = new SimpleDateFormat(key.f0); + sdf.setTimeZone(key.f1); + return sdf; + } + }; /** Converts the Java type used for UDF parameters of SQL TIME type * ({@link java.sql.Time}) to internal representation (int). @@ -93,4 +124,99 @@ public static java.sql.Timestamp internalToTimestamp(long v) { return new java.sql.Timestamp(v - LOCAL_TZ.getOffset(v)); } + /** + * Parse date time string to timestamp based on the given time zone and + * "yyyy-MM-dd HH:mm:ss" format. Returns null if parsing failed. + * + * @param dateText the date time string + * @param tz the time zone + */ + public static Long strToTimestamp(String dateText, TimeZone tz) { + return strToTimestamp(dateText, DEFAULT_DATETIME_FORMATS[0], tz); + } + + /** + * Parse date time string to timestamp based on the given time zone and format. + * Returns null if parsing failed. + * + * @param dateText the date time string + * @param format date time string format + * @param tz the time zone + */ + public static Long strToTimestamp(String dateText, String format, TimeZone tz) { + SimpleDateFormat formatter = FORMATTER_CACHE.get(Tuple2.of(format, tz)); + try { + return formatter.parse(dateText).getTime(); + } catch (ParseException e) { + return null; + } + } + + /** + * Format a timestamp as specific. + * @param ts the timestamp to format. + * @param format the string formatter. 
+ * @param tz the time zone + */ + public static String dateFormat(long ts, String format, TimeZone tz) { + SimpleDateFormat formatter = FORMATTER_CACHE.get(Tuple2.of(format, tz)); + Date dateTime = new Date(ts); + return formatter.format(dateTime); + } + + /** + * Convert a timestamp to string. + * @param ts the timestamp to convert. + * @param precision the milli second precision to preserve + * @param tz the time zone + */ + public static String timestampToString(long ts, int precision, TimeZone tz) { + int p = (precision <= 3 && precision >= 0) ? precision : 3; + String format = DEFAULT_DATETIME_FORMATS[p]; + return dateFormat(ts, format, tz); + } + + /** Helper for CAST({time} AS VARCHAR(n)). */ + public static String timeToString(int time) { + final StringBuilder buf = new StringBuilder(8); + timeToString(buf, time, 0); // set milli second precision to 0 + return buf.toString(); + } + + private static void timeToString(StringBuilder buf, int time, int precision) { + while (time < 0) { + time += MILLIS_PER_DAY; + } + int h = time / 3600000; + int time2 = time % 3600000; + int m = time2 / 60000; + int time3 = time2 % 60000; + int s = time3 / 1000; + int ms = time3 % 1000; + int2(buf, h); + buf.append(':'); + int2(buf, m); + buf.append(':'); + int2(buf, s); + if (precision > 0) { + buf.append('.'); + while (precision > 0) { + buf.append((char) ('0' + (ms / 100))); + ms = ms % 100; + ms = ms * 10; + + // keep consistent with Timestamp.toString() + if (ms == 0) { + break; + } + + --precision; + } + } + } + + private static void int2(StringBuilder buf, int i) { + buf.append((char) ('0' + (i / 10) % 10)); + buf.append((char) ('0' + i % 10)); + } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ThreadLocalCache.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ThreadLocalCache.java new file mode 100644 index 0000000000000..7425c0270b2a4 --- /dev/null +++ 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/functions/ThreadLocalCache.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Provides a ThreadLocal cache with a maximum cache size per thread. + * Values must not be null. 
+ */ +public abstract class ThreadLocalCache { + + private final ThreadLocal> cache = new ThreadLocal<>(); + private final int maxSizePerThread; + + protected ThreadLocalCache(int maxSizePerThread) { + this.maxSizePerThread = maxSizePerThread; + } + + public V get(K key) { + BoundedMap map = cache.get(); + if (map == null) { + map = new BoundedMap<>(maxSizePerThread); + cache.set(map); + } + V value = map.get(key); + if (value == null) { + value = getNewInstance(key); + map.put(key, value); + } + return value; + } + + public abstract V getNewInstance(K key); + + private static class BoundedMap extends LinkedHashMap { + + private static final long serialVersionUID = -211630219014422361L; + + private final int maxSize; + + private BoundedMap(int maxSize) { + this.maxSize = maxSize; + } + + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return this.size() > maxSize; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ByteHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ByteHashSet.java new file mode 100644 index 0000000000000..bc84943ff1cb4 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ByteHashSet.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Byte hash set. + */ +public class ByteHashSet extends OptimizableHashSet { + + private byte[] key; + + private byte min = Byte.MAX_VALUE; + private byte max = Byte.MIN_VALUE; + + public ByteHashSet(final int expected, final float f) { + super(expected, f); + this.key = new byte[this.n + 1]; + } + + public ByteHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public ByteHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(final byte k) { + if (k == 0) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + byte[] key = this.key; + int pos; + byte curr; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) != 0) { + if (curr == k) { + return false; + } + + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (curr == k) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + if (k < min) { + min = k; + } + if (k > max) { + max = k; + } + return true; + } + + public boolean contains(final byte k) { + if (isDense) { + return k >= min && k <= max && used[k - min]; + } else { + if (k == 0) { + return this.containsZero; + } else { + byte[] key = this.key; + byte curr; + int pos; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) == 0) { + return false; + } else if (k == curr) { + return true; + } else 
{ + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (k == curr) { + return true; + } + } + + return false; + } + } + } + } + + private void rehash(final int newN) { + byte[] key = this.key; + int mask = newN - 1; + byte[] newKey = new byte[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while(key[i] == 0); + + if (newKey[pos = MurmurHashUtil.fmix(key[i]) & mask] != 0) { + while (newKey[pos = pos + 1 & mask] != 0) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + int range = max - min; + if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) { + this.used = new boolean[max - min + 1]; + for (byte v : key) { + if (v != 0) { + used[v - min] = true; + } + } + if (containsZero) { + used[-min] = true; + } + isDense = true; + key = null; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/DoubleHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/DoubleHashSet.java new file mode 100644 index 0000000000000..0f1c056dd0649 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/DoubleHashSet.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Double hash set. + */ +public class DoubleHashSet extends OptimizableHashSet { + + private double[] key; + + public DoubleHashSet(final int expected, final float f) { + super(expected, f); + this.key = new double[this.n + 1]; + } + + public DoubleHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public DoubleHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + /** + * See {@link Double#equals(Object)}. + */ + public boolean add(final double k) { + long longKey = Double.doubleToLongBits(k); + if (longKey == 0L) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + double[] key = this.key; + int pos; + long curr; + if ((curr = Double.doubleToLongBits(key[pos = (int) MurmurHashUtil.fmix(longKey) & this.mask])) != 0L) { + if (curr == longKey) { + return false; + } + + while ((curr = Double.doubleToLongBits(key[pos = pos + 1 & this.mask])) != 0L) { + if (curr == longKey) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + return true; + } + + /** + * See {@link Double#equals(Object)}. 
+ */ + public boolean contains(final double k) { + long longKey = Double.doubleToLongBits(k); + if (longKey == 0L) { + return this.containsZero; + } else { + double[] key = this.key; + long curr; + int pos; + if ((curr = Double.doubleToLongBits(key[pos = (int) MurmurHashUtil.fmix(longKey) & this.mask])) == 0L) { + return false; + } else if (longKey == curr) { + return true; + } else { + while ((curr = Double.doubleToLongBits(key[pos = pos + 1 & this.mask])) != 0L) { + if (longKey == curr) { + return true; + } + } + + return false; + } + } + } + + private void rehash(final int newN) { + double[] key = this.key; + int mask = newN - 1; + double[] newKey = new double[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (Double.doubleToLongBits(key[i]) == 0L); + + if (Double.doubleToLongBits(newKey[pos = + (int) MurmurHashUtil.fmix(Double.doubleToLongBits(key[i])) & mask]) != 0L) { + while (Double.doubleToLongBits(newKey[pos = pos + 1 & mask]) != 0L) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/FloatHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/FloatHashSet.java new file mode 100644 index 0000000000000..8d7e26e4c6aed --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/FloatHashSet.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Float hash set. + */ +public class FloatHashSet extends OptimizableHashSet { + + private float[] key; + + public FloatHashSet(final int expected, final float f) { + super(expected, f); + this.key = new float[this.n + 1]; + } + + public FloatHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public FloatHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + /** + * See {@link Float#equals(Object)}. + */ + public boolean add(final float k) { + int intKey = Float.floatToIntBits(k); + if (intKey == 0) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + float[] key = this.key; + int pos; + int curr; + if ((curr = Float.floatToIntBits(key[pos = MurmurHashUtil.fmix(intKey) & this.mask])) != 0) { + if (curr == intKey) { + return false; + } + + while ((curr = Float.floatToIntBits(key[pos = pos + 1 & this.mask])) != 0) { + if (curr == intKey) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + return true; + } + + /** + * See {@link Float#equals(Object)}. 
+ */ + public boolean contains(final float k) { + int intKey = Float.floatToIntBits(k); + if (intKey == 0) { + return this.containsZero; + } else { + float[] key = this.key; + int curr; + int pos; + if ((curr = Float.floatToIntBits(key[pos = MurmurHashUtil.fmix(intKey) & this.mask])) == 0) { + return false; + } else if (intKey == curr) { + return true; + } else { + while ((curr = Float.floatToIntBits(key[pos = pos + 1 & this.mask])) != 0) { + if (intKey == curr) { + return true; + } + } + + return false; + } + } + } + + private void rehash(final int newN) { + float[] key = this.key; + int mask = newN - 1; + float[] newKey = new float[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (Float.floatToIntBits(key[i]) == 0); + + if (Float.floatToIntBits(newKey[pos = + MurmurHashUtil.fmix(Float.floatToIntBits(key[i])) & mask]) != 0) { + while (Float.floatToIntBits(newKey[pos = pos + 1 & mask]) != 0) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/IntHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/IntHashSet.java new file mode 100644 index 0000000000000..1159917cdba5b --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/IntHashSet.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Int hash set. + */ +public class IntHashSet extends OptimizableHashSet { + + private int[] key; + + private int min = Integer.MAX_VALUE; + private int max = Integer.MIN_VALUE; + + public IntHashSet(final int expected, final float f) { + super(expected, f); + this.key = new int[this.n + 1]; + } + + public IntHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public IntHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(final int k) { + if (k == 0) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + int[] key = this.key; + int pos; + int curr; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) != 0) { + if (curr == k) { + return false; + } + + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (curr == k) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + if (k < min) { + min = k; + } + if (k > max) { + max = k; + } + return true; + } + + public boolean contains(final int k) { + if (isDense) { + return k >= min && k <= max && used[k - min]; + } else { + if (k == 0) { + return this.containsZero; + } else { + int[] key = this.key; 
+ int curr; + int pos; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) == 0) { + return false; + } else if (k == curr) { + return true; + } else { + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (k == curr) { + return true; + } + } + + return false; + } + } + } + } + + private void rehash(final int newN) { + int[] key = this.key; + int mask = newN - 1; + int[] newKey = new int[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (key[i] == 0); + + if (newKey[pos = MurmurHashUtil.fmix(key[i]) & mask] != 0) { + while (newKey[pos = pos + 1 & mask] != 0) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + int range = max - min; + if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) { + this.used = new boolean[max - min + 1]; + for (int v : key) { + if (v != 0) { + used[v - min] = true; + } + } + if (containsZero) { + used[-min] = true; + } + isDense = true; + key = null; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/LongHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/LongHashSet.java new file mode 100644 index 0000000000000..37a8cd0244963 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/LongHashSet.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Long hash set. + */ +public class LongHashSet extends OptimizableHashSet { + + private long[] key; + + private long min = Long.MAX_VALUE; + private long max = Long.MIN_VALUE; + + public LongHashSet(final int expected, final float f) { + super(expected, f); + this.key = new long[this.n + 1]; + } + + public LongHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public LongHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(final long k) { + if (k == 0L) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + long[] key = this.key; + int pos; + long curr; + if ((curr = key[pos = (int) MurmurHashUtil.fmix(k) & this.mask]) != 0L) { + if (curr == k) { + return false; + } + + while ((curr = key[pos = pos + 1 & this.mask]) != 0L) { + if (curr == k) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + if (k < min) { + min = k; + } + if (k > max) { + max = k; + } + return true; + } + + public boolean contains(final long k) { + if (isDense) { + return k >= min && k <= max && used[(int) (k - min)]; + } else { + if (k == 0L) { + return this.containsZero; + } else { 
+ long[] key = this.key; + long curr; + int pos; + if ((curr = key[pos = (int) MurmurHashUtil.fmix(k) & this.mask]) == 0L) { + return false; + } else if (k == curr) { + return true; + } else { + while ((curr = key[pos = pos + 1 & this.mask]) != 0L) { + if (k == curr) { + return true; + } + } + + return false; + } + } + } + } + + private void rehash(final int newN) { + long[] key = this.key; + int mask = newN - 1; + long[] newKey = new long[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (key[i] == 0L); + + if (newKey[pos = (int) MurmurHashUtil.fmix(key[i]) & mask] != 0L) { + while (newKey[pos = pos + 1 & mask] != 0L) {} + } + } + + this.n = newN; + this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + long range = max - min; + if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) { + this.used = new boolean[(int) (max - min + 1)]; + for (long v : key) { + if (v != 0) { + used[(int) (v - min)] = true; + } + } + if (containsZero) { + used[(int) (-min)] = true; + } + isDense = true; + key = null; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ObjectHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ObjectHashSet.java new file mode 100644 index 0000000000000..2a85b22cf0cf3 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ObjectHashSet.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import java.util.HashSet; + +/** + * Wrap {@link HashSet} with hashSet interface. + */ +public class ObjectHashSet<T> extends OptimizableHashSet { + + private HashSet<T> set; + + public ObjectHashSet(final int expected, final float f) { + super(expected, f); + this.set = new HashSet<>(expected, f); + } + + public ObjectHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public ObjectHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(T t) { + return set.add(t); + } + + public boolean contains(final T t) { + return set.contains(t); + } + + @Override + public void optimize() { + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/OptimizableHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/OptimizableHashSet.java new file mode 100644 index 0000000000000..66b55232e9527 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/OptimizableHashSet.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import static org.apache.flink.util.Preconditions.checkArgument; + +/** + * A type-specific hash set with a fast, small-footprint implementation. + * Refer to the implementation of fastutil. + * + * <p>Instances of this class use a hash table to represent a set. The table is + * enlarged as needed by doubling its size when new entries are created. + * + *
 <p>The difference with fastutil is that if the range of the maximum and minimum values is + * small or the data is dense, a Dense array will be used to greatly improve the access speed. + */ +public abstract class OptimizableHashSet { + + /** The initial default size of a hash table. */ + public static final int DEFAULT_INITIAL_SIZE = 16; + + /** The default load factor of a hash table. */ + public static final float DEFAULT_LOAD_FACTOR = 0.75f; + + /** + * Decide whether to convert to dense mode if it does not require more memory or + * could fit within L1 cache. + */ + public static final int DENSE_THRESHOLD = 8192; + + /** The acceptable load factor. */ + protected final float f; + + /** The mask for wrapping a position counter. */ + protected int mask; + + /** The current table size. */ + protected int n; + + /** Threshold after which we rehash. It must be the table size times {@link #f}. */ + protected int maxFill; + + /** Whether this set has a null key. */ + protected boolean containsNull; + + /** Whether this set has a zero key. */ + protected boolean containsZero; + + /** Number of entries in the set. */ + protected int size; + + /** Whether the set is now in dense mode. */ + protected boolean isDense = false; + + /** Used array for dense mode. */ + protected boolean[] used; + + public OptimizableHashSet(final int expected, final float f) { + checkArgument(f > 0 && f <= 1); + checkArgument(expected >= 0); + this.f = f; + this.n = OptimizableHashSet.arraySize(expected, f); + this.mask = this.n - 1; + this.maxFill = OptimizableHashSet.maxFill(this.n, f); + } + + /** + * Add a null key. + */ + public void addNull() { + this.containsNull = true; + } + + /** + * Is there a null key. + */ + public boolean containsNull() { + return containsNull; + } + + protected int realSize() { + return this.containsZero ? this.size - 1 : this.size; + } + + /** + * Decide whether to convert to dense mode. 
+ */ + public abstract void optimize(); + + /** + * Return the least power of two greater than or equal to the specified value. + * + *
 <p>Note that this function will return 1 when the argument is 0. + * + * @param x a long integer smaller than or equal to 2<sup>62</sup>. + * @return the least power of two greater than or equal to the specified value. + */ + public static long nextPowerOfTwo(long x) { + if (x == 0L) { + return 1L; + } else { + --x; + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; + return (x | x >> 32) + 1L; + } + } + + /** + * Returns the maximum number of entries that can be filled before rehashing. + * + * @param n the size of the backing array. + * @param f the load factor. + * @return the maximum number of entries before rehashing. + */ + public static int maxFill(int n, float f) { + return Math.min((int) Math.ceil((double) ((float) n * f)), n - 1); + } + + /** + * Returns the least power of two smaller than or equal to 2<sup>30</sup> and larger than + * or equal to Math.ceil( expected / f ). + * + * @param expected the expected number of elements in a hash table. + * @param f the load factor. + * @return the minimum possible size for a backing array. + * @throws IllegalArgumentException if the necessary size is larger than 2<sup>30</sup>. 
+ */ + public static int arraySize(int expected, float f) { + long s = Math.max(2L, nextPowerOfTwo((long) Math.ceil((double) ((float) expected / f)))); + if (s > (Integer.MAX_VALUE / 2 + 1)) { + throw new IllegalArgumentException( + "Too large (" + expected + " expected elements with load factor " + f + ")"); + } else { + return (int) s; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ShortHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ShortHashSet.java new file mode 100644 index 0000000000000..54e42f9c22340 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ShortHashSet.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.util.collections; + +import org.apache.flink.table.util.MurmurHashUtil; + +/** + * Short hash set. 
+ */ +public class ShortHashSet extends OptimizableHashSet { + + private short[] key; + + private short min = Short.MAX_VALUE; + private short max = Short.MIN_VALUE; + + public ShortHashSet(final int expected, final float f) { + super(expected, f); + this.key = new short[this.n + 1]; + } + + public ShortHashSet(final int expected) { + this(expected, DEFAULT_LOAD_FACTOR); + } + + public ShortHashSet() { + this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + } + + public boolean add(final short k) { + if (k == 0) { + if (this.containsZero) { + return false; + } + + this.containsZero = true; + } else { + short[] key = this.key; + int pos; + short curr; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) != 0) { + if (curr == k) { + return false; + } + + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (curr == k) { + return false; + } + } + } + + key[pos] = k; + } + + if (this.size++ >= this.maxFill) { + this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); + } + + if (k < min) { + min = k; + } + if (k > max) { + max = k; + } + return true; + } + + public boolean contains(final short k) { + if (isDense) { + return k >= min && k <= max && used[k - min]; + } else { + if (k == 0) { + return this.containsZero; + } else { + short[] key = this.key; + short curr; + int pos; + if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) == 0) { + return false; + } else if (k == curr) { + return true; + } else { + while ((curr = key[pos = pos + 1 & this.mask]) != 0) { + if (k == curr) { + return true; + } + } + + return false; + } + } + } + } + + private void rehash(final int newN) { + short[] key = this.key; + int mask = newN - 1; + short[] newKey = new short[newN + 1]; + int i = this.n; + + int pos; + for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { + do { + --i; + } while (key[i] == 0); + + if (newKey[pos = MurmurHashUtil.fmix(key[i]) & mask] != 0) { + while (newKey[pos = pos + 1 & mask] != 0) {} + } + } + + this.n = newN; + 
this.mask = mask; + this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); + this.key = newKey; + } + + @Override + public void optimize() { + int range = max - min; + if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) { + this.used = new boolean[max - min + 1]; + for (short v : key) { + if (v != 0) { + used[v - min] = true; + } + } + if (containsZero) { + used[-min] = true; + } + isDense = true; + key = null; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/InternalTypeUtils.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/InternalTypeUtils.java new file mode 100644 index 0000000000000..85ca05522b784 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/InternalTypeUtils.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.type; + +/** + * Utilities for {@link InternalType}. + */ +public class InternalTypeUtils { + + /** + * Gets the arity of the type. 
+ */ + public static int getArity(InternalType t) { + if (t instanceof RowType) { + return ((RowType) t).getArity(); + } else { + return 1; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/RowType.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/RowType.java index 380ec4493d889..db89521bed51d 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/RowType.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/RowType.java @@ -106,7 +106,8 @@ public boolean equals(Object o) { RowType that = (RowType) o; // RowType comparisons should not compare names and are compatible with the behavior of CompositeTypeInfo. - return Arrays.equals(getFieldTypes(), that.getFieldTypes()); + return Arrays.equals(getFieldTypes(), that.getFieldTypes()) && + Arrays.equals(getFieldNames(), that.getFieldNames()); } @Override diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java index ea90a77d77df9..69628dc86a7af 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/type/TypeConverters.java @@ -120,6 +120,9 @@ public static InternalType createInternalTypeFromTypeInfo(TypeInformation typeIn .toArray(InternalType[]::new), compositeType.getFieldNames() ); + } else if (typeInfo instanceof DecimalTypeInfo) { + DecimalTypeInfo decimalType = (DecimalTypeInfo) typeInfo; + return InternalTypes.createDecimalType(decimalType.precision(), decimalType.scale()); } else if (typeInfo instanceof PrimitiveArrayTypeInfo) { PrimitiveArrayTypeInfo arrayType = (PrimitiveArrayTypeInfo) typeInfo; return InternalTypes.createArrayType( diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/util/MurmurHashUtil.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/util/MurmurHashUtil.java index b4e0b41a271fd..dbae3e87b3851 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/util/MurmurHashUtil.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/util/MurmurHashUtil.java @@ -148,7 +148,7 @@ private static int fmix(int h1, int length) { return fmix(h1); } - private static int fmix(int h) { + public static int fmix(int h) { h ^= h >>> 16; h *= 0x85ebca6b; h ^= h >>> 13; @@ -156,4 +156,13 @@ private static int fmix(int h) { h ^= h >>> 16; return h; } + + public static long fmix(long h) { + h ^= (h >>> 33); + h *= 0xff51afd7ed558ccdL; + h ^= (h >>> 33); + h *= 0xc4ceb9fe1a85ec53L; + h ^= (h >>> 33); + return h; + } }