Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,9 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* EquivalenceClass, this is used for equality propagation when predicate compensation
Expand All @@ -40,14 +37,19 @@ public class EquivalenceClass {
* a: [a, b],
* b: [a, b]
* }
* or column a = a,
* this would be
* {
* a: [a, a]
* }
*/
private Map<SlotReference, Set<SlotReference>> equivalenceSlotMap = new LinkedHashMap<>();
private List<Set<SlotReference>> equivalenceSlotList;
private Map<SlotReference, List<SlotReference>> equivalenceSlotMap = new LinkedHashMap<>();
private List<List<SlotReference>> equivalenceSlotList;

public EquivalenceClass() {
}

public EquivalenceClass(Map<SlotReference, Set<SlotReference>> equivalenceSlotMap) {
public EquivalenceClass(Map<SlotReference, List<SlotReference>> equivalenceSlotMap) {
this.equivalenceSlotMap = equivalenceSlotMap;
}

Expand All @@ -56,13 +58,13 @@ public EquivalenceClass(Map<SlotReference, Set<SlotReference>> equivalenceSlotMa
*/
public void addEquivalenceClass(SlotReference leftSlot, SlotReference rightSlot) {

Set<SlotReference> leftSlotSet = equivalenceSlotMap.get(leftSlot);
Set<SlotReference> rightSlotSet = equivalenceSlotMap.get(rightSlot);
List<SlotReference> leftSlotSet = equivalenceSlotMap.get(leftSlot);
List<SlotReference> rightSlotSet = equivalenceSlotMap.get(rightSlot);
if (leftSlotSet != null && rightSlotSet != null) {
// Both present, we need to merge
if (leftSlotSet.size() < rightSlotSet.size()) {
// We swap them to merge
Set<SlotReference> tmp = rightSlotSet;
List<SlotReference> tmp = rightSlotSet;
rightSlotSet = leftSlotSet;
leftSlotSet = tmp;
}
Expand All @@ -80,15 +82,15 @@ public void addEquivalenceClass(SlotReference leftSlot, SlotReference rightSlot)
equivalenceSlotMap.put(leftSlot, rightSlotSet);
} else {
// None are present, add to same equivalence class
Set<SlotReference> equivalenceClass = new LinkedHashSet<>();
List<SlotReference> equivalenceClass = new ArrayList<>();
equivalenceClass.add(leftSlot);
equivalenceClass.add(rightSlot);
equivalenceSlotMap.put(leftSlot, equivalenceClass);
equivalenceSlotMap.put(rightSlot, equivalenceClass);
}
}

public Map<SlotReference, Set<SlotReference>> getEquivalenceSlotMap() {
public Map<SlotReference, List<SlotReference>> getEquivalenceSlotMap() {
return equivalenceSlotMap;
}

Expand All @@ -101,15 +103,15 @@ public boolean isEmpty() {
*/
public EquivalenceClass permute(Map<SlotReference, SlotReference> mapping) {

Map<SlotReference, Set<SlotReference>> permutedEquivalenceSlotMap = new HashMap<>();
for (Map.Entry<SlotReference, Set<SlotReference>> slotReferenceSetEntry : equivalenceSlotMap.entrySet()) {
Map<SlotReference, List<SlotReference>> permutedEquivalenceSlotMap = new HashMap<>();
for (Map.Entry<SlotReference, List<SlotReference>> slotReferenceSetEntry : equivalenceSlotMap.entrySet()) {
SlotReference mappedSlotReferenceKey = mapping.get(slotReferenceSetEntry.getKey());
if (mappedSlotReferenceKey == null) {
// can not permute then need to return null
return null;
}
Set<SlotReference> equivalenceValueSet = slotReferenceSetEntry.getValue();
final Set<SlotReference> mappedSlotReferenceSet = new HashSet<>();
List<SlotReference> equivalenceValueSet = slotReferenceSetEntry.getValue();
final List<SlotReference> mappedSlotReferenceSet = new ArrayList<>();
for (SlotReference target : equivalenceValueSet) {
SlotReference mappedSlotReferenceValue = mapping.get(target);
if (mappedSlotReferenceValue == null) {
Expand All @@ -123,15 +125,14 @@ public EquivalenceClass permute(Map<SlotReference, SlotReference> mapping) {
}

/**
* Return the list of equivalence set, remove duplicate
* Return the list of equivalence list, remove duplicate
*/
public List<Set<SlotReference>> getEquivalenceSetList() {

public List<List<SlotReference>> getEquivalenceSetList() {
if (equivalenceSlotList != null) {
return equivalenceSlotList;
}
List<Set<SlotReference>> equivalenceSets = new ArrayList<>();
Set<Set<SlotReference>> visited = new HashSet<>();
List<List<SlotReference>> equivalenceSets = new ArrayList<>();
List<List<SlotReference>> visited = new ArrayList<>();
equivalenceSlotMap.values().forEach(slotSet -> {
if (!visited.contains(slotSet)) {
equivalenceSets.add(slotSet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.EquivalenceClassSetMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.EquivalenceClassMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
Expand All @@ -33,6 +33,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
Expand Down Expand Up @@ -98,15 +99,15 @@ public static Set<Expression> compensateEquivalence(StructInfo queryStructInfo,
if (queryEquivalenceClass.isEmpty() && !viewEquivalenceClass.isEmpty()) {
return null;
}
EquivalenceClassSetMapping queryToViewEquivalenceMapping =
EquivalenceClassSetMapping.generate(queryEquivalenceClass, viewEquivalenceClassQueryBased);
EquivalenceClassMapping queryToViewEquivalenceMapping =
EquivalenceClassMapping.generate(queryEquivalenceClass, viewEquivalenceClassQueryBased);
// can not map all target equivalence class, can not compensate
if (queryToViewEquivalenceMapping.getEquivalenceClassSetMap().size()
< viewEquivalenceClass.getEquivalenceSetList().size()) {
return null;
}
// do equal compensate
Set<Set<SlotReference>> mappedQueryEquivalenceSet =
Set<List<SlotReference>> mappedQueryEquivalenceSet =
queryToViewEquivalenceMapping.getEquivalenceClassSetMap().keySet();
queryEquivalenceClass.getEquivalenceSetList().forEach(
queryEquivalenceSet -> {
Expand All @@ -120,9 +121,9 @@ public static Set<Expression> compensateEquivalence(StructInfo queryStructInfo,
}
} else {
// compensate the equivalence both in query and view, but query has more equivalence
Set<SlotReference> viewEquivalenceSet =
List<SlotReference> viewEquivalenceSet =
queryToViewEquivalenceMapping.getEquivalenceClassSetMap().get(queryEquivalenceSet);
Set<SlotReference> copiedQueryEquivalenceSet = new HashSet<>(queryEquivalenceSet);
List<SlotReference> copiedQueryEquivalenceSet = new ArrayList<>(queryEquivalenceSet);
copiedQueryEquivalenceSet.removeAll(viewEquivalenceSet);
SlotReference first = viewEquivalenceSet.iterator().next();
for (SlotReference slotReference : copiedQueryEquivalenceSet) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.trees.expressions.SlotReference;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -30,39 +31,41 @@
* This will extract the equivalence class set in EquivalenceClass and mapping set in
* two different EquivalenceClass.
*/
public class EquivalenceClassSetMapping extends Mapping {
public class EquivalenceClassMapping extends Mapping {

private final Map<Set<SlotReference>, Set<SlotReference>> equivalenceClassSetMap;
private final Map<List<SlotReference>, List<SlotReference>> equivalenceClassSetMap;

public EquivalenceClassSetMapping(Map<Set<SlotReference>,
Set<SlotReference>> equivalenceClassSetMap) {
public EquivalenceClassMapping(Map<List<SlotReference>,
List<SlotReference>> equivalenceClassSetMap) {
this.equivalenceClassSetMap = equivalenceClassSetMap;
}

public static EquivalenceClassSetMapping of(Map<Set<SlotReference>, Set<SlotReference>> equivalenceClassSetMap) {
return new EquivalenceClassSetMapping(equivalenceClassSetMap);
public static EquivalenceClassMapping of(Map<List<SlotReference>, List<SlotReference>> equivalenceClassSetMap) {
return new EquivalenceClassMapping(equivalenceClassSetMap);
}

/**
* Generate source equivalence set map to target equivalence set
*/
public static EquivalenceClassSetMapping generate(EquivalenceClass source, EquivalenceClass target) {
public static EquivalenceClassMapping generate(EquivalenceClass source, EquivalenceClass target) {

Map<Set<SlotReference>, Set<SlotReference>> equivalenceClassSetMap = new HashMap<>();
List<Set<SlotReference>> sourceSets = source.getEquivalenceSetList();
List<Set<SlotReference>> targetSets = target.getEquivalenceSetList();
Map<List<SlotReference>, List<SlotReference>> equivalenceClassSetMap = new HashMap<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change set to list could fix this problem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

such as the expression o_orderstatus = o_orderstatus we should compensate o_orderstatus = o_orderstatus on materialized view.
If we record the slot equal expression in set, only get {o_orderstatus}, in Predicates#compensateEquivalence we couldn't compensate filter.
Change set to list. we get {o_orderstatus, o_orderstatus}, then we can compensate filter by Predicates#compensateEquivalence .

List<List<SlotReference>> sourceSets = source.getEquivalenceSetList();
List<List<SlotReference>> targetSets = target.getEquivalenceSetList();

for (Set<SlotReference> sourceSet : sourceSets) {
for (Set<SlotReference> targetSet : targetSets) {
for (List<SlotReference> sourceList : sourceSets) {
Set<SlotReference> sourceSet = new HashSet<>(sourceList);
for (List<SlotReference> targetList : targetSets) {
Set<SlotReference> targetSet = new HashSet<>(targetList);
if (sourceSet.containsAll(targetSet)) {
equivalenceClassSetMap.put(sourceSet, targetSet);
equivalenceClassSetMap.put(sourceList, targetList);
}
}
}
return EquivalenceClassSetMapping.of(equivalenceClassSetMap);
return EquivalenceClassMapping.of(equivalenceClassSetMap);
}

public Map<Set<SlotReference>, Set<SlotReference>> getEquivalenceClassSetMap() {
public Map<List<SlotReference>, List<SlotReference>> getEquivalenceClassSetMap() {
return equivalenceClassSetMap;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !query1_0_before --
1 o mm
2 o mi
4 o yy

-- !query1_0_after --
1 o mm
2 o mi
4 o yy

Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package mv.unsafe_equals
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("null_unsafe_equals") {
String db = context.config.getDbNameByFile(context.file)
sql "use ${db}"
sql "set runtime_filter_mode=OFF";
sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"

sql """
drop table if exists orders
"""

sql """
CREATE TABLE IF NOT EXISTS orders (
o_orderkey INTEGER NULL,
o_custkey INTEGER NULL,
o_orderstatus CHAR(1) NULL,
o_totalprice DECIMALV3(15,2) NULL,
o_orderdate DATE NULL,
o_orderpriority CHAR(15) NULL,
o_clerk CHAR(15) NULL,
o_shippriority INTEGER NULL,
O_COMMENT VARCHAR(79) NULL
)
DUPLICATE KEY(o_orderkey, o_custkey)
PARTITION BY RANGE(o_orderdate) (
PARTITION `day_2` VALUES LESS THAN ('2023-12-9'),
PARTITION `day_3` VALUES LESS THAN ("2023-12-11"),
PARTITION `day_4` VALUES LESS THAN ("2023-12-30")
)
DISTRIBUTED BY HASH(o_orderkey) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
"""

sql """
insert into orders values
(null, 1, 'o', 9.5, '2023-12-08', 'a', 'b', 1, 'yy'),
(1, null, 'o', 10.5, '2023-12-08', 'a', 'b', 1, 'yy'),
(2, 1, null, 11.5, '2023-12-09', 'a', 'b', 1, 'yy'),
(3, 1, 'o', null, '2023-12-10', 'a', 'b', 1, 'yy'),
(3, 1, 'o', 33.5, null, 'a', 'b', 1, 'yy'),
(4, 2, 'o', 43.2, '2023-12-11', null,'d',2, 'mm'),
(5, 2, 'o', 56.2, '2023-12-12', 'c',null, 2, 'mi'),
(5, 2, 'o', 1.2, '2023-12-12', 'c','d', null, 'mi');
"""

def mv1_0 =
"""
select count(*), o_orderstatus, o_comment
from orders
group by
o_orderstatus, o_comment;
"""
// query contains the filter which is 'o_orderstatus = o_orderstatus' should reject null
def query1_0 =
"""
select count(*), o_orderstatus, o_comment
from orders
where o_orderstatus = o_orderstatus
group by
o_orderstatus, o_comment;
"""
order_qt_query1_0_before "${query1_0}"
async_mv_rewrite_success(db, mv1_0, query1_0, "mv1_0")
order_qt_query1_0_after "${query1_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv1_0"""
}