Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.NestedColumnPrunable;
import org.apache.doris.nereids.types.StructField;
import org.apache.doris.nereids.types.StructType;
Expand All @@ -66,6 +68,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Stack;

Expand Down Expand Up @@ -136,7 +139,8 @@ public Void visitAlias(Alias alias, CollectorContext context) {
public Void visitCast(Cast cast, CollectorContext context) {
if (!context.accessPathBuilder.isEmpty()
&& cast.getDataType() instanceof NestedColumnPrunable
&& cast.child().getDataType() instanceof NestedColumnPrunable) {
&& cast.child().getDataType() instanceof NestedColumnPrunable
&& !mapTypeIsChanged(cast.child().getDataType(), cast.getDataType(), false)) {

DataTypeAccessTree castTree = DataTypeAccessTree.of(cast.getDataType(), TAccessPathType.DATA);
DataTypeAccessTree originTree = DataTypeAccessTree.of(cast.child().getDataType(), TAccessPathType.DATA);
Expand Down Expand Up @@ -521,4 +525,46 @@ public int hashCode() {
return path.hashCode();
}
}

// if the map type is changed, we can not prune the type, because the map type need distinct the keys,
// e.g. select map_values(cast(map(3.0, 1, 3.1, 2) as map<int, int>));
// the result is [2] because the keys: 3.0 and 3.1 will cast to 3 and the second entry remained.
// backend will throw exception because it can not only access the values without the cast keys,
// so we should check whether the map type is changed, if not changed, we can prune the type.
private static boolean mapTypeIsChanged(DataType originType, DataType castType, boolean inMap) {
if (originType.isMapType()) {
MapType originMapType = (MapType) originType;
MapType castMapType = (MapType) castType;
if (mapTypeIsChanged(originMapType.getKeyType(), castMapType.getKeyType(), true)
|| mapTypeIsChanged(originMapType.getValueType(), castMapType.getValueType(), true)) {
return true;
}
return false;
} else if (originType.isStructType()) {
StructType originStructType = (StructType) originType;
StructType castStructType = (StructType) castType;
List<Entry<String, StructField>> originFields
= new ArrayList<>(originStructType.getNameToFields().entrySet());
List<Entry<String, StructField>> castFields
= new ArrayList<>(castStructType.getNameToFields().entrySet());

for (int i = 0; i < originFields.size(); i++) {
DataType originFieldType = originFields.get(i).getValue().getDataType();
DataType castFieldType = castFields.get(i).getValue().getDataType();
if (mapTypeIsChanged(originFieldType, castFieldType, inMap)) {
return true;
}
}
return false;
} else if (originType.isArrayType()) {
ArrayType originArrayType = (ArrayType) originType;
ArrayType castArrayType = (ArrayType) castType;
return mapTypeIsChanged(originArrayType.getItemType(), castArrayType.getItemType(), inMap);
} else if (inMap) {
return !originType.equals(castType);
} else {
// other type changed which not in map will not affect the map
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,20 @@ public void testStruct() throws Throwable {

@Test
public void testPruneCast() throws Exception {
// the map type is changed, so we can not prune type
assertColumn("select struct_element(cast(s as struct<k:text,l:array<map<int,struct<x:int,y:int>>>>), 'k') from tbl",
"struct<city:text,data:array<map<int,struct<a:int,b:double>>>>",
ImmutableList.of(path("s")),
ImmutableList.of()
);

assertColumn("select struct_element(cast(s as struct<k:text,l:array<map<int,struct<x:int,y:double>>>>), 'k') from tbl",
"struct<city:text>",
ImmutableList.of(path("s", "city")),
ImmutableList.of()
);

assertColumn("select struct_element(map_values(struct_element(cast(s as struct<k:text,l:array<map<int,struct<x:double,y:double>>>>), 'l')[0])[0], 'x') from tbl",
assertColumn("select struct_element(map_values(struct_element(cast(s as struct<k:text,l:array<map<int,struct<x:int,y:double>>>>), 'l')[0])[0], 'x') from tbl",
"struct<data:array<map<int,struct<a:int>>>>",
ImmutableList.of(path("s", "data", "*", "VALUES", "a")),
ImmutableList.of()
Expand Down
Loading