diff --git a/docs/en/administrator-guide/variables.md b/docs/en/administrator-guide/variables.md index e2d1f23d282e0d..1379a2b062e40d 100644 --- a/docs/en/administrator-guide/variables.md +++ b/docs/en/administrator-guide/variables.md @@ -388,3 +388,7 @@ Note that the comment must start with /*+ and can only follow the SELECT. In a sort query, when an upper level node receives the ordered data of the lower level node, it will sort the corresponding data on the exchange node to ensure that the final data is ordered. However, when a single thread merges multiple channels of data, if the amount of data is too large, it will lead to a single point of exchange node merge bottleneck. Doris optimizes this part if there are too many data nodes in the lower layer. Exchange node will start multithreading for parallel merging to speed up the sorting process. This parameter is false by default, which means that exchange node does not adopt parallel merge sort to reduce the extra CPU and memory consumption. + +* `extract_wide_range_expr` + + Controls whether the 'Wide Common Factors' rewrite rule is enabled. Valid values are true and false. It is enabled by default. 
diff --git a/docs/zh-CN/administrator-guide/variables.md b/docs/zh-CN/administrator-guide/variables.md index bfcfc99b214bc0..441cb6fd6ed76c 100644 --- a/docs/zh-CN/administrator-guide/variables.md +++ b/docs/zh-CN/administrator-guide/variables.md @@ -383,3 +383,7 @@ SELECT /*+ SET_VAR(query_timeout = 1, enable_partition_cache=true) */ sleep(3); 在一个排序的查询之中,一个上层节点接收下层节点有序数据时,会在exchange node上进行对应的排序来保证最终的数据是有序的。但是单线程进行多路数据归并时,如果数据量过大,会导致exchange node的单点的归并瓶颈。 Doris在这部分进行了优化处理,如果下层的数据节点过多。exchange node会启动多线程进行并行归并来加速排序过程。该参数默认为False,即表示 exchange node 不采取并行的归并排序,来减少额外的CPU和内存消耗。 + +* `extract_wide_range_expr` + + 用于控制是否开启 「宽泛公因式提取」的优化。取值有两种:true 和 false 。默认情况下开启。 diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java index 8fc26a7d4dac77..36347686dec637 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java @@ -38,6 +38,7 @@ import org.apache.doris.rewrite.BetweenToCompoundRule; import org.apache.doris.rewrite.ExprRewriteRule; import org.apache.doris.rewrite.ExprRewriter; +import org.apache.doris.rewrite.ExtractCommonFactorsRule; import org.apache.doris.rewrite.FoldConstantsRule; import org.apache.doris.rewrite.RewriteFromUnixTimeRule; import org.apache.doris.rewrite.NormalizeBinaryPredicatesRule; @@ -259,7 +260,9 @@ public GlobalState(Catalog catalog, ConnectContext context) { rules.add(FoldConstantsRule.INSTANCE); rules.add(RewriteFromUnixTimeRule.INSTANCE); rules.add(SimplifyInvalidDateBinaryPredicatesDateRule.INSTANCE); - exprRewriter_ = new ExprRewriter(rules); + List onceRules = Lists.newArrayList(); + onceRules.add(ExtractCommonFactorsRule.INSTANCE); + exprRewriter_ = new ExprRewriter(rules, onceRules); // init mv rewriter List mvRewriteRules = Lists.newArrayList(); mvRewriteRules.add(ToBitmapToSlotRefRule.INSTANCE); diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java index 600a825dfb3635..f399334401aebb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java @@ -33,6 +33,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Lists; +import com.google.common.collect.Range; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -465,7 +466,28 @@ public boolean slotIsLeft() { Preconditions.checkState(slotIsleft != null); return slotIsleft; } - + + public Range convertToRange() { + Preconditions.checkState(getChild(0) instanceof SlotRef); + Preconditions.checkState(getChild(1) instanceof LiteralExpr); + LiteralExpr literalExpr = (LiteralExpr) getChild(1); + switch (op) { + case EQ: + return Range.singleton(literalExpr); + case GE: + return Range.atLeast(literalExpr); + case GT: + return Range.greaterThan(literalExpr); + case LE: + return Range.atMost(literalExpr); + case LT: + return Range.lessThan(literalExpr); + case NE: + default: + return null; + } + } + // public static enum Operator2 { // EQ("=", FunctionOperator.EQ, FunctionOperator.FILTER_EQ), // NE("!=", FunctionOperator.NE, FunctionOperator.FILTER_NE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java index 0257e6b6b2b595..e298a0cfca970c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java @@ -284,7 +284,7 @@ private Expr castTo(LiteralExpr value) throws AnalysisException { return new IntLiteral(value.getLongValue(), type); } else if (type.isLargeIntType()) { return new LargeIntLiteral(value.getStringValue()); - } else if (type.isDecimal()) { + } else if (type.isDecimal() || 
type.isDecimalV2()) { return new DecimalLiteral(value.getStringValue()); } else if (type.isFloatingPointType()) { return new FloatLiteral(value.getDoubleValue(), type); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java index 66d4b0531aa611..91ce58603b0e50 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java @@ -157,7 +157,17 @@ public int compareLiteral(LiteralExpr expr) { if (expr instanceof NullLiteral) { return 1; } - return this.value.compareTo(((DecimalLiteral) expr).value); + if (expr instanceof DecimalLiteral) { + return this.value.compareTo(((DecimalLiteral) expr).value); + } else { + try { + DecimalLiteral decimalLiteral = new DecimalLiteral(expr.getStringValue()); + return this.compareLiteral(decimalLiteral); + } catch (AnalysisException e) { + throw new ClassCastException("Those two values cannot be compared: " + value + + " and " + expr.toSqlImpl()); + } + } } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java index b8515cb4825080..59d85ace26a6a3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java @@ -17,16 +17,16 @@ package org.apache.doris.analysis; +import org.apache.doris.common.io.Writable; + +import com.google.common.collect.Lists; + import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.util.List; import java.util.Objects; -import org.apache.doris.common.io.Writable; - -import com.google.common.collect.Lists; - /** * Return value of the grammar production that parses function * parameters. These parameters can be for scalar or aggregate functions. 
@@ -111,8 +111,10 @@ public static FunctionParams read(DataInput in) throws IOException { @Override public int hashCode() { int result = 31 * Boolean.hashCode(isStar) + Boolean.hashCode(isDistinct); - for (Expr expr : exprs) { - result = 31 * result + Objects.hashCode(expr); + if (exprs != null) { + for (Expr expr : exprs) { + result = 31 * result + Objects.hashCode(expr); + } } return result; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/InPredicate.java index c54ff49a2e2825..e684aa27bb617d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/InPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/InPredicate.java @@ -105,7 +105,7 @@ public int getInElementNum() { } @Override - public Expr clone() { + public InPredicate clone() { return new InPredicate(this); } @@ -212,17 +212,37 @@ public void analyzeImpl(Analyzer analyzer) throws AnalysisException { Reference slotRefRef = new Reference(); Reference idxRef = new Reference(); - if (isSingleColumnPredicate(slotRefRef, - idxRef) && idxRef.getRef() == 0 && slotRefRef.getRef().getNumDistinctValues() > 0) { - selectivity = - (double) (getChildren().size() - 1) / (double) slotRefRef.getRef() - .getNumDistinctValues(); + if (isSingleColumnPredicate(slotRefRef, idxRef) + && idxRef.getRef() == 0 && slotRefRef.getRef().getNumDistinctValues() > 0) { + selectivity = (double) (getChildren().size() - 1) / (double) slotRefRef.getRef() + .getNumDistinctValues(); selectivity = Math.max(0.0, Math.min(1.0, selectivity)); } else { selectivity = Expr.DEFAULT_SELECTIVITY; } } + public InPredicate union(InPredicate inPredicate) { + Preconditions.checkState(inPredicate.isLiteralChildren()); + Preconditions.checkState(this.isLiteralChildren()); + Preconditions.checkState(getChild(0).equals(inPredicate.getChild(0))); + List unionChildren = new ArrayList<>(getListChildren()); + 
unionChildren.removeAll(inPredicate.getListChildren()); + unionChildren.addAll(inPredicate.getListChildren()); + InPredicate union = new InPredicate(getChild(0), unionChildren, isNotIn); + return union; + } + + public InPredicate intersection(InPredicate inPredicate) { + Preconditions.checkState(inPredicate.isLiteralChildren()); + Preconditions.checkState(this.isLiteralChildren()); + Preconditions.checkState(getChild(0).equals(inPredicate.getChild(0))); + List intersectChildren = new ArrayList<>(getListChildren()); + intersectChildren.retainAll(inPredicate.getListChildren()); + InPredicate intersection = new InPredicate(getChild(0), intersectChildren, isNotIn); + return intersection; + } + @Override protected void toThrift(TExprNode msg) { // Can't serialize a predicate with a subquery diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java index 5bed750709f060..370c4d405606b0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java @@ -32,7 +32,7 @@ import java.io.IOException; import java.nio.ByteBuffer; -public abstract class LiteralExpr extends Expr { +public abstract class LiteralExpr extends Expr implements Comparable { private static final Logger LOG = LogManager.getLogger(LiteralExpr.class); public LiteralExpr() { @@ -128,6 +128,11 @@ public Object getRealValue() { // must handle MaxLiteral. public abstract int compareLiteral(LiteralExpr expr); + @Override + public int compareTo(LiteralExpr literalExpr) { + return compareLiteral(literalExpr); + } + // Returns the string representation of the literal's value. Used when passing // literal values to the metastore rather than to Palo backends. This is similar to // the toSql() method, but does not perform any formatting of the string values. 
Neither diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java index 8f617da2acbb0a..d67ad8bf1511f0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectStmt.java @@ -40,6 +40,7 @@ import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.qe.ConnectContext; import org.apache.doris.rewrite.ExprRewriter; + import com.google.common.base.Preconditions; import com.google.common.base.Predicates; import com.google.common.base.Strings; @@ -56,7 +57,6 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -480,8 +480,7 @@ public void analyze(Analyzer analyzer) throws AnalysisException, UserException { "cannot combine SELECT DISTINCT with analytic functions"); } } - // do this before whereClause.analyze , some expr is not analyzed, this may cause some - // function not work as expected such as equals; + whereClauseRewrite(); if (whereClause != null) { if (checkGroupingFn(whereClause)) { @@ -579,162 +578,6 @@ private void whereClauseRewrite() { whereClause = new BoolLiteral(true); } } - Expr deDuplicatedWhere = deduplicateOrs(whereClause); - if (deDuplicatedWhere != null) { - whereClause = deDuplicatedWhere; - } - } - - /** - * this function only process (a and b and c) or (d and e and f) like clause, - * this function will extract this to [[a, b, c], [d, e, f]] - */ - private List> extractDuplicateOrs(CompoundPredicate expr) { - List> orExprs = new ArrayList<>(); - for (Expr child : expr.getChildren()) { - if (child instanceof CompoundPredicate) { - CompoundPredicate childCp = (CompoundPredicate) child; - if (childCp.getOp() == CompoundPredicate.Operator.OR) { - orExprs.addAll(extractDuplicateOrs(childCp)); - continue; - } else if 
(childCp.getOp() == CompoundPredicate.Operator.AND) { - orExprs.add(flatAndExpr(child)); - continue; - } - } - orExprs.add(Arrays.asList(child)); - } - return orExprs; - } - - /** - * This function attempts to apply the inverse OR distributive law: - * ((A AND B) OR (A AND C)) => (A AND (B OR C)) - * That is, locate OR clauses in which every subclause contains an - * identical term, and pull out the duplicated terms. - */ - private Expr deduplicateOrs(Expr expr) { - if (expr == null) { - return null; - } else if (expr instanceof CompoundPredicate && ((CompoundPredicate) expr).getOp() == CompoundPredicate.Operator.OR) { - Expr rewritedExpr = processDuplicateOrs(extractDuplicateOrs((CompoundPredicate) expr)); - if (rewritedExpr != null) { - return rewritedExpr; - } - } else { - for (int i = 0; i < expr.getChildren().size(); i++) { - Expr rewritedExpr = deduplicateOrs(expr.getChild(i)); - if (rewritedExpr != null) { - expr.setChild(i, rewritedExpr); - } - } - } - return expr; - } - - /** - * try to flat and , a and b and c => [a, b, c] - */ - private List flatAndExpr(Expr expr) { - List andExprs = new ArrayList<>(); - if (expr instanceof CompoundPredicate && ((CompoundPredicate) expr).getOp() == CompoundPredicate.Operator.AND) { - andExprs.addAll(flatAndExpr(expr.getChild(0))); - andExprs.addAll(flatAndExpr(expr.getChild(1))); - } else { - andExprs.add(expr); - } - return andExprs; - } - - /** - * the input is a list of list, the inner list is and connected exprs, the outer list is or connected - * for example clause (a and b and c) or (a and e and f) after extractDuplicateOrs will be [[a, b, c], [a, e, f]] - * this is the input of this function, first step is deduplicate [[a, b, c], [a, e, f]] => [[a], [b, c], [e, f]] - * then rebuild the expr to a and ((b and c) or (e and f)) - */ - private Expr processDuplicateOrs(List> exprs) { - if (exprs.size() < 2) { - return null; - } - // 1. 
remove duplicated elements [[a,a], [a, b], [a,b]] => [[a], [a,b]] - Set> set = new LinkedHashSet<>(); - for (List ex : exprs) { - Set es = new LinkedHashSet<>(); - es.addAll(ex); - set.add(es); - } - List> clearExprs = new ArrayList<>(); - for (Set es : set) { - List el = new ArrayList<>(); - el.addAll(es); - clearExprs.add(el); - } - if (clearExprs.size() == 1) { - return makeCompound(clearExprs.get(0), CompoundPredicate.Operator.AND); - } - // 2. find duplicate cross the clause - List cloneExprs = new ArrayList<>(clearExprs.get(0)); - for (int i = 1; i < clearExprs.size(); ++i) { - cloneExprs.retainAll(clearExprs.get(i)); - } - List temp = new ArrayList<>(); - if (CollectionUtils.isNotEmpty(cloneExprs)) { - temp.add(makeCompound(cloneExprs, CompoundPredicate.Operator.AND)); - } - - Expr result; - boolean isReturnCommonFactorExpr = false; - for (List exprList : clearExprs) { - exprList.removeAll(cloneExprs); - if (exprList.size() == 0) { - // For example, the sql is "where (a = 1) or (a = 1 and B = 2)" - // if "(a = 1)" is extracted as a common factor expression, then the first expression "(a = 1)" has no expression - // other than a common factor expression, and the second expression "(a = 1 and B = 2)" has an expression of "(B = 2)" - // - // In this case, the common factor expression ("a = 1") can be directly used to replace the whole CompoundOrPredicate. - // In Fact, the common factor expression is actually the parent set of expression "(a = 1)" and expression "(a = 1 and B = 2)" - // - // exprList.size() == 0 means one child of CompoundOrPredicate has no expression other than a common factor expression. - isReturnCommonFactorExpr = true; - break; - } - temp.add(makeCompound(exprList, CompoundPredicate.Operator.AND)); - } - if (isReturnCommonFactorExpr) { - result = temp.get(0); - } else { - // rebuild CompoundPredicate if found duplicate predicate will build (predicate) and (.. or ..) predicate in - // step 1: will build (.. or ..) 
- if (CollectionUtils.isNotEmpty(cloneExprs)) { - result = new CompoundPredicate(CompoundPredicate.Operator.AND, temp.get(0), - makeCompound(temp.subList(1, temp.size()), CompoundPredicate.Operator.OR)); - result.setPrintSqlInParens(true); - } else { - result = makeCompound(temp, CompoundPredicate.Operator.OR); - } - } - if (LOG.isDebugEnabled()) { - LOG.debug("equal ors: " + result.toSql()); - } - return result; - } - - /** - * Rebuild CompoundPredicate, [a, e, f] AND => a and e and f - */ - private Expr makeCompound(List exprs, CompoundPredicate.Operator op) { - if (CollectionUtils.isEmpty(exprs)) { - return null; - } - if (exprs.size() == 1) { - return exprs.get(0); - } - CompoundPredicate result = new CompoundPredicate(op, exprs.get(0), exprs.get(1)); - for (int i = 2; i < exprs.size(); ++i) { - result = new CompoundPredicate(op, result.clone(), exprs.get(i)); - } - result.setPrintSqlInParens(true); - return result; } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 73fc865c126199..3146ae31459f6e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -120,6 +120,7 @@ public class SessionVariable implements Serializable, Writable { public static final String DELETE_WITHOUT_PARTITION = "delete_without_partition"; + public static final String EXTRACT_WIDE_RANGE_EXPR = "extract_wide_range_expr"; public static final long DEFAULT_INSERT_VISIBLE_TIMEOUT_MS = 10_000; public static final long MIN_INSERT_VISIBLE_TIMEOUT_MS = 1000; // If user set a very small value, use this value instead. 
@@ -307,6 +308,9 @@ public class SessionVariable implements Serializable, Writable { @VariableMgr.VarAttr(name = DELETE_WITHOUT_PARTITION, needForward = true) public boolean deleteWithoutPartition = false; + @VariableMgr.VarAttr(name = EXTRACT_WIDE_RANGE_EXPR, needForward = true) + public boolean extractWideRangeExpr = true; + public long getMaxExecMemByte() { return maxExecMemByte; } @@ -621,6 +625,10 @@ public boolean isDeleteWithoutPartition() { return deleteWithoutPartition; } + public boolean isExtractWideRangeExpr() { + return extractWideRangeExpr; + } + // Serialize to thrift object // used for rest api public TQueryOptions toThrift() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExprRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExprRewriter.java index d5207ff5960beb..1217f41a33e274 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExprRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExprRewriter.java @@ -17,13 +17,14 @@ package org.apache.doris.rewrite; -import java.util.List; - import org.apache.doris.analysis.Analyzer; import org.apache.doris.analysis.Expr; import org.apache.doris.common.AnalysisException; + import com.google.common.collect.Lists; +import java.util.List; + /** * Helper class that drives the transformation of Exprs according to a given list of * ExprRewriteRules. The rules are applied as follows: @@ -32,15 +33,30 @@ * - the rule list is applied repeatedly until no rule has made any changes * - the rules are applied in the order they appear in the rule list * Keeps track of how many transformations were applied. + * + * There are two types of Rewriter, the first is Repeat Rewriter, + * and the other is Once Rewriter. + * The Repeat Rewriter framework will call Rule repeatedly + * until the entire expression does not change. + * The Once Rewriter framework will only call Rule once. 
+ * According to different Rule strategies, + * Doris match different Rewriter framework execution. */ public class ExprRewriter { private int numChanges_ = 0; private final List rules_; + // Once-only Rules + private List onceRules_ = Lists.newArrayList(); public ExprRewriter(List rules) { rules_ = rules; } + public ExprRewriter(List rules, List onceRules) { + rules_ = rules; + onceRules_ = onceRules; + } + public ExprRewriter(ExprRewriteRule rule) { rules_ = Lists.newArrayList(rule); } @@ -55,6 +71,18 @@ public Expr rewrite(Expr expr, Analyzer analyzer) throws AnalysisException { rewrittenExpr = applyRuleRepeatedly(rewrittenExpr, rule, analyzer); } } while (oldNumChanges != numChanges_); + + for (ExprRewriteRule rule: onceRules_) { + rewrittenExpr = applyRuleOnce(rewrittenExpr, rule, analyzer); + } + return rewrittenExpr; + } + + private Expr applyRuleOnce(Expr expr, ExprRewriteRule rule, Analyzer analyzer) throws AnalysisException { + Expr rewrittenExpr = rule.apply(expr, analyzer); + if (rewrittenExpr != expr) { + numChanges_++; + } return rewrittenExpr; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java new file mode 100644 index 00000000000000..bdf4ed1f887b07 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java @@ -0,0 +1,449 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.rewrite; + +import org.apache.doris.analysis.Analyzer; +import org.apache.doris.analysis.BinaryPredicate; +import org.apache.doris.analysis.CompoundPredicate; +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.InPredicate; +import org.apache.doris.analysis.LiteralExpr; +import org.apache.doris.analysis.SlotRef; +import org.apache.doris.common.AnalysisException; + +import com.google.common.base.Preconditions; +import com.google.common.collect.BoundType; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; +import com.google.common.collect.TreeRangeSet; + +import org.apache.commons.collections.CollectionUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This rule extracts common predicate from multiple disjunctions when it is applied + * recursively bottom-up to a tree of CompoundPredicates. + * There are two common predicate that will be extracted as following: + * 1. Common Factors: (a and b) or (a and c) -> a and (b or c) + * 2. Wide common factors: (1 (1 [[a, b], [a, e, f]] + * 2. Extract common factors: + * @code commonFactorList: [a] + * @code clearExprs: [[b], [e, f]] + * 3. Extract wide common factors: + * @code wideCommonExpr: b' + * @code commonFactorList: [a, b'] + * 4. 
Construct new expr: + * @return: a and b' and (b or (e and f)) + */ + private Expr extractCommonFactors(List> exprs, Analyzer analyzer) { + if (exprs.size() < 2) { + return null; + } + // 1. remove duplicated elements [[a,a], [a, b], [a,b]] => [[a], [a,b]] + Set> set = new LinkedHashSet<>(); + for (List ex : exprs) { + Set es = new LinkedHashSet<>(); + es.addAll(ex); + set.add(es); + } + List> clearExprs = new ArrayList<>(); + for (Set es : set) { + List el = new ArrayList<>(); + el.addAll(es); + clearExprs.add(el); + } + if (clearExprs.size() == 1) { + return makeCompound(clearExprs.get(0), CompoundPredicate.Operator.AND); + } + + // 2. find duplicate cross the clause + List commonFactorList = new ArrayList<>(clearExprs.get(0)); + for (int i = 1; i < clearExprs.size(); ++i) { + commonFactorList.retainAll(clearExprs.get(i)); + } + boolean isReturnCommonFactorExpr = false; + for (List exprList : clearExprs) { + exprList.removeAll(commonFactorList); + if (exprList.size() == 0) { + // For example, the sql is "where (a = 1) or (a = 1 and B = 2)" + // if "(a = 1)" is extracted as a common factor expression, then the first expression "(a = 1)" has no expression + // other than a common factor expression, and the second expression "(a = 1 and B = 2)" has an expression of "(B = 2)" + // + // In this case, the common factor expression ("a = 1") can be directly used to replace the whole CompoundOrPredicate. + // In Fact, the common factor expression is actually the parent set of expression "(a = 1)" and expression "(a = 1 and B = 2)" + // + // exprList.size() == 0 means one child of CompoundOrPredicate has no expression other than a common factor expression. 
+ isReturnCommonFactorExpr = true; + break; + } + } + if (isReturnCommonFactorExpr) { + Preconditions.checkState(!commonFactorList.isEmpty()); + Expr result = makeCompound(commonFactorList, CompoundPredicate.Operator.AND); + if (LOG.isDebugEnabled()) { + LOG.debug("equal ors: " + result.toSql()); + } + return result; + } + + // 3. find merge cross the clause + if (analyzer.getContext().getSessionVariable().isExtractWideRangeExpr()) { + Expr wideCommonExpr = findWideRangeExpr(clearExprs); + if (wideCommonExpr != null) { + commonFactorList.add(wideCommonExpr); + } + } + + // 4. construct new expr + // rebuild CompoundPredicate if found duplicate predicate will build (predicate) and (.. or ..) predicate in + // step 1: will build (.. or ..) + List remainingOrClause = Lists.newArrayList(); + for (List clearExpr : clearExprs) { + Preconditions.checkState(!clearExpr.isEmpty()); + remainingOrClause.add(makeCompound(clearExpr, CompoundPredicate.Operator.AND)); + } + Expr result = null; + if (CollectionUtils.isNotEmpty(commonFactorList)) { + result = new CompoundPredicate(CompoundPredicate.Operator.AND, + makeCompound(commonFactorList, CompoundPredicate.Operator.AND), + makeCompound(remainingOrClause, CompoundPredicate.Operator.OR)); + result.setPrintSqlInParens(true); + } else { + result = makeCompound(remainingOrClause, CompoundPredicate.Operator.OR); + } + if (LOG.isDebugEnabled()) { + LOG.debug("equal ors: " + result.toSql()); + } + return result; + } + + /** + * The wide range of expr must satisfy two conditions as following: + * 1. the expr is a constant filter for single column in single table. + * 2. the single column of expr must appear in all clauses. + * The expr extracted here is a wider range of expressions, similar to a pre-filtering. + * But pre-filtering does not necessarily mean that it must have a positive impact on performance. + */ + private Expr findWideRangeExpr(List> exprList) { + // 1. 
construct map + List>> columnNameToRangeList = Lists.newArrayList(); + List> columnNameToInList = Lists.newArrayList(); + OUT_CONJUNCTS: + for (List conjuncts : exprList) { + Map> columnNameToRange = Maps.newHashMap(); + Map columnNameToInPredicate = Maps.newHashMap(); + for (Expr predicate : conjuncts) { + if (!singleColumnPredicate(predicate)) { + continue; + } + SlotRef columnName = (SlotRef) predicate.getChild(0); + if (predicate instanceof BinaryPredicate) { + Range predicateRange = ((BinaryPredicate)predicate).convertToRange(); + if (predicateRange == null){ + continue; + } + Range range = columnNameToRange.get(columnName); + if (range == null) { + range = predicateRange; + } else { + try { + range = range.intersection(predicateRange); + } catch (IllegalArgumentException | ClassCastException e) { + // (a >1 and a < 0) ignore this OR clause + LOG.debug("The range without intersection", e); + continue OUT_CONJUNCTS; + } + } + columnNameToRange.put(columnName, range); + } else if (predicate instanceof InPredicate) { + InPredicate inPredicate = (InPredicate) predicate; + InPredicate intersectInPredicate = columnNameToInPredicate.get(columnName); + if (intersectInPredicate == null) { + intersectInPredicate = new InPredicate(inPredicate.getChild(0), inPredicate.getListChildren(), + inPredicate.isNotIn()); + } else { + intersectInPredicate = intersectInPredicate.intersection((InPredicate) predicate); + } + columnNameToInPredicate.put(columnName, intersectInPredicate); + } + } + columnNameToRangeList.add(columnNameToRange); + columnNameToInList.add(columnNameToInPredicate); + } + + // 2. 
merge clause + Map> resultRangeMap = Maps.newHashMap(); + for (Map.Entry> entry: columnNameToRangeList.get(0).entrySet()) { + RangeSet rangeSet = TreeRangeSet.create(); + rangeSet.add(entry.getValue()); + resultRangeMap.put(entry.getKey(), rangeSet); + } + for (int i = 1; i < columnNameToRangeList.size(); i++) { + Map> columnNameToRange = columnNameToRangeList.get(i); + resultRangeMap = mergeTwoClauseRange(resultRangeMap, columnNameToRange); + if (resultRangeMap.isEmpty()) { + break; + } + } + Map resultInMap = columnNameToInList.get(0); + for (int i = 1; i < columnNameToRangeList.size(); i++) { + Map columnNameToIn = columnNameToInList.get(i); + resultInMap = mergeTwoClauseIn(resultInMap, columnNameToIn); + if (resultInMap.isEmpty()) { + break; + } + } + + // 3. construct wide range expr + List wideRangeExprList = Lists.newArrayList(); + for (Map.Entry> entry : resultRangeMap.entrySet()) { + Expr wideRangeExpr = rangeSetToCompoundPredicate(entry.getKey(), entry.getValue()); + if (wideRangeExpr != null) { + wideRangeExprList.add(wideRangeExpr); + } + } + wideRangeExprList.addAll(resultInMap.values()); + return makeCompound(wideRangeExprList, CompoundPredicate.Operator.AND); + } + + /** + * An expression that meets the following three conditions will return true: + * 1. single column from single table + * 2. in or binary predicate + * 3. 
one child of predicate is constant + */ + private boolean singleColumnPredicate(Expr expr) { + List slotRefs = Lists.newArrayList(); + expr.collect(SlotRef.class, slotRefs); + if (slotRefs.size() != 1) { + return false; + } + if (expr instanceof InPredicate) { + InPredicate inPredicate = (InPredicate) expr; + if (!inPredicate.isLiteralChildren()) { + return false; + } + if (inPredicate.isNotIn()) { + return false; + } + if (inPredicate.getChild(0) instanceof SlotRef) { + return true; + } + return false; + } else if (expr instanceof BinaryPredicate) { + BinaryPredicate binaryPredicate = (BinaryPredicate) expr; + if (binaryPredicate.getChild(0) instanceof SlotRef + && binaryPredicate.getChild(1) instanceof LiteralExpr) { + return true; + } + return false; + } else { + return false; + } + } + + /** + * RangeSet1: 1> mergeTwoClauseRange(Map> clause1, + Map> clause2) { + Map> result = Maps.newHashMap(); + for (Map.Entry> clause1Entry: clause1.entrySet()) { + SlotRef columnName = clause1Entry.getKey(); + Range clause2Value = clause2.get(columnName); + if (clause2Value == null) { + continue; + } + try { + clause1Entry.getValue().add(clause2Value); + } catch (ClassCastException e) { + // ignore a >1.0 or a mergeTwoClauseIn(Map clause1, + Map clause2) { + Map result = Maps.newHashMap(); + for (Map.Entry clause1Entry: clause1.entrySet()) { + SlotRef columnName = clause1Entry.getKey(); + InPredicate clause2Value = clause2.get(columnName); + if (clause2Value == null) { + continue; + } + InPredicate union = clause1Entry.getValue().union(clause2Value); + result.put(columnName, union); + } + return result; + } + + /** + * this function only process (a and b and c) or (d and e and f) like clause, + * this function will format this to [[a, b, c], [d, e, f]] + */ + private List> exprFormatting(CompoundPredicate expr) { + List> orExprs = new ArrayList<>(); + for (Expr child : expr.getChildren()) { + if (child instanceof CompoundPredicate) { + CompoundPredicate childCp = 
(CompoundPredicate) child; + if (childCp.getOp() == CompoundPredicate.Operator.OR) { + orExprs.addAll(exprFormatting(childCp)); + continue; + } else if (childCp.getOp() == CompoundPredicate.Operator.AND) { + orExprs.add(flatAndExpr(child)); + continue; + } + } + orExprs.add(Arrays.asList(child)); + } + return orExprs; + } + + /** + * Flatten nested AND predicates into a list: a and b and c => [a, b, c]; any other expr becomes a one-element list + */ + private List flatAndExpr(Expr expr) { + List andExprs = new ArrayList<>(); + if (expr instanceof CompoundPredicate && ((CompoundPredicate) expr).getOp() == CompoundPredicate.Operator.AND) { + andExprs.addAll(flatAndExpr(expr.getChild(0))); + andExprs.addAll(flatAndExpr(expr.getChild(1))); + } else { + andExprs.add(expr); + } + return andExprs; + } + + /** + * Rebuild a CompoundPredicate from a list, [a, e, f] AND => a and e and f; returns null for an empty list and the element itself for a single-element list + */ + private Expr makeCompound(List exprs, CompoundPredicate.Operator op) { + if (CollectionUtils.isEmpty(exprs)) { + return null; + } + if (exprs.size() == 1) { + return exprs.get(0); + } + CompoundPredicate result = new CompoundPredicate(op, exprs.get(0), exprs.get(1)); + for (int i = 2; i < exprs.size(); ++i) { + result = new CompoundPredicate(op, result.clone(), exprs.get(i)); + } + result.setPrintSqlInParens(true); + return result; + } + + /** + * Convert RangeSet to Compound Predicate + * @param slotRef: column used as the left operand of every generated binary predicate + * @param rangeSet: {(1,3), (6,7)} + * @return: (k1>1 and k1<3) or (k1>6 and k1<7) + */ + public Expr rangeSetToCompoundPredicate(SlotRef slotRef, RangeSet rangeSet) { + List compoundList = Lists.newArrayList(); + for (Range range : rangeSet.asRanges()) { + LiteralExpr lowerBound = null; + LiteralExpr upperBound = null; + if (range.hasLowerBound()) { + lowerBound = range.lowerEndpoint(); + } + if (range.hasUpperBound()) { + upperBound = range.upperEndpoint(); + } + if (lowerBound == null && upperBound == null) { + continue; + } + if (lowerBound != null && upperBound != null && lowerBound.equals(upperBound)) { + compoundList.add(new
BinaryPredicate(BinaryPredicate.Operator.EQ, slotRef, lowerBound)); + continue; + } + List binaryPredicateList = Lists.newArrayList(); + if (lowerBound != null) { + if (range.lowerBoundType() == BoundType.OPEN) { + binaryPredicateList.add(new BinaryPredicate(BinaryPredicate.Operator.GT, slotRef, lowerBound)); + } else { + binaryPredicateList.add(new BinaryPredicate(BinaryPredicate.Operator.GE, slotRef, lowerBound)); + } + } + if (upperBound !=null) { + if (range.upperBoundType() == BoundType.OPEN) { + binaryPredicateList.add(new BinaryPredicate(BinaryPredicate.Operator.LT, slotRef, upperBound)); + } else { + binaryPredicateList.add(new BinaryPredicate(BinaryPredicate.Operator.LE, slotRef, upperBound)); + } + } + compoundList.add(makeCompound(binaryPredicateList, CompoundPredicate.Operator.AND)); + } + return makeCompound(compoundList, CompoundPredicate.Operator.OR); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/BinaryPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/BinaryPredicateTest.java index 0744b4501b0119..aa863a2003c5c7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/BinaryPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/BinaryPredicateTest.java @@ -21,14 +21,20 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.jmockit.Deencapsulation; + +import com.google.common.collect.BoundType; import com.google.common.collect.Lists; +import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; +import com.google.common.collect.TreeRangeSet; -import org.junit.Assert; -import org.junit.Test; +import java.util.List; import mockit.Expectations; import mockit.Injectable; import mockit.Mocked; +import org.junit.Assert; +import org.junit.Test; public class BinaryPredicateTest { @@ -105,4 +111,16 @@ public void testWrongOperand(@Injectable Expr child0, @Injectable Expr child1) { } catch 
(AnalysisException e) { } } + + @Test + public void testConvertToRange() { + SlotRef slotRef = new SlotRef(new TableName("db1", "tb1"), "k1"); + LiteralExpr literalExpr = new IntLiteral(1); + BinaryPredicate binaryPredicate = new BinaryPredicate(BinaryPredicate.Operator.LE, slotRef, literalExpr); + Range range = binaryPredicate.convertToRange(); + Assert.assertEquals(literalExpr, range.upperEndpoint()); + Assert.assertEquals(BoundType.CLOSED, range.upperBoundType()); + Assert.assertFalse(range.hasLowerBound()); + } + } \ No newline at end of file diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/InPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/InPredicateTest.java new file mode 100644 index 00000000000000..7ce6f9b737fd46 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/InPredicateTest.java @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.analysis; + +import org.apache.doris.common.AnalysisException; + +import java.util.List; + +import com.clearspring.analytics.util.Lists; +import org.junit.Assert; +import org.junit.Test; + +public class InPredicateTest { + + /* + InPredicate1: k1 in (1,2) + InPredicate2: k1 in (2,3) + Intersection: k1 in (2) + */ + @Test + public void testIntersection() throws AnalysisException { + SlotRef slotRef1 = new SlotRef(new TableName("db1", "tb1"), "k1"); + LiteralExpr literalChild1 = new IntLiteral(1); + LiteralExpr literalChild2 = new IntLiteral(2); + List literalChildren1 = Lists.newArrayList(); + literalChildren1.add(literalChild1); + literalChildren1.add(literalChild2); + InPredicate inPredicate1 = new InPredicate(slotRef1, literalChildren1, false); + + SlotRef slotRef2 = new SlotRef(new TableName("db1", "tb1"), "k1"); + LiteralExpr literalChild3 = new LargeIntLiteral("2"); + LiteralExpr literalChild4 = new LargeIntLiteral("3"); + List literalChildren2 = Lists.newArrayList(); + literalChildren2.add(literalChild3); + literalChildren2.add(literalChild4); + InPredicate inPredicate2 = new InPredicate(slotRef2, literalChildren2, false); + + // check result + InPredicate intersection = inPredicate1.intersection(inPredicate2); + Assert.assertEquals(slotRef1, intersection.getChild(0)); + Assert.assertTrue(intersection.isLiteralChildren()); + Assert.assertEquals(1, intersection.getListChildren().size()); + Assert.assertEquals(literalChild2, intersection.getChild(1)); + + // keep origin predicate + Assert.assertTrue(inPredicate1.isLiteralChildren()); + Assert.assertEquals(2, inPredicate1.getListChildren().size()); + Assert.assertTrue(inPredicate1.contains(literalChild1)); + Assert.assertTrue(inPredicate1.contains(literalChild2)); + Assert.assertTrue(inPredicate2.isLiteralChildren()); + Assert.assertEquals(2, inPredicate2.getListChildren().size()); + Assert.assertTrue(inPredicate2.contains(literalChild3)); + 
Assert.assertTrue(inPredicate2.contains(literalChild4)); + } + + /* + InPredicate1: k1 in (1,2) + InPredicate2: k1 in (1) + Union: k1 in (1,2) + */ + @Test + public void testUnion() throws AnalysisException { + SlotRef slotRef1 = new SlotRef(new TableName("db1", "tb1"), "k1"); + LiteralExpr literalChild1 = new IntLiteral(1); + LiteralExpr literalChild2 = new IntLiteral(2); + List literalChildren1 = Lists.newArrayList(); + literalChildren1.add(literalChild1); + literalChildren1.add(literalChild2); + InPredicate inPredicate1 = new InPredicate(slotRef1, literalChildren1, false); + + SlotRef slotRef2 = new SlotRef(new TableName("db1", "tb1"), "k1"); + LiteralExpr literalChild3 = new LargeIntLiteral("1"); + List literalChildren2 = Lists.newArrayList(); + literalChildren2.add(literalChild3); + InPredicate inPredicate2 = new InPredicate(slotRef2, literalChildren2, false); + + // check result + InPredicate union = inPredicate1.union(inPredicate2); + Assert.assertEquals(slotRef1, union.getChild(0)); + Assert.assertTrue(union.isLiteralChildren()); + Assert.assertEquals(2, union.getListChildren().size()); + Assert.assertTrue(union.getListChildren().contains(literalChild1)); + Assert.assertTrue(union.getListChildren().contains(literalChild2)); + + // keep origin predicate + Assert.assertTrue(inPredicate1.isLiteralChildren()); + Assert.assertEquals(2, inPredicate1.getListChildren().size()); + Assert.assertTrue(inPredicate1.contains(literalChild1)); + Assert.assertTrue(inPredicate1.contains(literalChild2)); + Assert.assertTrue(inPredicate2.isLiteralChildren()); + Assert.assertEquals(1, inPredicate2.getListChildren().size()); + Assert.assertTrue(inPredicate2.contains(literalChild3)); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/RangeCompareTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/RangeCompareTest.java new file mode 100644 index 00000000000000..fd76a119f31cd7 --- /dev/null +++ 
b/fe/fe-core/src/test/java/org/apache/doris/analysis/RangeCompareTest.java @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.analysis; + +import org.apache.doris.catalog.Type; +import org.apache.doris.common.AnalysisException; + +import com.google.common.collect.BoundType; +import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; +import com.google.common.collect.TreeRangeSet; + +import java.util.Set; + +import org.junit.Assert; +import org.junit.Test; + +public class RangeCompareTest { + + /* + Range1: a>=1 + Range2: a<2 + Intersection: 1<=a<2 + */ + @Test + public void testIntersection() { + LiteralExpr lowerBoundOfRange1 = new IntLiteral(1); + Range range1 = Range.atLeast(lowerBoundOfRange1); + LiteralExpr upperBoundOfRange2 = new IntLiteral(2); + Range range2 = Range.lessThan(upperBoundOfRange2); + Range intersectionRange = range1.intersection(range2); + Assert.assertTrue(intersectionRange.hasLowerBound()); + Assert.assertEquals(lowerBoundOfRange1, intersectionRange.lowerEndpoint()); + Assert.assertEquals(BoundType.CLOSED, intersectionRange.lowerBoundType()); + Assert.assertTrue(intersectionRange.hasUpperBound()); +
Assert.assertEquals(upperBoundOfRange2, intersectionRange.upperEndpoint()); + Assert.assertEquals(BoundType.OPEN, intersectionRange.upperBoundType()); + } + + /* + Range1: a>1 + Range2: a<1 + Intersection: null + */ + @Test + public void testWithoutIntersection() { + LiteralExpr lowerBoundOfRange1 = new IntLiteral(1); + Range range1 = Range.greaterThan(lowerBoundOfRange1); + LiteralExpr upperBoundOfRange2 = new IntLiteral(1); + Range range2 = Range.lessThan(upperBoundOfRange2); + try { + Range intersectionRange = range1.intersection(range2); + Assert.fail(); + } catch (IllegalArgumentException e) { + System.out.println(e); + } + } + + // Range1: a>=1 + // Range2: a<0.1 + // Intersection: null + @Test + public void testIntersectionInvalidRange() throws AnalysisException { + LiteralExpr lowerBoundOfRange1 = new IntLiteral(1); + Range range1 = Range.atLeast(lowerBoundOfRange1); + LiteralExpr upperBoundOfRange2 = new DecimalLiteral("0.1"); + Range range2 = Range.lessThan(upperBoundOfRange2); + try { + Range intersectionRange = range1.intersection(range2); + Assert.fail(); + } catch (IllegalArgumentException e) { + System.out.println(e); + } + } + + // Range1: a>=3.0 + // Range2: a<6 + // Intersection: 3.0<=a<6 + @Test + public void testIntersectionWithDifferentType() throws AnalysisException { + LiteralExpr lowerBoundOfRange1 = new DecimalLiteral("3.0"); + Range range1 = Range.atLeast(lowerBoundOfRange1); + LiteralExpr upperBoundOfRange2 = new IntLiteral(6); + Range range2 = Range.lessThan(upperBoundOfRange2); + try { + Range intersectionRange = range1.intersection(range2); + Assert.assertEquals(lowerBoundOfRange1, intersectionRange.lowerEndpoint()); + Assert.assertEquals(upperBoundOfRange2, intersectionRange.upperEndpoint()); + } catch (ClassCastException e) { + Assert.fail(e.getMessage()); + } + } + + // Range1: a>=3.0 + // Range2: a range1 = Range.atLeast(lowerBoundOfRange1); + LiteralExpr upperBoundOfRange2 = new BoolLiteral(true); + Range range2 =
Range.lessThan(upperBoundOfRange2); + try { + range1.intersection(range2); + Assert.fail(); + } catch (IllegalArgumentException e) { + System.out.println(e); + } + } + + /* + Range1: 15 + Intersection: 5 range1 = Range.range(lowerBoundOfRange1, BoundType.OPEN, upperBoundOfRange1, BoundType.OPEN); + LiteralExpr lowerBoundOfRange2 = new IntLiteral(5); + Range range2 = Range.greaterThan(lowerBoundOfRange2); + Range intersection = range1.intersection(range2); + Assert.assertEquals(lowerBoundOfRange2, intersection.lowerEndpoint()); + Assert.assertEquals(BoundType.OPEN, intersection.lowerBoundType()); + Assert.assertEquals(upperBoundOfRange1, intersection.upperEndpoint()); + Assert.assertEquals(BoundType.OPEN, intersection.upperBoundType()); + } + + /* + Range1: a>=1 + Range2: a<0 + Merge Range Set: a >=1, a <0 + */ + @Test + public void testMergeRangeWithoutIntersection() { + LiteralExpr lowerBoundOfRange1 = new IntLiteral(1); + Range range1 = Range.atLeast(lowerBoundOfRange1); + LiteralExpr upperBoundOfRange2 = new IntLiteral(0); + Range range2 = Range.lessThan(upperBoundOfRange2); + RangeSet rangeSet = TreeRangeSet.create(); + rangeSet.add(range1); + rangeSet.add(range2); + Set> rangeList = rangeSet.asRanges(); + Assert.assertEquals(2, rangeList.size()); + Assert.assertTrue(rangeList.contains(range1)); + Assert.assertTrue(rangeList.contains(range2)); + } + + /* + Range1: 1<=a<=10 + Range2: 10<=a<=20 + Merge Range Set: 1<=a<=20 + */ + @Test + public void testMergeRangeWithIntersection1() { + LiteralExpr lowerBoundOfRange1 = new IntLiteral(1); + LiteralExpr upperBoundOfRange1 = new IntLiteral(10); + Range range1 = Range.range(lowerBoundOfRange1, BoundType.CLOSED, upperBoundOfRange1, BoundType.CLOSED); + LiteralExpr lowerBoundOfRange2 = new IntLiteral(10); + LiteralExpr upperBoundOfRange2 = new IntLiteral(20); + Range range2 = Range.range(lowerBoundOfRange2, BoundType.CLOSED, upperBoundOfRange2, BoundType.CLOSED); + RangeSet rangeSet = TreeRangeSet.create(); + 
rangeSet.add(range1); + rangeSet.add(range2); + Set> rangeList = rangeSet.asRanges(); + Assert.assertEquals(1, rangeList.size()); + Range intersection = rangeList.iterator().next(); + Assert.assertEquals(lowerBoundOfRange1, intersection.lowerEndpoint()); + Assert.assertEquals(upperBoundOfRange2, intersection.upperEndpoint()); + Assert.assertEquals(BoundType.CLOSED, intersection.lowerBoundType()); + Assert.assertEquals(BoundType.CLOSED, intersection.upperBoundType()); + } + + /* + Range1: 1<=a<10 + Range2: 5<=a<=20 + Merge Range Set: 1<=a<=20 + */ + @Test + public void testMergeRangeWithIntersection2() { + LiteralExpr lowerBoundOfRange1 = new IntLiteral(1); + LiteralExpr upperBoundOfRange1 = new IntLiteral(10); + Range range1 = Range.range(lowerBoundOfRange1, BoundType.CLOSED, upperBoundOfRange1, BoundType.OPEN); + LiteralExpr lowerBoundOfRange2 = new IntLiteral(5); + LiteralExpr upperBoundOfRange2 = new IntLiteral(20); + Range range2 = Range.range(lowerBoundOfRange2, BoundType.CLOSED, upperBoundOfRange2, BoundType.CLOSED); + RangeSet rangeSet = TreeRangeSet.create(); + rangeSet.add(range1); + rangeSet.add(range2); + Set> rangeList = rangeSet.asRanges(); + Assert.assertEquals(1, rangeList.size()); + Range intersection = rangeList.iterator().next(); + Assert.assertEquals(lowerBoundOfRange1, intersection.lowerEndpoint()); + Assert.assertEquals(upperBoundOfRange2, intersection.upperEndpoint()); + Assert.assertEquals(BoundType.CLOSED, intersection.lowerBoundType()); + Assert.assertEquals(BoundType.CLOSED, intersection.upperBoundType()); + } + + /* + Range1: a<3.0 + Range2: a>=2021-01-01 + Merge: ClassCastException + */ + @Test + public void testMergeRangeWithDifferentType() throws AnalysisException { + LiteralExpr lowerBoundOfRange1 = new DecimalLiteral("3.0"); + Range range1 = Range.lessThan(lowerBoundOfRange1); + LiteralExpr upperBoundOfRange2 = new DateLiteral("2021-01-01", Type.DATE); + Range range2 = Range.atLeast(upperBoundOfRange2); + RangeSet rangeSet =
TreeRangeSet.create(); + rangeSet.add(range1); + try { + rangeSet.add(range2); + Assert.fail(); + } catch (ClassCastException e) { + System.out.println(e); + } + } + +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java index 20bbb1acc373aa..17b107cd413f89 100755 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java @@ -29,19 +29,18 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; - import java.util.List; import java.util.Set; import java.util.UUID; import mockit.Mock; import mockit.MockUp; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; public class SelectStmtTest { private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/"; @@ -270,16 +269,22 @@ public void testDeduplicateOrs() throws Exception { " );"; SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql, ctx); stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter()); - String rewritedFragment1 = "((`t1`.`k2` = `t4`.`k2` AND `t3`.`k3` = `t1`.`k3`) " + + String rewritedFragment1 = "((`t1`.`k2` = `t4`.`k2` AND `t3`.`k3` = `t1`.`k3` " + + "AND ((`t3`.`k1` = 'D' OR `t3`.`k1` = 'S' OR `t3`.`k1` = 'W') " + + "AND (`t4`.`k3` = '2 yr Degree' OR `t4`.`k3` = 'Advanced Degree' OR `t4`.`k3` = 'Secondary') " + + "AND (`t4`.`k4` = 1 OR `t4`.`k4` = 3))) " + "AND ((`t3`.`k1` = 'D' AND `t4`.`k3` = '2 yr Degree' " + "AND `t1`.`k4` >= 100.00 AND `t1`.`k4` <= 150.00 AND `t4`.`k4` = 3) " + "OR (`t3`.`k1` = 'S' AND `t4`.`k3` = 
'Secondary' AND `t1`.`k4` >= 50.00 " + "AND `t1`.`k4` <= 100.00 AND `t4`.`k4` = 1) OR (`t3`.`k1` = 'W' AND `t4`.`k3` = 'Advanced Degree' " + "AND `t1`.`k4` >= 150.00 AND `t1`.`k4` <= 200.00 AND `t4`.`k4` = 1)))"; - String rewritedFragment2 = "((`t1`.`k1` = `t5`.`k1` AND `t5`.`k2` = 'United States') " + + String rewritedFragment2 = "((`t1`.`k1` = `t5`.`k1` AND `t5`.`k2` = 'United States' " + + "AND ((`t1`.`k4` >= 50 AND `t1`.`k4` <= 300) " + + "AND `t5`.`k3` IN ('CO', 'IL', 'MN', 'OH', 'MT', 'NM', 'TX', 'MO', 'MI'))) " + "AND ((`t5`.`k3` IN ('CO', 'IL', 'MN') AND `t1`.`k4` >= 100 AND `t1`.`k4` <= 200) " + "OR (`t5`.`k3` IN ('OH', 'MT', 'NM') AND `t1`.`k4` >= 150 AND `t1`.`k4` <= 300) OR (`t5`.`k3` IN " + "('TX', 'MO', 'MI') AND `t1`.`k4` >= 50 AND `t1`.`k4` <= 250)))"; + System.out.println(stmt.toSql()); Assert.assertTrue(stmt.toSql().contains(rewritedFragment1)); Assert.assertTrue(stmt.toSql().contains(rewritedFragment2)); diff --git a/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java new file mode 100644 index 00000000000000..e1e7838c6ade20 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +package org.apache.doris.rewrite; + +import org.apache.doris.common.FeConstants; +import org.apache.doris.utframe.DorisAssert; +import org.apache.doris.utframe.UtFrameUtils; + +import org.apache.commons.lang3.StringUtils; + +import java.util.UUID; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class ExtractCommonFactorsRuleFunctionTest { + private static String baseDir = "fe"; + private static String runningDir = baseDir + "/mocked/ExtractCommonFactorsRuleFunctionTest/" + + UUID.randomUUID().toString() + "/"; + private static DorisAssert dorisAssert; + private static final String DB_NAME = "db1"; + private static final String TABLE_NAME_1 = "tb1"; + private static final String TABLE_NAME_2 = "tb2"; + + @BeforeClass + public static void beforeClass() throws Exception { + FeConstants.default_scheduler_interval_millisecond = 10; + FeConstants.runningUnitTest = true; + UtFrameUtils.createMinDorisCluster(runningDir); + dorisAssert = new DorisAssert(); + dorisAssert.withDatabase(DB_NAME).useDatabase(DB_NAME); + String createTableSQL = "create table " + DB_NAME + "." + TABLE_NAME_1 + + " (k1 int, k2 int) " + + "distributed by hash(k1) buckets 3 properties('replication_num' = '1');"; + dorisAssert.withTable(createTableSQL); + createTableSQL = "create table " + DB_NAME + "." 
+ TABLE_NAME_2 + + " (k1 int, k2 int) " + + "distributed by hash(k1) buckets 3 properties('replication_num' = '1');"; + dorisAssert.withTable(createTableSQL); + } + + @AfterClass + public static void afterClass() throws Exception { + UtFrameUtils.cleanDorisFeDir(baseDir); + } + + @Test + public void testWithoutRewritten() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1 =1) or (tb2.k2=1)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void testCommonFactors() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1=tb2.k1 and tb1.k2 =1) or (tb1.k1=tb2.k1 and tb2.k2=1)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("HASH JOIN")); + Assert.assertEquals(1, StringUtils.countMatches(planString, "`tb1`.`k1` = `tb2`.`k1`")); + } + + @Test + public void testWideCommonFactorsWithEqualPredicate() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1=1 and tb2.k1=1) or (tb1.k1 =2 and tb2.k1=2)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("(`tb1`.`k1` = 1 OR `tb1`.`k1` = 2)")); + Assert.assertTrue(planString.contains("(`tb2`.`k1` = 1 OR `tb2`.`k1` = 2)")); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void testWithoutWideCommonFactorsWhenInfinityRangePredicate() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1>1 and tb2.k1=1) or (tb1.k1 <2 and tb2.k2=2)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertFalse(planString.contains("(`tb1`.`k1` > 1 OR `tb1`.`k1` < 2)")); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void testWideCommonFactorsWithMergeRangePredicate() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1 between 1 and 3 and tb2.k1=1) or (tb1.k1 <2 and 
tb2.k2=2)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("`tb1`.`k1` <= 3")); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void testWideCommonFactorsWithIntersectRangePredicate() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1 >1 and tb1.k1 <3 and tb1.k1 <5 and tb2.k1=1) " + + "or (tb1.k1 <2 and tb2.k2=2)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("`tb1`.`k1` < 5")); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void testWideCommonFactorsWithDuplicateRangePredicate() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1 >1 and tb1.k1 >1 and tb1.k1 <5 and tb2.k1=1) " + + "or (tb1.k1 <2 and tb2.k2=2)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("`tb1`.`k1` < 5")); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void testWideCommonFactorsWithInPredicate() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1 in (1) and tb2.k1 in(1)) " + + "or (tb1.k1 in(2) and tb2.k1 in(2))"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("`tb1`.`k1` IN (1, 2)")); + Assert.assertTrue(planString.contains("`tb2`.`k1` IN (1, 2)")); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void testWideCommonFactorsWithDuplicateInPredicate() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1 in (1,2) and tb2.k1 in(1,2)) " + + "or (tb1.k1 in(3) and tb2.k1 in(2))"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("`tb1`.`k1` IN (1, 2, 3)")); + Assert.assertTrue(planString.contains("`tb2`.`k1` IN (1, 2)")); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void 
testWideCommonFactorsWithRangeAndIn() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1 between 1 and 3 and tb2.k1 in(1,2)) " + + "or (tb1.k1 between 2 and 4 and tb2.k1 in(3))"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("`tb1`.`k1` >= 1")); + Assert.assertTrue(planString.contains("`tb1`.`k1` <= 4")); + Assert.assertTrue(planString.contains("`tb2`.`k1` IN (1, 2, 3)")); + Assert.assertTrue(planString.contains("CROSS JOIN")); + } + + @Test + public void testWideCommonFactorsAndCommonFactors() throws Exception { + String query = "select * from tb1, tb2 where (tb1.k1 between 1 and 3 and tb1.k1=tb2.k1) " + + "or (tb1.k1=tb2.k1 and tb1.k1 between 2 and 4)"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("`tb1`.`k1` >= 1")); + Assert.assertTrue(planString.contains("`tb1`.`k1` <= 4")); + Assert.assertTrue(planString.contains("`tb1`.`k1` = `tb2`.`k1`")); + Assert.assertTrue(planString.contains("HASH JOIN")); + } + + // TPC-H Q19 + @Test + public void testComplexQuery() throws Exception { + String createTableSQL = "CREATE TABLE `lineitem` (\n" + + " `l_orderkey` int(11) NOT NULL COMMENT \"\",\n" + + " `l_partkey` int(11) NOT NULL COMMENT \"\",\n" + + " `l_suppkey` int(11) NOT NULL COMMENT \"\",\n" + + " `l_linenumber` int(11) NOT NULL COMMENT \"\",\n" + + " `l_quantity` decimal(15, 2) NOT NULL COMMENT \"\",\n" + + " `l_extendedprice` decimal(15, 2) NOT NULL COMMENT \"\",\n" + + " `l_discount` decimal(15, 2) NOT NULL COMMENT \"\",\n" + + " `l_tax` decimal(15, 2) NOT NULL COMMENT \"\",\n" + + " `l_returnflag` char(1) NOT NULL COMMENT \"\",\n" + + " `l_linestatus` char(1) NOT NULL COMMENT \"\",\n" + + " `l_shipdate` date NOT NULL COMMENT \"\",\n" + + " `l_commitdate` date NOT NULL COMMENT \"\",\n" + + " `l_receiptdate` date NOT NULL COMMENT \"\",\n" + + " `l_shipinstruct` char(25) NOT NULL COMMENT \"\",\n" + + " `l_shipmode` char(10) NOT 
NULL COMMENT \"\",\n" + + " `l_comment` varchar(44) NOT NULL COMMENT \"\"\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(`l_orderkey`)\n" + + "COMMENT \"OLAP\"\n" + + "DISTRIBUTED BY HASH(`l_orderkey`) BUCKETS 2\n" + + "PROPERTIES (\n" + + "\"replication_num\" = \"1\",\n" + + "\"in_memory\" = \"false\",\n" + + "\"storage_format\" = \"V2\"\n" + + ");"; + dorisAssert.withTable(createTableSQL); + createTableSQL = "CREATE TABLE `part` (\n" + + " `p_partkey` int(11) NOT NULL COMMENT \"\",\n" + + " `p_name` varchar(55) NOT NULL COMMENT \"\",\n" + + " `p_mfgr` char(25) NOT NULL COMMENT \"\",\n" + + " `p_brand` char(10) NOT NULL COMMENT \"\",\n" + + " `p_type` varchar(25) NOT NULL COMMENT \"\",\n" + + " `p_size` int(11) NOT NULL COMMENT \"\",\n" + + " `p_container` char(10) NOT NULL COMMENT \"\",\n" + + " `p_retailprice` decimal(15, 2) NOT NULL COMMENT \"\",\n" + + " `p_comment` varchar(23) NOT NULL COMMENT \"\"\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(`p_partkey`)\n" + + "COMMENT \"OLAP\"\n" + + "DISTRIBUTED BY HASH(`p_partkey`) BUCKETS 2\n" + + "PROPERTIES (\n" + + "\"replication_num\" = \"1\",\n" + + "\"in_memory\" = \"false\",\n" + + "\"storage_format\" = \"V2\"\n" + + ");"; + dorisAssert.withTable(createTableSQL); + String query = "select sum(l_extendedprice* (1 - l_discount)) as revenue " + + "from lineitem, part " + + "where ( p_partkey = l_partkey and p_brand = 'Brand#11' " + + "and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') " + + "and l_quantity >= 9 and l_quantity <= 9 + 10 " + + "and p_size between 1 and 5 and l_shipmode in ('AIR', 'AIR REG') " + + "and l_shipinstruct = 'DELIVER IN PERSON' ) " + + "or ( p_partkey = l_partkey and p_brand = 'Brand#21' " + + "and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') " + + "and l_quantity >= 20 and l_quantity <= 20 + 10 " + + "and p_size between 1 and 10 and l_shipmode in ('AIR', 'AIR REG') " + + "and l_shipinstruct = 'DELIVER IN PERSON' ) " + + "or ( p_partkey = l_partkey and p_brand = 
'Brand#32' " + + "and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') " + + "and l_quantity >= 26 and l_quantity <= 26 + 10 " + + "and p_size between 1 and 15 and l_shipmode in ('AIR', 'AIR REG') " + + "and l_shipinstruct = 'DELIVER IN PERSON' )"; + String planString = dorisAssert.query(query).explainQuery(); + Assert.assertTrue(planString.contains("HASH JOIN")); + Assert.assertTrue(planString.contains("`l_partkey` = `p_partkey`")); + Assert.assertTrue(planString.contains("`l_shipmode` IN ('AIR', 'AIR REG')")); + Assert.assertTrue(planString.contains("`l_shipinstruct` = 'DELIVER IN PERSON'")); + Assert.assertTrue(planString.contains("((`l_quantity` >= 9 AND `l_quantity` <= 19) " + + "OR (`l_quantity` >= 20 AND `l_quantity` <= 36))")); + Assert.assertTrue(planString.contains("`p_size` >= 1")); + Assert.assertTrue(planString.contains("(`p_brand` = 'Brand#11' OR `p_brand` = 'Brand#21' OR `p_brand` = 'Brand#32')")); + Assert.assertTrue(planString.contains("`p_size` <= 15")); + Assert.assertTrue(planString.contains("`p_container` IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG', 'MED BAG', " + + "'MED BOX', 'MED PKG', 'MED PACK', 'LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')")); + } + +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleTest.java new file mode 100644 index 00000000000000..d53705666193af --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleTest.java @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
//

package org.apache.doris.rewrite;

import org.apache.doris.analysis.CompoundPredicate;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.InPredicate;
import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TableName;
import org.apache.doris.common.jmockit.Deencapsulation;

import com.google.common.collect.BoundType;
import com.google.common.collect.Maps;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;

import java.util.List;
import java.util.Map;
import java.util.Set;

// NOTE(review): this Lists is not the Guava one used for the other collection helpers above;
// presumably an IDE auto-import - confirm whether com.google.common.collect.Lists was intended.
import com.clearspring.analytics.util.Lists;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for helper methods of {@link ExtractCommonFactorsRule}.
 *
 * <p>Each test builds its input expressions by hand and calls the helper by name through
 * {@code Deencapsulation.invoke}, so the method-name strings below must match the rule's
 * method names exactly.
 */
public class ExtractCommonFactorsRuleTest {

    // Input: k1 in (k2, 1)
    // Result: false
    // The IN list mixes a column (k2) with a literal, so the predicate is expected to be rejected.
    @Test
    public void testSingleColumnPredicateInColumn() {
        SlotRef child0 = new SlotRef(new TableName("db1", "tb1"), "k1");
        SlotRef inColumn = new SlotRef(new TableName("db1", "tb1"), "k2");
        IntLiteral intLiteral = new IntLiteral(1);
        List<Expr> inExprList = Lists.newArrayList();
        inExprList.add(inColumn);
        inExprList.add(intLiteral);
        InPredicate inPredicate = new InPredicate(child0, inExprList, false);
        ExtractCommonFactorsRule extractCommonFactorsRule = new ExtractCommonFactorsRule();
        boolean result = Deencapsulation.invoke(extractCommonFactorsRule, "singleColumnPredicate", inPredicate);
        Assert.assertFalse(result);
    }

    // Clause1: a range on k1 starting at 1 (exclusive), plus a range on k2
    // Clause2: 2<k1<4
    // Result: 1<k1<4 (only k1 appears in both clauses)
    //
    // NOTE(review): the header comment and the first statements of this test were lost in
    // extraction. The k1/k2 slot refs follow the pattern of the sibling tests; the lower bound of
    // k1Range1 is fixed by the assertions below, but its upper bound (3) and the whole of k2Range
    // are reconstructions - confirm against the upstream file.
    @Test
    public void testMergeTwoClauseRange() {
        // Clause1
        SlotRef k1SlotRef = new SlotRef(new TableName("db1", "tb1"), "k1");
        SlotRef k2SlotRef = new SlotRef(new TableName("db1", "tb1"), "k2");
        Range<LiteralExpr> k1Range1 = Range.range(new IntLiteral(1), BoundType.OPEN, new IntLiteral(3), BoundType.OPEN);
        Range<LiteralExpr> k2Range = Range.greaterThan(new IntLiteral(2));
        RangeSet<LiteralExpr> k1RangeSet1 = TreeRangeSet.create();
        k1RangeSet1.add(k1Range1);
        RangeSet<LiteralExpr> k2RangeSet = TreeRangeSet.create();
        k2RangeSet.add(k2Range);
        // NOTE(review): the map key type was stripped in extraction; SlotRef is assumed.
        Map<SlotRef, RangeSet<LiteralExpr>> clause1 = Maps.newHashMap();
        clause1.put(k1SlotRef, k1RangeSet1);
        clause1.put(k2SlotRef, k2RangeSet);
        // Clause2
        Range<LiteralExpr> k1Range2 = Range.range(new IntLiteral(2), BoundType.OPEN, new IntLiteral(4), BoundType.OPEN);
        // Clause2 maps the column to a single Range (not a RangeSet), as in the original statement.
        Map<SlotRef, Range<LiteralExpr>> clause2 = Maps.newHashMap();
        clause2.put(k1SlotRef, k1Range2);

        ExtractCommonFactorsRule extractCommonFactorsRule = new ExtractCommonFactorsRule();
        Map<SlotRef, RangeSet<LiteralExpr>> result = Deencapsulation.invoke(extractCommonFactorsRule,
                "mergeTwoClauseRange", clause1, clause2);
        // k2 only appears in clause1, so only k1 may survive the merge.
        Assert.assertEquals(1, result.size());
        Assert.assertTrue(result.containsKey(k1SlotRef));
        Set<Range<LiteralExpr>> k1ResultRangeSet = result.get(k1SlotRef).asRanges();
        // The two overlapping k1 ranges must have been coalesced into one.
        Assert.assertEquals(1, k1ResultRangeSet.size());
        Range<LiteralExpr> k1ResultRange = k1ResultRangeSet.iterator().next();
        Assert.assertEquals(new IntLiteral(1), k1ResultRange.lowerEndpoint());
        Assert.assertEquals(new IntLiteral(4), k1ResultRange.upperEndpoint());
    }

    // Clause1: k1 in (1), k2 in (1)
    // Clause2: k1 in (1, 2)
    // Result: k1 in (1, 2)
    @Test
    public void testMergeTwoClauseIn() {
        // Clause1
        SlotRef k1SlotRef = new SlotRef(new TableName("db1", "tb1"), "k1");
        SlotRef k2SlotRef = new SlotRef(new TableName("db1", "tb1"), "k2");
        IntLiteral intLiteral1 = new IntLiteral(1);
        IntLiteral intLiteral2 = new IntLiteral(2);
        List<Expr> k1Values1 = Lists.newArrayList();
        k1Values1.add(intLiteral1);
        List<Expr> k2Values1 = Lists.newArrayList();
        k2Values1.add(intLiteral1);
        InPredicate k1InPredicate1 = new InPredicate(k1SlotRef, k1Values1, false);
        InPredicate k2InPredicate = new InPredicate(k2SlotRef, k2Values1, false);
        // NOTE(review): the map type arguments were stripped in extraction; <SlotRef, InPredicate> is assumed.
        Map<SlotRef, InPredicate> clause1 = Maps.newHashMap();
        clause1.put(k1SlotRef, k1InPredicate1);
        clause1.put(k2SlotRef, k2InPredicate);
        // Clause2
        List<Expr> k1Values2 = Lists.newArrayList();
        k1Values2.add(intLiteral1);
        k1Values2.add(intLiteral2);
        InPredicate k1InPredicate2 = new InPredicate(k1SlotRef, k1Values2, false);
        Map<SlotRef, InPredicate> clause2 = Maps.newHashMap();
        clause2.put(k1SlotRef, k1InPredicate2);

        ExtractCommonFactorsRule extractCommonFactorsRule = new ExtractCommonFactorsRule();
        Map<SlotRef, InPredicate> result = Deencapsulation.invoke(extractCommonFactorsRule,
                "mergeTwoClauseIn", clause1, clause2);
        // k2 only appears in clause1, so only k1 may survive the merge.
        Assert.assertEquals(1, result.size());
        Assert.assertTrue(result.containsKey(k1SlotRef));
        InPredicate k1Result = result.get(k1SlotRef);
        Assert.assertEquals(2, k1Result.getListChildren().size());
        List<Expr> k1ResultValues = k1Result.getListChildren();
        // The membership checks were previously bare contains() calls whose results were
        // discarded; assert them so a wrong merged IN list actually fails the test.
        Assert.assertTrue(k1ResultValues.contains(intLiteral1));
        Assert.assertTrue(k1ResultValues.contains(intLiteral2));
    }

    // RangeSet: {(1,3], (6,7)}
    // SlotRef: k1
    // Result: (k1>1 and k1<=3) or (k1>6 and k1<7)
    @Test
    public void testRangeSetToCompoundPredicate() {
        Range<LiteralExpr> range1 = Range.range(new IntLiteral(1), BoundType.OPEN, new IntLiteral(3), BoundType.CLOSED);
        Range<LiteralExpr> range2 = Range.range(new IntLiteral(6), BoundType.OPEN, new IntLiteral(7), BoundType.OPEN);
        RangeSet<LiteralExpr> rangeSet = TreeRangeSet.create();
        rangeSet.add(range1);
        rangeSet.add(range2);
        SlotRef slotRef = new SlotRef(new TableName("db1", "tb1"), "k1");

        ExtractCommonFactorsRule extractCommonFactorsRule = new ExtractCommonFactorsRule();
        Expr result = Deencapsulation.invoke(extractCommonFactorsRule,
                "rangeSetToCompoundPredicate", slotRef, rangeSet);
        // Top level must be an OR whose two children are the per-range AND predicates.
        Assert.assertTrue(result instanceof CompoundPredicate);
        CompoundPredicate compoundPredicate = (CompoundPredicate) result;
        Assert.assertEquals(CompoundPredicate.Operator.OR, compoundPredicate.getOp());
        Assert.assertTrue(compoundPredicate.getChild(0) instanceof CompoundPredicate);
        Assert.assertTrue(compoundPredicate.getChild(1) instanceof CompoundPredicate);
        CompoundPredicate clause1 = (CompoundPredicate) compoundPredicate.getChild(0);
        CompoundPredicate clause2 = (CompoundPredicate) compoundPredicate.getChild(1);
        Assert.assertEquals(CompoundPredicate.Operator.AND, clause1.getOp());
        Assert.assertEquals(CompoundPredicate.Operator.AND, clause2.getOp());
    }
}