diff --git a/fe/src/main/java/org/apache/doris/analysis/CaseExpr.java b/fe/src/main/java/org/apache/doris/analysis/CaseExpr.java index 602aabdf92a46c..ffe0d880e5f7a2 100644 --- a/fe/src/main/java/org/apache/doris/analysis/CaseExpr.java +++ b/fe/src/main/java/org/apache/doris/analysis/CaseExpr.java @@ -26,6 +26,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Lists; +import java.util.ArrayList; import java.util.List; /** @@ -101,6 +102,7 @@ public boolean equals(Object obj) { CaseExpr expr = (CaseExpr) obj; return hasCaseExpr == expr.hasCaseExpr && hasElseExpr == expr.hasElseExpr; } + public boolean hasCaseExpr() { return hasCaseExpr; } @@ -251,4 +253,84 @@ public List getReturnExprs() { } return exprs; } + + // this method just compare literal value and not completely consistent with be,for two cases + // 1 not deal float + // 2 just compare literal value with same type. for a example sql 'select case when 123 then '1' else '2' end as col' + // for be will return '1', because be only regard 0 as false + // but for current LiteralExpr.compareLiteral, `123`' won't be regard as true + // the case which two values has different type left to be + public static Expr computeCaseExpr(CaseExpr expr) { + LiteralExpr caseExpr; + int startIndex = 0; + int endIndex = expr.getChildren().size(); + if (expr.hasCaseExpr()) { + // just deal literal here + // and avoid `float compute` in java,float should be dealt in be + Expr caseChildExpr = expr.getChild(0); + if (!caseChildExpr.isLiteral() + || caseChildExpr instanceof DecimalLiteral || caseChildExpr instanceof FloatLiteral) { + return expr; + } + caseExpr = (LiteralExpr) expr.getChild(0); + startIndex++; + } else { + caseExpr = new BoolLiteral(true); + } + + if (caseExpr instanceof NullLiteral) { + if (expr.hasElseExpr) { + return expr.getChild(expr.getChildren().size() - 1); + } else { + return new NullLiteral(); + } + } + + if (expr.hasElseExpr) { + endIndex--; + } + + // early return when the `when expr` can't be converted to constants + Expr startExpr = expr.getChild(startIndex); + if ((!startExpr.isLiteral() || startExpr instanceof DecimalLiteral || startExpr instanceof FloatLiteral) + || (!(startExpr instanceof NullLiteral) && !startExpr.getClass().toString().equals(caseExpr.getClass().toString()))) { + return expr; + } + + for (int i = startIndex; i < endIndex; i = i + 2) { + Expr currentWhenExpr = expr.getChild(i); + // skip null literal + if (currentWhenExpr instanceof NullLiteral) { + continue; + } + // stop convert in three cases + // 1 not literal + // 2 float + // 3 `case expr` and `when expr` don't have same type + if ((!currentWhenExpr.isLiteral() || currentWhenExpr instanceof DecimalLiteral || currentWhenExpr instanceof FloatLiteral) + || !currentWhenExpr.getClass().toString().equals(caseExpr.getClass().toString())) { + // remove the expr which has been evaluated + List exprLeft = new ArrayList<>(); + if (expr.hasCaseExpr()) { + exprLeft.add(caseExpr); + } + for (int j = i; j < expr.getChildren().size(); j++) { + exprLeft.add(expr.getChild(j)); + } + Expr retCaseExpr = expr.clone(); + retCaseExpr.getChildren().clear(); + retCaseExpr.addChildren(exprLeft); + return retCaseExpr; + } else if (caseExpr.compareLiteral((LiteralExpr) currentWhenExpr) == 0) { + return expr.getChild(i + 1); + } + } + + if (expr.hasElseExpr) { + return expr.getChild(expr.getChildren().size() - 1); + } else { + return new NullLiteral(); + } + } + } diff --git a/fe/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java b/fe/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java index 1c4298e9ed3e25..f697150400669e 100644 --- a/fe/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java +++ b/fe/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java @@ -19,6 +19,7 @@ import org.apache.doris.analysis.Analyzer; +import org.apache.doris.analysis.CaseExpr; import org.apache.doris.analysis.CastExpr; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.NullLiteral; @@ -48,6 +49,11 @@ public class FoldConstantsRule implements ExprRewriteRule { @Override public Expr apply(Expr expr, Analyzer analyzer) throws AnalysisException { + // evaluate `case when expr` when possible + if (expr instanceof CaseExpr) { + return CaseExpr.computeCaseExpr((CaseExpr) expr); + } + // Avoid calling Expr.isConstant() because that would lead to repeated traversals // of the Expr tree. Assumes the bottom-up application of this rule. Constant // children should have been folded at this point. diff --git a/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java b/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java index c192d183278b89..741d71eab2462c 100644 --- a/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java +++ b/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java @@ -17,6 +17,7 @@ package org.apache.doris.planner; +import org.apache.commons.lang3.StringUtils; import org.apache.doris.analysis.CreateDbStmt; import org.apache.doris.analysis.CreateTableStmt; import org.apache.doris.analysis.DropDbStmt; @@ -242,31 +243,31 @@ public static void beforeClass() throws Exception { "PROPERTIES (\n" + " \"replication_num\" = \"1\"\n" + ");"); - + createTable("CREATE TABLE test.`pushdown_test` (\n" + - " `k1` tinyint(4) NULL COMMENT \"\",\n" + - " `k2` smallint(6) NULL COMMENT \"\",\n" + - " `k3` int(11) NULL COMMENT \"\",\n" + - " `k4` bigint(20) NULL COMMENT \"\",\n" + - " `k5` decimal(9, 3) NULL COMMENT \"\",\n" + - " `k6` char(5) NULL COMMENT \"\",\n" + - " `k10` date NULL COMMENT \"\",\n" + - " `k11` datetime NULL COMMENT \"\",\n" + - " `k7` varchar(20) NULL COMMENT \"\",\n" + - " `k8` double MAX NULL COMMENT \"\",\n" + - " `k9` float SUM NULL COMMENT \"\"\n" + - ") ENGINE=OLAP\n" + - "AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n" + - "COMMENT \"OLAP\"\n" + - "PARTITION BY RANGE(`k1`)\n" + - "(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" + - "PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" + - "PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" + - "DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" + - "PROPERTIES (\n" + - "\"replication_num\" = \"1\",\n" + - "\"in_memory\" = \"false\",\n" + - "\"storage_format\" = \"DEFAULT\"\n" + + " `k1` tinyint(4) NULL COMMENT \"\",\n" + + " `k2` smallint(6) NULL COMMENT \"\",\n" + + " `k3` int(11) NULL COMMENT \"\",\n" + + " `k4` bigint(20) NULL COMMENT \"\",\n" + + " `k5` decimal(9, 3) NULL COMMENT \"\",\n" + + " `k6` char(5) NULL COMMENT \"\",\n" + + " `k10` date NULL COMMENT \"\",\n" + + " `k11` datetime NULL COMMENT \"\",\n" + + " `k7` varchar(20) NULL COMMENT \"\",\n" + + " `k8` double MAX NULL COMMENT \"\",\n" + + " `k9` float SUM NULL COMMENT \"\"\n" + + ") ENGINE=OLAP\n" + + "AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n" + + "COMMENT \"OLAP\"\n" + + "PARTITION BY RANGE(`k1`)\n" + + "(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" + + "PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" + + "PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" + + "DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" + + "PROPERTIES (\n" + + "\"replication_num\" = \"1\",\n" + + "\"in_memory\" = \"false\",\n" + + "\"storage_format\" = \"DEFAULT\"\n" + ");"); } @@ -656,18 +657,119 @@ public void testJoinPredicateTransitivity() throws Exception { Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1")); Assert.assertTrue(explainString.contains("PREDICATES: `join2`.`id` > 1")); } - + + @Test + public void testConvertCaseWhenToConstant() throws Exception { + // basic test + String caseWhenSql = "select " + + "case when date_format(now(),'%H%i') < 123 then 1 else 0 end as col " + + "from test.test1 " + + "where time = case when date_format(now(),'%H%i') < 123 then date_format(date_sub(now(),2),'%Y%m%d') else date_format(date_sub(now(),1),'%Y%m%d') end"; + Assert.assertTrue(!StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + caseWhenSql), "CASE WHEN")); + + // test 1: case when then + // 1.1 multi when in on `case when` and can be converted to constants + String sql11 = "select case when false then 2 when true then 3 else 0 end as col11;"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql11), "constant exprs: \n 3")); + + // 1.2 multi `when expr` in on `case when` ,`when expr` can not be converted to constants + String sql121 = "select case when false then 2 when substr(k7,2,1) then 3 else 0 end as col121 from test.baseall"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql121), + "OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 3 ELSE 0 END")); + + // 1.2.2 when expr which can not be converted to constants in the first + String sql122 = "select case when substr(k7,2,1) then 2 when false then 3 else 0 end as col122 from test.baseall"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql122), + "OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 2 WHEN FALSE THEN 3 ELSE 0 END")); + + // 1.2.3 test return `then expr` in the middle + String sql124 = "select case when false then 1 when true then 2 when false then 3 else 'other' end as col124"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql124), "constant exprs: \n '2'")); + + // 1.3 test return null + String sql3 = "select case when false then 2 end as col3"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql3), "constant exprs: \n NULL")); + + // 1.3.1 test return else expr + String sql131 = "select case when false then 2 when false then 3 else 4 end as col131"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql131), "constant exprs: \n 4")); + + // 1.4 nest `case when` and can be converted to constants + String sql14 = "select case when (case when true then true else false end) then 2 when false then 3 else 0 end as col"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql14), "constant exprs: \n 2")); + + // 1.5 nest `case when` and can not be converted to constants + String sql15 = "select case when case when substr(k7,2,1) then true else false end then 2 when false then 3 else 0 end as col from test.baseall"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql15), + "OUTPUT EXPRS:CASE WHEN CASE WHEN substr(`k7`, 2, 1) THEN TRUE ELSE FALSE END THEN 2 WHEN FALSE THEN 3 ELSE 0 END")); + + // 1.6 test when expr is null + String sql16 = "select case when null then 1 else 2 end as col16;"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql16), "constant exprs: \n 2")); + + // test 2: case xxx when then + // 2.1 test equal + String sql2 = "select case 1 when 1 then 'a' when 2 then 'b' else 'other' end as col2;"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql2), "constant exprs: \n 'a'")); + + // 2.1.2 test not equal + String sql212 = "select case 'a' when 1 then 'a' when 'a' then 'b' else 'other' end as col212;"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql212), "constant exprs: \n 'b'")); + + // 2.2 test return null + String sql22 = "select case 'a' when 1 then 'a' when 'b' then 'b' end as col22;"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql22), "constant exprs: \n NULL")); + + // 2.2.2 test return else + String sql222 = "select case 1 when 2 then 'a' when 3 then 'b' else 'other' end as col222;"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql222), "constant exprs: \n 'other'")); + + // 2.3 test can not convert to constant,middle when expr is not constant + String sql23 = "select case 'a' when 'b' then 'a' when substr(k7,2,1) then 2 when false then 3 else 0 end as col23 from test.baseall"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql23), + "OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '0' THEN '3' ELSE '0' END")); + + // 2.3.1 first when expr is not constant + String sql231 = "select case 'a' when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end as col231 from test.baseall"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql231), + "OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END")); + + // 2.3.2 case expr is not constant + String sql232 = "select case k1 when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end as col232 from test.baseall"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql232), + "OUTPUT EXPRS:CASE`k1` WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END")); + + // 3.1 test float,float in case expr + String sql31 = "select case cast(100 as float) when 1 then 'a' when 2 then 'b' else 'other' end as col31;"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql31), + "constant exprs: \n CASE100.0 WHEN 1.0 THEN 'a' WHEN 2.0 THEN 'b' ELSE 'other' END")); + + // 4.1 test null in case expr return else + String sql41 = "select case null when 1 then 'a' when 2 then 'b' else 'other' end as col41"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql41), "constant exprs: \n 'other'")); + + // 4.1.2 test null in case expr return null + String sql412 = "select case null when 1 then 'a' when 2 then 'b' end as col41"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql412), "constant exprs: \n NULL")); + + // 4.2.1 test null in when expr + String sql421 = "select case 'a' when null then 'a' else 'other' end as col421"; + Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql421), "constant exprs: \n 'other'")); + } + @Test public void testJoinPredicateTransitivityWithSubqueryInWhereClause() throws Exception { connectContext.setDatabase("default_cluster:test"); - String sql = "SELECT *\n" + + String sql = "SELECT *\n" + "FROM test.pushdown_test\n" + "WHERE 0 < (\n" + - " SELECT MAX(k9)\n" + + " SELECT MAX(k9)\n" + " FROM test.pushdown_test);"; String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql); Assert.assertTrue(explainString.contains("PLAN FRAGMENT")); Assert.assertTrue(explainString.contains("CROSS JOIN")); Assert.assertTrue(!explainString.contains("PREDICATES")); } + + }