diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java index 8a41e91d35f2ca..b4daa214ee557f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java @@ -235,25 +235,25 @@ public void analyzeImpl(Analyzer analyzer) throws AnalysisException { // Add casts to case expr to compatible type. if (hasCaseExpr) { // Cast case expr. - if (children.get(0).type != whenType) { + if (!children.get(0).getType().equals(whenType)) { castChild(whenType, 0); } // Add casts to when exprs to compatible type. for (int i = loopStart; i < loopEnd; i += 2) { - if (children.get(i).type != whenType) { + if (!children.get(i).getType().equals(whenType)) { castChild(whenType, i); } } } // Cast then exprs to compatible type. for (int i = loopStart + 1; i < children.size(); i += 2) { - if (children.get(i).type != returnType) { + if (!children.get(i).getType().equals(returnType)) { castChild(returnType, i); } } // Cast else expr to compatible type. if (hasElseExpr) { - if (children.get(children.size() - 1).type != returnType) { + if (!children.get(children.size() - 1).getType().equals(returnType)) { castChild(returnType, children.size() - 1); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/CaseExprTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/CaseExprTest.java index c768be50304a97..72bf214c361b4f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/CaseExprTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/CaseExprTest.java @@ -17,12 +17,16 @@ package org.apache.doris.analysis; +import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.catalog.ScalarType; + import com.clearspring.analytics.util.Lists; import mockit.Expectations; import mockit.Injectable; import org.junit.Assert; import org.junit.Test; +import java.lang.reflect.Constructor; import java.util.List; public class CaseExprTest { @@ -65,4 +69,44 @@ public void testIsNullable(@Injectable Expr caseExpr, }; Assert.assertTrue(caseExpr3.isNullable()); } + + @Test + public void testTypeCast( + @Injectable Expr caseExpr, + @Injectable Expr whenExpr, + @Injectable Expr thenExpr, + @Injectable Expr elseExpr) throws Exception { + CaseWhenClause caseWhenClause = new CaseWhenClause(whenExpr, thenExpr); + List caseWhenClauseList = Lists.newArrayList(); + caseWhenClauseList.add(caseWhenClause); + CaseExpr expr = new CaseExpr(caseExpr, caseWhenClauseList, elseExpr); + Class clazz = Class.forName("org.apache.doris.catalog.ScalarType"); + Constructor constructor = clazz.getDeclaredConstructor(PrimitiveType.class); + constructor.setAccessible(true); + ScalarType intType = (ScalarType) constructor.newInstance(PrimitiveType.INT); + ScalarType tinyIntType = (ScalarType) constructor.newInstance(PrimitiveType.TINYINT); + new Expectations() { + { + caseExpr.getType(); + result = intType; + + whenExpr.getType(); + result = ScalarType.createType(PrimitiveType.INT); + whenExpr.castTo(intType); + times = 0; + + thenExpr.getType(); + result = tinyIntType; + thenExpr.castTo(tinyIntType); + times = 0; + + elseExpr.getType(); + result = ScalarType.createType(PrimitiveType.TINYINT); + elseExpr.castTo(tinyIntType); + times = 0; + } + }; + Analyzer analyzer = new Analyzer(null, null); + expr.analyzeImpl(analyzer); + } } diff --git a/regression-test/data/correctness/test_case_when.out b/regression-test/data/correctness/test_case_when.out index d4db120b762f22..fba725b132d4a4 100644 --- a/regression-test/data/correctness/test_case_when.out +++ b/regression-test/data/correctness/test_case_when.out @@ -3,3 +3,9 @@ 2 01 -1 3 00 1 +-- !select_agg -- +1 90020004 +2 45010002 +3 90020004 +4 90020004 + diff --git a/regression-test/suites/correctness/test_case_when.groovy b/regression-test/suites/correctness/test_case_when.groovy index f5afe3026fe7e3..bdade75ead3a9a 100644 --- a/regression-test/suites/correctness/test_case_when.groovy +++ b/regression-test/suites/correctness/test_case_when.groovy @@ -64,4 +64,16 @@ suite("test_case_when") { and channel_id = '00' group by hour_time, station_type; """ -} \ No newline at end of file + + qt_select_agg """ + SELECT + hour_time, + sum((CASE WHEN TRUE THEN merchant_id ELSE 0 END)) mid + FROM + dws_scan_qrcode_user_ts + GROUP BY + hour_time + ORDER BY + hour_time; + """ +}