Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.DefaultValueSlot;
import org.apache.doris.nereids.trees.expressions.DereferenceExpression;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Exists;
Expand Down Expand Up @@ -3408,8 +3409,7 @@ public Expression visitDereference(DereferenceContext ctx) {
UnboundSlot slot = new UnboundSlot(nameParts, Optional.empty());
return slot;
} else {
// todo: base is an expression, may be not a table name.
throw new ParseException("Unsupported dereference expression: " + ctx.getText(), ctx);
return new DereferenceExpression(e, new StringLiteral(ctx.identifier().getText()));
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,12 +469,12 @@ private LogicalHaving<Plan> bindHavingAggregate(
Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build());

return (analyzer, unboundSlot) -> {
List<Slot> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
List<Expression> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
if (!boundInGroupBy.isEmpty()) {
return ImmutableList.of(boundInGroupBy.get(0));
}

List<Slot> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
List<Expression> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
if (!boundInAggOutput.isEmpty()) {
return ImmutableList.of(boundInAggOutput.get(0));
}
Expand Down Expand Up @@ -553,7 +553,7 @@ private LogicalHaving<Plan> bindHavingByScopes(
SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer(
having, cascadesContext, defaultScope, false, true,
(self, unboundSlot) -> {
List<Slot> slots = self.bindSlotByScope(unboundSlot, defaultScope);
List<Expression> slots = self.bindSlotByScope(unboundSlot, defaultScope);
if (!slots.isEmpty()) {
return slots;
}
Expand Down Expand Up @@ -1006,7 +1006,7 @@ private void bindQualifyByProject(LogicalProject<? extends Plan> project, Cascad
SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer(
qualify, cascadesContext, defaultScope.get(), true, true,
(self, unboundSlot) -> {
List<Slot> slots = self.bindSlotByScope(unboundSlot, defaultScope.get());
List<Expression> slots = self.bindSlotByScope(unboundSlot, defaultScope.get());
if (!slots.isEmpty()) {
return slots;
}
Expand Down Expand Up @@ -1044,11 +1044,11 @@ private void bindQualifyByAggregate(Aggregate<? extends Plan> aggregate, Cascade
Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build());

return (analyzer, unboundSlot) -> {
List<Slot> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
List<Expression> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
if (!boundInGroupBy.isEmpty()) {
return ImmutableList.of(boundInGroupBy.get(0));
}
List<Slot> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
List<Expression> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
if (!boundInAggOutput.isEmpty()) {
return ImmutableList.of(boundInAggOutput.get(0));
}
Expand Down Expand Up @@ -1368,15 +1368,15 @@ private List<Expression> bindGroupBy(
// see: https://github.com/apache/doris/pull/15240
//
// first, try to bind by agg.child.output
List<Slot> slotsInChildren = self.bindExactSlotsByThisScope(unboundSlot, childOutputScope);
List<Expression> slotsInChildren = self.bindExactSlotsByThisScope(unboundSlot, childOutputScope);
if (slotsInChildren.size() == 1) {
// bind succeed
return slotsInChildren;
}
// second, bind failed:
// if the slot not found, or more than one candidate slots found in agg.child.output,
// then try to bind by agg.output
List<Slot> slotsInOutput = self.bindExactSlotsByThisScope(
List<Expression> slotsInOutput = self.bindExactSlotsByThisScope(
unboundSlot, aggOutputScopeWithoutAggFun.get());
if (slotsInOutput.isEmpty()) {
// if slotsInChildren.size() > 1 && slotsInOutput.isEmpty(),
Expand All @@ -1385,7 +1385,7 @@ private List<Expression> bindGroupBy(
}

Builder<Expression> useOutputExpr = ImmutableList.builderWithExpectedSize(slotsInOutput.size());
for (Slot slotInOutput : slotsInOutput) {
for (Expression slotInOutput : slotsInOutput) {
// mappingSlot is provided by aggOutputScopeWithoutAggFun
// and no non-MappingSlot slot exist in the Scope, so we
// can direct cast it safely
Expand Down Expand Up @@ -1476,7 +1476,7 @@ private Plan bindSortWithoutSetOperation(MatchingContext<LogicalSort<Plan>> ctx)
sort, cascadesContext, inputScope, true, false,
(self, unboundSlot) -> {
// first, try to bind slot in Scope(input.output)
List<Slot> slotsInInput = self.bindExactSlotsByThisScope(unboundSlot, inputScope);
List<Expression> slotsInInput = self.bindExactSlotsByThisScope(unboundSlot, inputScope);
if (!slotsInInput.isEmpty()) {
// bind succeed
return ImmutableList.of(slotsInInput.get(0));
Expand Down Expand Up @@ -1678,7 +1678,7 @@ private SimpleExprAnalyzer getAnalyzerForOrderByAggFunc(Plan finalInput, Cascade
sort, cascadesContext, inputScope, true, false,
(analyzer, unboundSlot) -> {
if (finalInput instanceof LogicalAggregate) {
List<Slot> boundInOutputWithoutAggFunc = analyzer.bindSlotByScope(unboundSlot,
List<Expression> boundInOutputWithoutAggFunc = analyzer.bindSlotByScope(unboundSlot,
outputWithoutAggFunc);
if (!boundInOutputWithoutAggFunc.isEmpty()) {
return ImmutableList.of(boundInOutputWithoutAggFunc.get(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.DereferenceExpression;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
Expand All @@ -74,7 +75,9 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf;
Expand All @@ -93,6 +96,8 @@
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.StructField;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
Expand Down Expand Up @@ -240,6 +245,25 @@ public Expression visitUnboundAlias(UnboundAlias unboundAlias, ExpressionRewrite
}
}

@Override
public Expression visitDereferenceExpression(DereferenceExpression dereferenceExpression,
ExpressionRewriteContext context) {
Expression expression = dereferenceExpression.child(0).accept(this, context);
DataType dataType = expression.getDataType();
if (dataType.isStructType()) {
StructType structType = (StructType) dataType;
StructField field = structType.getField(dereferenceExpression.fieldName);
if (field != null) {
return new StructElement(expression, dereferenceExpression.child(1));
}
} else if (dataType.isMapType()) {
return new ElementAt(expression, dereferenceExpression.child(1));
} else if (dataType.isVariantType()) {
return new ElementAt(expression, dereferenceExpression.child(1));
}
throw new AnalysisException("Can not dereference field: " + dereferenceExpression.fieldName);
}

@Override
public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) {
Optional<Scope> outerScope = getScope().getOuterScope();
Expand Down Expand Up @@ -913,13 +937,13 @@ protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot
return bindSlotByScope(unboundSlot, getScope());
}

protected List<Slot> bindExactSlotsByThisScope(UnboundSlot unboundSlot, Scope scope) {
List<Slot> candidates = bindSlotByScope(unboundSlot, scope);
protected List<Expression> bindExactSlotsByThisScope(UnboundSlot unboundSlot, Scope scope) {
List<Expression> candidates = bindSlotByScope(unboundSlot, scope);
if (candidates.size() == 1) {
return candidates;
}
List<Slot> extractSlots = Utils.filterImmutableList(candidates, bound ->
unboundSlot.getNameParts().size() == bound.getQualifier().size() + 1
List<Expression> extractSlots = Utils.filterImmutableList(candidates, bound ->
bound instanceof Slot && unboundSlot.getNameParts().size() == ((Slot) bound).getQualifier().size() + 1
);
// we should return origin candidates slots if extract slots is empty,
// and then throw an ambiguous exception
Expand All @@ -938,33 +962,137 @@ private List<Slot> addSqlIndexInfo(List<Slot> slots, Optional<Pair<Integer, Inte
}

/** bindSlotByScope */
public List<Slot> bindSlotByScope(UnboundSlot unboundSlot, Scope scope) {
public List<Expression> bindSlotByScope(UnboundSlot unboundSlot, Scope scope) {
List<String> nameParts = unboundSlot.getNameParts();
Optional<Pair<Integer, Integer>> idxInSql = unboundSlot.getIndexInSqlString();
int namePartSize = nameParts.size();
switch (namePartSize) {
// column
case 1: {
return addSqlIndexInfo(bindSingleSlotByName(nameParts.get(0), scope), idxInSql);
return (List<Expression>) bindExpressionByColumn(unboundSlot, nameParts, idxInSql, scope);
}
// table.column
case 2: {
return addSqlIndexInfo(bindSingleSlotByTable(nameParts.get(0), nameParts.get(1), scope), idxInSql);
return (List<Expression>) bindExpressionByTableColumn(unboundSlot, nameParts, idxInSql, scope);
}
// db.table.column
case 3: {
return addSqlIndexInfo(bindSingleSlotByDb(nameParts.get(0), nameParts.get(1), nameParts.get(2), scope),
idxInSql);
return (List<Expression>) bindExpressionByDbTableColumn(unboundSlot, nameParts, idxInSql, scope);
}
// catalog.db.table.column
case 4: {
return addSqlIndexInfo(bindSingleSlotByCatalog(
nameParts.get(0), nameParts.get(1), nameParts.get(2), nameParts.get(3), scope), idxInSql);
}
default: {
throw new AnalysisException("Not supported name: " + StringUtils.join(nameParts, "."));
return (List<Expression>) bindExpressionByCatalogDbTableColumn(unboundSlot, nameParts, idxInSql, scope);
}
}
}

private List<? extends Expression> bindExpressionByCatalogDbTableColumn(
UnboundSlot unboundSlot, List<String> nameParts, Optional<Pair<Integer, Integer>> idxInSql, Scope scope) {
List<Slot> slots = addSqlIndexInfo(bindSingleSlotByCatalog(
nameParts.get(0), nameParts.get(1), nameParts.get(2), nameParts.get(3), scope), idxInSql);
if (slots.isEmpty()) {
return bindExpressionByDbTableColumn(unboundSlot, nameParts, idxInSql, scope);
} else if (slots.size() > 1) {
return slots;
}
if (nameParts.size() == 4) {
return slots;
}

Optional<Expression> expression = bindNestedFields(
unboundSlot, slots.get(0), nameParts.subList(4, nameParts.size())
);
if (!expression.isPresent()) {
return slots;
}
return ImmutableList.of(expression.get());
}

private List<? extends Expression> bindExpressionByDbTableColumn(
UnboundSlot unboundSlot, List<String> nameParts, Optional<Pair<Integer, Integer>> idxInSql, Scope scope) {
List<Slot> slots = addSqlIndexInfo(
bindSingleSlotByDb(nameParts.get(0), nameParts.get(1), nameParts.get(2), scope), idxInSql);
if (slots.isEmpty()) {
return bindExpressionByTableColumn(unboundSlot, nameParts, idxInSql, scope);
} else if (slots.size() > 1) {
return slots;
}
if (nameParts.size() == 3) {
return slots;
}

Optional<Expression> expression = bindNestedFields(
unboundSlot, slots.get(0), nameParts.subList(3, nameParts.size())
);
if (!expression.isPresent()) {
return slots;
}
return ImmutableList.of(expression.get());
}

private List<? extends Expression> bindExpressionByTableColumn(
UnboundSlot unboundSlot, List<String> nameParts, Optional<Pair<Integer, Integer>> idxInSql, Scope scope) {
List<Slot> slots = addSqlIndexInfo(bindSingleSlotByTable(nameParts.get(0), nameParts.get(1), scope), idxInSql);
if (slots.isEmpty()) {
return bindExpressionByColumn(unboundSlot, nameParts, idxInSql, scope);
} else if (slots.size() > 1) {
return slots;
}
if (nameParts.size() == 2) {
return slots;
}

Optional<Expression> expression = bindNestedFields(
unboundSlot, slots.get(0), nameParts.subList(2, nameParts.size())
);
if (!expression.isPresent()) {
return slots;
}
return ImmutableList.of(expression.get());
}

private List<? extends Expression> bindExpressionByColumn(
UnboundSlot unboundSlot, List<String> nameParts, Optional<Pair<Integer, Integer>> idxInSql, Scope scope) {
List<Slot> slots = addSqlIndexInfo(bindSingleSlotByName(nameParts.get(0), scope), idxInSql);
if (slots.size() != 1) {
return slots;
}
if (nameParts.size() == 1) {
return slots;
}
Optional<Expression> expression = bindNestedFields(
unboundSlot, slots.get(0), nameParts.subList(1, nameParts.size())
);
if (!expression.isPresent()) {
return slots;
}
return ImmutableList.of(expression.get());
}

private Optional<Expression> bindNestedFields(UnboundSlot unboundSlot, Slot slot, List<String> fieldNames) {
Expression expression = slot;
String lastFieldName = slot.getName();
for (String fieldName : fieldNames) {
DataType dataType = expression.getDataType();
if (dataType.isStructType()) {
StructType structType = (StructType) dataType;
StructField field = structType.getField(fieldName);
if (field == null) {
throw new AnalysisException("No such struct field '" + fieldName + "' in '" + lastFieldName + "'");
}
lastFieldName = fieldName;
expression = new StructElement(expression, new StringLiteral(fieldName));
continue;
} else if (dataType.isMapType()) {
expression = new ElementAt(expression, new StringLiteral(fieldName));
continue;
} else if (dataType.isVariantType()) {
expression = new ElementAt(expression, new StringLiteral(fieldName));
continue;
}
throw new AnalysisException("No such field '" + fieldName + "' in '" + lastFieldName + "'");
}
return Optional.of(new Alias(expression));
}

public static boolean sameTableName(String boundSlot, String unboundSlot) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.nereids.analyzer.Unbound;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;

import com.google.common.collect.ImmutableList;

/** DereferenceExpression*/
public class DereferenceExpression extends Expression implements BinaryExpression, PropagateNullable, Unbound {
public final String fieldName;

public DereferenceExpression(Expression expression, StringLiteral fieldName) {
super(ImmutableList.of(expression, fieldName));
this.fieldName = fieldName.getValue();
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitDereferenceExpression(this, context);
}
}
Loading
Loading