Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Interval;
import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral;
Expand Down Expand Up @@ -3046,10 +3047,14 @@ public List<Expression> withInList(PredicateContext ctx) {

@Override
public Literal visitDecimalLiteral(DecimalLiteralContext ctx) {
if (Config.enable_decimal_conversion) {
return new DecimalV3Literal(new BigDecimal(ctx.getText()));
} else {
return new DecimalLiteral(new BigDecimal(ctx.getText()));
try {
if (Config.enable_decimal_conversion) {
return new DecimalV3Literal(new BigDecimal(ctx.getText()));
} else {
return new DecimalLiteral(new BigDecimal(ctx.getText()));
}
} catch (Exception e) {
return new DoubleLiteral(Double.parseDouble(ctx.getText()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ private Statistics computeRepeat(Repeat<? extends Plan> repeat) {
.setNumNulls(stats.numNulls < 0 ? stats.numNulls : stats.numNulls * groupingSetNum)
.setDataSize(stats.dataSize < 0 ? stats.dataSize : stats.dataSize * groupingSetNum);
return Pair.of(kv.getKey(), columnStatisticBuilder.build());
}).collect(Collectors.toMap(Pair::key, Pair::value));
}).collect(Collectors.toMap(Pair::key, Pair::value, (item1, item2) -> item1));
return new Statistics(rowCount < 0 ? rowCount : rowCount * groupingSetNum, 1, columnStatisticMap);
}

Expand All @@ -808,7 +808,7 @@ private Statistics computeOneRowRelation(List<NamedExpression> projects) {
// TODO: compute the literal size
return Pair.of(project.toSlot(), statistic);
})
.collect(Collectors.toMap(Pair::key, Pair::value));
.collect(Collectors.toMap(Pair::key, Pair::value, (item1, item2) -> item1));
int rowCount = 1;
return new Statistics(rowCount, 1, columnStatsMap);
}
Expand All @@ -823,7 +823,7 @@ private Statistics computeEmptyRelation(EmptyRelation emptyRelation) {
.setAvgSizeByte(0);
return Pair.of(project.toSlot(), columnStat.build());
})
.collect(Collectors.toMap(Pair::key, Pair::value));
.collect(Collectors.toMap(Pair::key, Pair::value, (item1, item2) -> item1));
int rowCount = 0;
return new Statistics(rowCount, 1, columnStatsMap);
}
Expand Down Expand Up @@ -998,7 +998,7 @@ private Statistics computeWindow(Window windowOperator) {
}
}
return Pair.of(expr.toSlot(), colStatsBuilder.build());
}).collect(Collectors.toMap(Pair::key, Pair::value));
}).collect(Collectors.toMap(Pair::key, Pair::value, (item1, item2) -> item1));
columnStatisticMap.putAll(childColumnStats);
return new Statistics(childStats.getRowCount(), 1, columnStatisticMap);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,7 @@ private static FunctionSignature defaultDateTimeV2PrecisionPromotion(
if (finalType == null) {
finalType = dateTimeV2Type;
} else {
finalType = DateTimeV2Type.getWiderDatetimeV2Type(finalType,
DateTimeV2Type.forType(arguments.get(i).getDataType()));
finalType = DateTimeV2Type.getWiderDatetimeV2Type(finalType, dateTimeV2Type);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;

import java.text.NumberFormat;

/**
* Double literal
*/
Expand All @@ -51,16 +49,4 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
public LiteralExpr toLegacyLiteral() {
return new FloatLiteral(value, Type.DOUBLE);
}

@Override
public String toString() {
NumberFormat nf = NumberFormat.getInstance();
nf.setGroupingUsed(false);
return nf.format(value);
}

@Override
public String getStringValue() {
return toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.FloatType;

import java.text.NumberFormat;

/**
* float type literal
*/
Expand All @@ -50,11 +48,4 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.FloatLiteral((double) value, Type.FLOAT);
}

@Override
public String getStringValue() {
NumberFormat nf = NumberFormat.getInstance();
nf.setGroupingUsed(false);
return nf.format(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.IntegralType;

import com.google.common.collect.ImmutableList;

Expand Down Expand Up @@ -234,6 +235,13 @@ protected Expression uncheckedCastTo(DataType targetType) throws AnalysisExcepti
return Literal.of(true);
}
}
if (targetType instanceof IntegralType) {
// do trailing zeros to avoid number parse error when cast to integral type
BigDecimal bigDecimal = new BigDecimal(desc);
if (bigDecimal.stripTrailingZeros().scale() <= 0) {
desc = bigDecimal.stripTrailingZeros().toPlainString();
}
}
if (targetType.isTinyIntType()) {
return Literal.of(Byte.valueOf(desc));
} else if (targetType.isSmallIntType()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.RelationUtil;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.planner.DataSink;
import org.apache.doris.planner.OlapTableSink;
import org.apache.doris.proto.InternalService;
Expand Down Expand Up @@ -553,20 +554,22 @@ public static Plan normalizePlan(Plan plan, TableIf table) {
throw new AnalysisException("Column count doesn't match value count");
}
for (int i = 0; i < values.size(); i++) {
if (values.get(i) instanceof DefaultValueSlot) {
boolean hasDefaultValue = false;
for (Column column : columns) {
if (unboundTableSink.getColNames().get(i).equalsIgnoreCase(column.getName())) {
constantExprs.add(generateDefaultExpression(column));
hasDefaultValue = true;
}
}
if (!hasDefaultValue) {
throw new AnalysisException("Unknown column '"
+ unboundTableSink.getColNames().get(i) + "' in target table.");
Column sameNameColumn = null;
for (Column column : table.getBaseSchema(true)) {
if (unboundTableSink.getColNames().get(i).equalsIgnoreCase(column.getName())) {
sameNameColumn = column;
break;
}
}
if (sameNameColumn == null) {
throw new AnalysisException("Unknown column '"
+ unboundTableSink.getColNames().get(i) + "' in target table.");
}
if (values.get(i) instanceof DefaultValueSlot) {
constantExprs.add(generateDefaultExpression(sameNameColumn));
} else {
constantExprs.add(values.get(i));
DataType targetType = DataType.fromCatalogType(sameNameColumn.getType());
constantExprs.add((NamedExpression) castValue(values.get(i), targetType));
}
}
} else {
Expand All @@ -577,7 +580,8 @@ public static Plan normalizePlan(Plan plan, TableIf table) {
if (values.get(i) instanceof DefaultValueSlot) {
constantExprs.add(generateDefaultExpression(columns.get(i)));
} else {
constantExprs.add(values.get(i));
DataType targetType = DataType.fromCatalogType(columns.get(i).getType());
constantExprs.add((NamedExpression) castValue(values.get(i), targetType));
}
}
}
Expand All @@ -595,6 +599,14 @@ public static Plan normalizePlan(Plan plan, TableIf table) {
}
}

private static Expression castValue(Expression value, DataType targetType) {
if (value instanceof UnboundAlias) {
return value.withChildren(TypeCoercionUtils.castUnbound(((UnboundAlias) value).child(), targetType));
} else {
return TypeCoercionUtils.castUnbound(value, targetType);
}
}

/**
* get target table from names.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,14 @@ public static Expression castIfNotSameType(Expression input, DataType targetType
}
}

public static Expression castUnbound(Expression expression, DataType targetType) {
if (expression instanceof Literal) {
return TypeCoercionUtils.castIfNotSameType(expression, targetType);
} else {
return TypeCoercionUtils.unSafeCast(expression, targetType);
}
}

/**
* like castIfNotSameType does, but varchar or char type would be cast to target length exactly
*/
Expand Down Expand Up @@ -467,11 +475,6 @@ private static Expression unSafeCast(Expression input, DataType dataType) {
return promoted;
}
}
// adapt scale when from string to datetimev2 with float
if (type.isStringLikeType() && dataType.isDateTimeV2Type()) {
return recordTypeCoercionForSubQuery(input,
DateTimeV2Type.forTypeFromString(((Literal) input).getStringValue()));
}
}
return recordTypeCoercionForSubQuery(input, dataType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ void testSimplifyArithmetic() {
assertRewriteAfterTypeCoercion("IA", "IA");
assertRewriteAfterTypeCoercion("IA + 1", "IA + 1");
assertRewriteAfterTypeCoercion("IA + IB", "IA + IB");
assertRewriteAfterTypeCoercion("1 * 3 / IA", "(3 / cast(IA as DOUBLE))");
assertRewriteAfterTypeCoercion("1 * 3 / IA", "(3.0 / cast(IA as DOUBLE))");
assertRewriteAfterTypeCoercion("1 - IA", "1 - IA");
assertRewriteAfterTypeCoercion("1 + 1", "2");
assertRewriteAfterTypeCoercion("IA + 2 - 1", "IA + 1");
assertRewriteAfterTypeCoercion("IA + 2 - (1 - 1)", "IA + 2");
assertRewriteAfterTypeCoercion("IA + 2 - ((1 - IB) - (3 + IC))", "IA + IB + IC + 4");
assertRewriteAfterTypeCoercion("IA * IB + 2 - IC * 2", "(IA * IB) - (IC * 2) + 2");
assertRewriteAfterTypeCoercion("IA * IB", "IA * IB");
assertRewriteAfterTypeCoercion("IA * IB / 2 * 2", "cast((IA * IB) as DOUBLE) / 1");
assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as DOUBLE) / 4");
assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as DOUBLE) / 4");
assertRewriteAfterTypeCoercion("IA * (IB / 2) * 2)", "cast(IA as DOUBLE) * cast(IB as DOUBLE) / 1");
assertRewriteAfterTypeCoercion("IA * (IB / 2) * (IC + 1))", "cast(IA as DOUBLE) * cast(IB as DOUBLE) * cast((IC + 1) as DOUBLE) / 2");
assertRewriteAfterTypeCoercion("IA * IB / 2 / IC * 2 * ID / 4", "(((cast((IA * IB) as DOUBLE) / cast(IC as DOUBLE)) * cast(ID as DOUBLE)) / 4)");
assertRewriteAfterTypeCoercion("IA * IB / 2 * 2", "cast((IA * IB) as DOUBLE) / 1.0");
assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as DOUBLE) / 4.0");
assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as DOUBLE) / 4.0");
assertRewriteAfterTypeCoercion("IA * (IB / 2) * 2)", "cast(IA as DOUBLE) * cast(IB as DOUBLE) / 1.0");
assertRewriteAfterTypeCoercion("IA * (IB / 2) * (IC + 1))", "cast(IA as DOUBLE) * cast(IB as DOUBLE) * cast((IC + 1) as DOUBLE) / 2.0");
assertRewriteAfterTypeCoercion("IA * IB / 2 / IC * 2 * ID / 4", "(((cast((IA * IB) as DOUBLE) / cast(IC as DOUBLE)) * cast(ID as DOUBLE)) / 4.0)");
}

@Test
Expand Down Expand Up @@ -86,8 +86,8 @@ void testSimplifyArithmeticComparison() {
assertRewriteAfterTypeCoercion("IA + 1 > IB", "cast(IA as BIGINT) > (cast(IB as BIGINT) - 1)");
assertRewriteAfterTypeCoercion("IA + 1 > IB * IC", "cast(IA as BIGINT) > ((IB * IC) - 1)");
assertRewriteAfterTypeCoercion("IA * ID > IB * IC", "IA * ID > IB * IC");
assertRewriteAfterTypeCoercion("IA * ID / 2 > IB * IC", "cast((IA * ID) as DOUBLE) > cast((IB * IC) as DOUBLE) * 2");
assertRewriteAfterTypeCoercion("IA * ID / -2 > IB * IC", "cast((IB * IC) as DOUBLE) * -2 > cast((IA * ID) as DOUBLE)");
assertRewriteAfterTypeCoercion("IA * ID / 2 > IB * IC", "cast((IA * ID) as DOUBLE) > cast((IB * IC) as DOUBLE) * 2.0");
assertRewriteAfterTypeCoercion("IA * ID / -2 > IB * IC", "cast((IB * IC) as DOUBLE) * -2.0 > cast((IA * ID) as DOUBLE)");
assertRewriteAfterTypeCoercion("1 - IA > 1", "(cast(IA as BIGINT) < 0)");
assertRewriteAfterTypeCoercion("1 - IA + 1 * 3 - 5 > 1", "(cast(IA as BIGINT) < -2)");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,17 +389,17 @@ void testMapDateTimeV2ComputePrecision() {
new NullLiteral(),
new DateTimeV2Literal("2020-02-02 00:00:00.1234"));
signature = ComputeSignatureHelper.computePrecision(new FakeComputeSignature(), signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertEquals(DateTimeV2Type.of(6),
Assertions.assertInstanceOf(MapType.class, signature.getArgType(0));
Assertions.assertEquals(DateTimeV2Type.of(4),
((MapType) signature.getArgType(0)).getKeyType());
Assertions.assertEquals(DateTimeV2Type.of(6),
Assertions.assertEquals(DateTimeV2Type.of(4),
((MapType) signature.getArgType(0)).getValueType());
Assertions.assertTrue(signature.getArgType(1) instanceof MapType);
Assertions.assertEquals(DateTimeV2Type.of(6),
Assertions.assertInstanceOf(MapType.class, signature.getArgType(1));
Assertions.assertEquals(DateTimeV2Type.of(4),
((MapType) signature.getArgType(1)).getKeyType());
Assertions.assertEquals(DateTimeV2Type.of(6),
Assertions.assertEquals(DateTimeV2Type.of(4),
((MapType) signature.getArgType(1)).getValueType());
Assertions.assertEquals(DateTimeV2Type.of(6),
Assertions.assertEquals(DateTimeV2Type.of(4),
signature.getArgType(2));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ PhysicalResultSink
----------------PhysicalOlapScan[com_dd_library] apply RFs: RF0
------------PhysicalDistribute[DistributionSpecHash]
--------------hashAgg[LOCAL]
----------------filter((cast(experiment_id as DOUBLE) = 37))
----------------filter((cast(experiment_id as DOUBLE) = 37.0))
------------------PhysicalOlapScan[shunt_log_com_dd_library]

-- !2 --
Expand All @@ -25,7 +25,7 @@ PhysicalResultSink
--------------PhysicalOlapScan[com_dd_library] apply RFs: RF0
------------PhysicalDistribute[DistributionSpecHash]
--------------hashAgg[LOCAL]
----------------filter((cast(experiment_id as DOUBLE) = 73))
----------------filter((cast(experiment_id as DOUBLE) = 73.0))
------------------PhysicalOlapScan[shunt_log_com_dd_library]

-- !3 --
Expand All @@ -37,7 +37,7 @@ PhysicalResultSink
----------hashJoin[INNER_JOIN] hashCondition=((a.device_id = b.device_id)) otherCondition=() build RFs:RF0 device_id->[device_id]
------------PhysicalOlapScan[com_dd_library] apply RFs: RF0
------------PhysicalDistribute[DistributionSpecHash]
--------------filter((cast(experiment_id as DOUBLE) = 73))
--------------filter((cast(experiment_id as DOUBLE) = 73.0))
----------------PhysicalOlapScan[shunt_log_com_dd_library]

-- !4 --
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ PhysicalResultSink
--------------filter((a.event_id = 'ad_click'))
----------------PhysicalOlapScan[com_dd_library_one_side] apply RFs: RF0
------------PhysicalDistribute[DistributionSpecHash]
--------------filter((cast(experiment_id as DOUBLE) = 37))
--------------filter((cast(experiment_id as DOUBLE) = 37.0))
----------------PhysicalOlapScan[shunt_log_com_dd_library_one_side]

-- !2 --
Expand All @@ -23,7 +23,7 @@ PhysicalResultSink
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[com_dd_library_one_side] apply RFs: RF0
------------PhysicalDistribute[DistributionSpecHash]
--------------filter((cast(experiment_id as DOUBLE) = 73))
--------------filter((cast(experiment_id as DOUBLE) = 73.0))
----------------PhysicalOlapScan[shunt_log_com_dd_library_one_side]

-- !3 --
Expand All @@ -35,7 +35,7 @@ PhysicalResultSink
----------hashJoin[INNER_JOIN] hashCondition=((a.device_id = b.device_id)) otherCondition=() build RFs:RF0 device_id->[device_id]
------------PhysicalOlapScan[com_dd_library_one_side] apply RFs: RF0
------------PhysicalDistribute[DistributionSpecHash]
--------------filter((cast(experiment_id as DOUBLE) = 73))
--------------filter((cast(experiment_id as DOUBLE) = 73.0))
----------------PhysicalOlapScan[shunt_log_com_dd_library_one_side]

-- !4 --
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ PhysicalResultSink

-- !filter_aggregation_group_set --
PhysicalResultSink
--filter((cast(msg as DOUBLE) = 1))
--filter((cast(msg as DOUBLE) = 1.0))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalRepeat
Expand All @@ -252,7 +252,7 @@ PhysicalResultSink

-- !filter_aggregation_group_set --
PhysicalResultSink
--filter(((t1.id > 10) OR (cast(msg as DOUBLE) = 1)))
--filter(((t1.id > 10) OR (cast(msg as DOUBLE) = 1.0)))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalRepeat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalProject
------filter((if((mean = 0), 0, (stdev / mean)) > 1))
------filter((if((mean = 0.0), 0.0, (stdev / mean)) > 1.0))
--------hashAgg[GLOBAL]
----------PhysicalDistribute[DistributionSpecHash]
------------hashAgg[LOCAL]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ PhysicalResultSink
------------------------------------PhysicalOlapScan[store]
--------------------------PhysicalDistribute[DistributionSpecReplicated]
----------------------------PhysicalProject
------------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1) and hd_buy_potential IN ('1001-5000', '5001-10000'))
------------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.0) and hd_buy_potential IN ('1001-5000', '5001-10000'))
--------------------------------PhysicalOlapScan[household_demographics]

Loading