Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,7 @@ functionCallExpression
: functionIdentifier
LEFT_PAREN (
(DISTINCT|ALL)?
(LEFT_BRACKET identifier RIGHT_BRACKET)?
arguments+=expression (COMMA arguments+=expression)*
(ORDER BY sortItem (COMMA sortItem)*)?
)? RIGHT_PAREN
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
public class UnboundFunction extends Function implements Unbound, PropagateNullable {
private final String dbName;
private final boolean isDistinct;
private final boolean isSkew;
// for create view stmt, the start and end position of the function string in original sql
private final Optional<Pair<Integer, Integer>> indexInSqlString;
// the start and end position of the function string in original sql
Expand Down Expand Up @@ -69,29 +70,31 @@ public FunctionIndexInSql indexInQueryPart(int offset) {
}

/**
 * Creates a non-distinct, non-skew function call with no database qualifier.
 *
 * @param name function name
 * @param arguments function arguments
 */
public UnboundFunction(String name, List<Expression> arguments) {
    this(null, name, false, arguments, false, Optional.empty(), Optional.empty());
}

/**
 * Creates a non-distinct, non-skew function call qualified by a database name.
 *
 * @param dbName database name, may be null
 * @param name function name
 * @param arguments function arguments
 */
public UnboundFunction(String dbName, String name, List<Expression> arguments) {
    this(dbName, name, false, arguments, false, Optional.empty(), Optional.empty());
}

/**
 * Creates a non-skew function call with no database qualifier.
 *
 * @param name function name
 * @param isDistinct whether DISTINCT was specified
 * @param arguments function arguments
 */
public UnboundFunction(String name, boolean isDistinct, List<Expression> arguments) {
    this(null, name, isDistinct, arguments, false, Optional.empty(), Optional.empty());
}

/**
 * Creates a non-skew function call; kept so callers that do not specify the skew flag keep working.
 *
 * @param dbName database name, may be null
 * @param name function name
 * @param isDistinct whether DISTINCT was specified
 * @param arguments function arguments
 */
public UnboundFunction(String dbName, String name, boolean isDistinct, List<Expression> arguments) {
    this(dbName, name, isDistinct, arguments, false, Optional.empty(), Optional.empty());
}

/**
 * Creates a function call with an explicit skew flag and no position information.
 *
 * @param dbName database name, may be null
 * @param name function name
 * @param isDistinct whether DISTINCT was specified
 * @param arguments function arguments
 * @param isSkew whether the skew flag was specified for this call
 */
public UnboundFunction(String dbName, String name, boolean isDistinct, List<Expression> arguments,
        boolean isSkew) {
    this(dbName, name, isDistinct, arguments, isSkew, Optional.empty(), Optional.empty());
}

/**
 * Creates a non-skew function call with position information; kept so callers that do not
 * specify the skew flag keep working.
 *
 * @param dbName database name, may be null
 * @param name function name
 * @param isDistinct whether DISTINCT was specified
 * @param arguments function arguments
 * @param functionIndexInSql optional position of the function in the original sql
 * @param indexInSqlString optional start and end position of the function string in the original sql
 */
public UnboundFunction(String dbName, String name, boolean isDistinct,
        List<Expression> arguments, Optional<FunctionIndexInSql> functionIndexInSql,
        Optional<Pair<Integer, Integer>> indexInSqlString) {
    this(dbName, name, isDistinct, arguments, false, functionIndexInSql, indexInSqlString);
}

/**
 * Primary constructor; every other constructor delegates here.
 *
 * @param dbName database name, may be null
 * @param name function name
 * @param isDistinct whether DISTINCT was specified
 * @param arguments function arguments, passed to the super constructor as children
 * @param isSkew whether the skew flag was specified for this call
 * @param functionIndexInSql optional position of the function in the original sql
 * @param indexInSqlString optional start and end position of the function string in the original sql
 */
public UnboundFunction(String dbName, String name, boolean isDistinct,
        List<Expression> arguments, boolean isSkew, Optional<FunctionIndexInSql> functionIndexInSql,
        Optional<Pair<Integer, Integer>> indexInSqlString) {
    super(name, arguments);
    this.dbName = dbName;
    this.isDistinct = isDistinct;
    this.isSkew = isSkew;
    this.functionIndexInSql = functionIndexInSql;
    this.indexInSqlString = indexInSqlString;
}

@Override
Expand All @@ -110,6 +113,10 @@ public boolean isDistinct() {
return isDistinct;
}

/**
 * Returns the skew flag this function was constructed with.
 *
 * @return the value of the {@code isSkew} field
 */
public boolean isSkew() {
    return this.isSkew;
}

/**
 * Returns the arguments of this function, which are stored as its children.
 *
 * @return the result of {@code children()}
 */
public List<Expression> getArguments() {
    return this.children();
}
Expand All @@ -135,15 +142,17 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

/**
 * Returns a copy of this function with the given children as arguments.
 * The database name, distinct flag, skew flag and both position fields are carried over.
 *
 * @param children the new arguments
 * @return a new UnboundFunction
 */
@Override
public UnboundFunction withChildren(List<Expression> children) {
    return new UnboundFunction(dbName, getName(), isDistinct, children, isSkew, functionIndexInSql,
            indexInSqlString);
}

/**
 * Returns the optional {@code FunctionIndexInSql} held by this function.
 *
 * @return the value of the {@code functionIndexInSql} field
 */
public Optional<FunctionIndexInSql> getFunctionIndexInSql() {
    return this.functionIndexInSql;
}

/**
 * Returns a copy of this function with {@code functionIndexInSql} replaced.
 * The children, database name, distinct flag, skew flag and {@code indexInSqlString} are carried over.
 *
 * @param functionIndexInSql the new optional function position
 * @return a new UnboundFunction
 */
public UnboundFunction withIndexInSqlString(Optional<FunctionIndexInSql> functionIndexInSql) {
    return new UnboundFunction(dbName, getName(), isDistinct, children, isSkew, functionIndexInSql,
            indexInSqlString);
}

@Override
Expand All @@ -167,7 +176,7 @@ public int computeHashCode() {
}

public UnboundFunction withIndexInSql(Pair<Integer, Integer> index) {
return new UnboundFunction(dbName, getName(), isDistinct, children, functionIndexInSql,
return new UnboundFunction(dbName, getName(), isDistinct, children, isSkew, functionIndexInSql,
Optional.ofNullable(index));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.rewrite.CostBasedRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteJob;
import org.apache.doris.nereids.rules.JoinSplitForNullSkew;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
Expand Down Expand Up @@ -325,7 +326,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
bottomUp(new EliminateEmptyRelation()),
// when union has empty relation child and constantExprsList is not empty,
// after EliminateEmptyRelation, project can be pushed into union
topDown(new PushProjectIntoUnion())
topDown(new PushProjectIntoUnion()),
costBased(topDown(new JoinSplitForNullSkew()))
),
topic("infer In-predicate from Or-predicate",
topDown(new InferInPredicateFromOr())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2692,6 +2692,7 @@ public Expression visitFunctionCallExpression(DorisParser.FunctionCallExpression
params.addAll(visit(ctx.expression(), Expression.class));
List<OrderKey> orderKeys = visit(ctx.sortItem(), OrderKey.class);
params.addAll(orderKeys.stream().map(OrderExpression::new).collect(Collectors.toList()));
boolean isSkew = ctx.identifier() != null && ctx.identifier().getText().equalsIgnoreCase("skew");

List<UnboundStar> unboundStars = ExpressionUtils.collectAll(params, UnboundStar.class::isInstance);
if (!unboundStars.isEmpty()) {
Expand All @@ -2717,7 +2718,7 @@ public Expression visitFunctionCallExpression(DorisParser.FunctionCallExpression
if (ctx.functionIdentifier().dbName != null) {
dbName = ctx.functionIdentifier().dbName.getText();
}
UnboundFunction function = new UnboundFunction(dbName, functionName, isDistinct, params);
UnboundFunction function = new UnboundFunction(dbName, functionName, isDistinct, params, isSkew);
if (ctx.windowSpec() != null) {
if (isDistinct) {
throw new ParseException("DISTINCT not allowed in analytic function: " + functionName, ctx);
Expand Down Expand Up @@ -4182,7 +4183,7 @@ public Object visitCallProcedure(CallProcedureContext ctx) {
.<Expression>map(this::typedVisit)
.collect(ImmutableList.toImmutableList());
UnboundFunction unboundFunction = new UnboundFunction(procedureName.getDbName(), procedureName.getName(),
true, arguments);
true, arguments, false);
return new CallCommand(unboundFunction, getOriginSql(ctx));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules;

import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.copier.DeepCopierContext;
import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.List;

/**
* LogicalLeftOuterJoin(hashConjuncts:t1.a=t2.a)
* +--Plan1(output:t1.a)
* +--Plan2(output:t2.a)
* ->
* LogicalUnion
* +--LogicalProject
* +--LogicalFilter(t1.a is null)
* +--Plan1
* +--LogicalLeftOuterJoin(t1.a=t2.a)
* +--LogicalFilter(t1.a is not null)
* +--Plan1
* +--Plan2
*
* LogicalRightOuterJoin(hashConjuncts:t1.a=t2.a)
* +--Plan1(output:t1.a)
* +--Plan2(output:t2.a)
* ->
* LogicalUnion
* +--LogicalProject
* +--LogicalFilter(t2.a is null)
* +--Plan2
* +--LogicalRightOuterJoin(t1.a=t2.a)
* +--Plan1
* +--LogicalFilter(t2.a is not null)
* +--Plan2
* */
public class JoinSplitForNullSkew extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalJoin()
.when(join -> join.getJoinType().isOneSideOuterJoin())
.whenNot(join -> join.isMarkJoin() || !join.getMarkJoinConjuncts().isEmpty())
.when(join -> join.getHashJoinConjuncts().size() == 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should check mark join conjuncts

.then(this::splitJoin)
.toRule(RuleType.JOIN_SPLIT_FOR_NULL_SKEW);
}

private Plan splitJoin(LogicalJoin<Plan, Plan> join) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not rewrite if leftExpr is not null already

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

boolean isLeftJoin = join.getJoinType().isLeftOuterJoin();
Plan primarySide = isLeftJoin ? join.left() : join.right();
Plan associatedSide = isLeftJoin ? join.right() : join.left();
Expression conjunct = join.getHashJoinConjuncts().get(0);
if (!(conjunct instanceof EqualTo)) {
return null;
}
EqualTo equalTo = (EqualTo) conjunct;
Expression splitExpr;
if (primarySide.getOutputSet().containsAll(equalTo.left().getInputSlots())) {
splitExpr = equalTo.left();
} else {
splitExpr = equalTo.right();
}
if (!splitExpr.nullable()) {
return null;
}
// avoid duplicate application of rules
Expression isNotNull = new Not(new IsNull(splitExpr));
if (primarySide instanceof LogicalFilter
&& ((LogicalFilter<?>) primarySide).getConjuncts().contains(isNotNull)) {
return null;
}

// is not null side construct
LogicalFilter<Plan> isNotNullFilter = new LogicalFilter<>(ImmutableSet.of(isNotNull), primarySide);
LogicalJoin<Plan, Plan> newJoin;
if (isLeftJoin) {
newJoin = join.withChildren(ImmutableList.of(isNotNullFilter, associatedSide));
} else {
newJoin = join.withChildren(ImmutableList.of(associatedSide, isNotNullFilter));
}
Plan deepCopyJoin = LogicalPlanDeepCopier.INSTANCE.deepCopy(newJoin, new DeepCopierContext());

// is null side construct
LogicalFilter<Plan> isNullFilter = new LogicalFilter<>(ImmutableSet.of(new IsNull(splitExpr)), primarySide);
Plan deepCopyFilter = LogicalPlanDeepCopier.INSTANCE.deepCopy(isNullFilter, new DeepCopierContext());
List<NamedExpression> newProjects = new ArrayList<>(join.getOutput().size());
if (isLeftJoin) {
newProjects.addAll(deepCopyFilter.getOutput());
for (Slot slot : associatedSide.getOutput()) {
newProjects.add(new Alias(new NullLiteral(slot.getDataType())));
}
} else {
for (Slot slot : associatedSide.getOutput()) {
newProjects.add(new Alias(new NullLiteral(slot.getDataType())));
}
newProjects.addAll(deepCopyFilter.getOutput());
}
LogicalProject<Plan> isNullProject = new LogicalProject<>(newProjects, deepCopyFilter);

// regularChildrenOutputs construct
List<List<SlotReference>> regularChildrenOutputs = new ArrayList<>();
regularChildrenOutputs.add((List) isNullProject.getOutput());
regularChildrenOutputs.add((List) deepCopyJoin.getOutput());

return new LogicalUnion(Qualifier.ALL, (List) join.getOutput(), regularChildrenOutputs,
ImmutableList.of(), false, ImmutableList.of(isNullProject, deepCopyJoin));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ public enum RuleType {
CROSS_TO_INNER_JOIN(RuleTypeClass.REWRITE),
PRUNE_EMPTY_PARTITION(RuleTypeClass.REWRITE),
PROJECT_OTHER_JOIN_CONDITION(RuleTypeClass.REWRITE),

JOIN_SPLIT_FOR_NULL_SKEW(RuleTypeClass.REWRITE),
// split limit
SPLIT_LIMIT(RuleTypeClass.REWRITE),
PULL_UP_JOIN_FROM_UNION_ALL(RuleTypeClass.REWRITE),
Expand Down Expand Up @@ -505,6 +505,7 @@ public enum RuleType {
LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),
LOGICAL_GENERATE_TO_PHYSICAL_GENERATE(RuleTypeClass.IMPLEMENTATION),
LOGICAL_WINDOW_TO_PHYSICAL_WINDOW_RULE(RuleTypeClass.IMPLEMENTATION),
COUNT_DISTINCT_AGG_SKEW_REWRITE(RuleTypeClass.IMPLEMENTATION),
IMPLEMENTATION_SENTINEL(RuleTypeClass.IMPLEMENTATION),

// sentinel, use to count rules
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ public Expression visitUnboundFunction(UnboundFunction unboundFunction, Expressi
// For DISTINCT calls the argument list is prefixed with two flags (distinct, skew),
// so the builder is sized for arity + 2 elements; otherwise the arguments are used as is.
List<Object> arguments = unboundFunction.isDistinct()
        ? ImmutableList.builderWithExpectedSize(unboundFunction.arity() + 2)
                .add(unboundFunction.isDistinct())
                .add(unboundFunction.isSkew())
                .addAll(unboundFunction.getArguments())
                .build()
        : (List) unboundFunction.getArguments();
Expand Down
Loading