Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions fe/src/main/java/org/apache/doris/analysis/CaseExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.List;

/**
Expand Down Expand Up @@ -101,6 +102,7 @@ public boolean equals(Object obj) {
CaseExpr expr = (CaseExpr) obj;
return hasCaseExpr == expr.hasCaseExpr && hasElseExpr == expr.hasElseExpr;
}

public boolean hasCaseExpr() {
return hasCaseExpr;
}
Expand Down Expand Up @@ -251,4 +253,84 @@ public List<Expr> getReturnExprs() {
}
return exprs;
}

// this method just compare literal value and not completely consistent with be,for two cases
// 1 not deal float
// 2 just compare literal value with same type. for a example sql 'select case when 123 then '1' else '2' end as col'
// for be will return '1', because be only regard 0 as false
// but for current LiteralExpr.compareLiteral, `123`' won't be regard as true
// the case which two values has different type left to be
public static Expr computeCaseExpr(CaseExpr expr) {
LiteralExpr caseExpr;
int startIndex = 0;
int endIndex = expr.getChildren().size();
if (expr.hasCaseExpr()) {
// just deal literal here
// and avoid `float compute` in java,float should be dealt in be
Expr caseChildExpr = expr.getChild(0);
if (!caseChildExpr.isLiteral()
|| caseChildExpr instanceof DecimalLiteral || caseChildExpr instanceof FloatLiteral) {
return expr;
}
caseExpr = (LiteralExpr) expr.getChild(0);
startIndex++;
} else {
caseExpr = new BoolLiteral(true);
}

if (caseExpr instanceof NullLiteral) {
if (expr.hasElseExpr) {
return expr.getChild(expr.getChildren().size() - 1);
} else {
return new NullLiteral();
}
}

if (expr.hasElseExpr) {
endIndex--;
}

// early return when the `when expr` can't be converted to constants
Expr startExpr = expr.getChild(startIndex);
if ((!startExpr.isLiteral() || startExpr instanceof DecimalLiteral || startExpr instanceof FloatLiteral)
|| (!(startExpr instanceof NullLiteral) && !startExpr.getClass().toString().equals(caseExpr.getClass().toString()))) {
return expr;
}

for (int i = startIndex; i < endIndex; i = i + 2) {
Expr currentWhenExpr = expr.getChild(i);
// skip null literal
if (currentWhenExpr instanceof NullLiteral) {
continue;
}
// stop convert in three cases
// 1 not literal
// 2 float
// 3 `case expr` and `when expr` don't have same type
if ((!currentWhenExpr.isLiteral() || currentWhenExpr instanceof DecimalLiteral || currentWhenExpr instanceof FloatLiteral)
|| !currentWhenExpr.getClass().toString().equals(caseExpr.getClass().toString())) {
// remove the expr which has been evaluated
List<Expr> exprLeft = new ArrayList<>();
if (expr.hasCaseExpr()) {
exprLeft.add(caseExpr);
}
for (int j = i; j < expr.getChildren().size(); j++) {
exprLeft.add(expr.getChild(j));
}
Expr retCaseExpr = expr.clone();
retCaseExpr.getChildren().clear();
retCaseExpr.addChildren(exprLeft);
return retCaseExpr;
} else if (caseExpr.compareLiteral((LiteralExpr) currentWhenExpr) == 0) {
return expr.getChild(i + 1);
}
}

if (expr.hasElseExpr) {
return expr.getChild(expr.getChildren().size() - 1);
} else {
return new NullLiteral();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.CaseExpr;
import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.NullLiteral;
Expand Down Expand Up @@ -48,6 +49,11 @@ public class FoldConstantsRule implements ExprRewriteRule {

@Override
public Expr apply(Expr expr, Analyzer analyzer) throws AnalysisException {
// evaluate `case when expr` when possible
if (expr instanceof CaseExpr) {
return CaseExpr.computeCaseExpr((CaseExpr) expr);
}

// Avoid calling Expr.isConstant() because that would lead to repeated traversals
// of the Expr tree. Assumes the bottom-up application of this rule. Constant
// children should have been folded at this point.
Expand Down
156 changes: 129 additions & 27 deletions fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.planner;

import org.apache.commons.lang3.StringUtils;
import org.apache.doris.analysis.CreateDbStmt;
import org.apache.doris.analysis.CreateTableStmt;
import org.apache.doris.analysis.DropDbStmt;
Expand Down Expand Up @@ -242,31 +243,31 @@ public static void beforeClass() throws Exception {
"PROPERTIES (\n" +
" \"replication_num\" = \"1\"\n" +
");");

createTable("CREATE TABLE test.`pushdown_test` (\n" +
" `k1` tinyint(4) NULL COMMENT \"\",\n" +
" `k2` smallint(6) NULL COMMENT \"\",\n" +
" `k3` int(11) NULL COMMENT \"\",\n" +
" `k4` bigint(20) NULL COMMENT \"\",\n" +
" `k5` decimal(9, 3) NULL COMMENT \"\",\n" +
" `k6` char(5) NULL COMMENT \"\",\n" +
" `k10` date NULL COMMENT \"\",\n" +
" `k11` datetime NULL COMMENT \"\",\n" +
" `k7` varchar(20) NULL COMMENT \"\",\n" +
" `k8` double MAX NULL COMMENT \"\",\n" +
" `k9` float SUM NULL COMMENT \"\"\n" +
") ENGINE=OLAP\n" +
"AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n" +
"COMMENT \"OLAP\"\n" +
"PARTITION BY RANGE(`k1`)\n" +
"(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" +
"PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" +
"PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" +
"DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" +
"PROPERTIES (\n" +
"\"replication_num\" = \"1\",\n" +
"\"in_memory\" = \"false\",\n" +
"\"storage_format\" = \"DEFAULT\"\n" +
" `k1` tinyint(4) NULL COMMENT \"\",\n" +
" `k2` smallint(6) NULL COMMENT \"\",\n" +
" `k3` int(11) NULL COMMENT \"\",\n" +
" `k4` bigint(20) NULL COMMENT \"\",\n" +
" `k5` decimal(9, 3) NULL COMMENT \"\",\n" +
" `k6` char(5) NULL COMMENT \"\",\n" +
" `k10` date NULL COMMENT \"\",\n" +
" `k11` datetime NULL COMMENT \"\",\n" +
" `k7` varchar(20) NULL COMMENT \"\",\n" +
" `k8` double MAX NULL COMMENT \"\",\n" +
" `k9` float SUM NULL COMMENT \"\"\n" +
") ENGINE=OLAP\n" +
"AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n" +
"COMMENT \"OLAP\"\n" +
"PARTITION BY RANGE(`k1`)\n" +
"(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" +
"PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" +
"PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" +
"DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" +
"PROPERTIES (\n" +
"\"replication_num\" = \"1\",\n" +
"\"in_memory\" = \"false\",\n" +
"\"storage_format\" = \"DEFAULT\"\n" +
");");
}

Expand Down Expand Up @@ -656,18 +657,119 @@ public void testJoinPredicateTransitivity() throws Exception {
Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1"));
Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` > 1"));
}


@Test
public void testConvertCaseWhenToConstant() throws Exception {
// basic test
String caseWhenSql = "select "
+ "case when date_format(now(),'%H%i') < 123 then 1 else 0 end as col "
+ "from test.test1 "
+ "where time = case when date_format(now(),'%H%i') < 123 then date_format(date_sub(now(),2),'%Y%m%d') else date_format(date_sub(now(),1),'%Y%m%d') end";
Assert.assertTrue(!StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + caseWhenSql), "CASE WHEN"));

// test 1: case when then
// 1.1 multi when in on `case when` and can be converted to constants
String sql11 = "select case when false then 2 when true then 3 else 0 end as col11;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql11), "constant exprs: \n 3"));

// 1.2 multi `when expr` in on `case when` ,`when expr` can not be converted to constants
String sql121 = "select case when false then 2 when substr(k7,2,1) then 3 else 0 end as col121 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql121),
"OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 3 ELSE 0 END"));

// 1.2.2 when expr which can not be converted to constants in the first
String sql122 = "select case when substr(k7,2,1) then 2 when false then 3 else 0 end as col122 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql122),
"OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 2 WHEN FALSE THEN 3 ELSE 0 END"));

// 1.2.3 test return `then expr` in the middle
String sql124 = "select case when false then 1 when true then 2 when false then 3 else 'other' end as col124";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql124), "constant exprs: \n '2'"));

// 1.3 test return null
String sql3 = "select case when false then 2 end as col3";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql3), "constant exprs: \n NULL"));

// 1.3.1 test return else expr
String sql131 = "select case when false then 2 when false then 3 else 4 end as col131";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql131), "constant exprs: \n 4"));

// 1.4 nest `case when` and can be converted to constants
String sql14 = "select case when (case when true then true else false end) then 2 when false then 3 else 0 end as col";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql14), "constant exprs: \n 2"));

// 1.5 nest `case when` and can not be converted to constants
String sql15 = "select case when case when substr(k7,2,1) then true else false end then 2 when false then 3 else 0 end as col from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql15),
"OUTPUT EXPRS:CASE WHEN CASE WHEN substr(`k7`, 2, 1) THEN TRUE ELSE FALSE END THEN 2 WHEN FALSE THEN 3 ELSE 0 END"));

// 1.6 test when expr is null
String sql16 = "select case when null then 1 else 2 end as col16;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql16), "constant exprs: \n 2"));

// test 2: case xxx when then
// 2.1 test equal
String sql2 = "select case 1 when 1 then 'a' when 2 then 'b' else 'other' end as col2;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql2), "constant exprs: \n 'a'"));

// 2.1.2 test not equal
String sql212 = "select case 'a' when 1 then 'a' when 'a' then 'b' else 'other' end as col212;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql212), "constant exprs: \n 'b'"));

// 2.2 test return null
String sql22 = "select case 'a' when 1 then 'a' when 'b' then 'b' end as col22;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql22), "constant exprs: \n NULL"));

// 2.2.2 test return else
String sql222 = "select case 1 when 2 then 'a' when 3 then 'b' else 'other' end as col222;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql222), "constant exprs: \n 'other'"));

// 2.3 test can not convert to constant,middle when expr is not constant
String sql23 = "select case 'a' when 'b' then 'a' when substr(k7,2,1) then 2 when false then 3 else 0 end as col23 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql23),
"OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '0' THEN '3' ELSE '0' END"));

// 2.3.1 first when expr is not constant
String sql231 = "select case 'a' when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end as col231 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql231),
"OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END"));

// 2.3.2 case expr is not constant
String sql232 = "select case k1 when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end as col232 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql232),
"OUTPUT EXPRS:CASE`k1` WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END"));

// 3.1 test float,float in case expr
String sql31 = "select case cast(100 as float) when 1 then 'a' when 2 then 'b' else 'other' end as col31;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql31),
"constant exprs: \n CASE100.0 WHEN 1.0 THEN 'a' WHEN 2.0 THEN 'b' ELSE 'other' END"));

// 4.1 test null in case expr return else
String sql41 = "select case null when 1 then 'a' when 2 then 'b' else 'other' end as col41";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql41), "constant exprs: \n 'other'"));

// 4.1.2 test null in case expr return null
String sql412 = "select case null when 1 then 'a' when 2 then 'b' end as col41";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql412), "constant exprs: \n NULL"));

// 4.2.1 test null in when expr
String sql421 = "select case 'a' when null then 'a' else 'other' end as col421";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql421), "constant exprs: \n 'other'"));
}

@Test
public void testJoinPredicateTransitivityWithSubqueryInWhereClause() throws Exception {
connectContext.setDatabase("default_cluster:test");
String sql = "SELECT *\n" +
String sql = "SELECT *\n" +
"FROM test.pushdown_test\n" +
"WHERE 0 < (\n" +
" SELECT MAX(k9)\n" +
" SELECT MAX(k9)\n" +
" FROM test.pushdown_test);";
String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql);
Assert.assertTrue(explainString.contains("PLAN FRAGMENT"));
Assert.assertTrue(explainString.contains("CROSS JOIN"));
Assert.assertTrue(!explainString.contains("PREDICATES"));
}


}