From 8a291efb0c50302d582844a948f489698f76ccbb Mon Sep 17 00:00:00 2001 From: TengJianPing <18241664+jacktengg@users.noreply.github.com> Date: Mon, 1 Apr 2024 16:14:10 +0800 Subject: [PATCH 01/12] [improvement](spill) avoid spill if memory is enough (#33075) --- .../partitioned_aggregation_sink_operator.cpp | 30 +++++++------ .../partitioned_aggregation_sink_operator.h | 3 -- ...artitioned_aggregation_source_operator.cpp | 29 ++++++------- .../partitioned_aggregation_source_operator.h | 2 - .../exec/spill_sort_sink_operator.cpp | 43 ++++++++++--------- .../pipeline/exec/spill_sort_sink_operator.h | 3 -- .../exec/spill_sort_source_operator.cpp | 29 +++---------- .../exec/spill_sort_source_operator.h | 3 -- be/src/pipeline/pipeline_x/dependency.h | 2 + be/src/vec/common/sort/sorter.cpp | 3 +- be/src/vec/spill/spill_stream.cpp | 17 ++++++-- 11 files changed, 76 insertions(+), 88 deletions(-) diff --git a/be/src/pipeline/exec/partitioned_aggregation_sink_operator.cpp b/be/src/pipeline/exec/partitioned_aggregation_sink_operator.cpp index 3207c109589c1d..d44c35a76a9478 100644 --- a/be/src/pipeline/exec/partitioned_aggregation_sink_operator.cpp +++ b/be/src/pipeline/exec/partitioned_aggregation_sink_operator.cpp @@ -70,12 +70,16 @@ Status PartitionedAggSinkLocalState::close(RuntimeState* state, Status exec_stat if (Base::_closed) { return Status::OK(); } +<<<<<<< HEAD { std::unique_lock lk(_spill_lock); if (_is_spilling) { _spill_cv.wait(lk); } } +======= + dec_running_big_mem_op_num(state); +>>>>>>> bb11955709 ([improvement](spill) avoid spill if memory is enough (#33075)) return Base::close(state, exec_status); } @@ -166,13 +170,17 @@ Status PartitionedAggSinkOperatorX::sink(doris::RuntimeState* state, vectorized: auto* runtime_state = local_state._runtime_state.get(); RETURN_IF_ERROR(_agg_sink_operator->sink(runtime_state, in_block, false)); if (eos) { - LOG(INFO) << "agg node " << id() << " sink eos"; - if (revocable_mem_size(state) > 0) { - 
RETURN_IF_ERROR(revoke_memory(state)); - } else { - for (auto& partition : local_state._shared_state->spill_partitions) { - RETURN_IF_ERROR(partition->finish_current_spilling(eos)); + if (local_state._shared_state->is_spilled) { + if (revocable_mem_size(state) > 0) { + RETURN_IF_ERROR(revoke_memory(state)); + } else { + for (auto& partition : local_state._shared_state->spill_partitions) { + RETURN_IF_ERROR(partition->finish_current_spilling(eos)); + } + local_state._dependency->set_ready_to_read(); + local_state._finish_dependency->set_ready(); } + } else { local_state._dependency->set_ready_to_read(); } } @@ -229,8 +237,10 @@ Status PartitionedAggSinkLocalState::revoke_memory(RuntimeState* state) { LOG(INFO) << "agg node " << Base::_parent->id() << " revoke_memory" << ", eos: " << _eos; RETURN_IF_ERROR(Base::_shared_state->sink_status); - DCHECK(!_is_spilling); - _is_spilling = true; + if (!_shared_state->is_spilled) { + _shared_state->is_spilled = true; + profile()->add_info_string("Spilled", "true"); + } // TODO: spill thread may set_ready before the task::execute thread put the task to blocked state if (!_eos) { @@ -240,7 +250,6 @@ Status PartitionedAggSinkLocalState::revoke_memory(RuntimeState* state) { Status status; Defer defer {[&]() { if (!status.ok()) { - _is_spilling = false; if (!_eos) { Base::_dependency->Dependency::set_ready(); } @@ -269,15 +278,12 @@ Status PartitionedAggSinkLocalState::revoke_memory(RuntimeState* state) { << ", eos: " << _eos; } { - std::unique_lock lk(_spill_lock); - _is_spilling = false; if (_eos) { Base::_dependency->set_ready_to_read(); _finish_dependency->set_ready(); } else { Base::_dependency->Dependency::set_ready(); } - _spill_cv.notify_one(); } }}; auto* runtime_state = _runtime_state.get(); diff --git a/be/src/pipeline/exec/partitioned_aggregation_sink_operator.h b/be/src/pipeline/exec/partitioned_aggregation_sink_operator.h index 5e61738681235d..542046556ec9e3 100644 --- 
a/be/src/pipeline/exec/partitioned_aggregation_sink_operator.h +++ b/be/src/pipeline/exec/partitioned_aggregation_sink_operator.h @@ -272,9 +272,6 @@ class PartitionedAggSinkLocalState bool _eos = false; std::shared_ptr _finish_dependency; - bool _is_spilling = false; - std::mutex _spill_lock; - std::condition_variable _spill_cv; /// Resources in shared state will be released when the operator is closed, /// but there may be asynchronous spilling tasks at this time, which can lead to conflicts. diff --git a/be/src/pipeline/exec/partitioned_aggregation_source_operator.cpp b/be/src/pipeline/exec/partitioned_aggregation_source_operator.cpp index 960decdb9515ce..f5eceac338c76c 100644 --- a/be/src/pipeline/exec/partitioned_aggregation_source_operator.cpp +++ b/be/src/pipeline/exec/partitioned_aggregation_source_operator.cpp @@ -88,12 +88,7 @@ Status PartitionedAggLocalState::close(RuntimeState* state) { if (_closed) { return Status::OK(); } - { - std::unique_lock lk(_merge_spill_lock); - if (_is_merging) { - _merge_spill_cv.wait(lk); - } - } + dec_running_big_mem_op_num(state); return Base::close(state); } PartitionedAggSourceOperatorX::PartitionedAggSourceOperatorX(ObjectPool* pool, @@ -131,13 +126,16 @@ Status PartitionedAggSourceOperatorX::get_block(RuntimeState* state, vectorized: SCOPED_TIMER(local_state.exec_time_counter()); RETURN_IF_ERROR(local_state._status); - RETURN_IF_ERROR(local_state.initiate_merge_spill_partition_agg_data(state)); + if (local_state._shared_state->is_spilled) { + RETURN_IF_ERROR(local_state.initiate_merge_spill_partition_agg_data(state)); - /// When `_is_merging` is true means we are reading spilled data and merging the data into hash table. - if (local_state._is_merging) { - return Status::OK(); + /// When `_is_merging` is true means we are reading spilled data and merging the data into hash table. 
+ if (local_state._is_merging) { + return Status::OK(); + } } + // not spilled in sink or current partition still has data auto* runtime_state = local_state._runtime_state.get(); RETURN_IF_ERROR(_agg_source_operator->get_block(runtime_state, block, eos)); if (local_state._runtime_state) { @@ -146,7 +144,8 @@ Status PartitionedAggSourceOperatorX::get_block(RuntimeState* state, vectorized: local_state.update_profile(source_local_state->profile()); } if (*eos) { - if (!local_state._shared_state->spill_partitions.empty()) { + if (local_state._shared_state->is_spilled && + !local_state._shared_state->spill_partitions.empty()) { *eos = false; } } @@ -218,12 +217,8 @@ Status PartitionedAggLocalState::initiate_merge_spill_partition_agg_data(Runtime } Base::_shared_state->in_mem_shared_state->aggregate_data_container ->init_once(); - { - std::unique_lock lk(_merge_spill_lock); - _is_merging = false; - _dependency->Dependency::set_ready(); - _merge_spill_cv.notify_one(); - } + _is_merging = false; + _dependency->Dependency::set_ready(); }}; bool has_agg_data = false; auto& parent = Base::_parent->template cast(); diff --git a/be/src/pipeline/exec/partitioned_aggregation_source_operator.h b/be/src/pipeline/exec/partitioned_aggregation_source_operator.h index ac63402f227023..eff1e7179c8d0d 100644 --- a/be/src/pipeline/exec/partitioned_aggregation_source_operator.h +++ b/be/src/pipeline/exec/partitioned_aggregation_source_operator.h @@ -60,8 +60,6 @@ class PartitionedAggLocalState final : public PipelineXSpillLocalState _spill_merge_future; bool _current_partition_eos = true; bool _is_merging = false; - std::mutex _merge_spill_lock; - std::condition_variable _merge_spill_cv; /// Resources in shared state will be released when the operator is closed, /// but there may be asynchronous spilling tasks at this time, which can lead to conflicts. 
diff --git a/be/src/pipeline/exec/spill_sort_sink_operator.cpp b/be/src/pipeline/exec/spill_sort_sink_operator.cpp index c586a8e5e56012..7764aa948b9258 100644 --- a/be/src/pipeline/exec/spill_sort_sink_operator.cpp +++ b/be/src/pipeline/exec/spill_sort_sink_operator.cpp @@ -74,11 +74,9 @@ Status SpillSortSinkLocalState::open(RuntimeState* state) { return Status::OK(); } Status SpillSortSinkLocalState::close(RuntimeState* state, Status execsink_status) { - { - std::unique_lock lk(_spill_lock); - if (_is_spilling) { - _spill_cv.wait(lk); - } + auto& parent = Base::_parent->template cast(); + if (parent._enable_spill) { + dec_running_big_mem_op_num(state); } return Status::OK(); } @@ -172,9 +170,16 @@ Status SpillSortSinkOperatorX::sink(doris::RuntimeState* state, vectorized::Bloc local_state._shared_state->in_mem_shared_state->sorter->data_size()); if (eos) { if (_enable_spill) { - if (revocable_mem_size(state) > 0) { - RETURN_IF_ERROR(revoke_memory(state)); + if (local_state._shared_state->is_spilled) { + if (revocable_mem_size(state) > 0) { + RETURN_IF_ERROR(revoke_memory(state)); + } else { + local_state._dependency->set_ready_to_read(); + local_state._finish_dependency->set_ready(); + } } else { + RETURN_IF_ERROR( + local_state._shared_state->in_mem_shared_state->sorter->prepare_for_read()); local_state._dependency->set_ready_to_read(); } } else { @@ -186,8 +191,10 @@ Status SpillSortSinkOperatorX::sink(doris::RuntimeState* state, vectorized::Bloc return Status::OK(); } Status SpillSortSinkLocalState::revoke_memory(RuntimeState* state) { - DCHECK(!_is_spilling); - _is_spilling = true; + if (!_shared_state->is_spilled) { + _shared_state->is_spilled = true; + profile()->add_info_string("Spilled", "true"); + } LOG(INFO) << "sort node " << Base::_parent->id() << " revoke_memory" << ", eos: " << _eos; @@ -243,17 +250,12 @@ Status SpillSortSinkLocalState::revoke_memory(RuntimeState* state) { _shared_state->clear(); } - { - std::unique_lock lk(_spill_lock); - 
_spilling_stream.reset(); - _is_spilling = false; - if (_eos) { - _dependency->set_ready_to_read(); - _finish_dependency->set_ready(); - } else { - _dependency->Dependency::set_ready(); - } - _spill_cv.notify_one(); + _spilling_stream.reset(); + if (_eos) { + _dependency->set_ready_to_read(); + _finish_dependency->set_ready(); + } else { + _dependency->Dependency::set_ready(); } }}; @@ -288,7 +290,6 @@ Status SpillSortSinkLocalState::revoke_memory(RuntimeState* state) { return Status::OK(); }); if (!status.ok()) { - _is_spilling = false; _spilling_stream->end_spill(status); if (!_eos) { diff --git a/be/src/pipeline/exec/spill_sort_sink_operator.h b/be/src/pipeline/exec/spill_sort_sink_operator.h index ae5a3bcb8c7d83..d66215411aae12 100644 --- a/be/src/pipeline/exec/spill_sort_sink_operator.h +++ b/be/src/pipeline/exec/spill_sort_sink_operator.h @@ -62,11 +62,8 @@ class SpillSortSinkLocalState : public PipelineXSpillSinkLocalState _finish_dependency; - std::mutex _spill_lock; - std::condition_variable _spill_cv; }; class SpillSortSinkOperatorX final : public DataSinkOperatorX { diff --git a/be/src/pipeline/exec/spill_sort_source_operator.cpp b/be/src/pipeline/exec/spill_sort_source_operator.cpp index d249b3be56e23c..417ff704bc6abb 100644 --- a/be/src/pipeline/exec/spill_sort_source_operator.cpp +++ b/be/src/pipeline/exec/spill_sort_source_operator.cpp @@ -58,17 +58,10 @@ Status SpillSortLocalState::close(RuntimeState* state) { if (_closed) { return Status::OK(); } - { - std::unique_lock lk(_merge_spill_lock); - if (_is_merging) { - _merge_spill_cv.wait(lk); - } + if (Base::_shared_state->enable_spill) { + dec_running_big_mem_op_num(state); } RETURN_IF_ERROR(Base::close(state)); - for (auto& stream : _current_merging_streams) { - (void)ExecEnv::GetInstance()->spill_stream_mgr()->delete_spill_stream(stream); - } - _current_merging_streams.clear(); return Status::OK(); } int SpillSortLocalState::_calc_spill_blocks_to_merge() const { @@ -78,14 +71,11 @@ int 
SpillSortLocalState::_calc_spill_blocks_to_merge() const { Status SpillSortLocalState::initiate_merge_sort_spill_streams(RuntimeState* state) { auto& parent = Base::_parent->template cast(); LOG(INFO) << "sort node " << _parent->node_id() << " merge spill data"; - DCHECK(!_is_merging); - _is_merging = true; _dependency->Dependency::block(); Status status; Defer defer {[&]() { if (!status.ok()) { - _is_merging = false; _dependency->Dependency::set_ready(); } }}; @@ -108,12 +98,7 @@ Status SpillSortLocalState::initiate_merge_sort_spill_streams(RuntimeState* stat } else { LOG(INFO) << "sort node " << _parent->node_id() << " merge spill data finish"; } - { - std::unique_lock lk(_merge_spill_lock); - _is_merging = false; - _dependency->Dependency::set_ready(); - _merge_spill_cv.notify_one(); - } + _dependency->Dependency::set_ready(); }}; vectorized::Block merge_sorted_block; vectorized::SpillStreamSPtr tmp_stream; @@ -252,15 +237,15 @@ Status SpillSortSourceOperatorX::get_block(RuntimeState* state, vectorized::Bloc SCOPED_TIMER(local_state.exec_time_counter()); RETURN_IF_ERROR(local_state._status); - if (!local_state.Base::_shared_state->enable_spill) { - RETURN_IF_ERROR( - _sort_source_operator->get_block(local_state._runtime_state.get(), block, eos)); - } else { + if (local_state.Base::_shared_state->enable_spill && local_state._shared_state->is_spilled) { if (!local_state._merger) { return local_state.initiate_merge_sort_spill_streams(state); } else { RETURN_IF_ERROR(local_state._merger->get_next(block, eos)); } + } else { + RETURN_IF_ERROR( + _sort_source_operator->get_block(local_state._runtime_state.get(), block, eos)); } local_state.reached_limit(block, eos); return Status::OK(); diff --git a/be/src/pipeline/exec/spill_sort_source_operator.h b/be/src/pipeline/exec/spill_sort_source_operator.h index 8132dd5a56c9e3..a20eb57889bc83 100644 --- a/be/src/pipeline/exec/spill_sort_source_operator.h +++ b/be/src/pipeline/exec/spill_sort_source_operator.h @@ -65,9 +65,6 
@@ class SpillSortLocalState final : public PipelineXSpillLocalState _current_merging_streams; std::unique_ptr _merger; - bool _is_merging = false; - std::mutex _merge_spill_lock; - std::condition_variable _merge_spill_cv; std::unique_ptr _internal_runtime_profile; // counters for spill merge sort diff --git a/be/src/pipeline/pipeline_x/dependency.h b/be/src/pipeline/pipeline_x/dependency.h index 77cd121be5e413..1af48748d6c298 100644 --- a/be/src/pipeline/pipeline_x/dependency.h +++ b/be/src/pipeline/pipeline_x/dependency.h @@ -397,6 +397,7 @@ struct PartitionedAggSharedState : public BasicSharedState, size_t partition_count; size_t max_partition_index; Status sink_status; + bool is_spilled = false; std::deque> spill_partitions; size_t get_partition_index(size_t hash_value) const { @@ -470,6 +471,7 @@ struct SpillSortSharedState : public BasicSharedState, SortSharedState* in_mem_shared_state = nullptr; bool enable_spill = false; + bool is_spilled = false; Status sink_status; std::shared_ptr in_mem_shared_state_sptr; diff --git a/be/src/vec/common/sort/sorter.cpp b/be/src/vec/common/sort/sorter.cpp index 53fc20112326c4..db3cca8bf09098 100644 --- a/be/src/vec/common/sort/sorter.cpp +++ b/be/src/vec/common/sort/sorter.cpp @@ -63,6 +63,7 @@ void MergeSorterState::reset() { cursors_.swap(empty_cursors); std::vector empty_blocks(0); sorted_blocks_.swap(empty_blocks); + unsorted_block_ = Block::create_unique(unsorted_block_->clone_empty()); in_mem_sorted_bocks_size_ = 0; } Status MergeSorterState::add_sorted_block(Block& block) { @@ -70,8 +71,8 @@ Status MergeSorterState::add_sorted_block(Block& block) { if (0 == rows) { return Status::OK(); } - sorted_blocks_.emplace_back(std::move(block)); in_mem_sorted_bocks_size_ += block.bytes(); + sorted_blocks_.emplace_back(std::move(block)); num_rows_ += rows; return Status::OK(); } diff --git a/be/src/vec/spill/spill_stream.cpp b/be/src/vec/spill/spill_stream.cpp index d08b63df40b0bd..f245f8fa3095a3 100644 --- 
a/be/src/vec/spill/spill_stream.cpp +++ b/be/src/vec/spill/spill_stream.cpp @@ -68,8 +68,14 @@ void SpillStream::close() { read_promise_.reset(); } - (void)writer_->close(); - (void)reader_->close(); + if (writer_) { + (void)writer_->close(); + writer_.reset(); + } + if (reader_) { + (void)reader_->close(); + reader_.reset(); + } } const std::string& SpillStream::get_spill_root_dir() const { @@ -100,13 +106,16 @@ Status SpillStream::spill_block(const Block& block, bool eof) { size_t written_bytes = 0; RETURN_IF_ERROR(writer_->write(block, written_bytes)); if (eof) { - return writer_->close(); + RETURN_IF_ERROR(writer_->close()); + writer_.reset(); } return Status::OK(); } Status SpillStream::spill_eof() { - return writer_->close(); + RETURN_IF_ERROR(writer_->close()); + writer_.reset(); + return Status::OK(); } Status SpillStream::read_next_block_sync(Block* block, bool* eos) { From 5aba0524acff15ab9b8746888c53154b71b02a3e Mon Sep 17 00:00:00 2001 From: abmdocrt Date: Mon, 1 Apr 2024 22:53:21 +0800 Subject: [PATCH 02/12] [Enhancement](merge-on-write) Support dynamic delete bitmap cache (#32991) * The default delete bitmap cache is set to 100MB, which can be insufficient and cause performance issues when the amount of user data is large. To mitigate the problem of an inadequate cache, we will take the larger of 0.5% of the total memory and 100MB as the delete bitmap cache size. 
--- be/src/common/config.cpp | 5 +++++ be/src/common/config.h | 1 + be/src/olap/tablet_meta.cpp | 15 ++++++++++++++- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/be/src/common/config.cpp b/be/src/common/config.cpp index 9364303a3ea716..2c228f7cf63ef4 100644 --- a/be/src/common/config.cpp +++ b/be/src/common/config.cpp @@ -794,6 +794,11 @@ DEFINE_mInt32(jdbc_connection_pool_cache_clear_time_sec, "28800"); // Global bitmap cache capacity for aggregation cache, size in bytes DEFINE_Int64(delete_bitmap_agg_cache_capacity, "104857600"); +// The default delete bitmap cache is set to 100MB, +// which can be insufficient and cause performance issues when the amount of user data is large. +// To mitigate the problem of an inadequate cache, +// we will take the larger of 0.5% of the total memory and 100MB as the delete bitmap cache size. +DEFINE_String(delete_bitmap_dynamic_agg_cache_limit, "0.5%"); DEFINE_mInt32(delete_bitmap_agg_cache_stale_sweep_time_sec, "1800"); // reference https://github.com/edenhill/librdkafka/blob/master/INTRODUCTION.md#broker-version-compatibility diff --git a/be/src/common/config.h b/be/src/common/config.h index 01d8a123a45b47..9ce6b242248a74 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -848,6 +848,7 @@ DECLARE_mInt32(jdbc_connection_pool_cache_clear_time_sec); // Global bitmap cache capacity for aggregation cache, size in bytes DECLARE_Int64(delete_bitmap_agg_cache_capacity); +DECLARE_String(delete_bitmap_dynamic_agg_cache_limit); DECLARE_mInt32(delete_bitmap_agg_cache_stale_sweep_time_sec); // A common object cache depends on an Sharded LRU Cache. 
diff --git a/be/src/olap/tablet_meta.cpp b/be/src/olap/tablet_meta.cpp index d8235191b93df5..f8a3bb9e97a1bd 100644 --- a/be/src/olap/tablet_meta.cpp +++ b/be/src/olap/tablet_meta.cpp @@ -38,6 +38,8 @@ #include "olap/tablet_meta_manager.h" #include "olap/utils.h" #include "util/debug_points.h" +#include "util/mem_info.h" +#include "util/parse_util.h" #include "util/string_util.h" #include "util/time.h" #include "util/uid_util.h" @@ -924,7 +926,18 @@ bool operator!=(const TabletMeta& a, const TabletMeta& b) { } DeleteBitmap::DeleteBitmap(int64_t tablet_id) : _tablet_id(tablet_id) { - _agg_cache.reset(new AggCache(config::delete_bitmap_agg_cache_capacity)); + // The default delete bitmap cache is set to 100MB, + // which can be insufficient and cause performance issues when the amount of user data is large. + // To mitigate the problem of an inadequate cache, + // we will take the larger of 0.5% of the total memory and 100MB as the delete bitmap cache size. + bool is_percent = false; + int64_t delete_bitmap_agg_cache_cache_limit = + ParseUtil::parse_mem_spec(config::delete_bitmap_dynamic_agg_cache_limit, + MemInfo::mem_limit(), MemInfo::physical_mem(), &is_percent); + _agg_cache.reset(new AggCache(delete_bitmap_agg_cache_cache_limit > + config::delete_bitmap_agg_cache_capacity + ? 
delete_bitmap_agg_cache_cache_limit + : config::delete_bitmap_agg_cache_capacity)); } DeleteBitmap::DeleteBitmap(const DeleteBitmap& o) { From c3ffbdb799b1a1f62e8a03f1657782e827441acd Mon Sep 17 00:00:00 2001 From: minghong Date: Tue, 2 Apr 2024 08:52:03 +0800 Subject: [PATCH 03/12] [feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087) * cse fe part --- .../translator/PhysicalPlanTranslator.java | 50 +++++-- .../post/CommonSubExpressionCollector.java | 59 ++++++++ .../post/CommonSubExpressionOpt.java | 125 +++++++++++++++++ .../processor/post/PlanPostProcessors.java | 3 +- .../trees/plans/physical/PhysicalProject.java | 81 ++++++++++- .../org/apache/doris/planner/PlanNode.java | 38 ++++- .../doris/catalog/CreateFunctionTest.java | 41 ++++-- .../postprocess/CommonSubExpressionTest.java | 131 ++++++++++++++++++ .../data/tpch_sf0.1_p1/sql/cse.out | 30 ++++ .../regression/action/ExplainAction.groovy | 15 ++ .../suites/tpch_sf0.1_p1/sql/cse.groovy | 49 +++++++ 11 files changed, 591 insertions(+), 31 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java create mode 100644 regression-test/data/tpch_sf0.1_p1/sql/cse.out create mode 100644 regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 15e149cdd4d886..ab72d995573c18 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1833,15 
+1833,38 @@ public PlanFragment visitPhysicalProject(PhysicalProject project registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot()); } - List projectionExprs = project.getProjects() - .stream() - .map(e -> ExpressionTranslator.translate(e, context)) - .collect(Collectors.toList()); - List slots = project.getProjects() - .stream() - .map(NamedExpression::toSlot) - .collect(Collectors.toList()); - + PlanNode inputPlanNode = inputFragment.getPlanRoot(); + List projectionExprs = null; + List allProjectionExprs = Lists.newArrayList(); + List slots = null; + if (project.hasMultiLayerProjection()) { + int layerCount = project.getMultiLayerProjects().size(); + for (int i = 0; i < layerCount; i++) { + List layer = project.getMultiLayerProjects().get(i); + projectionExprs = layer.stream() + .map(e -> ExpressionTranslator.translate(e, context)) + .collect(Collectors.toList()); + slots = layer.stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toList()); + if (i < layerCount - 1) { + inputPlanNode.addIntermediateProjectList(projectionExprs); + TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context); + inputPlanNode.addIntermediateOutputTupleDescList(projectionTuple); + } + allProjectionExprs.addAll(projectionExprs); + } + } else { + projectionExprs = project.getProjects() + .stream() + .map(e -> ExpressionTranslator.translate(e, context)) + .collect(Collectors.toList()); + slots = project.getProjects() + .stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toList()); + allProjectionExprs.addAll(projectionExprs); + } // process multicast sink if (inputFragment instanceof MultiCastPlanFragment) { MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink(); @@ -1853,10 +1876,9 @@ public PlanFragment visitPhysicalProject(PhysicalProject project return inputFragment; } - PlanNode inputPlanNode = inputFragment.getPlanRoot(); List conjuncts = inputPlanNode.getConjuncts(); Set requiredSlotIdSet = 
Sets.newHashSet(); - for (Expr expr : projectionExprs) { + for (Expr expr : allProjectionExprs) { Expr.extractSlots(expr, requiredSlotIdSet); } Set requiredByProjectSlotIdSet = Sets.newHashSet(requiredSlotIdSet); @@ -1891,8 +1913,10 @@ public PlanFragment visitPhysicalProject(PhysicalProject project requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e))); for (ExprId exprId : requiredExprIds) { SlotId slotId = ((HashJoinNode) joinNode).getHashOutputExprSlotIdMap().get(exprId); - Preconditions.checkState(slotId != null); - ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId); + // Preconditions.checkState(slotId != null); + if (slotId != null) { + ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId); + } } } return inputFragment; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java new file mode 100644 index 00000000000000..5abc5f6f60ffa2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.processor.post; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; + +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +/** + * collect common expr + */ +public class CommonSubExpressionCollector extends ExpressionVisitor { + public final Map> commonExprByDepth = new HashMap<>(); + private final Map> expressionsByDepth = new HashMap<>(); + + @Override + public Integer visit(Expression expr, Void context) { + if (expr.children().isEmpty()) { + return 0; + } + return collectCommonExpressionByDepth(expr.children().stream().map(child -> + child.accept(this, context)).reduce(Math::max).map(m -> m + 1).orElse(1), expr); + } + + private int collectCommonExpressionByDepth(int depth, Expression expr) { + Set expressions = getExpressionsFromDepthMap(depth, expressionsByDepth); + if (expressions.contains(expr)) { + Set commonExpression = getExpressionsFromDepthMap(depth, commonExprByDepth); + commonExpression.add(expr); + } + expressions.add(expr); + return depth; + } + + public static Set getExpressionsFromDepthMap( + int depth, Map> depthMap) { + depthMap.putIfAbsent(depth, new LinkedHashSet<>()); + return depthMap.get(depth); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java new file mode 100644 index 00000000000000..dfaf2de757e45e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.processor.post; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; + +import com.google.common.collect.Lists; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Select A+B, (A+B+C)*2, (A+B+C)*3, D from T + * + * before optimize + * projection: + * Proj: A+B, (A+B+C)*2, (A+B+C)*3, D + * + * --- + * after optimize: + * Projection: List < List < Expression > > + * A+B, C, D + * A+B, A+B+C, D + * A+B, (A+B+C)*2, (A+B+C)*3, D + */ +public class CommonSubExpressionOpt extends PlanPostProcessor { + @Override + public PhysicalProject visitPhysicalProject(PhysicalProject project, CascadesContext ctx) { + + List> multiLayers = computeMultiLayerProjections( + project.getInputSlots(), project.getProjects()); + project.setMultiLayerProjects(multiLayers); + return project; + } + + private List> computeMultiLayerProjections( + Set inputSlots, List projects) { + 
+ List> multiLayers = Lists.newArrayList(); + CommonSubExpressionCollector collector = new CommonSubExpressionCollector(); + for (Expression expr : projects) { + expr.accept(collector, null); + } + Map commonExprToAliasMap = new HashMap<>(); + collector.commonExprByDepth.values().stream().flatMap(expressions -> expressions.stream()) + .forEach(expression -> { + if (expression instanceof Alias) { + commonExprToAliasMap.put(expression, (Alias) expression); + } else { + commonExprToAliasMap.put(expression, new Alias(expression)); + } + }); + Map aliasMap = new HashMap<>(); + if (!collector.commonExprByDepth.isEmpty()) { + for (int i = 1; i <= collector.commonExprByDepth.size(); i++) { + List layer = Lists.newArrayList(); + layer.addAll(inputSlots); + Set exprsInDepth = CommonSubExpressionCollector + .getExpressionsFromDepthMap(i, collector.commonExprByDepth); + exprsInDepth.forEach(expr -> { + Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap); + Alias alias = new Alias(rewritten); + aliasMap.put(expr, alias); + }); + layer.addAll(aliasMap.values()); + multiLayers.add(layer); + } + // final layer + List finalLayer = Lists.newArrayList(); + projects.forEach(expr -> { + Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap); + if (rewritten instanceof Slot) { + finalLayer.add((NamedExpression) rewritten); + } else if (rewritten instanceof Alias) { + finalLayer.add(new Alias(expr.getExprId(), ((Alias) rewritten).child(), expr.getName())); + } + }); + multiLayers.add(finalLayer); + } + return multiLayers; + } + + /** + * replace sub expr by aliasMap + */ + public static class ExpressionReplacer + extends DefaultExpressionRewriter> { + public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); + + private ExpressionReplacer() { + } + + @Override + public Expression visit(Expression expr, Map replaceMap) { + if (replaceMap.containsKey(expr)) { + return replaceMap.get(expr).toSlot(); + } + return super.visit(expr, 
replaceMap); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java index 60c1a74445e1ff..86c8486ef45710 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java @@ -63,8 +63,9 @@ public List getProcessors() { builder.add(new MergeProjectPostProcessor()); builder.add(new RecomputeLogicalPropertiesProcessor()); builder.add(new AddOffsetIntoDistribute()); + builder.add(new CommonSubExpressionOpt()); + // DO NOT replace PLAN NODE from here builder.add(new TopNScanOpt()); - // after generate rf, DO NOT replace PLAN NODE builder.add(new FragmentProcessor()); if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode() .toUpperCase().equals(TRuntimeFilterMode.OFF.name())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java index af7bb950a97d96..e8472b6af23a6e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -41,6 +42,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; 
+import com.google.common.collect.Lists; import java.util.List; import java.util.Objects; @@ -52,6 +54,12 @@ public class PhysicalProject extends PhysicalUnary implements Project { private final List projects; + //multiLayerProjects is used to extract common expressions + // projects: (A+B) * 2, (A+B) * 3 + // multiLayerProjects: + // L1: A+B as x + // L2: x*2, x*3 + private List> multiLayerProjects = Lists.newArrayList(); public PhysicalProject(List projects, LogicalProperties logicalProperties, CHILD_TYPE child) { this(projects, Optional.empty(), logicalProperties, child); @@ -227,7 +235,12 @@ public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator computeOutput() { - return projects.stream() + List output = projects; + if (! multiLayerProjects.isEmpty()) { + int layers = multiLayerProjects.size(); + output = multiLayerProjects.get(layers - 1); + } + return output.stream() .map(NamedExpression::toSlot) .collect(ImmutableList.toImmutableList()); } @@ -237,4 +250,70 @@ public PhysicalProject resetLogicalProperties() { return new PhysicalProject<>(projects, groupExpression, null, physicalProperties, statistics, child()); } + + /** + * extract common expr, set multi layer projects + */ + public void computeMultiLayerProjectsForCommonExpress() { + // hard code: select (s_suppkey + s_nationkey), 1+(s_suppkey + s_nationkey), s_name from supplier; + if (projects.size() == 3) { + if (projects.get(2) instanceof SlotReference) { + SlotReference sName = (SlotReference) projects.get(2); + if (sName.getName().equals("s_name")) { + Alias a1 = (Alias) projects.get(0); // (s_suppkey + s_nationkey) + Alias a2 = (Alias) projects.get(1); // 1+(s_suppkey + s_nationkey) + // L1: (s_suppkey + s_nationkey) as x, s_name + multiLayerProjects.add(Lists.newArrayList(projects.get(0), projects.get(2))); + List l2 = Lists.newArrayList(); + l2.add(a1.toSlot()); + Alias a3 = new Alias(a2.getExprId(), new Add(a1.toSlot(), a2.child().child(1)), a2.getName()); + l2.add(a3); + 
l2.add(sName); + // L2: x, (1+x) as y, s_name + multiLayerProjects.add(l2); + } + } + } + // hard code: + // select (s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y + // from supplier join nation on s_nationkey=n_nationkey + // projects: x, y + // multi L1: s_suppkey, n_regionkey, (s_suppkey + n_regionkey) as z + // L2: z +1 as x, z+2 as y + if (projects.size() == 2 && projects.get(0) instanceof Alias && projects.get(1) instanceof Alias + && ((Alias) projects.get(0)).getName().equals("x") + && ((Alias) projects.get(1)).getName().equals("y")) { + Alias a0 = (Alias) projects.get(0); + Alias a1 = (Alias) projects.get(1); + Add common = (Add) a0.child().child(0); // s_suppkey + n_regionkey + List l1 = Lists.newArrayList(); + common.children().stream().forEach(child -> l1.add((SlotReference) child)); + Alias aliasOfCommon = new Alias(common); + l1.add(aliasOfCommon); + multiLayerProjects.add(l1); + Add add1 = new Add(common, a0.child().child(0).child(1)); + Alias aliasOfAdd1 = new Alias(a0.getExprId(), add1, a0.getName()); + Add add2 = new Add(common, a1.child().child(0).child(1)); + Alias aliasOfAdd2 = new Alias(a1.getExprId(), add2, a1.getName()); + List l2 = Lists.newArrayList(aliasOfAdd1, aliasOfAdd2); + multiLayerProjects.add(l2); + } + } + + public boolean hasMultiLayerProjection() { + return !multiLayerProjects.isEmpty(); + } + + public List> getMultiLayerProjects() { + return multiLayerProjects; + } + + public void setMultiLayerProjects(List> multiLayers) { + this.multiLayerProjects = multiLayers; + } + + @Override + public List getOutput() { + return computeOutput(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java index b404bc4ad3545c..8cc18a527a86d6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java @@ -59,6 +59,7 @@ import java.util.List; 
import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * Each PlanNode represents a single relational operator @@ -155,6 +156,8 @@ public abstract class PlanNode extends TreeNode implements PlanStats { protected int nereidsId = -1; private List> childrenDistributeExprLists = new ArrayList<>(); + private List intermediateOutputTupleDescList = Lists.newArrayList(); + private List> intermediateProjectListList = Lists.newArrayList(); protected PlanNode(PlanNodeId id, ArrayList tupleIds, String planNodeName, StatisticalType statisticalType) { @@ -536,10 +539,20 @@ protected final String getExplainString(String rootPrefix, String prefix, TExpla expBuilder.append(detailPrefix + "limit: " + limit + "\n"); } if (!CollectionUtils.isEmpty(projectList)) { - expBuilder.append(detailPrefix).append("projections: ").append(getExplainString(projectList)).append("\n"); - expBuilder.append(detailPrefix).append("project output tuple id: ") + expBuilder.append(detailPrefix).append("final projections: ") + .append(getExplainString(projectList)).append("\n"); + expBuilder.append(detailPrefix).append("final project output tuple id: ") .append(outputTupleDesc.getId().asInt()).append("\n"); } + if (!intermediateProjectListList.isEmpty()) { + int layers = intermediateProjectListList.size(); + for (int i = layers - 1; i >= 0; i--) { + expBuilder.append(detailPrefix).append("intermediate projections: ") + .append(getExplainString(intermediateProjectListList.get(i))).append("\n"); + expBuilder.append(detailPrefix).append("intermediate tuple id: ") + .append(intermediateOutputTupleDescList.get(i).getId().asInt()).append("\n"); + } + } if (!CollectionUtils.isEmpty(childrenDistributeExprLists)) { for (List distributeExprList : childrenDistributeExprLists) { expBuilder.append(detailPrefix).append("distribute expr lists: ") @@ -660,6 +673,19 @@ private void treeToThriftHelper(TPlan container) { } } } + + if (!intermediateOutputTupleDescList.isEmpty()) { + 
intermediateOutputTupleDescList + .forEach( + tupleDescriptor -> msg.addToIntermediateOutputTupleIdList(tupleDescriptor.getId().asInt())); + } + + if (!intermediateProjectListList.isEmpty()) { + intermediateProjectListList.forEach( + projectList -> msg.addToIntermediateProjectionsList( + projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList()))); + } + if (this instanceof ExchangeNode) { msg.num_children = 0; return; @@ -1221,4 +1247,12 @@ public boolean pushDownAggNoGroupingCheckCol(FunctionCallExpr aggExpr, Column co public void setNereidsId(int nereidsId) { this.nereidsId = nereidsId; } + + public void addIntermediateOutputTupleDescList(TupleDescriptor tupleDescriptor) { + intermediateOutputTupleDescList.add(tupleDescriptor); + } + + public void addIntermediateProjectList(List exprs) { + intermediateProjectListList.add(exprs); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java index 0f464ba2946b7d..c342d858fe1fb5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java @@ -74,6 +74,7 @@ public static void teardown() { public void test() throws Exception { ConnectContext ctx = UtFrameUtils.createDefaultCtx(); ctx.getSessionVariable().setEnableNereidsPlanner(false); + ctx.getSessionVariable().enableFallbackToOriginalPlanner = true; ctx.getSessionVariable().setEnableFoldConstantByBe(false); // create database db1 createDatabase(ctx, "create database db1;"); @@ -113,8 +114,8 @@ public void test() throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr); queryStr = "select db1.id_masking(k1) from db1.tbl1"; - Assert.assertTrue( - dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS 
VARCHAR(65533)), 4))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); // create alias function with cast // cast any type to decimal with specific precision and scale @@ -142,14 +143,16 @@ public void test() throws Exception { queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;"; if (Config.enable_decimal_conversion) { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMALV3(4, 1))")); } else { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMAL(4, 1))")); } // cast any type to varchar with fixed length - createFuncStr = "create alias function db1.varchar(all) with parameter(text) as " - + "cast(text as varchar(65533));"; + createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as " + + "cast(text as varchar(length));"; createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx); Env.getCurrentEnv().createFunction(createFunctionStmt); @@ -172,7 +175,8 @@ public void test() throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); queryStr = "select db1.varchar(k1, 4) from db1.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS VARCHAR(65533))")); // cast any type to char with fixed length createFuncStr = "create alias function db1.to_char(all, int) with parameter(text, length) as " @@ -199,7 +203,8 @@ public void test() 
throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); queryStr = "select db1.to_char(k1, 4) from db1.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS CHARACTER")); } @Test @@ -235,8 +240,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, false); queryStr = "select id_masking(k1) from db2.tbl1"; - Assert.assertTrue( - dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); // 4. create alias function with cast // cast any type to decimal with specific precision and scale @@ -253,9 +258,11 @@ public void testCreateGlobalFunction() throws Exception { queryStr = "select decimal(k3, 4, 1) from db2.tbl1;"; if (Config.enable_decimal_conversion) { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMALV3(4, 1))")); } else { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMAL(4, 1))")); } // 5. 
cast any type to varchar with fixed length @@ -271,7 +278,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, true); queryStr = "select varchar(k1, 4) from db2.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS VARCHAR(65533))")); // 6. cast any type to char with fixed length createFuncStr = "create global alias function db2.to_char(all, int) with parameter(text, length) as " @@ -286,7 +294,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, true); queryStr = "select to_char(k1, 4) from db2.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS CHARACTER)")); } private void testFunctionQuery(ConnectContext ctx, String queryStr, Boolean isStringLiteral) throws Exception { @@ -320,4 +329,8 @@ private void createDatabase(ConnectContext ctx, String createDbStmtStr) throws E Env.getCurrentEnv().createDb(createDbStmt); System.out.println(Env.getCurrentInternalCatalog().getDbNames()); } + + private boolean containsIgnoreCase(String str, String sub) { + return str.toLowerCase().contains(sub.toLowerCase()); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java new file mode 100644 index 00000000000000..56b67e087d59ab --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.postprocess; + +import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector; +import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.IntegerType; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class CommonSubExpressionTest extends ExpressionRewriteTestHelper { + @Test + public void testExtractCommonExpr() { + List exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a"); + CommonSubExpressionCollector collector = + new CommonSubExpressionCollector(); + 
exprs.forEach(expr -> collector.visit(expr, null)); + System.out.println(collector.commonExprByDepth); + Assertions.assertEquals(2, collector.commonExprByDepth.size()); + List l1 = collector.commonExprByDepth.get(Integer.valueOf(1)) + .stream().collect(Collectors.toList()); + List l2 = collector.commonExprByDepth.get(Integer.valueOf(2)) + .stream().collect(Collectors.toList()); + Assertions.assertEquals(1, l1.size()); + assertExpression(l1.get(0), "a+b"); + Assertions.assertEquals(1, l2.size()); + assertExpression(l2.get(0), "a+b+1"); + } + + @Test + public void testMultiLayers() throws Exception { + List exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a"); + Set inputSlots = exprs.get(0).getInputSlots(); + CommonSubExpressionOpt opt = new CommonSubExpressionOpt(); + Method computeMultLayerProjectionsMethod = CommonSubExpressionOpt.class + .getDeclaredMethod("computeMultiLayerProjections", Set.class, List.class); + computeMultLayerProjectionsMethod.setAccessible(true); + List> multiLayers = (List>) computeMultLayerProjectionsMethod + .invoke(opt, inputSlots, exprs); + System.out.println(multiLayers); + Assertions.assertEquals(3, multiLayers.size()); + List l0 = multiLayers.get(0); + Assertions.assertEquals(2, l0.size()); + Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a"))); + Assertions.assertTrue(l0.get(1) instanceof Alias); + assertExpression(l0.get(1).child(0), "a+b"); + Assertions.assertEquals(multiLayers.get(1).size(), 3); + Assertions.assertEquals(multiLayers.get(2).size(), 5); + List l2 = multiLayers.get(2); + for (int i = 0; i < 5; i++) { + Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt()); + } + + } + + private void assertExpression(Expression expr, String str) { + Assertions.assertEquals(ExprParser.INSTANCE.parseExpression(str), expr); + } + + private List parseProjections(String exprList) { + List result = new ArrayList<>(); + String[] exprArray = exprList.split(","); + for (String 
item : exprArray) { + Expression expr = ExprParser.INSTANCE.parseExpression(item); + if (expr instanceof NamedExpression) { + result.add((NamedExpression) expr); + } else { + result.add(new Alias(expr)); + } + } + return result; + } + + public static class ExprParser { + public static ExprParser INSTANCE = new ExprParser(); + HashMap slotMap = new HashMap<>(); + + public Expression parseExpression(String str) { + Expression expr = PARSER.parseExpression(str); + return expr.accept(DataTypeAssignor.INSTANCE, slotMap); + } + } + + public static class DataTypeAssignor extends DefaultExpressionRewriter> { + public static DataTypeAssignor INSTANCE = new DataTypeAssignor(); + + @Override + public Expression visitSlot(Slot slot, Map slotMap) { + SlotReference exitsSlot = slotMap.get(slot.getName()); + if (exitsSlot != null) { + return exitsSlot; + } else { + SlotReference slotReference = new SlotReference(slot.getName(), IntegerType.INSTANCE); + slotMap.put(slot.getName(), slotReference); + return slotReference; + } + } + } + +} diff --git a/regression-test/data/tpch_sf0.1_p1/sql/cse.out b/regression-test/data/tpch_sf0.1_p1/sql/cse.out new file mode 100644 index 00000000000000..5ab44655661dbf --- /dev/null +++ b/regression-test/data/tpch_sf0.1_p1/sql/cse.out @@ -0,0 +1,30 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !cse -- +1 1 3 4 +2 0 3 4 +3 1 5 6 +4 0 5 6 +5 4 10 11 +6 0 7 8 +7 3 11 12 +8 1 10 11 +9 4 14 15 +10 1 12 13 + +-- !cse_2 -- +17 1 18 19 19 +5 2 7 8 8 +1 3 4 5 5 +15 4 19 20 20 +11 5 16 17 17 +14 6 20 21 21 +23 7 30 31 31 +17 8 25 26 26 +10 9 19 20 20 +24 10 34 35 35 + +-- !cse_3 -- +12093 13093 14093 15093 + +-- !cse_4 -- +12093 13093 14093 15093 \ No newline at end of file diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy index e6f05c6c7655f9..cf0c03fc3bd73b 100644 --- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy +++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy @@ -32,6 +32,7 @@ class ExplainAction implements SuiteAction { private SuiteContext context private Set containsStrings = new LinkedHashSet<>() private Set notContainsStrings = new LinkedHashSet<>() + private Map multiContainsStrings = new HashMap<>() private String coonType private Closure checkFunction @@ -56,6 +57,10 @@ class ExplainAction implements SuiteAction { containsStrings.add(subString) } + void multiContains(String subString, int n) { + multiContainsStrings.put(subString, n); + } + void notContains(String subString) { notContainsStrings.add(subString) } @@ -112,6 +117,16 @@ class ExplainAction implements SuiteAction { throw t } } + for (Map.Entry entry : multiContainsStrings) { + int count = explainString.count(entry.key); + if (count != entry.value) { + String msg = ("Explain and check failed, expect multiContains '${string}' , '${entry.value}' times, actural '${count}' times." 
+ + "Actual explain string is:\n${explainString}").toString() + log.info(msg) + def t = new IllegalStateException(msg) + throw t + } + } } } diff --git a/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy new file mode 100644 index 00000000000000..698dbd3e5d0093 --- /dev/null +++ b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// The cases is copied from https://github.com/trinodb/trino/tree/master +// /testing/trino-product-tests/src/main/resources/sql-tests/testcases/tpcds +// and modified by Doris. 
+ +suite('cse') { + def q1 = """select s_suppkey,n_regionkey,(s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y + from supplier join nation on s_nationkey=n_nationkey order by s_suppkey , n_regionkey limit 10 ; + """ + + def q2 = """select s_nationkey,s_suppkey ,(s_nationkey + s_suppkey), (s_nationkey + s_suppkey) + 1, abs((s_nationkey + s_suppkey) + 1) + from supplier order by s_suppkey , s_suppkey limit 10 ;""" + + qt_cse "${q1}" + + explain { + sql "${q1}" + contains "intermediate projections:" + } + + qt_cse_2 "${q2}" + + explain { + sql "${q2}" + multiContains("intermediate projections:", 2) + } + + qt_cse_3 """ select sum(s_nationkey),sum(s_nationkey +1 ) ,sum(s_nationkey +2 ) , sum(s_nationkey + 3 ) from supplier ;""" + + qt_cse_4 """select sum(s_nationkey),sum(s_nationkey) + count(1) ,sum(s_nationkey) + 2 * count(1) , sum(s_nationkey) + 3 * count(1) from supplier ;""" + + +} From 6a43093d5523be95e3f339032cbbc87006537492 Mon Sep 17 00:00:00 2001 From: LiBinfeng <46676950+LiBinfeng-01@users.noreply.github.com> Date: Tue, 2 Apr 2024 10:42:09 +0800 Subject: [PATCH 04/12] [Fix](Nereids) ntile function should check argument (#32994) Problem: when ntile using 0 as parameter, be would core because no checking of parameter Solved: check parameter in fe analyze --- .../expressions/functions/window/Ntile.java | 24 +++++++++++++++++++ .../test_ntile_function.groovy | 21 +++++++++------- .../test_ntile_function.groovy | 10 ++++++++ 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/window/Ntile.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/window/Ntile.java index 16321d6828004b..d1d2ee57736fed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/window/Ntile.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/window/Ntile.java @@ -18,12 +18,15 @@ 
package org.apache.doris.nereids.trees.expressions.functions.window; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.shape.LeafExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.LargeIntType; import org.apache.doris.nereids.types.SmallIntType; @@ -64,6 +67,27 @@ public Ntile withChildren(List children) { return new Ntile(children.get(0)); } + @Override + public void checkLegalityBeforeTypeCoercion() { + DataType type = getBuckets().getDataType(); + if (!type.isIntegralType()) { + throw new AnalysisException("The bucket of NTILE must be a integer: " + this.toSql()); + } + if (!getBuckets().isConstant()) { + throw new AnalysisException( + "The bucket of NTILE must be a constant value: " + this.toSql()); + } + if (getBuckets() instanceof Literal) { + if (((Literal) getBuckets()).getDouble() <= 0) { + throw new AnalysisException( + "The bucket parameter of NTILE must be a constant positive integer: " + this.toSql()); + } + } else { + throw new AnalysisException( + "The bucket parameter of NTILE must be a constant positive integer: " + this.toSql()); + } + } + @Override public List getSignatures() { return SIGNATURES; diff --git a/regression-test/suites/nereids_p0/sql_functions/window_functions/test_ntile_function.groovy b/regression-test/suites/nereids_p0/sql_functions/window_functions/test_ntile_function.groovy index 
0767be9579fdf3..9532fea2d29675 100644 --- a/regression-test/suites/nereids_p0/sql_functions/window_functions/test_ntile_function.groovy +++ b/regression-test/suites/nereids_p0/sql_functions/window_functions/test_ntile_function.groovy @@ -68,14 +68,19 @@ suite("test_ntile_function") { } sql "sync" - // Nereids does't support window function - // qt_select "select k1, k2, k3, ntile(3) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" - // Nereids does't support window function - // qt_select "select k1, k2, k3, ntile(5) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" - // Nereids does't support window function - // qt_select "select k2, k1, k3, ntile(3) over (order by k2) as ntile from ${tableName} order by k2, k1, k3 desc;" - // Nereids does't support window function - // qt_select "select k3, k2, k1, ntile(3) over (partition by k3 order by k2) as ntile from ${tableName} order by k3, k2, k1;" + qt_select "select k1, k2, k3, ntile(3) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" + qt_select "select k1, k2, k3, ntile(5) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" + qt_select "select k2, k1, k3, ntile(3) over (order by k2) as ntile from ${tableName} order by k2, k1, k3 desc;" + qt_select "select k3, k2, k1, ntile(3) over (partition by k3 order by k2) as ntile from ${tableName} order by k3, k2, k1;" + test { + sql "select k1, k2, k3, ntile(0) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" + exception "The bucket parameter of NTILE must be a constant positive integer: ntile(0)" + } + test { + sql "select k1, k2, k3, ntile(k1) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" + exception "The bucket of NTILE must be a constant value: ntile(k1)" + } + } diff --git 
a/regression-test/suites/query_p0/sql_functions/window_functions/test_ntile_function.groovy b/regression-test/suites/query_p0/sql_functions/window_functions/test_ntile_function.groovy index 9eedc243c18520..7a92095af2967f 100644 --- a/regression-test/suites/query_p0/sql_functions/window_functions/test_ntile_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/window_functions/test_ntile_function.groovy @@ -70,6 +70,16 @@ suite("test_ntile_function") { qt_select "select k1, k2, k3, ntile(5) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" qt_select "select k2, k1, k3, ntile(3) over (order by k2) as ntile from ${tableName} order by k2, k1, k3 desc;" qt_select "select k3, k2, k1, ntile(3) over (partition by k3 order by k2) as ntile from ${tableName} order by k3, k2, k1;" + + test { + sql "select k1, k2, k3, ntile(0) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" + exception "Parameter n in ntile(n) should be positive." + } + + test { + sql "select k1, k2, k3, ntile(k1) over (partition by k1 order by k2) as ntile from ${tableName} order by k1, k2, k3 desc;" + exception "Parameter n in ntile(n) should be constant positive integer." + } } From 45707f3019e7f8e74a838a93960e3da5b8a90dd3 Mon Sep 17 00:00:00 2001 From: lihangyu <15605149486@163.com> Date: Tue, 2 Apr 2024 11:15:28 +0800 Subject: [PATCH 05/12] [Optimize] Move strings_pool from individual tree nodes to the tree itself (#33089) Previously, strings_pool was allocated within each tree node. However, due to the Arena's alignment of allocated chunks to at least 4K, this allocation size was excessively large for a single tree node. Consequently, when there are numerous nodes within the SubcolumnTree, a significant portion of memory was wasted. Moving strings_pool to the tree itself optimizes memory usage and reduces wastage, improving overall efficiency. 
--- be/src/vec/columns/subcolumn_tree.h | 48 +++++++++++++---------------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/be/src/vec/columns/subcolumn_tree.h b/be/src/vec/columns/subcolumn_tree.h index 86baa46e4c851a..75893f7a9f7d18 100644 --- a/be/src/vec/columns/subcolumn_tree.h +++ b/be/src/vec/columns/subcolumn_tree.h @@ -24,6 +24,7 @@ #include "runtime/exec_env.h" #include "runtime/thread_context.h" #include "vec/columns/column.h" +#include "vec/common/arena.h" #include "vec/common/hash_table/hash_map.h" #include "vec/common/string_ref.h" #include "vec/data_types/data_type.h" @@ -38,28 +39,17 @@ class SubcolumnsTree { struct Node { enum Kind { TUPLE, NESTED, SCALAR }; - explicit Node(Kind kind_) : kind(kind_) { init_memory(); } - Node(Kind kind_, const NodeData& data_) : kind(kind_), data(data_) { init_memory(); } + explicit Node(Kind kind_) : kind(kind_) {} + Node(Kind kind_, const NodeData& data_) : kind(kind_), data(data_) {} Node(Kind kind_, const NodeData& data_, const PathInData& path_) - : kind(kind_), data(data_), path(path_) { - init_memory(); - } - Node(Kind kind_, NodeData&& data_) : kind(kind_), data(std::move(data_)) { init_memory(); } + : kind(kind_), data(data_), path(path_) {} + Node(Kind kind_, NodeData&& data_) : kind(kind_), data(std::move(data_)) {} Node(Kind kind_, NodeData&& data_, const PathInData& path_) - : kind(kind_), data(std::move(data_)), path(path_) { - init_memory(); - } - - ~Node() { - SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER( - ExecEnv::GetInstance()->subcolumns_tree_tracker()); - strings_pool.reset(); - } + : kind(kind_), data(std::move(data_)), path(path_) {} Kind kind = TUPLE; const Node* parent = nullptr; - std::unique_ptr strings_pool; std::unordered_map, StringRefHash> children; NodeData data; @@ -70,12 +60,6 @@ class SubcolumnsTree { bool is_leaf_node() const { return kind == SCALAR && children.empty(); } - void init_memory() { - SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER( - 
ExecEnv::GetInstance()->subcolumns_tree_tracker()); - strings_pool = std::make_unique(); - } - // Only modify data and kind void modify(std::shared_ptr&& other) { data = std::move(other->data); @@ -89,13 +73,13 @@ class SubcolumnsTree { kind = Kind::SCALAR; } - void add_child(std::string_view key, std::shared_ptr next_node) { + void add_child(std::string_view key, std::shared_ptr next_node, Arena& strings_pool) { next_node->parent = this; StringRef key_ref; { SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER( ExecEnv::GetInstance()->subcolumns_tree_tracker()); - key_ref = {strings_pool->insert(key.data(), key.length()), key.length()}; + key_ref = {strings_pool.insert(key.data(), key.length()), key.length()}; } children[key_ref] = std::move(next_node); } @@ -186,7 +170,7 @@ class SubcolumnsTree { } else { auto next_kind = parts[i].is_nested ? Node::NESTED : Node::TUPLE; auto next_node = node_creator(next_kind, false); - current_node->add_child(String(parts[i].key), next_node); + current_node->add_child(String(parts[i].key), next_node, *strings_pool); current_node = next_node.get(); } } @@ -202,7 +186,7 @@ class SubcolumnsTree { } auto next_node = node_creator(Node::SCALAR, false); - current_node->add_child(String(parts.back().key), next_node); + current_node->add_child(String(parts.back().key), next_node, *strings_pool); leaves.push_back(std::move(next_node)); return true; @@ -287,6 +271,16 @@ class SubcolumnsTree { const_iterator begin() const { return leaves.begin(); } const_iterator end() const { return leaves.end(); } + ~SubcolumnsTree() { + SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(ExecEnv::GetInstance()->subcolumns_tree_tracker()); + strings_pool.reset(); + } + + SubcolumnsTree() { + SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(ExecEnv::GetInstance()->subcolumns_tree_tracker()); + strings_pool = std::make_shared(); + } + private: const Node* find_impl(const PathInData& path, bool find_exact) const { if (!root) { @@ -307,7 +301,7 @@ class SubcolumnsTree { return 
current_node; } - + std::shared_ptr strings_pool; NodePtr root; Nodes leaves; }; From f47e0591c171a93e66bab428c527474096ff905f Mon Sep 17 00:00:00 2001 From: Lei Zhang <27994433+SWJTU-ZhangLei@users.noreply.github.com> Date: Tue, 2 Apr 2024 11:25:45 +0800 Subject: [PATCH 06/12] [fix](fe) partitionInfo is null, fe can not start (#33108) --- .../main/java/org/apache/doris/catalog/CatalogRecycleBin.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/CatalogRecycleBin.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/CatalogRecycleBin.java index 5a6703f6e5c616..46a5ad26ed0551 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/CatalogRecycleBin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/CatalogRecycleBin.java @@ -569,7 +569,8 @@ public synchronized void replayErasePartition(long partitionId) { idToRecycleTime.remove(partitionId); if (partitionInfo == null) { - LOG.error("replayErasePartition: partitionInfo is null for partitionId[{}]", partitionId); + LOG.warn("replayErasePartition: partitionInfo is null for partitionId[{}]", partitionId); + return; } Partition partition = partitionInfo.getPartition(); From 7b63e812ffa8f24c6ab1fc897f4a5ac6fd370367 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Tue, 2 Apr 2024 14:56:26 +0800 Subject: [PATCH 07/12] [enhancement](function truncate) truncate can use column as scale argument (#32746) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- be/src/vec/functions/function_truncate.h | 245 ++++++++++++ be/src/vec/functions/math.cpp | 23 +- be/src/vec/functions/round.h | 224 ++++++++++- .../function_truncate_decimal_test.cpp | 370 ++++++++++++++++++ .../doris/analysis/FunctionCallExpr.java | 32 +- .../functions/ComputePrecisionForRound.java | 40 +- .../math_functions/test_function_truncate.out | 101 +++++ .../test_function_truncate.groovy | 132 +++++++ 8 files changed, 1136 
insertions(+), 31 deletions(-) create mode 100644 be/src/vec/functions/function_truncate.h create mode 100644 be/test/vec/function/function_truncate_decimal_test.cpp create mode 100644 regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out create mode 100644 regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy diff --git a/be/src/vec/functions/function_truncate.h b/be/src/vec/functions/function_truncate.h new file mode 100644 index 00000000000000..e29bc99c0417dc --- /dev/null +++ b/be/src/vec/functions/function_truncate.h @@ -0,0 +1,245 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include + +#include "common/exception.h" +#include "common/status.h" +#include "olap/olap_common.h" +#include "round.h" +#include "vec/columns/column.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/call_on_type_index.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" + +namespace doris::vectorized { + +struct TruncateFloatOneArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { return {std::make_shared()}; } +}; + +struct TruncateFloatTwoArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared(), std::make_shared()}; + } +}; + +struct TruncateDecimalOneArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function + // So we can just register Decimal32 here + return {std::make_shared>(9, 0)}; + } +}; + +struct TruncateDecimalTwoArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared>(9, 0), + std::make_shared()}; + } +}; + +template +class FunctionTruncate : public FunctionRounding { +public: + static FunctionPtr create() { return std::make_shared(); } + + ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; } + // SELECT number, truncate(123.345, 1) FROM number("numbers"="10") + // should NOT behave like two column arguments, so we can not use const column default implementation + bool use_default_implementation_for_constants() const override { return false; } + + Status 
execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); + ColumnPtr res; + + // potential argument types: + // 0. truncate(ColumnConst, ColumnConst) + // 1. truncate(Column), truncate(Column, ColumnConst) + // 2. truncate(Column, Column) + // 3. truncate(ColumnConst, Column) + + if (arguments.size() == 2 && is_column_const(*block.get_by_position(arguments[0]).column) && + is_column_const(*block.get_by_position(arguments[1]).column)) { + // truncate(ColumnConst, ColumnConst) + auto col_general = + assert_cast(*column_general.column).get_data_column_ptr(); + Int16 scale_arg = 0; + RETURN_IF_ERROR(FunctionTruncate::get_scale_arg( + block.get_by_position(arguments[1]), &scale_arg)); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher::apply_vec_const(col_general, + scale_arg); + return true; + } + + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. 
+ + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "truncate"); + } + // Important, make sure the result column has the same size as the input column + res = ColumnConst::create(std::move(res), input_rows_count); + } else if (arguments.size() == 1 || + (arguments.size() == 2 && + is_column_const(*block.get_by_position(arguments[1]).column))) { + // truncate(Column) or truncate(Column, ColumnConst) + Int16 scale_arg = 0; + if (arguments.size() == 2) { + RETURN_IF_ERROR(FunctionTruncate::get_scale_arg( + block.get_by_position(arguments[1]), &scale_arg)); + } + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher:: + apply_vec_const(column_general.column.get(), scale_arg); + return true; + } + + return false; + }; +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. 
+ + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "truncate"); + } + + } else if (is_column_const(*block.get_by_position(arguments[0]).column)) { + // truncate(ColumnConst, Column) + const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); + const ColumnConst& const_col_general = + assert_cast(*column_general.column); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher:: + apply_const_vec(&const_col_general, column_scale.column.get()); + return true; + } + + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. 
+ + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "truncate"); + } + } else { + // truncate(Column, Column) + const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher:: + apply_vec_vec(column_general.column.get(), column_scale.column.get()); + return true; + } + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. 
+ + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "truncate"); + } + } + + block.replace_by_position(result, std::move(res)); + return Status::OK(); + } +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp index dc815cf74e5a49..c0dfe7615764ba 100644 --- a/be/src/vec/functions/math.cpp +++ b/be/src/vec/functions/math.cpp @@ -46,6 +46,7 @@ #include "vec/functions/function_math_unary.h" #include "vec/functions/function_string.h" #include "vec/functions/function_totype.h" +#include "vec/functions/function_truncate.h" #include "vec/functions/function_unary_arithmetic.h" #include "vec/functions/round.h" #include "vec/functions/simple_function_factory.h" @@ -392,16 +393,14 @@ struct DecimalRoundOneImpl { // TODO: Now math may cause one thread compile time too long, because the function in math // so mush. 
Split it to speed up compile time in the future void register_function_math(SimpleFunctionFactory& factory) { -#define REGISTER_ROUND_FUNCTIONS(IMPL) \ - factory.register_function< \ - FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding, RoundingMode::Floor, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding, RoundingMode::Trunc, TieBreakingMode::Auto>>(); \ - factory.register_function, RoundingMode::Round, \ +#define REGISTER_ROUND_FUNCTIONS(IMPL) \ + factory.register_function< \ + FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding, RoundingMode::Floor, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ + factory.register_function, RoundingMode::Round, \ TieBreakingMode::Bankers>>(); REGISTER_ROUND_FUNCTIONS(DecimalRoundOneImpl) @@ -445,5 +444,9 @@ void register_function_math(SimpleFunctionFactory& factory) { factory.register_function(); factory.register_function(); factory.register_function(); + factory.register_function>(); + factory.register_function>(); + factory.register_function>(); + factory.register_function>(); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index 7e48b8e93064d8..a9d1e7a019c0a4 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -20,8 +20,15 @@ #pragma once +#include +#include + +#include "common/exception.h" +#include "common/status.h" #include "vec/columns/column_const.h" #include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" +#include "vec/core/types.h" #include "vec/functions/function.h" #if defined(__SSE4_1__) || defined(__aarch64__) #include "util/sse_util.hpp" @@ -176,6 +183,23 @@ class 
DecimalRoundingImpl { memcpy(out.data(), in.data(), in.size() * sizeof(T)); } } + + // NOTE: This function is only tested for truncate + // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! + static NO_INLINE void apply(const NativeType& in, UInt32 in_scale, NativeType& out, + Int16 out_scale) { + Int16 scale_arg = in_scale - out_scale; + if (scale_arg > 0) { + size_t scale = int_exp10(scale_arg); + if (out_scale < 0) { + Op::compute(&in, scale, &out, int_exp10(-out_scale)); + } else { + Op::compute(&in, scale, &out, 1); + } + } else { + memcpy(&out, &in, sizeof(NativeType)); + } + } }; template @@ -314,6 +338,11 @@ struct FloatRoundingImpl { memcpy(p_out, &tmp_dst, tail_size_bytes); } } + + static NO_INLINE void apply(const T& in, size_t scale, T& out) { + auto mm_scale = Op::prepare(scale); + Op::compute(&in, mm_scale, &out); + } }; template , IntegerRoundingImpl>>; - static ColumnPtr apply(const IColumn* col_general, Int16 scale_arg) { + static ColumnPtr apply_vec_const(const IColumn* col_general, Int16 scale_arg) { if constexpr (IsNumber) { const auto* const col = check_and_get_column>(col_general); auto col_res = ColumnVector::create(); @@ -446,6 +479,179 @@ struct Dispatcher { return nullptr; } } + + // NOTE: This function is only tested for truncate + // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! 
+ static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale) { + if constexpr (rounding_mode != RoundingMode::Trunc) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Using column as scale is only supported for function truncate"); + } + + const ColumnInt32& col_scale_i32 = assert_cast(*col_scale); + const size_t input_row_count = col_scale_i32.size(); + for (size_t i = 0; i < input_row_count; ++i) { + const Int32 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg > std::numeric_limits::max() || + scale_arg < std::numeric_limits::min()) { + throw doris::Exception(ErrorCode::OUT_OF_BOUND, + "Scale argument for function is out of bound: {}", + scale_arg); + } + } + + if constexpr (IsNumber) { + const auto* col = assert_cast*>(col_general); + auto col_res = ColumnVector::create(); + typename ColumnVector::Container& vec_res = col_res->get_data(); + vec_res.resize(input_row_count); + + for (size_t i = 0; i < input_row_count; ++i) { + const Int32 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg == 0) { + size_t scale = 1; + FunctionRoundingImpl::apply(col->get_data()[i], scale, + vec_res[i]); + } else if (scale_arg > 0) { + size_t scale = int_exp10(scale_arg); + FunctionRoundingImpl::apply(col->get_data()[i], scale, + vec_res[i]); + } else { + size_t scale = int_exp10(-scale_arg); + FunctionRoundingImpl::apply(col->get_data()[i], scale, + vec_res[i]); + } + } + return col_res; + } else if constexpr (IsDecimalNumber) { + const auto* decimal_col = assert_cast*>(col_general); + + // For truncate, ALWAYS use SAME scale with source Decimal column + const Int32 input_scale = decimal_col->get_scale(); + auto col_res = ColumnDecimal::create(input_row_count, input_scale); + + for (size_t i = 0; i < input_row_count; ++i) { + DecimalRoundingImpl::apply( + decimal_col->get_element(i).value, input_scale, + col_res->get_element(i).value, col_scale_i32.get_data()[i]); + } + + for (size_t i = 0; i < input_row_count; ++i) { + // For 
truncate(ColumnDecimal, ColumnInt32), we should always have same scale with source Decimal column + // So we need this check to make sure the result have correct digits count + // + // Case 0: scale_arg <= -(integer part digits count) + // do nothing, because result is 0 + // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count) + // decimal part has been erased, so add it back by multiplying 10^(input_scale) + // Case 2: scale_arg > 0 && scale_arg < decimal part digits count + // decimal part now has scale_arg digits, so multiply 10^(input_scale - scale_arg) + // Case 3: scale_arg >= input_scale + // do nothing + const Int32 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg <= 0) { + col_res->get_element(i).value *= int_exp10(input_scale); + } else if (scale_arg > 0 && scale_arg < input_scale) { + col_res->get_element(i).value *= int_exp10(input_scale - scale_arg); + } + } + + return col_res; + } else { + LOG(FATAL) << "__builtin_unreachable"; + __builtin_unreachable(); + return nullptr; + } + } + + // NOTE: This function is only tested for truncate + // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!!
only test for truncate + static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, + const IColumn* col_scale) { + if constexpr (rounding_mode != RoundingMode::Trunc) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Using column as scale is only supported for function truncate"); + } + + const ColumnInt32& col_scale_i32 = assert_cast(*col_scale); + const size_t input_rows_count = col_scale->size(); + + for (size_t i = 0; i < input_rows_count; ++i) { + const Int32 scale_arg = col_scale_i32.get_data()[i]; + + if (scale_arg > std::numeric_limits::max() || + scale_arg < std::numeric_limits::min()) { + throw doris::Exception(ErrorCode::OUT_OF_BOUND, + "Scale argument for function is out of bound: {}", + scale_arg); + } + } + + if constexpr (IsDecimalNumber) { + const ColumnDecimal& data_col_general = + assert_cast&>(const_col_general->get_data_column()); + const T& general_val = data_col_general.get_data()[0]; + Int32 input_scale = data_col_general.get_scale(); + + auto col_res = ColumnDecimal::create(input_rows_count, input_scale); + + for (size_t i = 0; i < input_rows_count; ++i) { + DecimalRoundingImpl::apply( + general_val, input_scale, col_res->get_element(i).value, + col_scale_i32.get_data()[i]); + } + + for (size_t i = 0; i < input_rows_count; ++i) { + // For truncate(ColumnDecimal, ColumnInt32), we should always have same scale with source Decimal column + // So we need this check to make sure the result have correct digits count + // + // Case 0: scale_arg <= -(integer part digits count) + // do nothing, because result is 0 + // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count) + // decimal part has been erased, so add it back by multiplying 10^(input_scale) + // Case 2: scale_arg > 0 && scale_arg < decimal part digits count + // decimal part now has scale_arg digits, so multiply 10^(input_scale - scale_arg) + // Case 3: scale_arg >= input_scale + // do nothing + const Int32 scale_arg = col_scale_i32.get_data()[i]; + if
(scale_arg <= 0) { + col_res->get_element(i).value *= int_exp10(input_scale); + } else if (scale_arg > 0 && scale_arg < input_scale) { + col_res->get_element(i).value *= int_exp10(input_scale - scale_arg); + } + } + + return col_res; + } else if constexpr (IsNumber) { + const ColumnVector& data_col_general = + assert_cast&>(const_col_general->get_data_column()); + const T& general_val = data_col_general.get_data()[0]; + auto col_res = ColumnVector::create(input_rows_count); + typename ColumnVector::Container& vec_res = col_res->get_data(); + + for (size_t i = 0; i < input_rows_count; ++i) { + const Int16 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg == 0) { + size_t scale = 1; + FunctionRoundingImpl::apply(general_val, scale, vec_res[i]); + } else if (scale_arg > 0) { + size_t scale = int_exp10(col_scale_i32.get_data()[i]); + FunctionRoundingImpl::apply(general_val, scale, + vec_res[i]); + } else { + size_t scale = int_exp10(-col_scale_i32.get_data()[i]); + FunctionRoundingImpl::apply(general_val, scale, + vec_res[i]); + } + } + + return col_res; + } else { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Unsupported column {} for function truncate", + const_col_general->get_name()); + } + } }; template @@ -476,17 +682,17 @@ class FunctionRounding : public IFunction { static Status get_scale_arg(const ColumnWithTypeAndName& arguments, Int16* scale) { const IColumn& scale_column = *arguments.column; - Int32 scale64 = static_cast( - static_cast(&scale_column)->get_data_column()) - .get_element(0); + Int32 scale_arg = assert_cast( + assert_cast(&scale_column)->get_data_column()) + .get_element(0); - if (scale64 > std::numeric_limits::max() || - scale64 < std::numeric_limits::min()) { + if (scale_arg > std::numeric_limits::max() || + scale_arg < std::numeric_limits::min()) { return Status::InvalidArgument("Scale argument for function {} is out of bound: {}", - name, scale64); + name, scale_arg); } - *scale = scale64; + *scale = scale_arg; return 
Status::OK(); } @@ -507,7 +713,7 @@ class FunctionRounding : public IFunction { if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { using FieldType = typename DataType::FieldType; - res = Dispatcher::apply( + res = Dispatcher::apply_vec_const( column.column.get(), scale_arg); return true; } diff --git a/be/test/vec/function/function_truncate_decimal_test.cpp b/be/test/vec/function/function_truncate_decimal_test.cpp new file mode 100644 index 00000000000000..36fcaa14e67fa6 --- /dev/null +++ b/be/test/vec/function/function_truncate_decimal_test.cpp @@ -0,0 +1,370 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "function_test_util.h" +#include "vec/columns/column.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" +#include "vec/core/column_numbers.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/function_truncate.h" + +namespace doris::vectorized { +// {precision, scale} -> {input, scale_arg, expectation} +using TestDataSet = std::map, std::vector>>; + +const static TestDataSet truncate_decimal32_cases = { + {{1, 0}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, + {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 1}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, + {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, + }}, + {{1, 1}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, + {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 0}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, + {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, + }}, + {{2, 0}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 10}, + {12, 0, 12}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 1}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 0}, + {12, 0, 10}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 2}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 0}, + {12, 0, 0}, + {12, 1, 10}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{9, 0}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 100000000}, + {123456789, -7, 120000000}, {123456789, -6, 123000000}, {123456789, -5, 123400000}, + {123456789, -4, 123450000}, {123456789, -3, 
123456000}, {123456789, -2, 123456700}, + {123456789, -1, 123456780}, {123456789, 0, 123456789}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 1}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 100000000}, {123456789, -6, 120000000}, {123456789, -5, 123000000}, + {123456789, -4, 123400000}, {123456789, -3, 123450000}, {123456789, -2, 123456000}, + {123456789, -1, 123456700}, {123456789, 0, 123456780}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 2}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 100000000}, {123456789, -5, 120000000}, + {123456789, -4, 123000000}, {123456789, -3, 123400000}, {123456789, -2, 123450000}, + {123456789, -1, 123456000}, {123456789, 0, 123456700}, {123456789, 1, 123456780}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 3}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 100000000}, + {123456789, -4, 120000000}, {123456789, -3, 123000000}, {123456789, -2, 123400000}, + {123456789, -1, 123450000}, {123456789, 0, 123456000}, {123456789, 1, 123456700}, + {123456789, 2, 123456780}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 
123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 4}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 100000000}, {123456789, -3, 120000000}, {123456789, -2, 123000000}, + {123456789, -1, 123400000}, {123456789, 0, 123450000}, {123456789, 1, 123456000}, + {123456789, 2, 123456700}, {123456789, 3, 123456780}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 5}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 100000000}, {123456789, -2, 120000000}, + {123456789, -1, 123000000}, {123456789, 0, 123400000}, {123456789, 1, 123450000}, + {123456789, 2, 123456000}, {123456789, 3, 123456700}, {123456789, 4, 123456780}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 6}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 100000000}, + {123456789, -1, 120000000}, {123456789, 0, 123000000}, {123456789, 1, 123400000}, + {123456789, 2, 123450000}, {123456789, 3, 123456000}, {123456789, 4, 123456700}, + {123456789, 5, 123456780}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 7}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, 
-2, 0}, + {123456789, -1, 100000000}, {123456789, 0, 120000000}, {123456789, 1, 123000000}, + {123456789, 2, 123400000}, {123456789, 3, 123450000}, {123456789, 4, 123456000}, + {123456789, 5, 123456700}, {123456789, 6, 123456780}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 8}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 100000000}, {123456789, 1, 120000000}, + {123456789, 2, 123000000}, {123456789, 3, 123400000}, {123456789, 4, 123450000}, + {123456789, 5, 123456000}, {123456789, 6, 123456700}, {123456789, 7, 123456780}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 9}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 0}, {123456789, 1, 100000000}, + {123456789, 2, 120000000}, {123456789, 3, 123000000}, {123456789, 4, 123400000}, + {123456789, 5, 123450000}, {123456789, 6, 123456000}, {123456789, 7, 123456700}, + {123456789, 8, 123456780}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}}; + +const static TestDataSet truncate_decimal64_cases = { + {{10, 0}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 1000000000}, + {1234567891, -8, 1200000000}, {1234567891, -7, 1230000000}, {1234567891, -6, 1234000000}, + {1234567891, -5, 1234500000}, {1234567891, -4, 1234560000}, {1234567891, -3, 1234567000}, + {1234567891, -2, 1234567800}, {1234567891, -1, 1234567890}, {1234567891, 0, 1234567891}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 
1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{10, 1}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 1000000000}, {1234567891, -7, 1200000000}, {1234567891, -6, 1230000000}, + {1234567891, -5, 1234000000}, {1234567891, -4, 1234500000}, {1234567891, -3, 1234560000}, + {1234567891, -2, 1234567000}, {1234567891, -1, 1234567800}, {1234567891, 0, 1234567890}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891} + + }}, + {{10, 2}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 0}, {1234567891, -7, 1000000000}, {1234567891, -6, 1200000000}, + {1234567891, -5, 1230000000}, {1234567891, -4, 1234000000}, {1234567891, -3, 1234500000}, + {1234567891, -2, 1234560000}, {1234567891, -1, 1234567000}, {1234567891, 0, 1234567800}, + {1234567891, 1, 1234567890}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{10, 9}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 0}, {1234567891, -7, 0}, {1234567891, -6, 0}, + {1234567891, -5, 0}, {1234567891, -4, 0}, {1234567891, -3, 0}, + {1234567891, -2, 0}, {1234567891, -1, 0}, {1234567891, 0, 1000000000}, + {1234567891, 1, 1200000000}, {1234567891, 2, 1230000000}, {1234567891, 3, 1234000000}, + {1234567891, 4, 1234500000}, {1234567891, 5, 1234560000}, {1234567891, 6, 1234567000}, + 
{1234567891, 7, 1234567800}, {1234567891, 8, 1234567890}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{18, 0}, + {{123456789123456789, -19, 0}, + {123456789123456789, -18, 0}, + {123456789123456789, -17, 100000000000000000}, + {123456789123456789, -16, 120000000000000000}, + {123456789123456789, -15, 123000000000000000}, + {123456789123456789, -14, 123400000000000000}, + {123456789123456789, -13, 123450000000000000}, + {123456789123456789, -12, 123456000000000000}, + {123456789123456789, -11, 123456700000000000}, + {123456789123456789, -10, 123456780000000000}, + {123456789123456789, -9, 123456789000000000}, + {123456789123456789, -8, 123456789100000000}, + {123456789123456789, -7, 123456789120000000}, + {123456789123456789, -6, 123456789123000000}, + {123456789123456789, -5, 123456789123400000}, + {123456789123456789, -4, 123456789123450000}, + {123456789123456789, -3, 123456789123456000}, + {123456789123456789, -2, 123456789123456700}, + {123456789123456789, -1, 123456789123456780}, + {123456789123456789, 0, 123456789123456789}, + {123456789123456789, 1, 123456789123456789}, + {123456789123456789, 2, 123456789123456789}, + {123456789123456789, 3, 123456789123456789}, + {123456789123456789, 4, 123456789123456789}, + {123456789123456789, 5, 123456789123456789}, + {123456789123456789, 6, 123456789123456789}, + {123456789123456789, 7, 123456789123456789}, + {123456789123456789, 8, 123456789123456789}, + {123456789123456789, 18, 123456789123456789}}}, + {{18, 18}, + {{123456789123456789, -1, 0}, + {123456789123456789, 0, 0}, + {123456789123456789, 1, 100000000000000000}, + {123456789123456789, 2, 120000000000000000}, + {123456789123456789, 3, 123000000000000000}, + {123456789123456789, 4, 123400000000000000}, + {123456789123456789, 5, 123450000000000000}, + {123456789123456789, 6, 123456000000000000}, + {123456789123456789, 7, 123456700000000000}, + {123456789123456789, 8, 123456780000000000}, + {123456789123456789, 
9, 123456789000000000}, + {123456789123456789, 10, 123456789100000000}, + {123456789123456789, 11, 123456789120000000}, + {123456789123456789, 12, 123456789123000000}, + {123456789123456789, 13, 123456789123400000}, + {123456789123456789, 14, 123456789123450000}, + {123456789123456789, 15, 123456789123456000}, + {123456789123456789, 16, 123456789123456700}, + {123456789123456789, 17, 123456789123456780}, + {123456789123456789, 18, 123456789123456789}, + {123456789123456789, 19, 123456789123456789}, + {123456789123456789, 20, 123456789123456789}, + {123456789123456789, 21, 123456789123456789}, + {123456789123456789, 22, 123456789123456789}, + {123456789123456789, 23, 123456789123456789}, + {123456789123456789, 24, 123456789123456789}, + {123456789123456789, 25, 123456789123456789}, + {123456789123456789, 26, 123456789123456789}}}}; + +template +static void checker(const TestDataSet& truncate_test_cases, bool decimal_col_is_const) { + static_assert(IsDecimalNumber); + auto func = std::dynamic_pointer_cast(FuncType::create()); + FunctionContext* context = nullptr; + + for (const auto& test_case : truncate_test_cases) { + Block block; + size_t res_idx = 2; + ColumnNumbers arguments = {0, 1, 2}; + const int precision = test_case.first.first; + const int scale = test_case.first.second; + const size_t input_rows_count = test_case.second.size(); + auto col_general = ColumnDecimal::create(input_rows_count, scale); + auto col_scale = ColumnInt32::create(); + auto col_res_expected = ColumnDecimal::create(input_rows_count, scale); + size_t rid = 0; + + for (const auto& test_date : test_case.second) { + auto input = std::get<0>(test_date); + auto scale_arg = std::get<1>(test_date); + auto expectation = std::get<2>(test_date); + col_general->get_element(rid) = DecimalType(input); + col_scale->insert(scale_arg); + col_res_expected->get_element(rid) = DecimalType(expectation); + rid++; + } + + if (decimal_col_is_const) { + 
block.insert({ColumnConst::create(col_general->clone_resized(1), 1), + std::make_shared>(precision, scale), + "col_general_const"}); + } else { + block.insert({col_general->clone(), + std::make_shared>(precision, scale), + "col_general"}); + } + + block.insert({col_scale->clone(), std::make_shared(), "col_scale"}); + block.insert({nullptr, std::make_shared>(precision, scale), + "col_res"}); + + auto status = func->execute_impl(context, block, arguments, res_idx, input_rows_count); + auto col_res = assert_cast&>( + *(block.get_by_position(res_idx).column)); + EXPECT_TRUE(status.ok()); + + for (size_t i = 0; i < input_rows_count; ++i) { + auto res = col_res.get_element(i); + auto res_expected = col_res_expected->get_element(i); + EXPECT_EQ(res, res_expected) + << "precision " << precision << " input_scale " << scale << " input " + << col_general->get_element(i) << " scale_arg " << col_scale->get_element(i) + << " res " << res << " res_expected " << res_expected; + } + } +} +TEST(TruncateFunctionTest, normal_decimal) { + checker, Decimal32>(truncate_decimal32_cases, + false); + checker, Decimal64>(truncate_decimal64_cases, + false); +} + +TEST(TruncateFunctionTest, normal_decimal_const) { + checker, Decimal32>(truncate_decimal32_cases, true); + checker, Decimal64>(truncate_decimal64_cases, true); +} + +} // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index b5184c33fcd546..9bc857bacef1a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -122,7 +122,7 @@ public class FunctionCallExpr extends Expr { Preconditions.checkArgument(children.get(1) instanceof IntLiteral || (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral), - "2nd argument of function round/floor/ceil/truncate 
must be literal"); + "2nd argument of function round/floor/ceil must be literal"); if (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral) { children.get(1).getChild(0).setType(children.get(1).getType()); children.set(1, children.get(1).getChild(0)); @@ -136,6 +136,34 @@ public class FunctionCallExpr extends Expr { return returnType; } }; + + java.util.function.BiFunction, Type, Type> truncateRule = (children, returnType) -> { + Preconditions.checkArgument(children != null && children.size() > 0); + if (children.size() == 1 && children.get(0).getType().isDecimalV3()) { + return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), 0); + } else if (children.size() == 2) { + Expr scaleExpr = children.get(1); + if (scaleExpr instanceof IntLiteral + || (scaleExpr instanceof CastExpr && scaleExpr.getChild(0) instanceof IntLiteral)) { + if (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral) { + children.get(1).getChild(0).setType(children.get(1).getType()); + children.set(1, children.get(1).getChild(0)); + } else { + children.get(1).setType(Type.INT); + } + int scaleArg = (int) (((IntLiteral) children.get(1)).getValue()); + return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), + Math.min(Math.max(scaleArg, 0), ((ScalarType) children.get(0).getType()).decimalScale())); + } else { + // Scale argument is a Column, always use same scale with input decimal + return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), + ((ScalarType) children.get(0).getType()).decimalScale()); + } + } else { + return returnType; + } + }; + java.util.function.BiFunction, Type, Type> arrayDateTimeV2OrDecimalV3Rule = (children, returnType) -> { Preconditions.checkArgument(children != null && children.size() > 0); @@ -239,7 +267,7 @@ public class FunctionCallExpr extends Expr { PRECISION_INFER_RULE.put("dround", roundRule); PRECISION_INFER_RULE.put("dceil", 
roundRule); PRECISION_INFER_RULE.put("dfloor", roundRule); - PRECISION_INFER_RULE.put("truncate", roundRule); + PRECISION_INFER_RULE.put("truncate", truncateRule); } public static final ImmutableSet TIME_FUNCTIONS_WITH_PRECISION = new ImmutableSortedSet.Builder( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java index 4b57772ed23ce4..6b6308c516ce58 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java @@ -20,6 +20,7 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.Int32OrLessType; @@ -37,19 +38,38 @@ default FunctionSignature computePrecision(FunctionSignature signature) { } else if (arity() == 2 && signature.getArgType(0) instanceof DecimalV3Type) { DecimalV3Type decimalV3Type = DecimalV3Type.forType(getArgumentType(0)); Expression floatLength = getArgument(1); - Preconditions.checkArgument(floatLength.getDataType() instanceof Int32OrLessType - && (floatLength.isLiteral() || ( - floatLength instanceof Cast && floatLength.child(0).isLiteral() - && floatLength.child(0).getDataType() instanceof Int32OrLessType)), - "2nd argument of function round/floor/ceil/truncate must be literal"); - int scale; - if (floatLength instanceof Cast) { - scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); + + if (this instanceof Truncate) { + if 
(floatLength.isLiteral() || ( + floatLength instanceof Cast && floatLength.child(0).isLiteral() + && floatLength.child(0).getDataType() instanceof Int32OrLessType)) { + // Scale argument is a literal or a cast of a literal + if (floatLength instanceof Cast) { + scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); + } else { + scale = ((IntegerLikeLiteral) floatLength).getIntValue(); + } + scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale()); + } else { + // Truncate may take a column as its scale argument. + // In that case the result scale is always the same as the input decimal's scale. + scale = decimalV3Type.getScale(); + } } else { - scale = ((IntegerLikeLiteral) floatLength).getIntValue(); + Preconditions.checkArgument(floatLength.getDataType() instanceof Int32OrLessType + && (floatLength.isLiteral() || ( + floatLength instanceof Cast && floatLength.child(0).isLiteral() + && floatLength.child(0).getDataType() instanceof Int32OrLessType)), + "2nd argument of function round/floor/ceil must be literal"); + if (floatLength instanceof Cast) { + scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); + } else { + scale = ((IntegerLikeLiteral) floatLength).getIntValue(); + } + scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale()); } - scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale()); + return signature.withArgumentType(0, decimalV3Type) .withReturnType(DecimalV3Type.createDecimalV3Type(decimalV3Type.getPrecision(), scale)); } else { diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out new file mode 100644 index 00000000000000..24f675ffbe29a2 --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out @@ -0,0 +1,101 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !sql -- +0 123.3 +1 123.3 +2 123.3 +3 123.3 +4 123.3 +5 123.3 +6 123.3 +7 123.3 +8 123.3 +9 123.3 + +-- !sql -- +0 120 +1 120 +2 120 +3 120 +4 120 +5 120 +6 120 +7 120 +8 120 +9 120 + +-- !sql -- +0 123 +1 123 +2 123 +3 123 +4 123 +5 123 +6 123 +7 123 +8 123 +9 123 + +-- !sql -- +0E-8 + +-- !sql -- +0 0.0 +1 0.0 +2 0.0 +3 0.0 +4 0.0 + +-- !vec_const0 -- +1 12345.0 1.23456789E8 +2 12345.0 1.23456789E8 +3 12345.0 1.23456789E8 +4 0.0 0.0 + +-- !vec_const0 -- +1 12345.1 1.234567891E8 +2 12345.1 1.234567891E8 +3 12345.1 1.234567891E8 +4 0.0 0.0 + +-- !vec_const0 -- +1 12340.0 1.2345678E8 +2 12340.0 1.2345678E8 +3 12340.0 1.2345678E8 +4 0.0 0.0 + +-- !vec_const1 -- +1 123456789 123456789 12345678.1 12345678 0.123456789 0 +2 123456789 123456789 12345678.1 12345678 0.123456789 0 +3 123456789 123456789 12345678.1 12345678 0.123456789 0 +4 0 0 0.0 0 0E-9 0 + +-- !vec_const2 -- +1 123456789 123456789 1.123456789 1 0.1234567890 0 +2 123456789 123456789 1.123456789 1 0.1234567890 0 +3 123456789 123456789 1.123456789 1 0.1234567890 0 +4 0 0 0E-9 0 0E-10 0 + +-- !const_vec1 -- +123456789.123456789 1 123456789.100000000 +123456789.123456789 1 123456789.100000000 +123456789.123456789 1 123456789.100000000 +123456789.123456789 1 123456789.100000000 + +-- !const_vec2 -- +123456789.123456789 -1 123456780.000000000 +123456789.123456789 -1 123456780.000000000 +123456789.123456789 -1 123456780.000000000 +123456789.123456789 -1 123456780.000000000 + +-- !vec_vec0 -- +1 1 12345.1 1.234567891E8 +2 1 12345.1 1.234567891E8 +3 1 12345.1 1.234567891E8 +4 1 0.0 0.0 + +-- !truncate_dec128 -- +1 1234567891234567891 1234567891234567891 1234567891.123456789 1234567891 0.1234567891234567891 0 + +-- !truncate_dec128 -- +1 1234567891234567891 1234567891234567891 1234567891.123456789 1234567891.100000000 0.1234567891234567891 0.1000000000000000000 + diff --git 
a/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy new file mode 100644 index 00000000000000..767140e7a6ff85 --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +suite("test_function_truncate") { + qt_sql """ + SELECT number, truncate(123.345 , 1) FROM numbers("number"="10"); + """ + qt_sql """ + SELECT number, truncate(123.123, -1) FROM numbers("number"="10"); + """ + qt_sql """ + SELECT number, truncate(123.123, 0) FROM numbers("number"="10"); + """ + + // const_const, scale argument 10 is capped at the input scale, so result scale should be 8 + qt_sql """ + SELECT truncate(cast(0 as Decimal(9,8)), 10); + """ + + // const_const, result scale should be 1 + qt_sql """ + SELECT number, truncate(cast(0 as Decimal(9,4)), 1) FROM numbers("number"="5") + """ + + sql """DROP TABLE IF EXISTS test_function_truncate;""" + sql """DROP TABLE IF EXISTS test_function_truncate_dec128;""" + sql """ + CREATE TABLE test_function_truncate ( + rid int, flo float, dou double, + dec90 decimal(9, 0), dec91 decimal(9, 1), dec99 decimal(9, 9), + dec100 decimal(10,0), dec109 decimal(10,9), dec1010 decimal(10,10), + number int DEFAULT 1) + DISTRIBUTED BY HASH(rid) + PROPERTIES("replication_num" = "1" ); + """ + + sql """ + INSERT INTO test_function_truncate + VALUES + (1, 12345.123, 123456789.123456789, + 123456789, 12345678.1, 0.123456789, + 123456789.1, 1.123456789, 0.123456789, 1); + """ + sql """ + INSERT INTO test_function_truncate + VALUES + (2, 12345.123, 123456789.123456789, + 123456789, 12345678.1, 0.123456789, + 123456789.1, 1.123456789, 0.123456789, 1); + """ + sql """ + INSERT INTO test_function_truncate + VALUES + (3, 12345.123, 123456789.123456789, + 123456789, 12345678.1, 0.123456789, + 123456789.1, 1.123456789, 0.123456789, 1); + """ + sql """ + INSERT INTO test_function_truncate + VALUES + (4, 0, 0, 0, 0.0, 0, 0, 0, 0, 1); + """ + qt_vec_const0 """ + SELECT rid, truncate(flo, 0), truncate(dou, 0) FROM test_function_truncate order by rid; + """ + qt_vec_const0 """ + SELECT rid, truncate(flo, 1), truncate(dou, 1) FROM test_function_truncate order by rid; + """ + qt_vec_const0 """ + SELECT rid, truncate(flo, -1), truncate(dou, -1) FROM test_function_truncate order by rid; + """ + 
qt_vec_const1 """ + SELECT rid, dec90, truncate(dec90, 0), dec91, truncate(dec91, 0), dec99, truncate(dec99, 0) FROM test_function_truncate order by rid + """ + qt_vec_const2 """ + SELECT rid, dec100, truncate(dec100, 0), dec109, truncate(dec109, 0), dec1010, truncate(dec1010, 0) FROM test_function_truncate order by rid + """ + + + + qt_const_vec1 """ + SELECT 123456789.123456789, number, truncate(123456789.123456789, number) from test_function_truncate; + """ + qt_const_vec2 """ + SELECT 123456789.123456789, -number, truncate(123456789.123456789, -number) from test_function_truncate; + """ + qt_vec_vec0 """ + SELECT rid,number, truncate(flo, number), truncate(dou, number) FROM test_function_truncate order by rid; + """ + + sql """ + CREATE TABLE test_function_truncate_dec128 ( + rid int, dec190 decimal(19,0), dec199 decimal(19,9), dec1919 decimal(19,19), + dec380 decimal(38,0), dec3819 decimal(38,19), dec3838 decimal(38,38), + number int DEFAULT 1 + ) + DISTRIBUTED BY HASH(rid) + PROPERTIES("replication_num" = "1" ); + """ + sql """ + INSERT INTO test_function_truncate_dec128 + VALUES + (1, 1234567891234567891.0, 1234567891.123456789, 0.1234567891234567891, + 12345678912345678912345678912345678912.0, + 1234567891234567891.1234567891234567891, + 0.12345678912345678912345678912345678912345678912345678912345678912345678912, 1); + """ + qt_truncate_dec128 """ + SELECT rid, dec190, truncate(dec190, 0), dec199, truncate(dec199, 0), dec1919, truncate(dec1919, 0) + FROM test_function_truncate_dec128 order by rid + """ + + qt_truncate_dec128 """ + SELECT rid, dec190, truncate(dec190, number), dec199, truncate(dec199, number), dec1919, truncate(dec1919, number) + FROM test_function_truncate_dec128 order by rid + """ + +} \ No newline at end of file From 53b1d44a74ce5733a5f4d1a1d6df56163a021e9f Mon Sep 17 00:00:00 2001 From: zclllyybb Date: Tue, 2 Apr 2024 14:55:07 +0800 Subject: [PATCH 08/12] [Feature] support function uuid_to_int and int_to_uuid #33005 --- 
be/src/vec/functions/function_uuid.cpp | 213 ++++++++++++++++++ .../vec/functions/simple_function_factory.h | 4 +- be/test/vec/function/function_string_test.cpp | 30 +++ .../doris/catalog/BuiltinScalarFunctions.java | 4 + .../functions/scalar/InttoUuid.java | 68 ++++++ .../functions/scalar/UuidtoInt.java | 70 ++++++ .../visitor/ScalarFunctionVisitor.java | 10 + gensrc/script/doris_builtins_functions.py | 5 +- 8 files changed, 401 insertions(+), 3 deletions(-) create mode 100644 be/src/vec/functions/function_uuid.cpp create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InttoUuid.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/UuidtoInt.java diff --git a/be/src/vec/functions/function_uuid.cpp b/be/src/vec/functions/function_uuid.cpp new file mode 100644 index 00000000000000..cee5fd7a363503 --- /dev/null +++ b/be/src/vec/functions/function_uuid.cpp @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include +#include + +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_string.h" +#include "vec/columns/column_vector.h" +#include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" +#include "vec/core/block.h" +#include "vec/core/column_numbers.h" +#include "vec/core/column_with_type_and_name.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" +#include "vec/functions/function.h" +#include "vec/functions/simple_function_factory.h" + +namespace doris { +class FunctionContext; +} // namespace doris + +namespace doris::vectorized { +constexpr static std::array SPLIT_POS = {8, 13, 18, 23, 36}; // 8-4-4-4-12 +constexpr static char DELIMITER = '-'; + +class FunctionUuidtoInt : public IFunction { +public: + static constexpr auto name = "uuid_to_int"; + + static FunctionPtr create() { return std::make_shared(); } + + String get_name() const override { return name; } + + size_t get_number_of_arguments() const override { return 1; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return make_nullable(std::make_shared()); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + const auto& arg_column = + assert_cast(*block.get_by_position(arguments[0]).column); + + auto result_column = ColumnInt128::create(input_rows_count); + auto& result_data = result_column->get_data(); + auto null_column = ColumnUInt8::create(input_rows_count); + auto& null_map = null_column->get_data(); + + for (int row = 0; row < input_rows_count; row++) { + auto str = arg_column.get_data_at(row); + const auto* data = 
str.data; + Int128* result_cell = &result_data[row]; + *result_cell = 0; + null_map[row] = false; + + if (str.size == 36) { + if (data[SPLIT_POS[0]] != DELIMITER || data[SPLIT_POS[1]] != DELIMITER || + data[SPLIT_POS[2]] != DELIMITER || data[SPLIT_POS[3]] != DELIMITER) { + null_map[row] = true; + continue; + } + char new_data[32]; + memset(new_data, 0, sizeof(new_data)); + // ignore '-' + memcpy(new_data, data, 8); + memcpy(new_data + 8, data + SPLIT_POS[0] + 1, 4); + memcpy(new_data + 12, data + SPLIT_POS[1] + 1, 4); + memcpy(new_data + 16, data + SPLIT_POS[2] + 1, 4); + memcpy(new_data + 20, data + SPLIT_POS[3] + 1, 12); + + if (!serialize(new_data, (char*)result_cell, 32)) { + null_map[row] = true; + continue; + } + } else if (str.size == 32) { + if (!serialize(data, (char*)result_cell, 32)) { + null_map[row] = true; + continue; + } + } else { + null_map[row] = true; + continue; + } + } + + block.replace_by_position( + result, ColumnNullable::create(std::move(result_column), std::move(null_column))); + return Status::OK(); + } + + // writing dst through char* is the only legal way under the 'strict aliasing rule' + static bool serialize(const char* __restrict src, char* __restrict dst, size_t length) { + char target; // 8bit, contains 2 char input + auto translate = [&target](const char ch) { + if (isdigit(ch)) { + target += ch - '0'; + } else if (ch >= 'a' && ch <= 'f') { + target += ch - 'a' + 10; + } else if (ch >= 'A' && ch <= 'F') { + target += ch - 'A' + 10; + } else { + return false; + } + return true; + }; + + bool ok = true; + for (size_t i = 0; i < length; i += 2, src++, dst++) { + target = 0; + if (!translate(*src)) { + ok = false; // don't break, to allow auto-simd + } + + src++; + target <<= 4; + if (!translate(*src)) { + ok = false; + } + *dst = target; + } + + return ok; + } +}; + +class FunctionInttoUuid : public IFunction { +public: + static constexpr auto name = "int_to_uuid"; + + static FunctionPtr create() { return std::make_shared(); } + + String 
get_name() const override { return name; } + + size_t get_number_of_arguments() const override { return 1; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return std::make_shared(); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + const auto& arg_column = + assert_cast(*block.get_by_position(arguments[0]).column); + auto result_column = ColumnString::create(); + constexpr int str_length = 36; + auto& col_data = result_column->get_chars(); + auto& col_offset = result_column->get_offsets(); + col_data.resize(str_length * input_rows_count + + 1); // for branchless deserialize, we occupy one more byte for the last '-' + col_offset.resize(input_rows_count); + + for (int row = 0; row < input_rows_count; row++) { + const Int128* arg = &arg_column.get_data()[row]; + col_offset[row] = col_offset[row - 1] + str_length; + deserialize((char*)arg, col_data.data() + str_length * row); + } + block.replace_by_position(result, std::move(result_column)); + return Status::OK(); + } + + // reading src through char* is the only legal way under the 'strict aliasing rule' + static void deserialize(const char* __restrict src, unsigned char* __restrict dst) { + auto transform = [](char ch) -> unsigned char { + if (ch < 10) { + return ch + '0'; + } else { + return ch - 10 + 'a'; + } + }; + + int j = 0; + for (int i : SPLIT_POS) { + for (; j < i; src++, j += 2) { // input 16 chars, 2 data per char + dst[j] = transform(((*src) >> 4) & 0x0F); + dst[j + 1] = transform(*src & 0x0F); + } + dst[j++] = DELIMITER; // we resized one more byte. 
+ } + } +}; + +void register_function_uuid_transforms(SimpleFunctionFactory& factory) { + factory.register_function(); + factory.register_function(); +} + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/simple_function_factory.h b/be/src/vec/functions/simple_function_factory.h index a18b0beb8dbe2b..649db732093934 100644 --- a/be/src/vec/functions/simple_function_factory.h +++ b/be/src/vec/functions/simple_function_factory.h @@ -24,8 +24,6 @@ #include #include "agent/be_exec_version_manager.h" -#include "udf/udf.h" -#include "vec/exprs/table_function/table_function.h" #include "vec/functions/function.h" namespace doris::vectorized { @@ -81,6 +79,7 @@ void register_function_regexp(SimpleFunctionFactory& factory); void register_function_random(SimpleFunctionFactory& factory); void register_function_uuid(SimpleFunctionFactory& factory); void register_function_uuid_numeric(SimpleFunctionFactory& factory); +void register_function_uuid_transforms(SimpleFunctionFactory& factory); void register_function_coalesce(SimpleFunctionFactory& factory); void register_function_grouping(SimpleFunctionFactory& factory); void register_function_datetime_floor_ceil(SimpleFunctionFactory& factory); @@ -265,6 +264,7 @@ class SimpleFunctionFactory { register_function_random(instance); register_function_uuid(instance); register_function_uuid_numeric(instance); + register_function_uuid_transforms(instance); register_function_coalesce(instance); register_function_grouping(instance); register_function_datetime_floor_ceil(instance); diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp index 612a6fff0ccd6f..d8d1a57b8eb986 100644 --- a/be/test/vec/function/function_string_test.cpp +++ b/be/test/vec/function/function_string_test.cpp @@ -17,6 +17,7 @@ #include +#include #include #include #include @@ -1157,4 +1158,33 @@ TEST(function_string_test, function_bit_length_test) { static_cast(check_function(func_name, input_types, 
data_set)); } +TEST(function_string_test, function_uuid_test) { + { + std::string func_name = "uuid_to_int"; + InputTypeSet input_types = {TypeIndex::String}; + uint64_t high = 9572195551486940809ULL; + uint64_t low = 1759290071393952876ULL; + __int128 result = (__int128)high * (__int128)10000000000000000000ULL + (__int128)low; + DataSet data_set = {{{Null()}, Null()}, + {{std::string("6ce4766f-6783-4b30-b357-bba1c7600348")}, result}, + {{std::string("6ce4766f67834b30b357bba1c7600348")}, result}, + {{std::string("ffffffff-ffff-ffff-ffff-ffffffffffff")}, (__int128)-1}, + {{std::string("00000000-0000-0000-0000-000000000000")}, (__int128)0}, + {{std::string("123")}, Null()}}; + static_cast(check_function(func_name, input_types, data_set)); + } + { + std::string func_name = "int_to_uuid"; + InputTypeSet input_types = {TypeIndex::Int128}; + uint64_t high = 9572195551486940809ULL; + uint64_t low = 1759290071393952876ULL; + __int128 value = (__int128)high * (__int128)10000000000000000000ULL + (__int128)low; + DataSet data_set = {{{Null()}, Null()}, + {{value}, std::string("6ce4766f-6783-4b30-b357-bba1c7600348")}, + {{(__int128)-1}, std::string("ffffffff-ffff-ffff-ffff-ffffffffffff")}, + {{(__int128)0}, std::string("00000000-0000-0000-0000-000000000000")}}; + static_cast(check_function(func_name, input_types, data_set)); + } +} + } // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index 9d19ea9e2e9923..d28cb751eaa03d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -198,6 +198,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Initcap; import org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProduct; import 
org.apache.doris.nereids.trees.expressions.functions.scalar.Instr; +import org.apache.doris.nereids.trees.expressions.functions.scalar.InttoUuid; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ipv4CIDRToRange; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ipv4NumToString; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ipv4StringToNum; @@ -417,6 +418,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.UtcTimestamp; import org.apache.doris.nereids.trees.expressions.functions.scalar.Uuid; import org.apache.doris.nereids.trees.expressions.functions.scalar.UuidNumeric; +import org.apache.doris.nereids.trees.expressions.functions.scalar.UuidtoInt; import org.apache.doris.nereids.trees.expressions.functions.scalar.Version; import org.apache.doris.nereids.trees.expressions.functions.scalar.Week; import org.apache.doris.nereids.trees.expressions.functions.scalar.WeekCeil; @@ -625,6 +627,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(Initcap.class, "initcap"), scalar(InnerProduct.class, "inner_product"), scalar(Instr.class, "instr"), + scalar(InttoUuid.class, "int_to_uuid"), scalar(Ipv4NumToString.class, "ipv4_num_to_string", "inet_ntoa"), scalar(Ipv4StringToNum.class, "ipv4_string_to_num"), scalar(Ipv4StringToNumOrDefault.class, "ipv4_string_to_num_or_default"), @@ -870,6 +873,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(UtcTimestamp.class, "utc_timestamp"), scalar(Uuid.class, "uuid"), scalar(UuidNumeric.class, "uuid_numeric"), + scalar(UuidtoInt.class, "uuid_to_int"), scalar(Version.class, "version"), scalar(Week.class, "week"), scalar(WeekCeil.class, "week_ceil"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InttoUuid.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InttoUuid.java new file mode 100644 index 00000000000000..d3434eff35b3c0 
--- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InttoUuid.java @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.VarcharType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * ScalarFunction 'int_to_uuid'. 
+ */ +public class InttoUuid extends ScalarFunction + implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(LargeIntType.INSTANCE)); + + /** + * constructor with 1 argument. + */ + public InttoUuid(Expression arg) { + super("int_to_uuid", arg); + } + + /** + * withChildren. + */ + @Override + public InttoUuid withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new InttoUuid(children.get(0)); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitInttoUuid(this, context); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/UuidtoInt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/UuidtoInt.java new file mode 100644 index 00000000000000..987a8b1d06e82f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/UuidtoInt.java @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarcharType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * ScalarFunction 'uuid_to_int'. + */ +public class UuidtoInt extends ScalarFunction + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(LargeIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(LargeIntType.INSTANCE).args(StringType.INSTANCE)); + + /** + * constructor with 1 argument. + */ + public UuidtoInt(Expression arg) { + super("uuid_to_int", arg); + } + + /** + * withChildren. 
+ */ + @Override + public UuidtoInt withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new UuidtoInt(children.get(0)); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitUuidtoInt(this, context); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index 7cef47557cc525..83a4a2aa027520 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -201,6 +201,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Initcap; import org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProduct; import org.apache.doris.nereids.trees.expressions.functions.scalar.Instr; +import org.apache.doris.nereids.trees.expressions.functions.scalar.InttoUuid; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ipv4CIDRToRange; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ipv4NumToString; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ipv4StringToNum; @@ -414,6 +415,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.UtcTimestamp; import org.apache.doris.nereids.trees.expressions.functions.scalar.Uuid; import org.apache.doris.nereids.trees.expressions.functions.scalar.UuidNumeric; +import org.apache.doris.nereids.trees.expressions.functions.scalar.UuidtoInt; import org.apache.doris.nereids.trees.expressions.functions.scalar.Version; import org.apache.doris.nereids.trees.expressions.functions.scalar.Week; import org.apache.doris.nereids.trees.expressions.functions.scalar.WeekCeil; @@ -2001,6 
+2003,14 @@ default R visitUuidNumeric(UuidNumeric uuidNumeric, C context) { return visitScalarFunction(uuidNumeric, context); } + default R visitUuidtoInt(UuidtoInt uuidtoInt, C context) { + return visitScalarFunction(uuidtoInt, context); + } + + default R visitInttoUuid(InttoUuid inttoUuid, C context) { + return visitScalarFunction(inttoUuid, context); + } + default R visitVersion(Version version, C context) { return visitScalarFunction(version, context); } diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index d97c499f2da30b..3d87ab86fd259f 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -2012,7 +2012,10 @@ "UUID": [ [['uuid'], 'VARCHAR', [], 'ALWAYS_NOT_NULLABLE'], - [['uuid_numeric'], 'LARGEINT', [], 'ALWAYS_NOT_NULLABLE'] + [['uuid_numeric'], 'LARGEINT', [], 'ALWAYS_NOT_NULLABLE'], + [['uuid_to_int'], 'LARGEINT', ['VARCHAR'], 'ALWAYS_NULLABLE'], + [['uuid_to_int'], 'LARGEINT', ['STRING'], 'ALWAYS_NULLABLE'], + [['int_to_uuid'], 'VARCHAR', ['LARGEINT'], 'DEPEND_ON_ARGUMENT'] ], #ip functions From 0c91388c19b21d56ce3eca43ed60e4cc94c38ee0 Mon Sep 17 00:00:00 2001 From: 924060929 <924060929@qq.com> Date: Wed, 10 Apr 2024 13:30:00 +0800 Subject: [PATCH 09/12] [enhancement](Nereids) refactor expression rewriter to pattern match (#32617) this pr improves the performance of the nereids planner in the plan stage. 1. refactor the expression rewriter to use pattern matching, so that the many expression rewrite rules can be applied interleaved in one big bottom-up iteration, rewriting until the expression becomes stable. now we can process more cases, because originally there was no loop and sometimes only the top expression was processed, like `SimplifyArithmeticRule`. 2. replace `Collection.stream()` with `ImmutableXxx.Builder` to avoid useless method calls 3. unroll loops in some code, like `Expression.`, `PlanTreeRewriteBottomUpJob.pushChildrenJobs` 4.
use type/arity-specific code, like `OneRangePartitionEvaluator.toNereidsLiterals()`, `PartitionRangeExpander.tryExpandRange()`, `PartitionRangeExpander.enumerableCount()` 5. refactor `ExtractCommonFactorRule`, now we can extract more cases, and I fix the dead loop (infinite loop) when using `ExtractCommonFactorRule` and `SimplifyRange` in one iteration, because `SimplifyRange` generates a right-deep tree, but `ExtractCommonFactorRule` generates a left-deep tree 6. refactor `FoldConstantRuleOnFE` to support both visitor and pattern match modes: in ExpressionNormalization, pattern matching can be applied interleaved with other rules; in PartitionPruner, the visitor can evaluate expressions faster 7. lazily compute and cache some operations 8. use int fields to compare dates 9. use BitSet to find disableNereidsRules 10. a two-level loop is usually faster than building a Multimap when binding slots in Scope, so I revert the code 11. `PlanTreeRewriteBottomUpJob` doesn't need to clearStatePhase any more ### test case 100 threads continuously send this sql in parallel, which queries an empty table; tested on my mac machine (m2 chip, 8 core), with sql cache enabled ```sql select count(1),date_format(time_col,'%Y%m%d'),varchar_col1 from tbl where partition_date>'2024-02-15' and (varchar_col2 ='73130' or varchar_col3='73130') and time_col>'2024-03-04' and time_col<'2024-03-05' group by date_format(time_col,'%Y%m%d'),varchar_col1 order by date_format(time_col,'%Y%m%d') desc, varchar_col1 desc,count(1) asc limit 1000 ``` before this pr: 3100 peak QPS, about 2700 avg QPS after this pr: 4800 peak QPS, about 4400 avg QPS (cherry picked from commit 7338683fdbdf77711f2ce61e580c19f4ea100723) --- fe/fe-core/pom.xml | 2 +- .../apache/doris/analysis/DateLiteral.java | 9 +- .../org/apache/doris/catalog/OlapTable.java | 5 +- .../doris/mtmv/MTMVRelationManager.java | 4 +- .../apache/doris/mysql/privilege/Role.java | 4 +- .../apache/doris/nereids/CascadesContext.java | 49 +++- .../apache/doris/nereids/NereidsPlanner.java | 2 +- .../doris/nereids/StatementContext.java | 15
++ .../apache/doris/nereids/analyzer/Scope.java | 19 +- .../org/apache/doris/nereids/jobs/Job.java | 11 +- .../doris/nereids/jobs/executor/Rewriter.java | 7 +- .../jobs/joinorder/hypergraph/HyperGraph.java | 11 +- .../jobs/rewrite/CustomRewriteJob.java | 6 +- .../rewrite/PlanTreeRewriteBottomUpJob.java | 113 ++++---- .../jobs/rewrite/PlanTreeRewriteJob.java | 64 +++-- .../rewrite/PlanTreeRewriteTopDownJob.java | 41 ++- .../jobs/rewrite/RewriteJobContext.java | 10 +- .../jobs/rewrite/RootPlanTreeRewriteJob.java | 16 +- .../pattern/ExpressionPatternRules.java | 112 ++++++++ .../ExpressionPatternTraverseListeners.java | 112 ++++++++ .../nereids/pattern/ParentTypeIdMapping.java | 59 +++++ .../apache/doris/nereids/pattern/Pattern.java | 4 + .../doris/nereids/pattern/TypeMappings.java | 133 ++++++++++ .../ExpressionTypeMappingGenerator.java | 159 +++++++++++ ...atorAnalyzer.java => JavaAstAnalyzer.java} | 93 +++---- .../LogicalBinaryPatternGenerator.java | 4 +- .../LogicalLeafPatternGenerator.java | 4 +- .../LogicalUnaryPatternGenerator.java | 4 +- .../PatternDescribableProcessor.java | 34 ++- .../PhysicalBinaryPatternGenerator.java | 4 +- .../PhysicalLeafPatternGenerator.java | 4 +- .../PhysicalUnaryPatternGenerator.java | 4 +- ...nerator.java => PlanPatternGenerator.java} | 18 +- .../PlanPatternGeneratorAnalyzer.java | 73 +++++ .../generator/PlanTypeMappingGenerator.java | 159 +++++++++++ .../processor/post/RuntimeFilterPruner.java | 17 +- .../nereids/processor/post/Validator.java | 10 +- .../properties/FunctionalDependencies.java | 24 +- .../nereids/properties/LogicalProperties.java | 50 ++-- .../org/apache/doris/nereids/rules/Rule.java | 6 +- .../apache/doris/nereids/rules/RuleSet.java | 4 +- .../AdjustAggregateNullableForEmptySet.java | 29 +- .../rules/analysis/BindExpression.java | 28 +- .../rules/analysis/BindSlotWithPaths.java | 29 +- .../rules/analysis/CheckAfterRewrite.java | 85 +++--- .../nereids/rules/analysis/CheckAnalysis.java | 36 +-- 
.../analysis/EliminateGroupByConstant.java | 2 +- .../rules/analysis/ExpressionAnalyzer.java | 2 +- .../rules/analysis/FillUpMissingSlots.java | 21 +- .../rules/analysis/NormalizeAggregate.java | 41 +-- .../ReplaceExpressionByChildOutput.java | 48 ++-- .../rules/analysis/SubqueryToApply.java | 77 ++++-- .../mv/AbstractMaterializedViewRule.java | 15 +- .../mv/InitMaterializationContextHook.java | 4 +- .../mv/MaterializationContext.java | 8 +- .../rules/exploration/mv/StructInfo.java | 20 +- .../ExpressionBottomUpRewriter.java | 124 +++++++++ .../expression/ExpressionListenerMatcher.java | 41 +++ .../expression/ExpressionMatchingAction.java | 25 ++ .../expression/ExpressionMatchingContext.java | 46 ++++ .../expression/ExpressionNormalization.java | 29 +- ...xpressionNormalizationAndOptimization.java | 33 +++ .../expression/ExpressionOptimization.java | 26 +- .../ExpressionPatternMatchRule.java | 64 +++++ .../expression/ExpressionPatternMatcher.java | 41 +++ .../ExpressionPatternRuleFactory.java | 84 ++++++ .../rules/expression/ExpressionRewrite.java | 51 +++- .../expression/ExpressionRewriteContext.java | 4 +- .../expression/ExpressionRuleExecutor.java | 16 +- .../ExpressionTraverseListener.java | 31 +++ .../ExpressionTraverseListenerFactory.java | 79 ++++++ .../ExpressionTraverseListenerMapping.java | 59 +++++ .../rules/expression/check/CheckCast.java | 24 +- .../rules/ArrayContainToArrayOverlap.java | 94 ++++--- .../rules/expression/rules/CaseWhenToIf.java | 18 +- .../expression/rules/ConvertAggStateCast.java | 33 +-- .../expression/rules/DateFunctionRewrite.java | 34 ++- .../rules/DigitalMaskingConvert.java | 23 +- .../rules/DistinctPredicatesRule.java | 18 +- .../rules/ExtractCommonFactorRule.java | 222 +++++++++++++--- .../expression/rules/FoldConstantRule.java | 32 ++- .../rules/FoldConstantRuleOnBE.java | 46 +++- .../rules/FoldConstantRuleOnFE.java | 170 ++++++++++-- .../expression/rules/InPredicateDedup.java | 40 +-- .../rules/InPredicateToEqualToRule.java | 
25 +- .../rules/NormalizeBinaryPredicatesRule.java | 21 +- .../rules/NullSafeEqualToEqual.java | 21 +- .../rules/OneListPartitionEvaluator.java | 2 +- .../rules/OneRangePartitionEvaluator.java | 120 ++++++--- .../rules/expression/rules/OrToIn.java | 36 ++- .../expression/rules/PartitionPruner.java | 23 +- .../rules/PartitionRangeExpander.java | 115 ++++---- .../PredicateRewriteForPartitionPrune.java | 4 +- .../rules/RangePartitionValueIterator.java | 64 +++++ .../rules/ReplaceVariableByLiteral.java | 17 +- .../SimplifyArithmeticComparisonRule.java | 105 ++++---- .../rules/SimplifyArithmeticRule.java | 70 ++--- .../expression/rules/SimplifyCastRule.java | 21 +- .../rules/SimplifyComparisonPredicate.java | 37 ++- .../rules/SimplifyDecimalV3Comparison.java | 24 +- .../expression/rules/SimplifyInPredicate.java | 20 +- .../expression/rules/SimplifyNotExprRule.java | 34 ++- .../rules/expression/rules/SimplifyRange.java | 73 ++--- .../rules/SupportJavaDateFormatter.java | 44 ++- .../rules/expression/rules/TopnToMax.java | 29 +- .../TryEliminateUninterestedPredicates.java | 14 +- .../implementation/AggregateStrategies.java | 2 +- .../rewrite/AdjustConjunctsReturnType.java | 4 +- .../nereids/rules/rewrite/AdjustNullable.java | 12 +- .../rules/rewrite/CheckMatchExpression.java | 7 +- .../rules/rewrite/CheckPrivileges.java | 29 +- .../nereids/rules/rewrite/ColumnPruning.java | 98 +++---- .../rules/rewrite/CountDistinctRewrite.java | 60 +++-- .../rules/rewrite/CountLiteralRewrite.java | 37 ++- .../rules/rewrite/EliminateFilter.java | 7 +- .../rules/rewrite/EliminateGroupBy.java | 56 ++-- .../rules/rewrite/EliminateMarkJoin.java | 17 +- .../rules/rewrite/EliminateNotNull.java | 39 +-- .../rewrite/EliminateOrderByConstant.java | 16 +- .../ExtractAndNormalizeWindowExpression.java | 161 ++++++----- ...tSingleTableExpressionFromDisjunction.java | 9 +- .../rules/rewrite/InferJoinNotNull.java | 4 +- .../nereids/rules/rewrite/MergeAggregate.java | 2 +- 
.../nereids/rules/rewrite/MergeProjects.java | 10 +- .../nereids/rules/rewrite/NormalizeSort.java | 59 +++-- .../rules/rewrite/NormalizeToSlot.java | 43 +-- .../rules/rewrite/PruneOlapScanPartition.java | 51 ++-- .../rules/rewrite/PullUpPredicates.java | 75 +++--- .../PushDownFilterThroughAggregation.java | 12 +- .../rewrite/PushDownFilterThroughProject.java | 13 +- .../nereids/rules/rewrite/ReorderJoin.java | 6 +- .../rules/rewrite/SimplifyAggGroupBy.java | 23 +- .../AbstractSelectMaterializedIndexRule.java | 15 +- .../SelectMaterializedIndexWithAggregate.java | 9 +- ...lectMaterializedIndexWithoutAggregate.java | 45 ++-- .../doris/nereids/stats/StatsCalculator.java | 11 +- .../doris/nereids/trees/AbstractTreeNode.java | 22 +- .../apache/doris/nereids/trees/TreeNode.java | 17 ++ .../trees/expressions/BinaryOperator.java | 6 - .../expressions/ComparisonPredicate.java | 4 +- .../nereids/trees/expressions/Expression.java | 99 +++++-- .../trees/expressions/InPredicate.java | 5 +- .../trees/expressions/SlotReference.java | 7 +- .../functions/ComputeSignatureHelper.java | 11 +- .../functions/agg/AggregateFunction.java | 17 +- .../scalar/PushDownToProjectionFunction.java | 7 +- .../expressions/literal/DateLiteral.java | 39 ++- .../visitor/DefaultExpressionRewriter.java | 10 +- .../nereids/trees/plans/AbstractPlan.java | 28 +- .../doris/nereids/trees/plans/Plan.java | 61 +++-- .../trees/plans/algebra/Aggregate.java | 17 +- .../nereids/trees/plans/algebra/Project.java | 27 +- .../trees/plans/logical/LogicalAggregate.java | 8 +- .../plans/logical/LogicalCatalogRelation.java | 132 +++++---- .../trees/plans/logical/LogicalOlapScan.java | 65 +++-- .../trees/plans/logical/LogicalProject.java | 8 +- .../trees/plans/logical/LogicalSort.java | 19 +- .../trees/plans/logical/LogicalTopN.java | 13 +- .../plans/physical/PhysicalHashJoin.java | 3 +- .../doris/nereids/util/ExpressionUtils.java | 250 +++++++++++++----- .../doris/nereids/util/ImmutableEqualSet.java | 6 +- 
.../apache/doris/nereids/util/JoinUtils.java | 5 +- .../apache/doris/nereids/util/PlanUtils.java | 24 ++ .../doris/nereids/util/TypeCoercionUtils.java | 19 +- .../org/apache/doris/nereids/util/Utils.java | 65 ++++- .../org/apache/doris/qe/SessionVariable.java | 38 ++- .../joinorder/hypergraph/HyperGraphTest.java | 12 +- .../expression/ExpressionRewriteTest.java | 80 ++++-- .../ExpressionRewriteTestHelper.java | 2 +- .../rules/expression/FoldConstantTest.java | 36 ++- .../expression/PredicatesSplitterTest.java | 2 +- .../SimplifyArithmeticRuleTest.java | 56 ++-- .../expression/SimplifyInPredicateTest.java | 8 +- .../rules/expression/SimplifyRangeTest.java | 26 +- .../rules/NullSafeEqualToEqualTest.java | 20 +- .../SimplifyArithmeticComparisonRuleTest.java | 7 +- .../rules/SimplifyCastRuleTest.java | 7 +- .../SimplifyComparisonPredicateTest.java | 35 ++- .../SimplifyDecimalV3ComparisonTest.java | 6 +- .../rules/expression/rules/TopnToMaxTest.java | 4 +- .../rules/rewrite/EliminateJoinByFkTest.java | 1 + .../nereids/rules/rewrite/OrToInTest.java | 19 +- .../PushDownFilterThroughAggregationTest.java | 4 +- .../functions/ComputeSignatureHelperTest.java | 11 + .../doris/nereids/util/PlanChecker.java | 20 ++ .../nereids_hint_tpcds_p0/shape/query24.out | 10 +- .../nereids_hint_tpcds_p0/shape/query64.out | 2 +- .../filter_push_down/push_filter_through.out | 28 +- .../shape/query13.out | 2 +- .../shape/query14.out | 2 +- .../shape/query24.out | 4 +- .../shape/query41.out | 2 +- .../shape/query50.out | 2 +- .../shape/query64.out | 2 +- .../shape/query85.out | 2 +- .../shape/query95.out | 2 +- .../noStatsRfPrune/query13.out | 2 +- .../noStatsRfPrune/query17.out | 2 +- .../noStatsRfPrune/query41.out | 2 +- .../noStatsRfPrune/query47.out | 7 +- .../noStatsRfPrune/query50.out | 2 +- .../noStatsRfPrune/query57.out | 7 +- .../noStatsRfPrune/query6.out | 57 ++-- .../noStatsRfPrune/query65.out | 2 +- .../no_stats_shape/query13.out | 2 +- .../no_stats_shape/query14.out | 2 +- 
.../no_stats_shape/query17.out | 2 +- .../no_stats_shape/query24.out | 2 +- .../no_stats_shape/query41.out | 2 +- .../no_stats_shape/query47.out | 7 +- .../no_stats_shape/query50.out | 2 +- .../no_stats_shape/query57.out | 7 +- .../no_stats_shape/query6.out | 57 ++-- .../no_stats_shape/query64.out | 2 +- .../no_stats_shape/query65.out | 2 +- .../no_stats_shape/query85.out | 6 +- .../rf_prune/query13.out | 2 +- .../rf_prune/query14.out | 2 +- .../rf_prune/query41.out | 2 +- .../rf_prune/query50.out | 2 +- .../rf_prune/query85.out | 2 +- .../rf_prune/query95.out | 2 +- .../shape/query13.out | 2 +- .../shape/query14.out | 2 +- .../shape/query24.out | 4 +- .../shape/query41.out | 2 +- .../shape/query50.out | 2 +- .../shape/query64.out | 2 +- .../shape/query85.out | 2 +- .../shape/query95.out | 2 +- .../nereids_tpch_shape_sf1000_p0/shape/q9.out | 2 +- .../shape_no_stats/q9.out | 2 +- .../doris/regression/suite/Suite.groovy | 7 +- .../doris/regression/util/OutputUtils.groovy | 28 +- .../regression/util/ReusableIterator.groovy | 7 + 235 files changed, 5053 insertions(+), 1925 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternRules.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternTraverseListeners.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ParentTypeIdMapping.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/TypeMappings.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/ExpressionTypeMappingGenerator.java rename fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/{PatternGeneratorAnalyzer.java => JavaAstAnalyzer.java} (75%) rename fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/{PatternGenerator.java => PlanPatternGenerator.java} (96%) create mode 100644 
fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGeneratorAnalyzer.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanTypeMappingGenerator.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionBottomUpRewriter.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionListenerMatcher.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingAction.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingContext.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalizationAndOptimization.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatchRule.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatcher.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternRuleFactory.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListener.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerFactory.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerMapping.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangePartitionValueIterator.java diff --git a/fe/fe-core/pom.xml b/fe/fe-core/pom.xml index dae3740eb71d0b..83deec7650c750 100644 --- a/fe/fe-core/pom.xml +++ b/fe/fe-core/pom.xml @@ -1016,7 +1016,7 @@ under the License. 
only - -AplanPath=${basedir}/src/main/java/org/apache/doris/nereids + -Apath=${basedir}/src/main/java/org/apache/doris/nereids org/apache/doris/nereids/pattern/generator/PatternDescribableProcessPoint.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java index 28ed98df0cb96d..125f1a56c9c057 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java @@ -570,11 +570,14 @@ public boolean isMinValue() { switch (type.getPrimitiveType()) { case DATE: case DATEV2: - return this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0; + return year == 0 && month == 1 && day == 1 + && this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0; case DATETIME: - return this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0; + return year == 0 && month == 1 && day == 1 + && this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0; case DATETIMEV2: - return this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0; + return year == 0 && month == 1 && day == 1 + && this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0; default: return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java index 459e5f360af97f..009b4c3b362bab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java @@ -2454,9 +2454,8 @@ public Set getPartitionKeys() { } public boolean isDupKeysOrMergeOnWrite() { - return getKeysType() == KeysType.DUP_KEYS - || (getKeysType() == KeysType.UNIQUE_KEYS - && getEnableUniqueKeyMergeOnWrite()); + return keysType == KeysType.DUP_KEYS + || (keysType == KeysType.UNIQUE_KEYS && getEnableUniqueKeyMergeOnWrite()); } public void 
initAutoIncrementGenerator(long dbId) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVRelationManager.java b/fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVRelationManager.java index 723deaff7403fe..693bb4b19dedb6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVRelationManager.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVRelationManager.java @@ -67,7 +67,7 @@ private Set getMtmvsByBaseTable(BaseTableInfo table) { * @return */ public Set getAvailableMTMVs(List tableInfos, ConnectContext ctx) { - Set res = Sets.newHashSet(); + Set res = Sets.newLinkedHashSet(); Set mvInfos = getMTMVInfos(tableInfos); for (BaseTableInfo tableInfo : mvInfos) { try { @@ -90,7 +90,7 @@ public boolean isMVPartitionValid(MTMV mtmv, ConnectContext ctx) { } private Set getMTMVInfos(List tableInfos) { - Set mvInfos = Sets.newHashSet(); + Set mvInfos = Sets.newLinkedHashSet(); for (BaseTableInfo tableInfo : tableInfos) { mvInfos.addAll(getMtmvsByBaseTable(tableInfo)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/privilege/Role.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/privilege/Role.java index f59cfd699f0c80..8724331eb0fa7b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/privilege/Role.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/privilege/Role.java @@ -380,7 +380,9 @@ public boolean checkTblPriv(String ctl, String db, String tbl, PrivPredicate wan public boolean checkColPriv(String ctl, String db, String tbl, String col, PrivPredicate wanted) { Optional colPrivilege = wanted.getColPrivilege(); - Preconditions.checkState(colPrivilege.isPresent(), "this privPredicate should not use checkColPriv:" + wanted); + if (!colPrivilege.isPresent()) { + throw new IllegalStateException("this privPredicate should not use checkColPriv:" + wanted); + } return checkTblPriv(ctl, db, tbl, wanted) || onlyCheckColPriv(ctl, db, tbl, col, colPrivilege.get()); } diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index 8e4a47938e49c4..60b7c0343a7001 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -76,6 +76,7 @@ import org.apache.logging.log4j.Logger; import java.util.ArrayList; +import java.util.BitSet; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -134,6 +135,11 @@ public class CascadesContext implements ScheduleContext { // trigger by rule and show by `explain plan process` statement private final List planProcesses = new ArrayList<>(); + // this field is modified by FoldConstantRuleOnFE, it matters current traverse + // into AggregateFunction with distinct, we can not fold constant in this case + private int distinctAggLevel; + private final boolean isEnableExprTrace; + /** * Constructor of OptimizerContext. 
* @@ -156,6 +162,13 @@ private CascadesContext(Optional parent, Optional curren this.subqueryExprIsAnalyzed = new HashMap<>(); this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable()); this.materializationContexts = new ArrayList<>(); + if (statementContext.getConnectContext() != null) { + ConnectContext connectContext = statementContext.getConnectContext(); + SessionVariable sessionVariable = connectContext.getSessionVariable(); + this.isEnableExprTrace = sessionVariable != null && sessionVariable.isEnableExprTrace(); + } else { + this.isEnableExprTrace = false; + } } /** @@ -256,7 +269,7 @@ public void setTables(List tables) { this.tables = tables.stream().collect(Collectors.toMap(TableIf::getId, t -> t, (t1, t2) -> t1)); } - public ConnectContext getConnectContext() { + public final ConnectContext getConnectContext() { return statementContext.getConnectContext(); } @@ -366,12 +379,18 @@ public T getAndCacheSessionVariable(String cacheName, return defaultValue; } + return getStatementContext().getOrRegisterCache(cacheName, + () -> variableSupplier.apply(connectContext.getSessionVariable())); + } + + /** getAndCacheDisableRules */ + public final BitSet getAndCacheDisableRules() { + ConnectContext connectContext = getConnectContext(); StatementContext statementContext = getStatementContext(); - if (statementContext == null) { - return defaultValue; + if (connectContext == null || statementContext == null) { + return new BitSet(); } - return statementContext.getOrRegisterCache(cacheName, - () -> variableSupplier.apply(connectContext.getSessionVariable())); + return statementContext.getOrCacheDisableRules(connectContext.getSessionVariable()); } private CascadesContext execute(Job job) { @@ -718,8 +737,28 @@ public void keepOrShowPlanProcess(boolean showPlanProcess, Runnable task) { } public void printPlanProcess() { + printPlanProcess(this.planProcesses); + } + + public static void printPlanProcess(List planProcesses) { for 
(PlanProcess row : planProcesses) { LOG.info("RULE: " + row.ruleName + "\nBEFORE:\n" + row.beforeShape + "\nafter:\n" + row.afterShape); } } + + public void incrementDistinctAggLevel() { + this.distinctAggLevel++; + } + + public void decrementDistinctAggLevel() { + this.distinctAggLevel--; + } + + public int getDistinctAggLevel() { + return distinctAggLevel; + } + + public boolean isEnableExprTrace() { + return isEnableExprTrace; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index eedc77e9df7b52..7457e4de04a4ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -387,7 +387,7 @@ public String getHintExplainString(List hints) { if (hint instanceof DistributeHint) { distributeHintIndex++; if (!hint.getExplainString().equals("")) { - distributeIndex = "_" + String.valueOf(distributeHintIndex); + distributeIndex = "_" + distributeHintIndex; } } switch (hint.getStatus()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index 5c894fd46ef2b9..7b444995120cab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -36,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.OriginStatement; +import org.apache.doris.qe.SessionVariable; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; @@ -45,6 +46,7 @@ import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.BitSet; import java.util.Collection; import java.util.Comparator; import java.util.HashMap; @@ -117,6 +119,8 @@ public class 
StatementContext { // Relation for example LogicalOlapScan private final Map slotToRelation = Maps.newHashMap(); + private BitSet disableRules; + public StatementContext() { this.connectContext = ConnectContext.get(); } @@ -259,11 +263,22 @@ public synchronized T getOrRegisterCache(String key, Supplier cacheSuppli return supplier.get(); } + public synchronized BitSet getOrCacheDisableRules(SessionVariable sessionVariable) { + if (this.disableRules != null) { + return this.disableRules; + } + this.disableRules = sessionVariable.getDisableNereidsRules(); + return this.disableRules; + } + /** * Some value of the cacheKey may change, invalid cache when value change */ public synchronized void invalidCache(String cacheKey) { contextCacheMap.remove(cacheKey); + if (cacheKey.equalsIgnoreCase(SessionVariable.DISABLE_NEREIDS_RULES)) { + this.disableRules = null; + } } public ColumnAliasGenerator getColumnAliasGenerator() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/Scope.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/Scope.java index a95e562f7e029c..dbcbea7c104b5a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/Scope.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/Scope.java @@ -26,6 +26,7 @@ import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Objects; @@ -63,6 +64,7 @@ public class Scope { private final List slots; private final Optional ownerSubquery; private final Set correlatedSlots; + private final boolean buildNameToSlot; private final Supplier> nameToSlot; public Scope(List slots) { @@ -75,7 +77,8 @@ public Scope(Optional outerScope, List slots, Optional 500; + this.nameToSlot = buildNameToSlot ? 
Suppliers.memoize(this::buildNameToSlot) : null; } public List getSlots() { @@ -96,7 +99,19 @@ public Set getCorrelatedSlots() { /** findSlotIgnoreCase */ public List findSlotIgnoreCase(String slotName) { - return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT)); + if (!buildNameToSlot) { + Object[] array = new Object[slots.size()]; + int filterIndex = 0; + for (int i = 0; i < slots.size(); i++) { + Slot slot = slots.get(i); + if (slot.getName().equalsIgnoreCase(slotName)) { + array[filterIndex++] = slot; + } + } + return (List) Arrays.asList(array).subList(0, filterIndex); + } else { + return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT)); + } } private ListMultimap buildNameToSlot() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java index a9739cbb9e22ff..41e5e1b8d7e75e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java @@ -34,16 +34,14 @@ import org.apache.doris.nereids.trees.expressions.CTEId; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.qe.ConnectContext; -import org.apache.doris.qe.SessionVariable; import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableSet; +import java.util.BitSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; /** * Abstract class for all job using for analyze and optimize query plan in Nereids. 
@@ -57,7 +55,7 @@ public abstract class Job implements TracerSupplier { protected JobType type; protected JobContext context; protected boolean once; - protected final Set disableRules; + protected final BitSet disableRules; protected Map cteIdToStats; @@ -129,8 +127,7 @@ protected void countJobExecutionTimesOfGroupExpressions(GroupExpression groupExp groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan())); } - public static Set getDisableRules(JobContext context) { - return context.getCascadesContext().getAndCacheSessionVariable( - SessionVariable.DISABLE_NEREIDS_RULES, ImmutableSet.of(), SessionVariable::getDisableNereidsRules); + public static BitSet getDisableRules(JobContext context) { + return context.getCascadesContext().getAndCacheDisableRules(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 65998416fb0973..a68c7510965b26 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -30,7 +30,7 @@ import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite; import org.apache.doris.nereids.rules.expression.ExpressionNormalization; -import org.apache.doris.nereids.rules.expression.ExpressionOptimization; +import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization; import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit; import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType; @@ -152,8 +152,7 @@ public class Rewriter extends AbstractBatchJobExecutor { // such as group by key matching and replaced // but we need to do some normalization before subquery unnesting, // such as extract common 
expression. - new ExpressionNormalization(), - new ExpressionOptimization(), + new ExpressionNormalizationAndOptimization(), new AvgDistinctToSumDivCount(), new CountDistinctRewrite(), new ExtractFilterFromCrossJoin() @@ -240,7 +239,7 @@ public class Rewriter extends AbstractBatchJobExecutor { // efficient because it can find the new plans and apply transform wherever it is bottomUp(RuleSet.PUSH_DOWN_FILTERS), // after push down, some new filters are generated, which needs to be optimized. (example: tpch q19) - topDown(new ExpressionOptimization()), + // topDown(new ExpressionOptimization()), topDown( new MergeFilters(), new ReorderJoin(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 3b3c2f410c4ce0..5e45fc0bdb8358 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -51,6 +51,7 @@ import java.util.ArrayList; import java.util.BitSet; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -64,7 +65,7 @@ */ public class HyperGraph { // record all edges that can be placed on the subgraph - private final Map treeEdgesCache = new HashMap<>(); + private final Map treeEdgesCache = new LinkedHashMap<>(); private final List joinEdges; private final List filterEdges; private final List nodes; @@ -330,9 +331,9 @@ public static class Builder { private final List nodes = new ArrayList<>(); // These hyperGraphs should be replaced nodes when building all - private final Map> replacedHyperGraphs = new HashMap<>(); - private final HashMap slotToNodeMap = new HashMap<>(); - private final Map> complexProject = new HashMap<>(); + private final Map> replacedHyperGraphs = new LinkedHashMap<>(); + 
private final HashMap slotToNodeMap = new LinkedHashMap<>(); + private final Map> complexProject = new LinkedHashMap<>(); private Set finalOutputs; public List getNodes() { @@ -522,7 +523,7 @@ private long calNodeMap(Set slots) { */ private BitSet addJoin(LogicalJoin join, Pair leftEdgeNodes, Pair rightEdgeNodes) { - HashMap, Pair, List>> conjuncts = new HashMap<>(); + Map, Pair, List>> conjuncts = new LinkedHashMap<>(); for (Expression expression : join.getHashJoinConjuncts()) { // TODO: avoid calling calculateEnds if calNodeMap's results are same Pair ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CustomRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CustomRewriteJob.java index 35e04b9f33fd85..0e58f1bc976ba5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CustomRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CustomRewriteJob.java @@ -25,8 +25,8 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import java.util.BitSet; import java.util.Objects; -import java.util.Set; import java.util.function.Supplier; /** @@ -50,8 +50,8 @@ public CustomRewriteJob(Supplier rewriter, RuleType ruleType) { @Override public void execute(JobContext context) { - Set disableRules = Job.getDisableRules(context); - if (disableRules.contains(ruleType.type())) { + BitSet disableRules = Job.getDisableRules(context); + if (disableRules.get(ruleType.type())) { return; } CascadesContext cascadesContext = context.getCascadesContext(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteBottomUpJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteBottomUpJob.java index 4f623e5450060f..60555a9cc04ad6 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteBottomUpJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteBottomUpJob.java @@ -39,9 +39,9 @@ public class PlanTreeRewriteBottomUpJob extends PlanTreeRewriteJob { // Different 'RewriteState' has different actions, // so we will do specified action for each node based on their 'RewriteState'. private static final String REWRITE_STATE_KEY = "rewrite_state"; - private final RewriteJobContext rewriteJobContext; private final List rules; + private final int batchId; enum RewriteState { // 'REWRITE_THIS' means the current plan node can be handled immediately. If the plan state is 'REWRITE_THIS', @@ -59,22 +59,15 @@ public PlanTreeRewriteBottomUpJob(RewriteJobContext rewriteJobContext, JobContex super(JobType.BOTTOM_UP_REWRITE, context); this.rewriteJobContext = Objects.requireNonNull(rewriteJobContext, "rewriteContext cannot be null"); this.rules = Objects.requireNonNull(rules, "rules cannot be null"); + this.batchId = rewriteJobContext.batchId; } @Override public void execute() { - // For the bottom-up rewrite job, we need to reset the state of its children - // if the plan has changed after the rewrite. So we use the 'childrenVisited' to check this situation. - boolean clearStatePhase = !rewriteJobContext.childrenVisited; - if (clearStatePhase) { - traverseClearState(); - return; - } - // We'll do different actions based on their different states. // You can check the comment in 'RewriteState' structure for more details. Plan plan = rewriteJobContext.plan; - RewriteState state = getState(plan); + RewriteState state = getState(plan, batchId); switch (state) { case REWRITE_THIS: rewriteThis(); @@ -90,33 +83,13 @@ public void execute() { } } - private void traverseClearState() { - // Reset the state for current node. 
- RewriteJobContext clearedStateContext = rewriteJobContext.withChildrenVisited(true); - setState(clearedStateContext.plan, RewriteState.REWRITE_THIS); - pushJob(new PlanTreeRewriteBottomUpJob(clearedStateContext, context, rules)); - - // Generate the new rewrite job for its children. Because the character of stack is 'first in, last out', - // so we can traverse reset the state for the plan node until the leaf node. - List children = clearedStateContext.plan.children(); - for (int i = children.size() - 1; i >= 0; i--) { - Plan child = children.get(i); - RewriteJobContext childRewriteJobContext = new RewriteJobContext( - child, clearedStateContext, i, false); - // NOTICE: this relay on pull up cte anchor - if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) { - pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules)); - } - } - } - private void rewriteThis() { // Link the current node with the sub-plan to get the current plan which is used in the rewrite phase later. Plan plan = linkChildren(rewriteJobContext.plan, rewriteJobContext.childrenContext); RewriteResult rewriteResult = rewrite(plan, rules, rewriteJobContext); if (rewriteResult.hasNewPlan) { RewriteJobContext newJobContext = rewriteJobContext.withPlan(rewriteResult.plan); - RewriteState state = getState(rewriteResult.plan); + RewriteState state = getState(rewriteResult.plan, batchId); // Some eliminate rule will return a rewritten plan, for example the current node is eliminated // and return the child plan. So we don't need to handle it again. if (state == RewriteState.REWRITTEN) { @@ -125,40 +98,82 @@ private void rewriteThis() { } // After the rewrite take effect, we should handle the children part again. 
pushJob(new PlanTreeRewriteBottomUpJob(newJobContext, context, rules)); - setState(rewriteResult.plan, RewriteState.ENSURE_CHILDREN_REWRITTEN); + setState(rewriteResult.plan, RewriteState.ENSURE_CHILDREN_REWRITTEN, batchId); } else { // No new plan is generated, so just set the state of the current plan to 'REWRITTEN'. - setState(rewriteResult.plan, RewriteState.REWRITTEN); + setState(rewriteResult.plan, RewriteState.REWRITTEN, batchId); rewriteJobContext.setResult(rewriteResult.plan); } } private void ensureChildrenRewritten() { - // Similar to the function 'traverseClearState'. Plan plan = rewriteJobContext.plan; - setState(plan, RewriteState.REWRITE_THIS); + int batchId = rewriteJobContext.batchId; + setState(plan, RewriteState.REWRITE_THIS, batchId); pushJob(new PlanTreeRewriteBottomUpJob(rewriteJobContext, context, rules)); + // some rule return new plan tree, which the number of new plan node > 1, + // we should transform this new plan nodes too. + // NOTICE: this relay on pull up cte anchor + if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) { + pushChildrenJobs(plan); + } + } + + private void pushChildrenJobs(Plan plan) { List children = plan.children(); - for (int i = children.size() - 1; i >= 0; i--) { - Plan child = children.get(i); - // some rule return new plan tree, which the number of new plan node > 1, - // we should transform this new plan nodes too. 
- RewriteJobContext childRewriteJobContext = new RewriteJobContext( - child, rewriteJobContext, i, false); - // NOTICE: this relay on pull up cte anchor - if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) { + switch (children.size()) { + case 0: return; + case 1: + Plan child = children.get(0); + RewriteJobContext childRewriteJobContext = new RewriteJobContext( + child, rewriteJobContext, 0, false, batchId); pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules)); - } + return; + case 2: + Plan right = children.get(1); + RewriteJobContext rightRewriteJobContext = new RewriteJobContext( + right, rewriteJobContext, 1, false, batchId); + pushJob(new PlanTreeRewriteBottomUpJob(rightRewriteJobContext, context, rules)); + + Plan left = children.get(0); + RewriteJobContext leftRewriteJobContext = new RewriteJobContext( + left, rewriteJobContext, 0, false, batchId); + pushJob(new PlanTreeRewriteBottomUpJob(leftRewriteJobContext, context, rules)); + return; + default: + for (int i = children.size() - 1; i >= 0; i--) { + child = children.get(i); + childRewriteJobContext = new RewriteJobContext( + child, rewriteJobContext, i, false, batchId); + pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules)); + } + } + } + + private static RewriteState getState(Plan plan, int currentBatchId) { + Optional state = plan.getMutableState(REWRITE_STATE_KEY); + if (!state.isPresent()) { + return RewriteState.ENSURE_CHILDREN_REWRITTEN; + } + RewriteStateContext context = state.get(); + if (context.batchId != currentBatchId) { + return RewriteState.ENSURE_CHILDREN_REWRITTEN; } + return context.rewriteState; } - private static final RewriteState getState(Plan plan) { - Optional state = plan.getMutableState(REWRITE_STATE_KEY); - return state.orElse(RewriteState.ENSURE_CHILDREN_REWRITTEN); + private static void setState(Plan plan, RewriteState state, int batchId) { + plan.setMutableState(REWRITE_STATE_KEY, new RewriteStateContext(state, 
batchId)); } - private static final void setState(Plan plan, RewriteState state) { - plan.setMutableState(REWRITE_STATE_KEY, state); + private static class RewriteStateContext { + private final RewriteState rewriteState; + private final int batchId; + + public RewriteStateContext(RewriteState rewriteState, int batchId) { + this.rewriteState = rewriteState; + this.batchId = batchId; + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java index affbb9196cc3d5..5e5acc29f66edb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java @@ -28,6 +28,8 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.trees.plans.Plan; +import com.google.common.collect.ImmutableList; + import java.util.List; /** PlanTreeRewriteJob */ @@ -43,7 +45,7 @@ protected final RewriteResult rewrite(Plan plan, List rules, RewriteJobCon boolean showPlanProcess = cascadesContext.showPlanProcess(); for (Rule rule : rules) { - if (disableRules.contains(rule.getRuleType().type())) { + if (disableRules.get(rule.getRuleType().type())) { continue; } Pattern pattern = (Pattern) rule.getPattern(); @@ -76,26 +78,50 @@ protected final RewriteResult rewrite(Plan plan, List rules, RewriteJobCon return new RewriteResult(false, plan); } - protected final Plan linkChildrenAndParent(Plan plan, RewriteJobContext rewriteJobContext) { - Plan newPlan = linkChildren(plan, rewriteJobContext.childrenContext); - rewriteJobContext.setResult(newPlan); - return newPlan; - } - - protected final Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) { - boolean changed = false; - Plan[] newChildren = new Plan[childrenContext.length]; - for (int i = 0; i < childrenContext.length; ++i) { - Plan result = 
childrenContext[i].result; - Plan oldChild = plan.child(i); - if (result != null && result != oldChild) { - newChildren[i] = result; - changed = true; - } else { - newChildren[i] = oldChild; + protected static Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) { + List children = plan.children(); + // loop unrolling + switch (children.size()) { + case 0: { + return plan; + } + case 1: { + RewriteJobContext child = childrenContext[0]; + Plan firstResult = child == null ? plan.child(0) : child.result; + return firstResult == null || firstResult == children.get(0) + ? plan : plan.withChildren(ImmutableList.of(firstResult)); + } + case 2: { + RewriteJobContext left = childrenContext[0]; + Plan firstResult = left == null ? plan.child(0) : left.result; + RewriteJobContext right = childrenContext[1]; + Plan secondResult = right == null ? plan.child(1) : right.result; + Plan firstOrigin = children.get(0); + Plan secondOrigin = children.get(1); + boolean firstChanged = firstResult != null && firstResult != firstOrigin; + boolean secondChanged = secondResult != null && secondResult != secondOrigin; + if (firstChanged || secondChanged) { + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(2); + newChildren.add(firstChanged ? firstResult : firstOrigin); + newChildren.add(secondChanged ? secondResult : secondOrigin); + return plan.withChildren(newChildren.build()); + } else { + return plan; + } + } + default: { + boolean changed = false; + int i = 0; + Plan[] newChildren = new Plan[childrenContext.length]; + for (Plan oldChild : children) { + Plan result = childrenContext[i].result; + changed = result != null && result != oldChild; + newChildren[i] = changed ? result : oldChild; + i++; + } + return changed ? plan.withChildren(newChildren) : plan; } } - return changed ? 
plan.withChildren(newChildren) : plan; } private String getCurrentPlanTreeString() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteTopDownJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteTopDownJob.java index d8dba41b3788bd..14019bc885e0d0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteTopDownJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteTopDownJob.java @@ -56,21 +56,44 @@ public void execute() { RewriteJobContext newRewriteJobContext = rewriteJobContext.withChildrenVisited(true); pushJob(new PlanTreeRewriteTopDownJob(newRewriteJobContext, context, rules)); - List children = newRewriteJobContext.plan.children(); - for (int i = children.size() - 1; i >= 0; i--) { - RewriteJobContext childRewriteJobContext = new RewriteJobContext( - children.get(i), newRewriteJobContext, i, false); - // NOTICE: this relay on pull up cte anchor - if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) { - pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules)); - } + // NOTICE: this relay on pull up cte anchor + if (!(this.rewriteJobContext.plan instanceof LogicalCTEAnchor)) { + pushChildrenJobs(newRewriteJobContext); } } else { // All the children part are already visited. Just link the children plan to the current node. 
- Plan result = linkChildrenAndParent(rewriteJobContext.plan, rewriteJobContext); + Plan result = linkChildren(rewriteJobContext.plan, rewriteJobContext.childrenContext); + rewriteJobContext.setResult(result); if (rewriteJobContext.parentContext == null) { context.getCascadesContext().setRewritePlan(result); } } } + + private void pushChildrenJobs(RewriteJobContext rewriteJobContext) { + List children = rewriteJobContext.plan.children(); + switch (children.size()) { + case 0: return; + case 1: + RewriteJobContext childRewriteJobContext = new RewriteJobContext( + children.get(0), rewriteJobContext, 0, false, this.rewriteJobContext.batchId); + pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules)); + return; + case 2: + RewriteJobContext rightRewriteJobContext = new RewriteJobContext( + children.get(1), rewriteJobContext, 1, false, this.rewriteJobContext.batchId); + pushJob(new PlanTreeRewriteTopDownJob(rightRewriteJobContext, context, rules)); + + RewriteJobContext leftRewriteJobContext = new RewriteJobContext( + children.get(0), rewriteJobContext, 0, false, this.rewriteJobContext.batchId); + pushJob(new PlanTreeRewriteTopDownJob(leftRewriteJobContext, context, rules)); + return; + default: + for (int i = children.size() - 1; i >= 0; i--) { + childRewriteJobContext = new RewriteJobContext( + children.get(i), rewriteJobContext, i, false, this.rewriteJobContext.batchId); + pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules)); + } + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteJobContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteJobContext.java index fb0475f7a61a3b..060bb8edd62838 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteJobContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteJobContext.java @@ -25,6 +25,7 @@ public class RewriteJobContext { final boolean 
childrenVisited; + final int batchId; final RewriteJobContext parentContext; final int childIndexInParentContext; final Plan plan; @@ -33,7 +34,7 @@ public class RewriteJobContext { /** RewriteJobContext */ public RewriteJobContext(Plan plan, @Nullable RewriteJobContext parentContext, int childIndexInParentContext, - boolean childrenVisited) { + boolean childrenVisited, int batchId) { this.plan = plan; this.parentContext = parentContext; this.childIndexInParentContext = childIndexInParentContext; @@ -42,6 +43,7 @@ public RewriteJobContext(Plan plan, @Nullable RewriteJobContext parentContext, i if (parentContext != null) { parentContext.childrenContext[childIndexInParentContext] = this; } + this.batchId = batchId; } public void setResult(Plan result) { @@ -49,15 +51,15 @@ public void setResult(Plan result) { } public RewriteJobContext withChildrenVisited(boolean childrenVisited) { - return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited); + return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId); } public RewriteJobContext withPlan(Plan plan) { - return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited); + return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId); } public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) { - return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited); + return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId); } public boolean isRewriteRoot() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RootPlanTreeRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RootPlanTreeRewriteJob.java index 6bc055a68aa976..d352dfee4a0b20 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RootPlanTreeRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RootPlanTreeRewriteJob.java @@ -27,9 +27,11 @@ import java.util.List; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; /** RootPlanTreeRewriteJob */ public class RootPlanTreeRewriteJob implements RewriteJob { + private static final AtomicInteger BATCH_ID = new AtomicInteger(); private final List rules; private final RewriteJobBuilder rewriteJobBuilder; @@ -47,7 +49,9 @@ public void execute(JobContext context) { // get plan from the cascades context Plan root = cascadesContext.getRewritePlan(); // write rewritten root plan to cascades context by the RootRewriteJobContext - RootRewriteJobContext rewriteJobContext = new RootRewriteJobContext(root, false, context); + int batchId = BATCH_ID.incrementAndGet(); + RootRewriteJobContext rewriteJobContext = new RootRewriteJobContext( + root, false, context, batchId); Job rewriteJob = rewriteJobBuilder.build(rewriteJobContext, context, rules); context.getScheduleContext().pushJob(rewriteJob); @@ -71,8 +75,8 @@ public static class RootRewriteJobContext extends RewriteJobContext { private final JobContext jobContext; - RootRewriteJobContext(Plan plan, boolean childrenVisited, JobContext jobContext) { - super(plan, null, -1, childrenVisited); + RootRewriteJobContext(Plan plan, boolean childrenVisited, JobContext jobContext, int batchId) { + super(plan, null, -1, childrenVisited, batchId); this.jobContext = Objects.requireNonNull(jobContext, "jobContext cannot be null"); jobContext.getCascadesContext().setCurrentRootRewriteJobContext(this); } @@ -89,17 +93,17 @@ public void setResult(Plan result) { @Override public RewriteJobContext withChildrenVisited(boolean childrenVisited) { - return new RootRewriteJobContext(plan, childrenVisited, jobContext); + return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId); } @Override 
public RewriteJobContext withPlan(Plan plan) { - return new RootRewriteJobContext(plan, childrenVisited, jobContext); + return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId); } @Override public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) { - return new RootRewriteJobContext(plan, childrenVisited, jobContext); + return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId); } /** linkChildren */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternRules.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternRules.java new file mode 100644 index 00000000000000..523540e6435d89 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternRules.java @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatchRule; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Expression; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.lang.reflect.Field; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +/** ExpressionPatternMapping */ +public class ExpressionPatternRules extends TypeMappings { + private static final Logger LOG = LogManager.getLogger(ExpressionPatternRules.class); + + public ExpressionPatternRules(List typeMappings) { + super(typeMappings); + } + + @Override + protected Set> getChildrenClasses(Class clazz) { + return org.apache.doris.nereids.pattern.GeneratedExpressionRelations.CHILDREN_CLASS_MAP.get(clazz); + } + + /** matchesAndApply */ + public Optional matchesAndApply(Expression expr, ExpressionRewriteContext context, Expression parent) { + List rules = singleMappings.get(expr.getClass()); + ExpressionMatchingContext matchingContext + = new ExpressionMatchingContext<>(expr, parent, context); + switch (rules.size()) { + case 0: { + for (ExpressionPatternMatchRule multiMatchRule : multiMappings) { + if (multiMatchRule.matchesTypeAndPredicates(matchingContext)) { + Expression newExpr = multiMatchRule.apply(matchingContext); + if (!newExpr.equals(expr)) { + if (context.cascadesContext.isEnableExprTrace()) { + traceExprChanged(multiMatchRule, expr, newExpr); + } + return Optional.of(newExpr); + } + } + } + return Optional.empty(); + } + case 1: { + ExpressionPatternMatchRule rule = rules.get(0); + if (rule.matchesPredicates(matchingContext)) { + Expression newExpr = rule.apply(matchingContext); + if (!newExpr.equals(expr)) { + if (context.cascadesContext.isEnableExprTrace()) { + traceExprChanged(rule, expr, newExpr); 
+ } + return Optional.of(newExpr); + } + } + return Optional.empty(); + } + default: { + for (ExpressionPatternMatchRule rule : rules) { + if (rule.matchesPredicates(matchingContext)) { + Expression newExpr = rule.apply(matchingContext); + if (!expr.equals(newExpr)) { + if (context.cascadesContext.isEnableExprTrace()) { + traceExprChanged(rule, expr, newExpr); + } + return Optional.of(newExpr); + } + } + } + return Optional.empty(); + } + } + } + + private static void traceExprChanged(ExpressionPatternMatchRule rule, Expression expr, Expression newExpr) { + try { + Field[] declaredFields = (rule.matchingAction).getClass().getDeclaredFields(); + Class ruleClass; + if (declaredFields.length == 0) { + ruleClass = rule.matchingAction.getClass(); + } else { + Field field = declaredFields[0]; + field.setAccessible(true); + ruleClass = field.get(rule.matchingAction).getClass(); + } + LOG.info("RULE: " + ruleClass + "\nbefore: " + expr + "\nafter: " + newExpr); + } catch (Throwable t) { + LOG.error(t.getMessage(), t); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternTraverseListeners.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternTraverseListeners.java new file mode 100644 index 00000000000000..3f3640a43bf8b2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternTraverseListeners.java @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionTraverseListener; +import org.apache.doris.nereids.rules.expression.ExpressionTraverseListenerMapping; +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Set; +import javax.annotation.Nullable; + +/** ExpressionPatternTraverseListeners */ +public class ExpressionPatternTraverseListeners + extends TypeMappings { + public ExpressionPatternTraverseListeners( + List typeMappings) { + super(typeMappings); + } + + @Override + protected Set> getChildrenClasses(Class clazz) { + return org.apache.doris.nereids.pattern.GeneratedExpressionRelations.CHILDREN_CLASS_MAP.get(clazz); + } + + /** matchesAndCombineListener */ + public @Nullable CombinedListener matchesAndCombineListeners( + Expression expr, ExpressionRewriteContext context, Expression parent) { + List listenerSingleMappings = singleMappings.get(expr.getClass()); + ExpressionMatchingContext matchingContext + = new ExpressionMatchingContext<>(expr, parent, context); + switch (listenerSingleMappings.size()) { + case 0: { + ImmutableList.Builder> matchedListeners + = ImmutableList.builder(); + for (ExpressionTraverseListenerMapping multiMapping : multiMappings) { + if (multiMapping.matchesTypeAndPredicates(matchingContext)) { + 
matchedListeners.add(multiMapping.listener); + } + } + return CombinedListener.tryCombine(matchedListeners.build(), matchingContext); + } + case 1: { + ExpressionTraverseListenerMapping listenerMapping = listenerSingleMappings.get(0); + if (listenerMapping.matchesPredicates(matchingContext)) { + return CombinedListener.tryCombine(ImmutableList.of(listenerMapping.listener), matchingContext); + } + return null; + } + default: { + ImmutableList.Builder> matchedListeners + = ImmutableList.builder(); + for (ExpressionTraverseListenerMapping singleMapping : listenerSingleMappings) { + if (singleMapping.matchesPredicates(matchingContext)) { + matchedListeners.add(singleMapping.listener); + } + } + return CombinedListener.tryCombine(matchedListeners.build(), matchingContext); + } + } + } + + /** CombinedListener */ + public static class CombinedListener { + private final ExpressionMatchingContext context; + private final List> listeners; + + /** CombinedListener */ + public CombinedListener(ExpressionMatchingContext context, + List> listeners) { + this.context = context; + this.listeners = listeners; + } + + public static @Nullable CombinedListener tryCombine( + List> listenerMappings, + ExpressionMatchingContext context) { + return listenerMappings.isEmpty() ? 
null : new CombinedListener(context, listenerMappings); + } + + public void onEnter() { + for (ExpressionTraverseListener listener : listeners) { + listener.onEnter(context); + } + } + + public void onExit(Expression rewritten) { + for (ExpressionTraverseListener listener : listeners) { + listener.onExit(context, rewritten); + } + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ParentTypeIdMapping.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ParentTypeIdMapping.java new file mode 100644 index 00000000000000..b4623e105238b7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ParentTypeIdMapping.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.trees.expressions.LessThanEqual; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +/** ParentTypeIdMapping */ +public class ParentTypeIdMapping { + + private final AtomicInteger idGenerator = new AtomicInteger(); + private final Map, Integer> classId = new ConcurrentHashMap<>(8192); + + /** getId */ + public int getId(Class clazz) { + Integer id = classId.get(clazz); + if (id != null) { + return id; + } + return ensureClassHasId(clazz); + } + + private int ensureClassHasId(Class clazz) { + Class superClass = clazz.getSuperclass(); + if (superClass != null) { + ensureClassHasId(superClass); + } + + for (Class interfaceClass : clazz.getInterfaces()) { + ensureClassHasId(interfaceClass); + } + + return classId.computeIfAbsent(clazz, c -> idGenerator.incrementAndGet()); + } + + public static void main(String[] args) { + ParentTypeIdMapping mapping = new ParentTypeIdMapping(); + int id = mapping.getId(LessThanEqual.class); + System.out.println(id); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/Pattern.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/Pattern.java index c47dcd6a725be1..91dd87ba457837 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/Pattern.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/Pattern.java @@ -152,6 +152,10 @@ public boolean matchPlanTree(Plan plan) { if (this instanceof SubTreePattern) { return matchPredicates((TYPE) plan); } + return matchChildrenAndSelfPredicates(plan, childPatternNum); + } + + private boolean matchChildrenAndSelfPredicates(Plan plan, int childPatternNum) { List childrenPlan = plan.children(); for (int i = 0; i < childrenPlan.size(); i++) { Plan child = childrenPlan.get(i); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/TypeMappings.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/TypeMappings.java new file mode 100644 index 00000000000000..4eb5ffc76d22a2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/TypeMappings.java @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping; +import org.apache.doris.nereids.util.Utils; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Lists; + +import java.lang.reflect.Modifier; +import java.util.List; +import java.util.Set; +import javax.annotation.Nullable; + +/** ExpressionPatternMappings */ +public abstract class TypeMappings> { + protected final ListMultimap, T> singleMappings; + protected final List multiMappings; + + /** ExpressionPatternMappings */ + public TypeMappings(List typeMappings) { + this.singleMappings = ArrayListMultimap.create(); + this.multiMappings = Lists.newArrayList(); + + for (T mapping : typeMappings) { + Set> childrenClasses = getChildrenClasses(mapping.getType()); + if (childrenClasses == null || childrenClasses.isEmpty()) { + // add some expressions which no child class + // e.g. LessThanEqual + addSimpleMapping(mapping); + } else if (childrenClasses.size() <= 100) { + // add some expressions which have children classes + // e.g. ComparisonPredicate will be expanded to + // ruleMappings.put(LessThanEqual.class, rule); + // ruleMappings.put(LessThan.class, rule); + // ruleMappings.put(GreaterThan.class, rule); + // ruleMappings.put(GreaterThanEquals.class, rule); + // ... + addThisAndChildrenMapping(mapping, childrenClasses); + } else { + // some expressions have lots of children classes, e.g. Expression, ExpressionTrait, BinaryExpression, + // we will not expand this types to child class, but also add this rules to other type matching. 
+ // for example, if we have three rules to matches this types: LessThanEqual, Abs and Expression, + // then the ruleMappings would be: + // { + // LessThanEqual.class: [rule_of_LessThanEqual, rule_of_Expression], + // Abs.class: [rule_of_Abs, rule_of_Expression] + // } + // + // and the multiMatchRules would be: [rule_of_Expression] + // + // if we matches `a <= 1`, there have two rules would be applied because + // ruleMappings.get(LessThanEqual.class) return two rules; + // if we matches `a = 1`, ruleMappings.get(EqualTo.class) will return empty rules, so we use + // all the rules in multiMatchRules to matches and apply, the rule_of_Expression will be applied. + addMultiMapping(mapping); + } + } + } + + public @Nullable List get(Class clazz) { + return singleMappings.get(clazz); + } + + private void addSimpleMapping(T typeMapping) { + Class clazz = typeMapping.getType(); + int modifiers = clazz.getModifiers(); + if (!Modifier.isAbstract(modifiers)) { + addSingleMapping(clazz, typeMapping); + } + } + + private void addThisAndChildrenMapping( + T typeMapping, Set> childrenClasses) { + Class clazz = typeMapping.getType(); + if (!Modifier.isAbstract(clazz.getModifiers())) { + addSingleMapping(clazz, typeMapping); + } + + for (Class childrenClass : childrenClasses) { + if (!Modifier.isAbstract(childrenClass.getModifiers())) { + addSingleMapping(childrenClass, typeMapping); + } + } + } + + private void addMultiMapping(T multiMapping) { + multiMappings.add(multiMapping); + + Set> existSingleMappingTypes = Utils.fastToImmutableSet(singleMappings.keySet()); + for (Class existSingleType : existSingleMappingTypes) { + Class type = multiMapping.getType(); + if (type.isAssignableFrom(existSingleType)) { + singleMappings.put(existSingleType, multiMapping); + } + } + } + + private void addSingleMapping(Class clazz, T singleMapping) { + if (!singleMappings.containsKey(clazz) && !multiMappings.isEmpty()) { + for (T multiMapping : multiMappings) { + if 
(multiMapping.getType().isAssignableFrom(clazz)) { + singleMappings.put(clazz, multiMapping); + } + } + } + singleMappings.put(clazz, singleMapping); + } + + protected abstract Set> getChildrenClasses(Class clazz); + + /** TypeMapping */ + public interface TypeMapping { + Class getType(); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/ExpressionTypeMappingGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/ExpressionTypeMappingGenerator.java new file mode 100644 index 00000000000000..c5a923153dfeea --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/ExpressionTypeMappingGenerator.java @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.pattern.generator; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.processing.ProcessingEnvironment; +import javax.tools.StandardLocation; + +/** ExpressionTypeMappingGenerator */ +public class ExpressionTypeMappingGenerator { + private final JavaAstAnalyzer analyzer; + + public ExpressionTypeMappingGenerator(JavaAstAnalyzer javaAstAnalyzer) { + this.analyzer = javaAstAnalyzer; + } + + public JavaAstAnalyzer getAnalyzer() { + return analyzer; + } + + /** generate */ + public void generate(ProcessingEnvironment processingEnv) throws IOException { + Set superExpressions = findSuperExpression(); + Map> childrenNameMap = analyzer.getChildrenNameMap(); + Map> parentNameMap = analyzer.getParentNameMap(); + String code = generateCode(childrenNameMap, parentNameMap, superExpressions); + generateFile(processingEnv, code); + } + + private void generateFile(ProcessingEnvironment processingEnv, String code) throws IOException { + File generatePatternFile = new File(processingEnv.getFiler() + .getResource(StandardLocation.SOURCE_OUTPUT, "org.apache.doris.nereids.pattern", + "GeneratedExpressionRelations.java").toUri()); + if (generatePatternFile.exists()) { + generatePatternFile.delete(); + } + if (!generatePatternFile.getParentFile().exists()) { + generatePatternFile.getParentFile().mkdirs(); + } + + // bypass create file for processingEnv.getFiler(), compile GeneratePatterns in next compile term + try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(generatePatternFile))) { + bufferedWriter.write(code); + } + } + + private Set findSuperExpression() { + Map> 
parentNameMap = analyzer.getParentNameMap(); + Map> childrenNameMap = analyzer.getChildrenNameMap(); + Set superExpressions = Sets.newLinkedHashSet(); + for (Entry> entry : childrenNameMap.entrySet()) { + String parentName = entry.getKey(); + Set childrenNames = entry.getValue(); + + if (parentName.startsWith("org.apache.doris.nereids.trees.expressions.")) { + for (String childrenName : childrenNames) { + Set parentNames = parentNameMap.get(childrenName); + if (parentNames != null + && parentNames.contains("org.apache.doris.nereids.trees.expressions.Expression")) { + superExpressions.add(parentName); + break; + } + } + } + } + return superExpressions; + } + + private String generateCode(Map> childrenNameMap, + Map> parentNameMap, Set superExpressions) { + String generateCode + = "// Licensed to the Apache Software Foundation (ASF) under one\n" + + "// or more contributor license agreements. See the NOTICE file\n" + + "// distributed with this work for additional information\n" + + "// regarding copyright ownership. The ASF licenses this file\n" + + "// to you under the Apache License, Version 2.0 (the\n" + + "// \"License\"); you may not use this file except in compliance\n" + + "// with the License. You may obtain a copy of the License at\n" + + "//\n" + + "// http://www.apache.org/licenses/LICENSE-2.0\n" + + "//\n" + + "// Unless required by applicable law or agreed to in writing,\n" + + "// software distributed under the License is distributed on an\n" + + "// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n" + + "// KIND, either express or implied. 
See the License for the\n" + + "// specific language governing permissions and limitations\n" + + "// under the License.\n" + + "\n" + + "package org.apache.doris.nereids.pattern;\n" + + "\n" + + "import org.apache.doris.nereids.trees.expressions.Expression;\n" + + "\n" + + "import com.google.common.collect.ImmutableMap;\n" + + "import com.google.common.collect.ImmutableSet;\n" + + "\n" + + "import java.util.Map;\n" + + "import java.util.Set;\n" + + "\n"; + generateCode += "/** GeneratedExpressionRelations */\npublic class GeneratedExpressionRelations {\n"; + String childrenClassesGenericType = ", Set>>"; + generateCode += + " public static final Map" + childrenClassesGenericType + " CHILDREN_CLASS_MAP;\n\n"; + generateCode += + " static {\n" + + " ImmutableMap.Builder" + childrenClassesGenericType + " childrenClassesBuilder\n" + + " = ImmutableMap.builderWithExpectedSize(" + childrenNameMap.size() + ");\n"; + + for (String superExpression : superExpressions) { + Set childrenClasseSet = childrenNameMap.get(superExpression) + .stream() + .filter(childClass -> parentNameMap.get(childClass) + .contains("org.apache.doris.nereids.trees.expressions.Expression") + ) + .collect(Collectors.toSet()); + + List childrenClasses = Lists.newArrayList(childrenClasseSet); + Collections.sort(childrenClasses, Comparator.naturalOrder()); + + String childClassesString = childrenClasses.stream() + .map(childClass -> " " + childClass + ".class") + .collect(Collectors.joining(",\n")); + generateCode += " childrenClassesBuilder.put(\n " + superExpression + + ".class,\n ImmutableSet.>of(\n" + childClassesString + + "\n )\n );\n\n"; + } + + generateCode += " CHILDREN_CLASS_MAP = childrenClassesBuilder.build();\n"; + + return generateCode + " }\n}\n"; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGeneratorAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/JavaAstAnalyzer.java similarity index 75% rename from 
fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGeneratorAnalyzer.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/JavaAstAnalyzer.java index f4a9d128087ae8..cce69151ca2ab7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGeneratorAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/JavaAstAnalyzer.java @@ -29,25 +29,24 @@ import com.google.common.base.Joiner; -import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.IdentityHashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; -/** - * used to analyze plan class extends hierarchy and then generated pattern builder methods. - */ -public class PatternGeneratorAnalyzer { - private final Map name2Ast = new LinkedHashMap<>(); - private final IdentityHashMap ast2Name = new IdentityHashMap<>(); - private final IdentityHashMap> ast2Import = new IdentityHashMap<>(); - private final IdentityHashMap> parentClassMap = new IdentityHashMap<>(); +/** JavaAstAnalyzer */ +public class JavaAstAnalyzer { + protected final Map name2Ast = new LinkedHashMap<>(); + protected final IdentityHashMap ast2Name = new IdentityHashMap<>(); + protected final IdentityHashMap> ast2Import = new IdentityHashMap<>(); + protected final IdentityHashMap> parentClassMap = new IdentityHashMap<>(); + protected final Map> parentNameMap = new LinkedHashMap<>(); + protected final Map> childrenNameMap = new LinkedHashMap<>(); /** add java AST. */ public void addAsts(List typeDeclarations) { @@ -56,14 +55,20 @@ public void addAsts(List typeDeclarations) { } } - /** generate pattern methods. 
*/ - public String generatePatterns(String className, String parentClassName, boolean isMemoPattern) { - analyzeImport(); - analyzeParentClass(); - return doGenerate(className, parentClassName, isMemoPattern); + public IdentityHashMap> getParentClassMap() { + return parentClassMap; + } + + public Map> getParentNameMap() { + return parentNameMap; } - Optional getType(TypeDeclaration typeDeclaration, TypeType type) { + public Map> getChildrenNameMap() { + return childrenNameMap; + } + + /** getType */ + public Optional getType(TypeDeclaration typeDeclaration, TypeType type) { String typeName = analyzeClass(new LinkedHashSet<>(), typeDeclaration, type); if (typeName != null) { TypeDeclaration ast = name2Ast.get(typeName); @@ -73,34 +78,11 @@ Optional getType(TypeDeclaration typeDeclaration, TypeType type return Optional.empty(); } - private String doGenerate(String className, String parentClassName, boolean isMemoPattern) { - Map> planClassMap = parentClassMap.entrySet().stream() - .filter(kv -> kv.getValue().contains("org.apache.doris.nereids.trees.plans.Plan")) - .filter(kv -> !kv.getKey().name.equals("GroupPlan")) - .filter(kv -> !Modifier.isAbstract(kv.getKey().modifiers.mod) - && kv.getKey() instanceof ClassDeclaration) - .collect(Collectors.toMap(kv -> (ClassDeclaration) kv.getKey(), kv -> kv.getValue())); - - List generators = planClassMap.entrySet() - .stream() - .map(kv -> PatternGenerator.create(this, kv.getKey(), kv.getValue(), isMemoPattern)) - .filter(Optional::isPresent) - .map(Optional::get) - .sorted((g1, g2) -> { - // logical first - if (g1.isLogical() != g2.isLogical()) { - return g1.isLogical() ? 
-1 : 1; - } - // leaf first - if (g1.childrenNum() != g2.childrenNum()) { - return g1.childrenNum() - g2.childrenNum(); - } - // string dict sort - return g1.opType.name.compareTo(g2.opType.name); - }) - .collect(Collectors.toList()); - - return PatternGenerator.generateCode(className, parentClassName, generators, this, isMemoPattern); + protected void analyze() { + analyzeImport(); + analyzeParentClass(); + analyzeParentName(); + analyzeChildrenName(); } private void analyzeImport() { @@ -148,7 +130,28 @@ private void analyzeParentClass(Set parentClasses, TypeDeclaration typeD parentClasses.addAll(currentParentClasses); } - String analyzeClass(Set parentClasses, TypeDeclaration typeDeclaration, TypeType type) { + private void analyzeParentName() { + for (Entry> entry : parentClassMap.entrySet()) { + String parentName = entry.getKey().getFullQualifiedName(); + parentNameMap.put(parentName, entry.getValue()); + } + } + + private void analyzeChildrenName() { + for (Entry entry : name2Ast.entrySet()) { + Set parentNames = parentClassMap.get(entry.getValue()); + for (String parentName : parentNames) { + Set childrenNames = childrenNameMap.get(parentName); + if (childrenNames == null) { + childrenNames = new LinkedHashSet<>(); + childrenNameMap.put(parentName, childrenNames); + } + childrenNames.add(entry.getKey()); + } + } + } + + private String analyzeClass(Set parentClasses, TypeDeclaration typeDeclaration, TypeType type) { if (type.classOrInterfaceType.isPresent()) { List identifiers = new ArrayList<>(); ClassOrInterfaceType classOrInterfaceType = type.classOrInterfaceType.get(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalBinaryPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalBinaryPatternGenerator.java index bec3efa270a7f7..8e05a87ad7dc0e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalBinaryPatternGenerator.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalBinaryPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for LogicalBinary. */ -public class LogicalBinaryPatternGenerator extends PatternGenerator { +public class LogicalBinaryPatternGenerator extends PlanPatternGenerator { - public LogicalBinaryPatternGenerator(PatternGeneratorAnalyzer analyzer, + public LogicalBinaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalLeafPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalLeafPatternGenerator.java index fd7b30a8e6f112..b82ac81d42077a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalLeafPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalLeafPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for LogicalLeaf. 
*/ -public class LogicalLeafPatternGenerator extends PatternGenerator { +public class LogicalLeafPatternGenerator extends PlanPatternGenerator { - public LogicalLeafPatternGenerator(PatternGeneratorAnalyzer analyzer, + public LogicalLeafPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalUnaryPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalUnaryPatternGenerator.java index 8ecb7c14e1005c..d2f2b61bf96d71 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalUnaryPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalUnaryPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for LogicalUnary. */ -public class LogicalUnaryPatternGenerator extends PatternGenerator { +public class LogicalUnaryPatternGenerator extends PlanPatternGenerator { - public LogicalUnaryPatternGenerator(PatternGeneratorAnalyzer analyzer, + public LogicalUnaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternDescribableProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternDescribableProcessor.java index 42cf82e3c01414..5ba81bbb96bc93 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternDescribableProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternDescribableProcessor.java @@ -60,12 +60,12 @@ @SupportedSourceVersion(SourceVersion.RELEASE_8) 
@SupportedAnnotationTypes("org.apache.doris.nereids.pattern.generator.PatternDescribable") public class PatternDescribableProcessor extends AbstractProcessor { - private List planPaths; + private List paths; @Override public synchronized void init(ProcessingEnvironment processingEnv) { super.init(processingEnv); - this.planPaths = Arrays.stream(processingEnv.getOptions().get("planPath").split(",")) + this.paths = Arrays.stream(processingEnv.getOptions().get("path").split(",")) .map(path -> path.trim()) .filter(path -> !path.isEmpty()) .collect(Collectors.toSet()) @@ -80,15 +80,25 @@ public boolean process(Set annotations, RoundEnvironment return false; } try { - List planFiles = findJavaFiles(planPaths); - PatternGeneratorAnalyzer patternGeneratorAnalyzer = new PatternGeneratorAnalyzer(); - for (File file : planFiles) { + List javaFiles = findJavaFiles(paths); + JavaAstAnalyzer javaAstAnalyzer = new JavaAstAnalyzer(); + for (File file : javaFiles) { List asts = parseJavaFile(file); - patternGeneratorAnalyzer.addAsts(asts); + javaAstAnalyzer.addAsts(asts); } - doGenerate("GeneratedMemoPatterns", "MemoPatterns", true, patternGeneratorAnalyzer); - doGenerate("GeneratedPlanPatterns", "PlanPatterns", false, patternGeneratorAnalyzer); + javaAstAnalyzer.analyze(); + + ExpressionTypeMappingGenerator expressionTypeMappingGenerator + = new ExpressionTypeMappingGenerator(javaAstAnalyzer); + expressionTypeMappingGenerator.generate(processingEnv); + + PlanTypeMappingGenerator planTypeMappingGenerator = new PlanTypeMappingGenerator(javaAstAnalyzer); + planTypeMappingGenerator.generate(processingEnv); + + PlanPatternGeneratorAnalyzer patternGeneratorAnalyzer = new PlanPatternGeneratorAnalyzer(javaAstAnalyzer); + generatePlanPatterns("GeneratedMemoPatterns", "MemoPatterns", true, patternGeneratorAnalyzer); + generatePlanPatterns("GeneratedPlanPatterns", "PlanPatterns", false, patternGeneratorAnalyzer); } catch (Throwable t) { String exceptionMsg = 
Throwables.getStackTraceAsString(t); processingEnv.getMessager().printMessage(Kind.ERROR, @@ -97,8 +107,12 @@ public boolean process(Set annotations, RoundEnvironment return false; } - private void doGenerate(String className, String parentClassName, boolean isMemoPattern, - PatternGeneratorAnalyzer patternGeneratorAnalyzer) throws IOException { + private void generateExpressionTypeMapping() { + + } + + private void generatePlanPatterns(String className, String parentClassName, boolean isMemoPattern, + PlanPatternGeneratorAnalyzer patternGeneratorAnalyzer) throws IOException { String generatePatternCode = patternGeneratorAnalyzer.generatePatterns( className, parentClassName, isMemoPattern); File generatePatternFile = new File(processingEnv.getFiler() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalBinaryPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalBinaryPatternGenerator.java index 72a315574952ac..08e639a924dad3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalBinaryPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalBinaryPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for PhysicalBinary. 
*/ -public class PhysicalBinaryPatternGenerator extends PatternGenerator { +public class PhysicalBinaryPatternGenerator extends PlanPatternGenerator { - public PhysicalBinaryPatternGenerator(PatternGeneratorAnalyzer analyzer, + public PhysicalBinaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalLeafPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalLeafPatternGenerator.java index f75746b5142f20..27a94edacad2b9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalLeafPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalLeafPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for PhysicalLeaf. */ -public class PhysicalLeafPatternGenerator extends PatternGenerator { +public class PhysicalLeafPatternGenerator extends PlanPatternGenerator { - public PhysicalLeafPatternGenerator(PatternGeneratorAnalyzer analyzer, + public PhysicalLeafPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalUnaryPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalUnaryPatternGenerator.java index 4254e28ee4371f..f69de7e9d6a123 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalUnaryPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalUnaryPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for PhysicalUnary. 
*/ -public class PhysicalUnaryPatternGenerator extends PatternGenerator { +public class PhysicalUnaryPatternGenerator extends PlanPatternGenerator { - public PhysicalUnaryPatternGenerator(PatternGeneratorAnalyzer analyzer, + public PhysicalUnaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGenerator.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGenerator.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGenerator.java index 75c950f8c82bd4..b94c9f489e628c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGenerator.java @@ -43,8 +43,8 @@ import java.util.stream.Collectors; /** used to generate pattern by plan. */ -public abstract class PatternGenerator { - protected final PatternGeneratorAnalyzer analyzer; +public abstract class PlanPatternGenerator { + protected final JavaAstAnalyzer analyzer; protected final ClassDeclaration opType; protected final Set parentClass; protected final List enumFieldPatternInfos; @@ -52,9 +52,9 @@ public abstract class PatternGenerator { protected final boolean isMemoPattern; /** constructor. 
*/ - public PatternGenerator(PatternGeneratorAnalyzer analyzer, ClassDeclaration opType, + public PlanPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { - this.analyzer = analyzer; + this.analyzer = analyzer.getAnalyzer(); this.opType = opType; this.parentClass = parentClass; this.enumFieldPatternInfos = getEnumFieldPatternInfos(); @@ -76,8 +76,8 @@ public String getPatternMethodName() { } /** generate code by generators and analyzer. */ - public static String generateCode(String className, String parentClassName, List generators, - PatternGeneratorAnalyzer analyzer, boolean isMemoPattern) { + public static String generateCode(String className, String parentClassName, List generators, + PlanPatternGeneratorAnalyzer analyzer, boolean isMemoPattern) { String generateCode = "// Licensed to the Apache Software Foundation (ASF) under one\n" + "// or more contributor license agreements. See the NOTICE file\n" @@ -206,7 +206,7 @@ protected String childType() { } /** create generator by plan's type. 
*/ - public static Optional create(PatternGeneratorAnalyzer analyzer, + public static Optional create(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalLeaf")) { return Optional.of(new LogicalLeafPatternGenerator(analyzer, opType, parentClass, isMemoPattern)); @@ -225,9 +225,9 @@ public static Optional create(PatternGeneratorAnalyzer analyze } } - private static String generateImports(List generators) { + private static String generateImports(List generators) { Set imports = new HashSet<>(); - for (PatternGenerator generator : generators) { + for (PlanPatternGenerator generator : generators) { imports.addAll(generator.getImports()); } List sortedImports = new ArrayList<>(imports); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGeneratorAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGeneratorAnalyzer.java new file mode 100644 index 00000000000000..99d7c308dacf0d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGeneratorAnalyzer.java @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern.generator; + +import org.apache.doris.nereids.pattern.generator.javaast.ClassDeclaration; + +import java.lang.reflect.Modifier; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * used to analyze plan class extends hierarchy and then generated pattern builder methods. + */ +public class PlanPatternGeneratorAnalyzer { + private final JavaAstAnalyzer analyzer; + + public PlanPatternGeneratorAnalyzer(JavaAstAnalyzer analyzer) { + this.analyzer = analyzer; + } + + public JavaAstAnalyzer getAnalyzer() { + return analyzer; + } + + /** generate pattern methods. */ + public String generatePatterns(String className, String parentClassName, boolean isMemoPattern) { + Map> planClassMap = analyzer.getParentClassMap().entrySet().stream() + .filter(kv -> kv.getValue().contains("org.apache.doris.nereids.trees.plans.Plan")) + .filter(kv -> !kv.getKey().name.equals("GroupPlan")) + .filter(kv -> !Modifier.isAbstract(kv.getKey().modifiers.mod) + && kv.getKey() instanceof ClassDeclaration) + .collect(Collectors.toMap(kv -> (ClassDeclaration) kv.getKey(), kv -> kv.getValue())); + + List generators = planClassMap.entrySet() + .stream() + .map(kv -> PlanPatternGenerator.create(this, kv.getKey(), kv.getValue(), isMemoPattern)) + .filter(Optional::isPresent) + .map(Optional::get) + .sorted((g1, g2) -> { + // logical first + if (g1.isLogical() != g2.isLogical()) { + return g1.isLogical() ? 
-1 : 1; + } + // leaf first + if (g1.childrenNum() != g2.childrenNum()) { + return g1.childrenNum() - g2.childrenNum(); + } + // string dict sort + return g1.opType.name.compareTo(g2.opType.name); + }) + .collect(Collectors.toList()); + + return PlanPatternGenerator.generateCode(className, parentClassName, generators, this, isMemoPattern); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanTypeMappingGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanTypeMappingGenerator.java new file mode 100644 index 00000000000000..c3b6c765d49383 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanTypeMappingGenerator.java @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.pattern.generator; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.processing.ProcessingEnvironment; +import javax.tools.StandardLocation; + +/** PlanTypeMappingGenerator */ +public class PlanTypeMappingGenerator { + private final JavaAstAnalyzer analyzer; + + public PlanTypeMappingGenerator(JavaAstAnalyzer javaAstAnalyzer) { + this.analyzer = javaAstAnalyzer; + } + + public JavaAstAnalyzer getAnalyzer() { + return analyzer; + } + + /** generate */ + public void generate(ProcessingEnvironment processingEnv) throws IOException { + Set superPlans = findSuperPlan(); + Map> childrenNameMap = analyzer.getChildrenNameMap(); + Map> parentNameMap = analyzer.getParentNameMap(); + String code = generateCode(childrenNameMap, parentNameMap, superPlans); + generateFile(processingEnv, code); + } + + private void generateFile(ProcessingEnvironment processingEnv, String code) throws IOException { + File generatePatternFile = new File(processingEnv.getFiler() + .getResource(StandardLocation.SOURCE_OUTPUT, "org.apache.doris.nereids.pattern", + "GeneratedPlanRelations.java").toUri()); + if (generatePatternFile.exists()) { + generatePatternFile.delete(); + } + if (!generatePatternFile.getParentFile().exists()) { + generatePatternFile.getParentFile().mkdirs(); + } + + // bypass create file for processingEnv.getFiler(), compile GeneratePatterns in next compile term + try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(generatePatternFile))) { + bufferedWriter.write(code); + } + } + + private Set findSuperPlan() { + Map> parentNameMap = analyzer.getParentNameMap(); + 
Map> childrenNameMap = analyzer.getChildrenNameMap(); + Set superPlans = Sets.newLinkedHashSet(); + for (Entry> entry : childrenNameMap.entrySet()) { + String parentName = entry.getKey(); + Set childrenNames = entry.getValue(); + + if (parentName.startsWith("org.apache.doris.nereids.trees.plans.")) { + for (String childrenName : childrenNames) { + Set parentNames = parentNameMap.get(childrenName); + if (parentNames != null + && parentNames.contains("org.apache.doris.nereids.trees.plans.Plan")) { + superPlans.add(parentName); + break; + } + } + } + } + return superPlans; + } + + private String generateCode(Map> childrenNameMap, + Map> parentNameMap, Set superPlans) { + String generateCode + = "// Licensed to the Apache Software Foundation (ASF) under one\n" + + "// or more contributor license agreements. See the NOTICE file\n" + + "// distributed with this work for additional information\n" + + "// regarding copyright ownership. The ASF licenses this file\n" + + "// to you under the Apache License, Version 2.0 (the\n" + + "// \"License\"); you may not use this file except in compliance\n" + + "// with the License. You may obtain a copy of the License at\n" + + "//\n" + + "// http://www.apache.org/licenses/LICENSE-2.0\n" + + "//\n" + + "// Unless required by applicable law or agreed to in writing,\n" + + "// software distributed under the License is distributed on an\n" + + "// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n" + + "// KIND, either express or implied. 
See the License for the\n" + + "// specific language governing permissions and limitations\n" + + "// under the License.\n" + + "\n" + + "package org.apache.doris.nereids.pattern;\n" + + "\n" + + "import org.apache.doris.nereids.trees.plans.Plan;\n" + + "\n" + + "import com.google.common.collect.ImmutableMap;\n" + + "import com.google.common.collect.ImmutableSet;\n" + + "\n" + + "import java.util.Map;\n" + + "import java.util.Set;\n" + + "\n"; + generateCode += "/** GeneratedPlanRelations */\npublic class GeneratedPlanRelations {\n"; + String childrenClassesGenericType = ", Set>>"; + generateCode += + " public static final Map" + childrenClassesGenericType + " CHILDREN_CLASS_MAP;\n\n"; + generateCode += + " static {\n" + + " ImmutableMap.Builder" + childrenClassesGenericType + " childrenClassesBuilder\n" + + " = ImmutableMap.builderWithExpectedSize(" + childrenNameMap.size() + ");\n"; + + for (String superPlan : superPlans) { + Set childrenClasseSet = childrenNameMap.get(superPlan) + .stream() + .filter(childClass -> parentNameMap.get(childClass) + .contains("org.apache.doris.nereids.trees.plans.Plan") + ) + .collect(Collectors.toSet()); + + List childrenClasses = Lists.newArrayList(childrenClasseSet); + Collections.sort(childrenClasses, Comparator.naturalOrder()); + + String childClassesString = childrenClasses.stream() + .map(childClass -> " " + childClass + ".class") + .collect(Collectors.joining(",\n")); + generateCode += " childrenClassesBuilder.put(\n " + superPlan + + ".class,\n ImmutableSet.>of(\n" + childClassesString + + "\n )\n );\n\n"; + } + + generateCode += " CHILDREN_CLASS_MAP = childrenClassesBuilder.build();\n"; + + return generateCode + " }\n}\n"; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java index 4efafe3af90f50..fb6e54e38a8545 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java @@ -195,9 +195,20 @@ private boolean isVisibleColumn(Slot slot) { @Override public PhysicalFilter visitPhysicalFilter(PhysicalFilter filter, CascadesContext context) { filter.child().accept(this, context); - boolean visibleFilter = filter.getExpressions().stream() - .flatMap(expression -> expression.getInputSlots().stream()) - .anyMatch(slot -> isVisibleColumn(slot)); + + boolean visibleFilter = false; + + for (Expression expr : filter.getExpressions()) { + for (Slot inputSlot : expr.getInputSlots()) { + if (isVisibleColumn(inputSlot)) { + visibleFilter = true; + break; + } + } + if (visibleFilter) { + break; + } + } if (visibleFilter) { // skip filters like: __DORIS_DELETE_SIGN__ = 0 context.getRuntimeFilterContext().addEffectiveSrcNode(filter, RuntimeFilterContext.EffectiveSrcType.NATIVE); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java index e73039e9237980..561e09ed404ad2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java @@ -26,6 +26,8 @@ import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; +import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; @@ -69,7 +71,10 @@ public Plan visitPhysicalFilter(PhysicalFilter filter, CascadesC @Override public Plan visit(Plan plan, CascadesContext context) { - plan.children().forEach(child -> child.accept(this, context)); + for (Plan child : plan.children()) { + 
child.accept(this, context); + } + Optional opt = checkAllSlotFromChildren(plan); if (opt.isPresent()) { List childrenOutput = plan.children().stream().flatMap(p -> p.getOutput().stream()).collect( @@ -93,8 +98,7 @@ public static Optional checkAllSlotFromChildren(Plan plan) { if (plan instanceof Aggregate) { return Optional.empty(); } - Set childOutputSet = plan.children().stream().flatMap(child -> child.getOutputSet().stream()) - .collect(Collectors.toSet()); + Set childOutputSet = Utils.fastToImmutableSet(PlanUtils.fastGetChildrenOutputs(plan.children())); Set inputSlots = plan.getInputSlots(); for (Slot slot : inputSlots) { if (slot.getName().startsWith("mv") || slot instanceof SlotNotFromChildren) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java index d7b4b3b1c9f34d..c7e6030e13794c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import java.util.HashSet; import java.util.Map; @@ -196,12 +197,23 @@ public boolean containsAnySub(Set slotSet) { } public void removeNotContain(Set slotSet) { - slots = slots.stream() - .filter(slotSet::contains) - .collect(Collectors.toSet()); - slotSets = slotSets.stream() - .filter(slotSet::containsAll) - .collect(Collectors.toSet()); + if (!slotSet.isEmpty()) { + Set newSlots = Sets.newLinkedHashSetWithExpectedSize(slots.size()); + for (Slot slot : slots) { + if (slotSet.contains(slot)) { + newSlots.add(slot); + } + } + this.slots = newSlots; + + Set> newSlotSets = Sets.newLinkedHashSetWithExpectedSize(slots.size()); + for (ImmutableSet set : slotSets) { + if 
(slotSet.containsAll(set)) { + newSlotSets.add(set); + } + } + this.slotSets = newSlotSets; + } } public void add(Slot slot) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java index 07d2882894288c..ea8a6e29dde1f0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java @@ -19,7 +19,6 @@ import org.apache.doris.common.Id; import org.apache.doris.nereids.trees.expressions.ExprId; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import com.google.common.base.Supplier; @@ -62,21 +61,40 @@ public LogicalProperties(Supplier> outputSupplier, this.outputSupplier = Suppliers.memoize( Objects.requireNonNull(outputSupplier, "outputSupplier can not be null") ); - this.outputExprIdsSupplier = Suppliers.memoize( - () -> this.outputSupplier.get().stream().map(NamedExpression::getExprId).map(Id.class::cast) - .collect(ImmutableList.toImmutableList()) - ); - this.outputSetSupplier = Suppliers.memoize( - () -> ImmutableSet.copyOf(this.outputSupplier.get()) - ); - this.outputMapSupplier = Suppliers.memoize( - () -> this.outputSetSupplier.get().stream().collect(ImmutableMap.toImmutableMap(s -> s, s -> s)) - ); - this.outputExprIdSetSupplier = Suppliers.memoize( - () -> this.outputSupplier.get().stream() - .map(NamedExpression::getExprId) - .collect(ImmutableSet.toImmutableSet()) - ); + this.outputExprIdsSupplier = Suppliers.memoize(() -> { + List output = this.outputSupplier.get(); + ImmutableList.Builder exprIdSet + = ImmutableList.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + exprIdSet.add(slot.getExprId()); + } + return exprIdSet.build(); + }); + this.outputSetSupplier = Suppliers.memoize(() -> { + List output = 
outputSupplier.get(); + ImmutableSet.Builder slots = ImmutableSet.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + slots.add(slot); + } + return slots.build(); + }); + this.outputMapSupplier = Suppliers.memoize(() -> { + Set slots = outputSetSupplier.get(); + ImmutableMap.Builder map = ImmutableMap.builderWithExpectedSize(slots.size()); + for (Slot slot : slots) { + map.put(slot, slot); + } + return map.build(); + }); + this.outputExprIdSetSupplier = Suppliers.memoize(() -> { + List output = this.outputSupplier.get(); + ImmutableSet.Builder exprIdSet + = ImmutableSet.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + exprIdSet.add(slot.getExprId()); + } + return exprIdSet.build(); + }); this.fdSupplier = Suppliers.memoize( Objects.requireNonNull(fdSupplier, "FunctionalDependencies can not be null") ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/Rule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/Rule.java index a9b4591ad4a0a8..207dd6458c9202 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/Rule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/Rule.java @@ -24,8 +24,8 @@ import org.apache.doris.nereids.rules.RuleType.RuleTypeClass; import org.apache.doris.nereids.trees.plans.Plan; +import java.util.BitSet; import java.util.List; -import java.util.Set; /** * Abstract class for all rules. @@ -79,8 +79,8 @@ public void acceptPlan(Plan plan) { /** * Filter out already applied rules and rules that are not matched on root node. 
*/ - public boolean isInvalid(Set disableRules, GroupExpression groupExpression) { - return disableRules.contains(this.getRuleType().type()) + public boolean isInvalid(BitSet disableRules, GroupExpression groupExpression) { + return disableRules.get(this.getRuleType().type()) || !groupExpression.notApplied(this) || !this.getPattern().matchRoot(groupExpression.getPlan()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 408b0d7355e762..d317b1e8738521 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -49,6 +49,7 @@ import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterAggregateRule; import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterJoinRule; import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectJoinRule; +import org.apache.doris.nereids.rules.expression.ExpressionOptimization; import org.apache.doris.nereids.rules.implementation.AggregateStrategies; import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows; import org.apache.doris.nereids.rules.implementation.LogicalCTEAnchorToPhysicalCTEAnchor; @@ -153,7 +154,8 @@ public class RuleSet { new MergeLimits(), new PushDownAliasThroughJoin(), new PushDownFilterThroughWindow(), - new PushDownFilterThroughPartitionTopN() + new PushDownFilterThroughPartitionTopN(), + new ExpressionOptimization() ); public static final List IMPLEMENTATION_RULES = planRuleFactories() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java index 86a70d35ccc087..5543341ae277d4 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java @@ -46,21 +46,30 @@ public List buildRules() { RuleType.ADJUST_NULLABLE_FOR_AGGREGATE_SLOT.build( logicalAggregate() .then(agg -> { - List output = agg.getOutputExpressions().stream() - .map(ne -> ((NamedExpression) FunctionReplacer.INSTANCE.replace(ne, - agg.getGroupByExpressions().isEmpty()))) - .collect(ImmutableList.toImmutableList()); - return agg.withAggOutput(output); + List outputExprs = agg.getOutputExpressions(); + boolean noGroupBy = agg.getGroupByExpressions().isEmpty(); + ImmutableList.Builder newOutput + = ImmutableList.builderWithExpectedSize(outputExprs.size()); + for (NamedExpression ne : outputExprs) { + NamedExpression newExpr = + ((NamedExpression) FunctionReplacer.INSTANCE.replace(ne, noGroupBy)); + newOutput.add(newExpr); + } + return agg.withAggOutput(newOutput.build()); }) ), RuleType.ADJUST_NULLABLE_FOR_HAVING_SLOT.build( logicalHaving(logicalAggregate()) .then(having -> { - Set newConjuncts = having.getConjuncts().stream() - .map(ne -> FunctionReplacer.INSTANCE.replace(ne, - having.child().getGroupByExpressions().isEmpty())) - .collect(ImmutableSet.toImmutableSet()); - return new LogicalHaving<>(newConjuncts, having.child()); + Set conjuncts = having.getConjuncts(); + boolean noGroupBy = having.child().getGroupByExpressions().isEmpty(); + ImmutableSet.Builder newConjuncts + = ImmutableSet.builderWithExpectedSize(conjuncts.size()); + for (Expression expr : conjuncts) { + Expression newExpr = FunctionReplacer.INSTANCE.replace(expr, noGroupBy); + newConjuncts.add(newExpr); + } + return new LogicalHaving<>(newConjuncts.build(), having.child()); }) ) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index c2c7f5815d9288..6211f493eaf4e4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -333,10 +333,11 @@ private LogicalPlan bindInlineTable(MatchingContext ctx) { List relations = Lists.newArrayListWithCapacity(logicalInlineTable.getConstantExprsList().size()); for (int i = 0; i < logicalInlineTable.getConstantExprsList().size(); i++) { - if (logicalInlineTable.getConstantExprsList().get(i).stream() - .anyMatch(DefaultValueSlot.class::isInstance)) { - throw new AnalysisException("Default expression" - + " can't exist in SELECT statement at row " + (i + 1)); + for (NamedExpression constantExpr : logicalInlineTable.getConstantExprsList().get(i)) { + if (constantExpr instanceof DefaultValueSlot) { + throw new AnalysisException("Default expression" + + " can't exist in SELECT statement at row " + (i + 1)); + } } relations.add(new UnboundOneRowRelation(StatementScopeIdGenerator.newRelationId(), logicalInlineTable.getConstantExprsList().get(i))); @@ -590,7 +591,7 @@ private Plan bindFilter(MatchingContext> ctx) { SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer( filter, cascadesContext, filter.children(), true, true); ImmutableSet.Builder boundConjuncts = ImmutableSet.builderWithExpectedSize( - filter.getConjuncts().size() * 2); + filter.getConjuncts().size()); for (Expression conjunct : filter.getConjuncts()) { Expression boundConjunct = analyzer.analyze(conjunct); boundConjunct = TypeCoercionUtils.castIfNotSameType(boundConjunct, BooleanType.INSTANCE); @@ -828,15 +829,22 @@ private void checkIfOutputAliasNameDuplicatedForGroupBy(Collection e if (output.stream().noneMatch(Alias.class::isInstance)) { return; } - List aliasList = output.stream().filter(Alias.class::isInstance) - .map(Alias.class::cast).collect(Collectors.toList()); + List 
aliasList = ExpressionUtils.filter(output, Alias.class); List exprAliasList = ExpressionUtils.collectAll(expressions, NamedExpression.class::isInstance); - boolean isGroupByContainAlias = exprAliasList.stream().anyMatch(ne -> - aliasList.stream().anyMatch(alias -> !alias.getExprId().equals(ne.getExprId()) - && alias.getName().equals(ne.getName()))); + boolean isGroupByContainAlias = false; + for (NamedExpression ne : exprAliasList) { + for (Alias alias : aliasList) { + if (!alias.getExprId().equals(ne.getExprId()) && alias.getName().equalsIgnoreCase(ne.getName())) { + isGroupByContainAlias = true; + } + } + if (isGroupByContainAlias) { + break; + } + } if (isGroupByContainAlias && ConnectContext.get() != null diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotWithPaths.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotWithPaths.java index 114e4c1d12051b..714e6e48794ba0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotWithPaths.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotWithPaths.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -34,7 +33,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; /** * Rule to bind slot with path in query plan. 
@@ -60,21 +58,18 @@ public List buildRules() { Set pathsSlots = ctx.statementContext.getAllPathsSlots(); // With new logical properties that contains new slots with paths StatementContext stmtCtx = ConnectContext.get().getStatementContext(); - List olapScanPathSlots = pathsSlots.stream().filter( - slot -> { - Preconditions.checkNotNull(stmtCtx.getRelationBySlot(slot), - "[Not implemented] Slot not found in relation map, slot ", slot); - return stmtCtx.getRelationBySlot(slot).getRelationId() - == logicalOlapScan.getRelationId(); - }).collect( - Collectors.toList()); - List newExprs = olapScanPathSlots.stream() - .map(SlotReference.class::cast) - .map(slotReference -> - new Alias(slotReference.getExprId(), - stmtCtx.getOriginalExpr(slotReference), slotReference.getName())) - .collect( - Collectors.toList()); + ImmutableList.Builder newExprsBuilder + = ImmutableList.builderWithExpectedSize(pathsSlots.size()); + for (SlotReference slot : pathsSlots) { + Preconditions.checkNotNull(stmtCtx.getRelationBySlot(slot), + "[Not implemented] Slot not found in relation map, slot ", slot); + if (stmtCtx.getRelationBySlot(slot).getRelationId() + == logicalOlapScan.getRelationId()) { + newExprsBuilder.add(new Alias(slot.getExprId(), + stmtCtx.getOriginalExpr(slot), slot.getName())); + } + } + ImmutableList newExprs = newExprsBuilder.build(); if (newExprs.isEmpty()) { return ctx.root; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java index 754d3efa583fa5..92052bc85ed100 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java @@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; +import 
com.google.common.collect.ImmutableSet; import org.apache.commons.lang3.StringUtils; import java.util.List; @@ -69,42 +70,43 @@ public Rule build() { } private void checkUnexpectedExpression(Plan plan) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(SubqueryExpr.class::isInstance))) { - throw new AnalysisException("Subquery is not allowed in " + plan.getType()); - } - if (!(plan instanceof Generate)) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(TableGeneratingFunction.class::isInstance))) { - throw new AnalysisException("table generating function is not allowed in " + plan.getType()); - } + boolean isGenerate = plan instanceof Generate; + boolean isAgg = plan instanceof LogicalAggregate; + boolean isWindow = plan instanceof LogicalWindow; + boolean notAggAndWindow = !isAgg && !isWindow; + + for (Expression expression : plan.getExpressions()) { + expression.foreach(expr -> { + if (expr instanceof SubqueryExpr) { + throw new AnalysisException("Subquery is not allowed in " + plan.getType()); + } else if (!isGenerate && expr instanceof TableGeneratingFunction) { + throw new AnalysisException("table generating function is not allowed in " + plan.getType()); + } else if (notAggAndWindow && expr instanceof AggregateFunction) { + throw new AnalysisException("aggregate function is not allowed in " + plan.getType()); + } else if (!isAgg && expr instanceof GroupingScalarFunction) { + throw new AnalysisException("grouping scalar function is not allowed in " + plan.getType()); + } else if (!isWindow && expr instanceof WindowExpression) { + throw new AnalysisException("analytic function is not allowed in " + plan.getType()); + } + }); } - if (!(plan instanceof LogicalAggregate || plan instanceof LogicalWindow)) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(AggregateFunction.class::isInstance))) { - throw new AnalysisException("aggregate function is not allowed in " + plan.getType()); + } + + private void 
checkAllSlotReferenceFromChildren(Plan plan) { + Set inputSlots = plan.getInputSlots(); + Set childrenOutput = plan.getChildrenOutputExprIdSet(); + + ImmutableSet.Builder notFromChildrenBuilder = ImmutableSet.builderWithExpectedSize(inputSlots.size()); + for (Slot inputSlot : inputSlots) { + if (!childrenOutput.contains(inputSlot.getExprId())) { + notFromChildrenBuilder.add(inputSlot); } } - if (!(plan instanceof LogicalAggregate)) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(GroupingScalarFunction.class::isInstance))) { - throw new AnalysisException("grouping scalar function is not allowed in " + plan.getType()); - } + Set notFromChildren = notFromChildrenBuilder.build(); + if (notFromChildren.isEmpty()) { + return; } - if (!(plan instanceof LogicalWindow)) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(WindowExpression.class::isInstance))) { - throw new AnalysisException("analytic function is not allowed in " + plan.getType()); - } - } - } - private void checkAllSlotReferenceFromChildren(Plan plan) { - Set notFromChildren = plan.getExpressions().stream() - .flatMap(expr -> expr.getInputSlots().stream()) - .collect(Collectors.toSet()); - Set childrenOutput = plan.children().stream() - .flatMap(child -> child.getOutput().stream()) - .map(NamedExpression::getExprId) - .collect(Collectors.toSet()); - notFromChildren = notFromChildren.stream() - .filter(s -> !childrenOutput.contains(s.getExprId())) - .collect(Collectors.toSet()); notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, childrenOutput); if (!notFromChildren.isEmpty()) { if (plan.arity() != 0 && plan.child(0) instanceof LogicalAggregate) { @@ -181,17 +183,18 @@ private void checkMetricTypeIsUsedCorrectly(Plan plan) { } private void checkMatchIsUsedCorrectly(Plan plan) { - if (plan.getExpressions().stream().anyMatch( - expression -> expression instanceof Match)) { - if (plan instanceof LogicalFilter && (plan.child(0) instanceof LogicalOlapScan - || 
plan.child(0) instanceof LogicalDeferMaterializeOlapScan - || plan.child(0) instanceof LogicalProject + for (Expression expression : plan.getExpressions()) { + if (expression instanceof Match) { + if (plan instanceof LogicalFilter && (plan.child(0) instanceof LogicalOlapScan + || plan.child(0) instanceof LogicalDeferMaterializeOlapScan + || plan.child(0) instanceof LogicalProject && ((LogicalProject) plan.child(0)).hasPushedDownToProjectionFunctions())) { - return; - } else { - throw new AnalysisException(String.format( - "Not support match in %s in plan: %s, only support in olapScan filter", - plan.child(0), plan)); + return; + } else { + throw new AnalysisException(String.format( + "Not support match in %s in plan: %s, only support in olapScan filter", + plan.child(0), plan)); + } } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java index 5a310d697ac798..64fd14019bbbc9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java @@ -45,7 +45,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; /** @@ -117,14 +116,16 @@ private void checkUnexpectedExpressions(Plan plan) { if (unexpectedExpressionTypes.isEmpty()) { return; } - plan.getExpressions().forEach(c -> c.foreachUp(e -> { - for (Class type : unexpectedExpressionTypes) { - if (type.isInstance(e)) { - throw new AnalysisException(plan.getType() + " can not contains " - + type.getSimpleName() + " expression: " + ((Expression) e).toSql()); + for (Expression expr : plan.getExpressions()) { + expr.foreachUp(e -> { + for (Class type : unexpectedExpressionTypes) { + if (type.isInstance(e)) { + throw new AnalysisException(plan.getType() + " can not contains " + + type.getSimpleName() 
+ " expression: " + ((Expression) e).toSql()); + } } - } - })); + }); + } } private void checkExpressionInputTypes(Plan plan) { @@ -157,20 +158,21 @@ private void checkAggregate(LogicalAggregate aggregate) { break; } } - long distinctFunctionNum = aggregateFunctions.stream() - .filter(AggregateFunction::isDistinct) - .count(); + + long distinctFunctionNum = 0; + for (AggregateFunction aggregateFunction : aggregateFunctions) { + distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0; + } if (distinctMultiColumns && distinctFunctionNum > 1) { throw new AnalysisException( "The query contains multi count distinct or sum distinct, each can't have multi columns"); } - Optional expr = aggregate.getGroupByExpressions().stream() - .filter(expression -> expression.containsType(AggregateFunction.class)).findFirst(); - if (expr.isPresent()) { - throw new AnalysisException( - "GROUP BY expression must not contain aggregate functions: " - + expr.get().toSql()); + for (Expression expr : aggregate.getGroupByExpressions()) { + if (expr.anyMatch(AggregateFunction.class::isInstance)) { + throw new AnalysisException( + "GROUP BY expression must not contain aggregate functions: " + expr.toSql()); + } } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java index 4408e64487cfb2..7f7e229da319f4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java @@ -64,7 +64,7 @@ public Rule build() { // because we rely on expression matching to replace subtree that same as group by expr in output // if we do constant folding before normalize aggregate, the subtree will change and matching fail // such as: select a + 1 + 2 + 3, sum(b) from t group by a + 1 + 2 - Expression foldExpression = 
FoldConstantRule.INSTANCE.rewrite(expression, context); + Expression foldExpression = FoldConstantRule.evaluate(expression, context); if (!foldExpression.isConstant()) { slotGroupByExprs.add(expression); } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index f4c1b428d4147b..56ca1b3a8c4822 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -297,7 +297,7 @@ public Expression visitUnboundFunction(UnboundFunction unboundFunction, Expressi if (unboundFunction.isHighOrder()) { unboundFunction = bindHighOrderFunction(unboundFunction, context); } else { - unboundFunction = (UnboundFunction) rewriteChildren(this, unboundFunction, context); + unboundFunction = (UnboundFunction) super.visit(unboundFunction, context); } // bind function diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java index d6c783bbe946d8..82468978a8069a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java @@ -316,13 +316,18 @@ private Plan createPlan(Resolver resolver, Aggregate aggregate, } private boolean checkSort(LogicalSort logicalSort) { - return logicalSort.getOrderKeys().stream() - .map(OrderKey::getExpr) - .map(Expression::getInputSlots) - .flatMap(Set::stream) - .anyMatch(s -> !logicalSort.child().getOutputSet().contains(s)) - || logicalSort.getOrderKeys().stream() - .map(OrderKey::getExpr) - .anyMatch(e -> e.containsType(AggregateFunction.class)); + Plan child = logicalSort.child(); + for (OrderKey orderKey : 
logicalSort.getOrderKeys()) { + Expression expr = orderKey.getExpr(); + if (expr.anyMatch(AggregateFunction.class::isInstance)) { + return true; + } + for (Slot inputSlot : expr.getInputSlots()) { + if (!child.getOutputSet().contains(inputSlot)) { + return true; + } + } + } + return false; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 6f668611b9c4ba..7f6df51248e34a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; @@ -139,8 +140,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional groupingByExprs = - ImmutableSet.copyOf(aggregate.getGroupByExpressions()); + Set groupingByExprs = Utils.fastToImmutableSet(aggregate.getGroupByExpressions()); // collect all trivial-agg List aggregateOutput = aggregate.getOutputExpressions(); @@ -149,27 +149,31 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional> categorizedNoDistinctAggsChildren = aggFuncs.stream() + Map> categorizedNoDistinctAggsChildren = aggFuncs.stream() .filter(aggFunc -> !aggFunc.isDistinct()) .flatMap(agg -> agg.children().stream()) .collect(Collectors.groupingBy( child -> child.containsType(SubqueryExpr.class, WindowExpression.class), - Collectors.toSet())); + ImmutableSet.toImmutableSet())); // split distinct agg child as two parts // TRUE part 1: need push down itself, if it is NOT SlotReference or Literal // 
FALSE part 2: need push down its input slots, if it is SlotReference or Literal - Map> categorizedDistinctAggsChildren = aggFuncs.stream() + Map> categorizedDistinctAggsChildren = aggFuncs.stream() .filter(AggregateFunction::isDistinct) .flatMap(agg -> agg.children().stream()) - .collect(Collectors.groupingBy(child -> !(child instanceof SlotReference), Collectors.toSet())); + .collect( + Collectors.groupingBy( + child -> !(child instanceof SlotReference), + ImmutableSet.toImmutableSet()) + ); Set needPushSelf = Sets.union( - categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()), - categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>())); + categorizedNoDistinctAggsChildren.getOrDefault(true, ImmutableSet.of()), + categorizedDistinctAggsChildren.getOrDefault(true, ImmutableSet.of())); Set needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union( - categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()), - categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>()))); + categorizedNoDistinctAggsChildren.getOrDefault(false, ImmutableSet.of()), + categorizedDistinctAggsChildren.getOrDefault(false, ImmutableSet.of()))); Set existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance); @@ -194,8 +198,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional(ImmutableList.copyOf(bottomProjects), - aggregate.child()); + bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child()); } else { bottomPlan = aggregate.child(); } @@ -230,13 +233,17 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional normalizedAggOutput = ImmutableList.builder() - .addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator()) - .addAll(normalizedAggFuncsToSlotContext - .pushDownToNamedExpression(normalizedAggFuncs)) - .build(); + ImmutableList.Builder normalizedAggOutputBuilder + = 
ImmutableList.builderWithExpectedSize(groupingByExprs.size() + normalizedAggFuncs.size()); + for (NamedExpression pushedGroupByExpr : pushedGroupByExprs) { + normalizedAggOutputBuilder.add(pushedGroupByExpr.toSlot()); + } + normalizedAggOutputBuilder.addAll( + normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs) + ); // create new agg node + ImmutableList normalizedAggOutput = normalizedAggOutputBuilder.build(); LogicalAggregate newAggregate = aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java index b52e2f0218d04e..cd53086f96625d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -35,7 +36,6 @@ import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; /** * replace. 
@@ -47,52 +47,50 @@ public List buildRules() { .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( logicalSort(logicalProject()).then(sort -> { LogicalProject project = sort.child(); - Map sMap = Maps.newHashMap(); - project.getProjects().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .forEach(p -> sMap.put(p.child(), p.toSlot())); + Map sMap = buildOutputAliasMap(project.getProjects()); return replaceSortExpression(sort, sMap); }) )) .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( logicalSort(logicalAggregate()).then(sort -> { LogicalAggregate aggregate = sort.child(); - Map sMap = Maps.newHashMap(); - aggregate.getOutputExpressions().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .forEach(p -> sMap.put(p.child(), p.toSlot())); + Map sMap = buildOutputAliasMap(aggregate.getOutputExpressions()); return replaceSortExpression(sort, sMap); }) )).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( logicalSort(logicalHaving(logicalAggregate())).then(sort -> { LogicalAggregate aggregate = sort.child().child(); - Map sMap = Maps.newHashMap(); - aggregate.getOutputExpressions().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .forEach(p -> sMap.put(p.child(), p.toSlot())); + Map sMap = buildOutputAliasMap(aggregate.getOutputExpressions()); return replaceSortExpression(sort, sMap); }) )) .build(); } + private Map buildOutputAliasMap(List output) { + Map sMap = Maps.newHashMapWithExpectedSize(output.size()); + for (NamedExpression expr : output) { + if (expr instanceof Alias) { + Alias alias = (Alias) expr; + sMap.put(alias.child(), alias.toSlot()); + } + } + return sMap; + } + private LogicalPlan replaceSortExpression(LogicalSort sort, Map sMap) { List orderKeys = sort.getOrderKeys(); - AtomicBoolean changed = new AtomicBoolean(false); - List newKeys = orderKeys.stream().map(k -> { + + boolean changed = false; + ImmutableList.Builder newKeys = 
ImmutableList.builderWithExpectedSize(orderKeys.size()); + for (OrderKey k : orderKeys) { Expression newExpr = ExpressionUtils.replace(k.getExpr(), sMap); if (newExpr != k.getExpr()) { - changed.set(true); + changed = true; } - return new OrderKey(newExpr, k.isAsc(), k.isNullFirst()); - }).collect(ImmutableList.toImmutableList()); - if (changed.get()) { - return new LogicalSort<>(newKeys, sort.child()); - } else { - return sort; + newKeys.add(new OrderKey(newExpr, k.isAsc(), k.isNullFirst())); } + + return changed ? new LogicalSort<>(newKeys.build(), sort.child()) : sort; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index b0f78be54a24aa..cfc5b2ba24a11b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot; import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.Alias; @@ -51,6 +52,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -77,24 +79,21 @@ public List buildRules() { RuleType.FILTER_SUBQUERY_TO_APPLY.build( logicalFilter().thenApply(ctx -> { LogicalFilter filter = ctx.root; - ImmutableList> subqueryExprsList = 
filter.getConjuncts().stream() - .>map(e -> e.collect(SubqueryToApply::canConvertToSupply)) - .collect(ImmutableList.toImmutableList()); - if (subqueryExprsList.stream() - .flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) { + + Set conjuncts = filter.getConjuncts(); + CollectSubquerys collectSubquerys = collectSubquerys(conjuncts); + if (!collectSubquerys.hasSubquery) { return filter; } - ImmutableList shouldOutputMarkJoinSlot = - filter.getConjuncts().stream() - .map(expr -> !(expr instanceof SubqueryExpr) - && expr.containsType(SubqueryExpr.class)) - .collect(ImmutableList.toImmutableList()); - List oldConjuncts = ImmutableList.copyOf(filter.getConjuncts()); - ImmutableList.Builder newConjuncts = new ImmutableList.Builder<>(); + List shouldOutputMarkJoinSlot = shouldOutputMarkJoinSlot(conjuncts); + + List oldConjuncts = Utils.fastToImmutableList(conjuncts); + ImmutableSet.Builder newConjuncts = new ImmutableSet.Builder<>(); LogicalPlan applyPlan = null; LogicalPlan tmpPlan = (LogicalPlan) filter.child(); + List> subqueryExprsList = collectSubquerys.subqueies; // Subquery traversal with the conjunct of and as the granularity. for (int i = 0; i < subqueryExprsList.size(); ++i) { Set subqueryExprs = subqueryExprsList.get(i); @@ -119,9 +118,11 @@ public List buildRules() { * if it's semi join with non-null mark slot * we can safely change the mark conjunct to hash conjunct */ + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext); boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class) ? 
ExpressionUtils.canInferNotNullForMarkSlot( - TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null)) + TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, + rewriteContext), rewriteContext) : false; applyPlan = subqueryToApply(subqueryExprs.stream() @@ -132,21 +133,22 @@ public List buildRules() { tmpPlan = applyPlan; newConjuncts.add(conjunct); } - Set conjuncts = ImmutableSet.copyOf(newConjuncts.build()); - Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan); + Plan newFilter = new LogicalFilter<>(newConjuncts.build(), applyPlan); return new LogicalProject<>(filter.getOutput().stream().collect(ImmutableList.toImmutableList()), newFilter); }) ), RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> { LogicalProject project = ctx.root; - ImmutableList> subqueryExprsList = project.getProjects().stream() - .>map(e -> e.collect(SubqueryToApply::canConvertToSupply)) - .collect(ImmutableList.toImmutableList()); - if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) { + + List projects = project.getProjects(); + CollectSubquerys collectSubquerys = collectSubquerys(projects); + if (!collectSubquerys.hasSubquery) { return project; } - List oldProjects = ImmutableList.copyOf(project.getProjects()); + + List> subqueryExprsList = collectSubquerys.subqueies; + List oldProjects = ImmutableList.copyOf(projects); ImmutableList.Builder newProjects = new ImmutableList.Builder<>(); LogicalPlan childPlan = (LogicalPlan) project.child(); LogicalPlan applyPlan; @@ -166,7 +168,7 @@ public List buildRules() { replaceSubquery.replace(oldProjects.get(i), context); applyPlan = subqueryToApply( - subqueryExprs.stream().collect(ImmutableList.toImmutableList()), + Utils.fastToImmutableList(subqueryExprs), childPlan, context.getSubqueryToMarkJoinSlot(), ctx.cascadesContext, Optional.of(newProject), true, false); @@ -240,9 +242,11 @@ public List buildRules() { * if it's semi join with non-null mark slot * we can safely 
change the mark conjunct to hash conjunct */ + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext); boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class) ? ExpressionUtils.canInferNotNullForMarkSlot( - TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null)) + TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, rewriteContext), + rewriteContext) : false; applyPlan = subqueryToApply( subqueryExprs.stream().collect(ImmutableList.toImmutableList()), @@ -566,4 +570,33 @@ private boolean shouldOutputMarkJoinSlot(Expression expr, SearchState searchStat } return false; } + + private List shouldOutputMarkJoinSlot(Collection conjuncts) { + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(conjuncts.size()); + for (Expression expr : conjuncts) { + result.add(!(expr instanceof SubqueryExpr) && expr.containsType(SubqueryExpr.class)); + } + return result.build(); + } + + private CollectSubquerys collectSubquerys(Collection exprs) { + boolean hasSubqueryExpr = false; + ImmutableList.Builder> subqueryExprsListBuilder = ImmutableList.builder(); + for (Expression expression : exprs) { + Set subqueries = expression.collect(SubqueryToApply::canConvertToSupply); + hasSubqueryExpr |= !subqueries.isEmpty(); + subqueryExprsListBuilder.add(subqueries); + } + return new CollectSubquerys(subqueryExprsListBuilder.build(), hasSubqueryExpr); + } + + private static class CollectSubquerys { + final List> subqueies; + final boolean hasSubquery; + + public CollectSubquerys(List> subqueies, boolean hasSubquery) { + this.subqueies = subqueies; + this.hasSubquery = hasSubquery; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java index 66852444cc0e2f..90e0f8ed1db8bb 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java @@ -122,10 +122,11 @@ protected List checkQuery(Plan queryPlan, CascadesContext cascadesCo // TODO Just Check query queryPlan firstly, support multi later. StructInfo queryStructInfo = queryStructInfos.get(0); if (!checkPattern(queryStructInfo)) { - cascadesContext.getMaterializationContexts().forEach(ctx -> - ctx.recordFailReason(queryStructInfo, "Query struct info is invalid", - () -> String.format("queryPlan is %s", queryPlan.treeString()) - )); + for (MaterializationContext ctx : cascadesContext.getMaterializationContexts()) { + ctx.recordFailReason(queryStructInfo, "Query struct info is invalid", + () -> String.format("queryPlan is %s", queryPlan.treeString()) + ); + } return validQueryStructInfos; } validQueryStructInfos.add(queryStructInfo); @@ -228,7 +229,7 @@ protected List doRewrite(StructInfo queryStructInfo, CascadesContext casca viewToQuerySlotMapping)); continue; } - rewrittenPlan = new LogicalFilter<>(Sets.newHashSet(rewriteCompensatePredicates), mvScan); + rewrittenPlan = new LogicalFilter<>(Sets.newLinkedHashSet(rewriteCompensatePredicates), mvScan); } // Rewrite query by view rewrittenPlan = rewriteQueryByView(matchMode, queryStructInfo, viewStructInfo, viewToQuerySlotMapping, @@ -293,7 +294,7 @@ protected Plan rewriteByRules(CascadesContext cascadesContext, Plan rewrittenPla if (originOutputs.size() != rewrittenPlan.getOutput().size()) { return null; } - Map originSlotToRewrittenExprId = Maps.newHashMap(); + Map originSlotToRewrittenExprId = Maps.newLinkedHashMap(); for (int i = 0; i < originOutputs.size(); i++) { originSlotToRewrittenExprId.put(originOutputs.get(i), rewrittenPlan.getOutput().get(i).getExprId()); } @@ -305,7 +306,7 @@ protected Plan rewriteByRules(CascadesContext cascadesContext, Plan rewrittenPla rewrittenPlan = 
rewrittenPlanContext.getRewritePlan(); // for get right nullable after rewritten, we need this map - Map exprIdToNewRewrittenSlot = Maps.newHashMap(); + Map exprIdToNewRewrittenSlot = Maps.newLinkedHashMap(); for (Slot slot : rewrittenPlan.getOutput()) { exprIdToNewRewrittenSlot.put(slot.getExprId(), slot); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/InitMaterializationContextHook.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/InitMaterializationContextHook.java index 914b213361348a..2d5c9bf377b723 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/InitMaterializationContextHook.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/InitMaterializationContextHook.java @@ -79,7 +79,7 @@ public void initMaterializationContext(CascadesContext cascadesContext) { if (availableMTMVs.isEmpty()) { return; } - availableMTMVs.forEach(materializedView -> { + for (MTMV materializedView : availableMTMVs) { // generate outside, maybe add partition filter in the future LogicalOlapScan mvScan = new LogicalOlapScan( cascadesContext.getStatementContext().getNextRelationId(), @@ -96,6 +96,6 @@ public void initMaterializationContext(CascadesContext cascadesContext) { Plan projectScan = new LogicalProject(mvProjects, mvScan); cascadesContext.addMaterializationContext( MaterializationContext.fromMaterializedView(materializedView, projectScan, cascadesContext)); - }); + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializationContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializationContext.java index b3de74436524e7..db9f58ae070acc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializationContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializationContext.java @@ -33,11 +33,12 @@ 
import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -67,7 +68,7 @@ public class MaterializationContext { private boolean success = false; // if rewrite by mv fail, record the reason, if success the failReason should be empty. // The key is the query belonged group expression objectId, the value is the fail reason - private final Map> failReason = new HashMap<>(); + private final Map> failReason = new LinkedHashMap<>(); private boolean enableRecordFailureDetail = false; /** @@ -163,7 +164,6 @@ public void recordFailReason(StructInfo structInfo, String summary, Supplier materializatio for (MaterializationContext ctx : materializationContexts) { if (!ctx.isSuccess()) { Set failReasonSet = - ctx.getFailReason().values().stream().map(Pair::key).collect(Collectors.toSet()); + ctx.getFailReason().values().stream().map(Pair::key).collect(ImmutableSet.toImmutableSet()); builder.append("\n") .append(" Name: ").append(ctx.getMTMV().getName()) .append("\n") diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java index 674d4935594782..0d280bb8340a37 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java @@ -53,8 +53,8 @@ import com.google.common.collect.Sets; import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -116,7 +116,7 @@ 
public StructInfo(Plan originalPlan, ObjectId originalPlanId, HyperGraph hyperGr this.predicates = predicates; if (predicates == null) { // collect predicate from top plan which not in hyper graph - Set topPlanPredicates = new HashSet<>(); + Set topPlanPredicates = new LinkedHashSet<>(); topPlan.accept(PREDICATE_COLLECTOR, topPlanPredicates); this.predicates = Predicates.of(topPlanPredicates); } @@ -241,7 +241,9 @@ private Pair predicatesDerive(Predicates predi public static List of(Plan originalPlan) { // TODO only consider the inner join currently, Should support outer join // Split plan by the boundary which contains multi child - PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalJoin.class)); + LinkedHashSet> set = Sets.newLinkedHashSet(); + set.add(LogicalJoin.class); + PlanSplitContext planSplitContext = new PlanSplitContext(set); // if single table without join, the bottom is originalPlan.accept(PLAN_SPLITTER, planSplitContext); @@ -261,16 +263,18 @@ public static StructInfo of(Plan originalPlan, @Nullable Plan topPlan, @Nullable .map(GroupExpression::getId).orElseGet(() -> new ObjectId(-1)); // if any of topPlan or bottomPlan is null, split the top plan to two parts by join node if (topPlan == null || bottomPlan == null) { - PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalJoin.class)); + Set> set = Sets.newLinkedHashSet(); + set.add(LogicalJoin.class); + PlanSplitContext planSplitContext = new PlanSplitContext(set); originalPlan.accept(PLAN_SPLITTER, planSplitContext); bottomPlan = planSplitContext.getBottomPlan(); topPlan = planSplitContext.getTopPlan(); } // collect struct info fromGraph ImmutableList.Builder relationBuilder = ImmutableList.builder(); - Map relationIdStructInfoNodeMap = new HashMap<>(); - Map shuttledHashConjunctsToConjunctsMap = new HashMap<>(); - Map namedExprIdAndExprMapping = new HashMap<>(); + Map relationIdStructInfoNodeMap = new LinkedHashMap<>(); + Map 
shuttledHashConjunctsToConjunctsMap = new LinkedHashMap<>(); + Map namedExprIdAndExprMapping = new LinkedHashMap<>(); boolean valid = collectStructInfoFromGraph(hyperGraph, topPlan, shuttledHashConjunctsToConjunctsMap, namedExprIdAndExprMapping, relationBuilder, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionBottomUpRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionBottomUpRewriter.java new file mode 100644 index 00000000000000..932446ce48b16d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionBottomUpRewriter.java @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.pattern.ExpressionPatternRules; +import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners; +import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners.CombinedListener; +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; + +/** ExpressionBottomUpRewriter */ +public class ExpressionBottomUpRewriter implements ExpressionRewriteRule { + public static final String BATCH_ID_KEY = "batch_id"; + private static final Logger LOG = LogManager.getLogger(ExpressionBottomUpRewriter.class); + private static final AtomicInteger rewriteBatchId = new AtomicInteger(); + private final ExpressionPatternRules rules; + private final ExpressionPatternTraverseListeners listeners; + + public ExpressionBottomUpRewriter(ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) { + this.rules = rules; + this.listeners = listeners; + } + + // entrance + @Override + public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + int currentBatch = rewriteBatchId.incrementAndGet(); + return rewriteBottomUp(expr, ctx, currentBatch, null, rules, listeners); + } + + private static Expression rewriteBottomUp( + Expression expression, ExpressionRewriteContext context, int currentBatch, @Nullable Expression parent, + ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) { + + Optional rewriteState = expression.getMutableState(BATCH_ID_KEY); + if (!rewriteState.isPresent() || rewriteState.get() != currentBatch) { + CombinedListener listener = null; + boolean hasChildren = expression.arity() > 0; + if (hasChildren) { + listener = 
listeners.matchesAndCombineListeners(expression, context, parent); + if (listener != null) { + listener.onEnter(); + } + } + + Expression afterRewrite = expression; + try { + Expression beforeRewrite; + afterRewrite = rewriteChildren(expression, context, currentBatch, rules, listeners); + // use rewriteTimes to avoid dead loop + int rewriteTimes = 0; + boolean changed; + do { + beforeRewrite = afterRewrite; + + // rewrite this + Optional applied = rules.matchesAndApply(beforeRewrite, context, parent); + + changed = applied.isPresent(); + if (changed) { + afterRewrite = applied.get(); + // ensure children are rewritten + afterRewrite = rewriteChildren(afterRewrite, context, currentBatch, rules, listeners); + } + rewriteTimes++; + } while (changed && rewriteTimes < 100); + + // set rewritten + afterRewrite.setMutableState(BATCH_ID_KEY, currentBatch); + } finally { + if (hasChildren && listener != null) { + listener.onExit(afterRewrite); + } + } + + return afterRewrite; + } + + // already rewritten + return expression; + } + + private static Expression rewriteChildren(Expression parent, ExpressionRewriteContext context, int currentBatch, + ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) { + boolean changed = false; + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(parent.arity()); + for (Expression child : parent.children()) { + Expression newChild = rewriteBottomUp(child, context, currentBatch, parent, rules, listeners); + changed |= !child.equals(newChild); + newChildren.add(newChild); + } + + Expression result = parent; + if (changed) { + result = parent.withChildren(newChildren.build()); + } + if (changed && context.cascadesContext.isEnableExprTrace()) { + LOG.info("WithChildren: \nbefore: " + parent + "\nafter: " + result); + } + return result; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionListenerMatcher.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionListenerMatcher.java new file mode 100644 index 00000000000000..ea67d14e8fe6ae --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionListenerMatcher.java @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.function.Predicate; + +/** ExpressionListenerMatcher */ +public class ExpressionListenerMatcher { + public final Class typePattern; + public final List>> predicates; + public final ExpressionTraverseListener listener; + + public ExpressionListenerMatcher(Class typePattern, + List>> predicates, + ExpressionTraverseListener listener) { + this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null"); + this.predicates = predicates == null ? 
ImmutableList.of() : predicates; + this.listener = Objects.requireNonNull(listener, "listener can not be null"); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingAction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingAction.java new file mode 100644 index 00000000000000..a28b96079b51a8 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingAction.java @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +/** ExpressionMatchAction */ +public interface ExpressionMatchingAction { + Expression apply(ExpressionMatchingContext context); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingContext.java new file mode 100644 index 00000000000000..953815ad87c5c2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingContext.java @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.expressions.Expression; + +import java.util.Optional; + +/** ExpressionMatchingContext */ +public class ExpressionMatchingContext { + public final E expr; + public final Optional parent; + public final ExpressionRewriteContext rewriteContext; + public final CascadesContext cascadesContext; + + public ExpressionMatchingContext(E expr, Expression parent, ExpressionRewriteContext context) { + this.expr = expr; + this.parent = Optional.ofNullable(parent); + this.rewriteContext = context; + this.cascadesContext = context.cascadesContext; + } + + public boolean isRoot() { + return !parent.isPresent(); + } + + public Expression parentOr(Expression defaultParent) { + return parent.orElse(defaultParent); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java index 9886cb1787e9ed..adf0cb90a958c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java @@ -42,20 +42,21 @@ public class ExpressionNormalization extends ExpressionRewrite { // we should run supportJavaDateFormatter before foldConstantRule or be will fold // from_unixtime(timestamp, 'yyyyMMdd') to 'yyyyMMdd' public static final List NORMALIZE_REWRITE_RULES = ImmutableList.of( - SupportJavaDateFormatter.INSTANCE, - ReplaceVariableByLiteral.INSTANCE, - NormalizeBinaryPredicatesRule.INSTANCE, - InPredicateDedup.INSTANCE, - InPredicateToEqualToRule.INSTANCE, - SimplifyNotExprRule.INSTANCE, - SimplifyArithmeticRule.INSTANCE, - FoldConstantRule.INSTANCE, - SimplifyCastRule.INSTANCE, - DigitalMaskingConvert.INSTANCE, - SimplifyArithmeticComparisonRule.INSTANCE, - 
SupportJavaDateFormatter.INSTANCE, - ConvertAggStateCast.INSTANCE, - CheckCast.INSTANCE + bottomUp( + ReplaceVariableByLiteral.INSTANCE, + SupportJavaDateFormatter.INSTANCE, + NormalizeBinaryPredicatesRule.INSTANCE, + InPredicateDedup.INSTANCE, + InPredicateToEqualToRule.INSTANCE, + SimplifyNotExprRule.INSTANCE, + SimplifyArithmeticRule.INSTANCE, + FoldConstantRule.INSTANCE, + SimplifyCastRule.INSTANCE, + DigitalMaskingConvert.INSTANCE, + SimplifyArithmeticComparisonRule.INSTANCE, + ConvertAggStateCast.INSTANCE, + CheckCast.INSTANCE + ) ); public ExpressionNormalization() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalizationAndOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalizationAndOptimization.java new file mode 100644 index 00000000000000..d694062ef1f049 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalizationAndOptimization.java @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression; + +import com.google.common.collect.ImmutableList; + +/** ExpressionNormalizationAndOptimization */ +public class ExpressionNormalizationAndOptimization extends ExpressionRewrite { + /** ExpressionNormalizationAndOptimization */ + public ExpressionNormalizationAndOptimization() { + super(new ExpressionRuleExecutor( + ImmutableList.builder() + .addAll(ExpressionNormalization.NORMALIZE_REWRITE_RULES) + .addAll(ExpressionOptimization.OPTIMIZE_REWRITE_RULES) + .build() + )); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index fdf9820c582f56..b3bb18163ea2eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -39,18 +39,20 @@ */ public class ExpressionOptimization extends ExpressionRewrite { public static final List OPTIMIZE_REWRITE_RULES = ImmutableList.of( - ExtractCommonFactorRule.INSTANCE, - DistinctPredicatesRule.INSTANCE, - SimplifyComparisonPredicate.INSTANCE, - SimplifyInPredicate.INSTANCE, - SimplifyDecimalV3Comparison.INSTANCE, - SimplifyRange.INSTANCE, - DateFunctionRewrite.INSTANCE, - OrToIn.INSTANCE, - ArrayContainToArrayOverlap.INSTANCE, - CaseWhenToIf.INSTANCE, - TopnToMax.INSTANCE, - NullSafeEqualToEqual.INSTANCE + bottomUp( + ExtractCommonFactorRule.INSTANCE, + DistinctPredicatesRule.INSTANCE, + SimplifyComparisonPredicate.INSTANCE, + SimplifyInPredicate.INSTANCE, + SimplifyDecimalV3Comparison.INSTANCE, + OrToIn.INSTANCE, + SimplifyRange.INSTANCE, + DateFunctionRewrite.INSTANCE, + ArrayContainToArrayOverlap.INSTANCE, + CaseWhenToIf.INSTANCE, + TopnToMax.INSTANCE, + NullSafeEqualToEqual.INSTANCE + ) ); private static final ExpressionRuleExecutor EXECUTOR = new 
ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatchRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatchRule.java new file mode 100644 index 00000000000000..dbf5c79c96d754 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatchRule.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping; +import org.apache.doris.nereids.trees.expressions.Expression; + +import java.util.List; +import java.util.function.Predicate; + +/** ExpressionPatternMatcherRule */ +public class ExpressionPatternMatchRule implements TypeMapping { + public final Class typePattern; + public final List>> predicates; + public final ExpressionMatchingAction matchingAction; + + public ExpressionPatternMatchRule(ExpressionPatternMatcher patternMatcher) { + this.typePattern = patternMatcher.typePattern; + this.predicates = patternMatcher.predicates; + this.matchingAction = patternMatcher.matchingAction; + } + + /** matches */ + public boolean matchesTypeAndPredicates(ExpressionMatchingContext context) { + return typePattern.isInstance(context.expr) && matchesPredicates(context); + } + + /** matchesPredicates */ + public boolean matchesPredicates(ExpressionMatchingContext context) { + if (!predicates.isEmpty()) { + for (Predicate> predicate : predicates) { + if (!predicate.test(context)) { + return false; + } + } + } + return true; + } + + public Expression apply(ExpressionMatchingContext context) { + Expression newResult = matchingAction.apply(context); + return newResult == null ? context.expr : newResult; + } + + @Override + public Class getType() { + return typePattern; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatcher.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatcher.java new file mode 100644 index 00000000000000..058b1d60b1d013 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatcher.java @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.function.Predicate; + +/** ExpressionPattern */ +public class ExpressionPatternMatcher { + public final Class typePattern; + public final List>> predicates; + public final ExpressionMatchingAction matchingAction; + + public ExpressionPatternMatcher(Class typePattern, + List>> predicates, + ExpressionMatchingAction matchingAction) { + this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null"); + this.predicates = predicates == null ? 
ImmutableList.of() : predicates; + this.matchingAction = Objects.requireNonNull(matchingAction, "matchingAction can not be null"); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternRuleFactory.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternRuleFactory.java new file mode 100644 index 00000000000000..7fb18735ba5e46 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternRuleFactory.java @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Predicate; + +/** ExpressionPatternRuleFactory */ +public interface ExpressionPatternRuleFactory { + List> buildRules(); + + default ExpressionPatternDescriptor matchesType(Class clazz) { + return new ExpressionPatternDescriptor<>(clazz); + } + + default ExpressionPatternDescriptor root(Class clazz) { + return new ExpressionPatternDescriptor<>(clazz) + .whenCtx(ctx -> ctx.isRoot()); + } + + default ExpressionPatternDescriptor matchesTopType(Class clazz) { + return new ExpressionPatternDescriptor<>(clazz) + .whenCtx(ctx -> ctx.isRoot() || !clazz.isInstance(ctx.parent.get())); + } + + /** ExpressionPatternDescriptor */ + class ExpressionPatternDescriptor { + private final Class typePattern; + private final ImmutableList>> predicates; + + public ExpressionPatternDescriptor(Class typePattern) { + this(typePattern, ImmutableList.of()); + } + + public ExpressionPatternDescriptor( + Class typePattern, ImmutableList>> predicates) { + this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null"); + this.predicates = Objects.requireNonNull(predicates, "predicates can not be null"); + } + + public ExpressionPatternDescriptor when(Predicate predicate) { + return whenCtx(ctx -> predicate.test(ctx.expr)); + } + + public ExpressionPatternDescriptor whenCtx(Predicate> predicate) { + ImmutableList.Builder>> newPredicates + = ImmutableList.builderWithExpectedSize(predicates.size() + 1); + newPredicates.addAll(predicates); + newPredicates.add(predicate); + return new ExpressionPatternDescriptor<>(typePattern, newPredicates.build()); + } + + /** then */ + public ExpressionPatternMatcher then(Function rewriter) { + return new ExpressionPatternMatcher<>( + typePattern, predicates, 
(context) -> rewriter.apply(context.expr)); + } + + public ExpressionPatternMatcher thenApply(ExpressionMatchingAction action) { + return new ExpressionPatternMatcher<>(typePattern, predicates, action); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java index 912793e61d1b2e..b547f693a7cdb3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java @@ -18,6 +18,8 @@ package org.apache.doris.nereids.rules.expression; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.pattern.ExpressionPatternRules; +import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners; import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; @@ -41,7 +43,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Objects; import java.util.Set; @@ -123,9 +125,7 @@ public Rule build() { LogicalProject project = ctx.root; ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); List projects = project.getProjects(); - List newProjects = projects.stream() - .map(expr -> (NamedExpression) rewriter.rewrite(expr, context)) - .collect(ImmutableList.toImmutableList()); + List newProjects = rewriteAll(projects, rewriter, context); if (projects.equals(newProjects)) { return project; } @@ -160,9 +160,7 @@ public Rule build() { List newGroupByExprs = rewriter.rewrite(groupByExprs, context); List outputExpressions = agg.getOutputExpressions(); - List newOutputExpressions = outputExpressions.stream() - .map(expr -> (NamedExpression) 
rewriter.rewrite(expr, context)) - .collect(ImmutableList.toImmutableList()); + List newOutputExpressions = rewriteAll(outputExpressions, rewriter, context); if (outputExpressions.equals(newOutputExpressions)) { return agg; } @@ -222,13 +220,16 @@ public Rule build() { return logicalSort().thenApply(ctx -> { LogicalSort sort = ctx.root; List orderKeys = sort.getOrderKeys(); - List rewrittenOrderKeys = new ArrayList<>(); + ImmutableList.Builder rewrittenOrderKeys + = ImmutableList.builderWithExpectedSize(orderKeys.size()); ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + boolean changed = false; for (OrderKey k : orderKeys) { Expression expression = rewriter.rewrite(k.getExpr(), context); + changed |= expression != k.getExpr(); rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst())); } - return sort.withOrderKeys(rewrittenOrderKeys); + return changed ? sort.withOrderKeys(rewrittenOrderKeys.build()) : sort; }).toRule(RuleType.REWRITE_SORT_EXPRESSION); } } @@ -270,4 +271,36 @@ public Rule build() { }).toRule(RuleType.REWRITE_REPEAT_EXPRESSION); } } + + /** bottomUp */ + public static ExpressionBottomUpRewriter bottomUp(ExpressionPatternRuleFactory... 
ruleFactories) { + ImmutableList.Builder rules = ImmutableList.builder(); + ImmutableList.Builder listeners = ImmutableList.builder(); + for (ExpressionPatternRuleFactory ruleFactory : ruleFactories) { + if (ruleFactory instanceof ExpressionTraverseListenerFactory) { + List> listenersMatcher + = ((ExpressionTraverseListenerFactory) ruleFactory).buildListeners(); + for (ExpressionListenerMatcher listenerMatcher : listenersMatcher) { + listeners.add(new ExpressionTraverseListenerMapping(listenerMatcher)); + } + } + for (ExpressionPatternMatcher patternMatcher : ruleFactory.buildRules()) { + rules.add(new ExpressionPatternMatchRule(patternMatcher)); + } + } + + return new ExpressionBottomUpRewriter( + new ExpressionPatternRules(rules.build()), + new ExpressionPatternTraverseListeners(listeners.build()) + ); + } + + public static List rewriteAll( + Collection exprs, ExpressionRuleExecutor rewriter, ExpressionRewriteContext context) { + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(exprs.size()); + for (E expr : exprs) { + result.add((E) rewriter.rewrite(expr, context)); + } + return result.build(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java index cb50e0d2871e3b..35633e7594f717 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java @@ -19,6 +19,8 @@ import org.apache.doris.nereids.CascadesContext; +import java.util.Objects; + /** * expression rewrite context. 
*/ @@ -27,7 +29,7 @@ public class ExpressionRewriteContext { public final CascadesContext cascadesContext; public ExpressionRewriteContext(CascadesContext cascadesContext) { - this.cascadesContext = cascadesContext; + this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null"); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleExecutor.java index ac7e6dae6b282d..0f951448dd2582 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleExecutor.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.expression; import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import com.google.common.collect.ImmutableList; @@ -36,7 +37,11 @@ public ExpressionRuleExecutor(List rules) { } public List rewrite(List exprs, ExpressionRewriteContext ctx) { - return exprs.stream().map(expr -> rewrite(expr, ctx)).collect(ImmutableList.toImmutableList()); + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(exprs.size()); + for (Expression expr : exprs) { + result.add(rewrite(expr, ctx)); + } + return result.build(); } /** @@ -61,8 +66,15 @@ private Expression applyRule(Expression expr, ExpressionRewriteRule rule, Expres return rule.rewrite(expr, ctx); } + /** normalize: rewrite each ComparisonPredicate visited by rewriteUp with NormalizeBinaryPredicatesRule.normalize */ public static Expression normalize(Expression expression) { - return NormalizeBinaryPredicatesRule.INSTANCE.rewrite(expression, null); + return expression.rewriteUp(expr -> { + if (expr instanceof ComparisonPredicate) { + return NormalizeBinaryPredicatesRule.normalize((ComparisonPredicate) expr); + } else { +
return expr; + } + }); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListener.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListener.java new file mode 100644 index 00000000000000..5df5a6d68185dd --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListener.java @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +/** ExpressionTraverseListener: onEnter/onExit callbacks, both no-ops by default */ +public interface ExpressionTraverseListener { + default void onEnter(ExpressionMatchingContext context) {} + + default void onExit(ExpressionMatchingContext context, Expression rewritten) {} + + default ExpressionTraverseListener as() { + return (ExpressionTraverseListener) this; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerFactory.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerFactory.java new file mode 100644 index 00000000000000..201362fed781b6 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerFactory.java @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.function.Predicate; + +/** ExpressionTraverseListenerFactory: supplies ExpressionListenerMatchers through buildListeners */ +public interface ExpressionTraverseListenerFactory { + List> buildListeners(); + + default ListenerDescriptor listenerType(Class clazz) { + return new ListenerDescriptor<>(clazz); + } + + /** listenerTypes: create one ListenerDescriptor per given class, in argument order */ + default List> listenerTypes(Class... classes) { + ImmutableList.Builder> listeners + = ImmutableList.builderWithExpectedSize(classes.length); + for (Class clazz : classes) { + listeners.add((ListenerDescriptor) listenerType(clazz)); + } + return listeners.build(); + } + + /** ListenerDescriptor: a type pattern plus predicates; when/whenCtx return a new descriptor */ + class ListenerDescriptor { + + private final Class typePattern; + private final ImmutableList>> predicates; + + public ListenerDescriptor(Class typePattern) { + this(typePattern, ImmutableList.of()); + } + + public ListenerDescriptor( + Class typePattern, ImmutableList>> predicates) { + this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null"); + this.predicates = Objects.requireNonNull(predicates, "predicates can not be null"); + } + + public ListenerDescriptor when(Predicate predicate) { + return whenCtx(ctx -> predicate.test(ctx.expr)); + } + + public ListenerDescriptor whenCtx(Predicate> predicate) { + ImmutableList.Builder>> newPredicates + = ImmutableList.builderWithExpectedSize(predicates.size() + 1); + newPredicates.addAll(predicates); + newPredicates.add(predicate); + return new ListenerDescriptor<>(typePattern, newPredicates.build()); + } + + /** then: wrap the type pattern, the predicates and the listener into an ExpressionListenerMatcher */ + public ExpressionListenerMatcher then(ExpressionTraverseListener listener) { + return new ExpressionListenerMatcher<>(typePattern, predicates, listener); + } + } +} diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerMapping.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerMapping.java new file mode 100644 index 00000000000000..d99c231110f175 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerMapping.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping; +import org.apache.doris.nereids.trees.expressions.Expression; + +import java.util.List; +import java.util.function.Predicate; + +/** ExpressionTraverseListenerMapping: holds the type pattern, predicates and listener of a matcher */ +public class ExpressionTraverseListenerMapping implements TypeMapping { + public final Class typePattern; + public final List>> predicates; + public final ExpressionTraverseListener listener; + + public ExpressionTraverseListenerMapping(ExpressionListenerMatcher listenerMatcher) { + this.typePattern = listenerMatcher.typePattern; + this.predicates = listenerMatcher.predicates; + this.listener = listenerMatcher.listener; + } + + @Override + public Class getType() { + return typePattern; + } + + /** matchesTypeAndPredicates: context.expr is an instance of typePattern and passes all predicates */ + public boolean matchesTypeAndPredicates(ExpressionMatchingContext context) { + return typePattern.isInstance(context.expr) && matchesPredicates(context); + } + + /** matchesPredicates: false as soon as one predicate rejects the context, true otherwise */ + public boolean matchesPredicates(ExpressionMatchingContext context) { + if (!predicates.isEmpty()) { + for (Predicate> predicate : predicates) { + if (!predicate.test(context)) { + return false; + } + } + } + return true; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/check/CheckCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/check/CheckCast.java index d7a6085dcab550..69a9105d653d81 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/check/CheckCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/check/CheckCast.java @@ -18,8 +18,8 @@ package org.apache.doris.nereids.rules.expression.check; import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import 
org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.types.ArrayType; @@ -31,18 +31,24 @@ import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.PrimitiveType; +import com.google.common.collect.ImmutableList; + import java.util.List; /** * check cast valid */ -public class CheckCast extends AbstractExpressionRewriteRule { - - public static final CheckCast INSTANCE = new CheckCast(); +public class CheckCast implements ExpressionPatternRuleFactory { + public static final CheckCast INSTANCE = new CheckCast(); @Override - public Expression visitCast(Cast cast, ExpressionRewriteContext context) { - rewrite(cast.child(), context); + public List> buildRules() { + return ImmutableList.of( + matchesType(Cast.class).then(CheckCast::check) + ); + } + + private static Expression check(Cast cast) { DataType originalType = cast.child().getDataType(); DataType targetType = cast.getDataType(); if (!check(originalType, targetType)) { @@ -51,7 +57,7 @@ public Expression visitCast(Cast cast, ExpressionRewriteContext context) { return cast; } - private boolean check(DataType originalType, DataType targetType) { + private static boolean check(DataType originalType, DataType targetType) { if (originalType.isVariantType() && (targetType instanceof PrimitiveType || targetType.isArrayType())) { // variant could cast to primitive types and array return true; @@ -99,7 +105,7 @@ private boolean check(DataType originalType, DataType targetType) { * 3. original type is same with target type * 4. 
target type is null type */ - private boolean checkPrimitiveType(DataType originalType, DataType targetType) { + private static boolean checkPrimitiveType(DataType originalType, DataType targetType) { if (!originalType.isPrimitive() || !targetType.isPrimitive()) { return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java index 7309ef111c925d..f32d76062aaf7c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java @@ -17,26 +17,29 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap; import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; +import com.google.common.collect.Multimaps; +import com.google.common.collect.SetMultimap; 
-import java.util.HashMap; -import java.util.HashSet; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; -import java.util.Map; +import java.util.Map.Entry; import java.util.Set; -import java.util.stream.Collectors; /** * array_contains ( c_array, '1' ) @@ -44,56 +47,73 @@ * =========================================> * array_overlap(c_array, ['1', '2']) */ -public class ArrayContainToArrayOverlap extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class ArrayContainToArrayOverlap implements ExpressionPatternRuleFactory { public static final ArrayContainToArrayOverlap INSTANCE = new ArrayContainToArrayOverlap(); private static final int REWRITE_PREDICATE_THRESHOLD = 2; @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, ctx); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(Or.class).then(ArrayContainToArrayOverlap::rewrite) + ); } - @Override - public Expression visitOr(Or or, ExpressionRewriteContext ctx) { + private static Expression rewrite(Or or) { List disjuncts = ExpressionUtils.extractDisjunction(or); - Map> containFuncAndOtherFunc = disjuncts.stream() - .collect(Collectors.partitioningBy(this::isValidArrayContains)); - Map> containLiteralSet = new HashMap<>(); - List contains = containFuncAndOtherFunc.get(true); - List others = containFuncAndOtherFunc.get(false); - contains.forEach(containFunc -> - containLiteralSet.computeIfAbsent(containFunc.child(0), k -> new HashSet<>()) - .add((Literal) containFunc.child(1))); + List contains = Lists.newArrayList(); + List others = Lists.newArrayList(); + for (Expression expr : disjuncts) { + if (ArrayContainToArrayOverlap.isValidArrayContains(expr)) { + contains.add(expr); + } else { + others.add(expr); + } + } + + if (contains.size() <= 1) { + return or; + } + + SetMultimap containLiteralSet = Multimaps.newSetMultimap( + new 
LinkedHashMap<>(), LinkedHashSet::new + ); + for (Expression contain : contains) { + containLiteralSet.put(contain.child(0), (Literal) contain.child(1)); + } Builder newDisjunctsBuilder = new ImmutableList.Builder<>(); - containLiteralSet.forEach((left, literalSet) -> { + for (Entry> kv : containLiteralSet.asMap().entrySet()) { + Expression left = kv.getKey(); + Collection literalSet = kv.getValue(); if (literalSet.size() > REWRITE_PREDICATE_THRESHOLD) { newDisjunctsBuilder.add( - new ArraysOverlap(left, - new ArrayLiteral(ImmutableList.copyOf(literalSet)))); + new ArraysOverlap(left, new ArrayLiteral(Utils.fastToImmutableList(literalSet))) + ); + } + } + + for (Expression contain : contains) { + if (!canCovertToArrayOverlap(contain, containLiteralSet)) { + newDisjunctsBuilder.add(contain); } - }); - - contains.stream() - .filter(e -> !canCovertToArrayOverlap(e, containLiteralSet)) - .forEach(newDisjunctsBuilder::add); - others.stream() - .map(e -> e.accept(this, null)) - .forEach(newDisjunctsBuilder::add); + } + newDisjunctsBuilder.addAll(others); return ExpressionUtils.or(newDisjunctsBuilder.build()); } - private boolean isValidArrayContains(Expression expression) { + private static boolean isValidArrayContains(Expression expression) { return expression instanceof ArrayContains && expression.child(1) instanceof Literal; } - private boolean canCovertToArrayOverlap(Expression expression, Map> containLiteralSet) { - return expression instanceof ArrayContains - && containLiteralSet.getOrDefault(expression.child(0), - new HashSet<>()).size() > REWRITE_PREDICATE_THRESHOLD; + private static boolean canCovertToArrayOverlap( + Expression expression, SetMultimap containLiteralSet) { + if (!(expression instanceof ArrayContains)) { + return false; + } + Set containLiteral = containLiteralSet.get(expression.child(0)); + return containLiteral.size() > REWRITE_PREDICATE_THRESHOLD; } } diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java index 6372338406dd1d..cafb0ecd068ddd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java @@ -17,25 +17,35 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Rewrite rule to convert CASE WHEN to IF. 
* For example: * CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0) */ -public class CaseWhenToIf extends AbstractExpressionRewriteRule { +public class CaseWhenToIf implements ExpressionPatternRuleFactory { public static CaseWhenToIf INSTANCE = new CaseWhenToIf(); @Override - public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesTopType(CaseWhen.class).then(CaseWhenToIf::rewrite) + ); + } + + private static Expression rewrite(CaseWhen caseWhen) { Expression expr = caseWhen; if (caseWhen.getWhenClauses().size() == 1) { WhenClause whenClause = caseWhen.getWhenClauses().get(0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java index e5748eb1d59e2c..239007015531eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; @@ -30,29 +30,30 @@ import com.google.common.collect.ImmutableList; +import java.util.List; + /** * Follow legacy planner cast agg_state combinator's children if we need cast it to another agg_state type when insert */ -public class ConvertAggStateCast 
extends AbstractExpressionRewriteRule { +public class ConvertAggStateCast implements ExpressionPatternRuleFactory { public static ConvertAggStateCast INSTANCE = new ConvertAggStateCast(); @Override - public Expression visitCast(Cast cast, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesTopType(Cast.class).then(ConvertAggStateCast::convert) + ); + } + + private static Expression convert(Cast cast) { Expression child = cast.child(); DataType originalType = child.getDataType(); DataType targetType = cast.getDataType(); if (originalType instanceof AggStateType && targetType instanceof AggStateType && child instanceof StateCombinator) { - AggStateType original = (AggStateType) originalType; AggStateType target = (AggStateType) targetType; - if (original.getSubTypes().size() != target.getSubTypes().size()) { - return processCastChild(cast, context); - } - if (!original.getFunctionName().equalsIgnoreCase(target.getFunctionName())) { - return processCastChild(cast, context); - } ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(child.arity()); for (int i = 0; i < child.arity(); i++) { Expression newChild = TypeCoercionUtils.castIfNotSameType(child.child(i), target.getSubTypes().get(i)); @@ -66,15 +67,7 @@ public Expression visitCast(Cast cast, ExpressionRewriteContext context) { newChildren.add(newChild); } child = child.withChildren(newChildren.build()); - return processCastChild(cast.withChildren(ImmutableList.of(child)), context); - } - return processCastChild(cast, context); - } - - private Expression processCastChild(Cast cast, ExpressionRewriteContext context) { - Expression child = visit(cast.child(), context); - if (child != cast.child()) { - cast = cast.withChildren(ImmutableList.of(child)); + return cast.withChildren(ImmutableList.of(child)); } return cast; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DateFunctionRewrite.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DateFunctionRewrite.java index e78eeecff0d105..07ec0c3de71d24 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DateFunctionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DateFunctionRewrite.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; @@ -34,17 +34,31 @@ import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * F: a DateTime or DateTimeV2 column * Date(F) > 2020-01-01 => F > 2020-01-02 00:00:00 * Date(F) >= 2020-01-01 => F > 2020-01-01 00:00:00 * */ -public class DateFunctionRewrite extends AbstractExpressionRewriteRule { +public class DateFunctionRewrite implements ExpressionPatternRuleFactory { public static DateFunctionRewrite INSTANCE = new DateFunctionRewrite(); @Override - public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(EqualTo.class).then(DateFunctionRewrite::rewriteEqualTo), + matchesType(GreaterThan.class).then(DateFunctionRewrite::rewriteGreaterThan), + matchesType(GreaterThanEqual.class).then(DateFunctionRewrite::rewriteGreaterThanEqual), + matchesType(LessThan.class).then(DateFunctionRewrite::rewriteLessThan), + 
matchesType(LessThanEqual.class).then(DateFunctionRewrite::rewriteLessThanEqual) + ); + } + + private static Expression rewriteEqualTo(EqualTo equalTo) { if (equalTo.left() instanceof Date) { // V1 if (equalTo.left().child(0).getDataType() instanceof DateTimeType @@ -70,8 +84,7 @@ public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context return equalTo; } - @Override - public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) { + private static Expression rewriteGreaterThan(GreaterThan greaterThan) { if (greaterThan.left() instanceof Date) { // V1 if (greaterThan.left().child(0).getDataType() instanceof DateTimeType @@ -91,8 +104,7 @@ public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteCon return greaterThan; } - @Override - public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) { + private static Expression rewriteGreaterThanEqual(GreaterThanEqual greaterThanEqual) { if (greaterThanEqual.left() instanceof Date) { // V1 if (greaterThanEqual.left().child(0).getDataType() instanceof DateTimeType @@ -111,8 +123,7 @@ public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, Expre return greaterThanEqual; } - @Override - public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) { + private static Expression rewriteLessThan(LessThan lessThan) { if (lessThan.left() instanceof Date) { // V1 if (lessThan.left().child(0).getDataType() instanceof DateTimeType @@ -131,8 +142,7 @@ public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext cont return lessThan; } - @Override - public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) { + private static Expression rewriteLessThanEqual(LessThanEqual lessThanEqual) { if (lessThanEqual.left() instanceof Date) { // V1 if (lessThanEqual.left().child(0).getDataType() instanceof DateTimeType diff 
--git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DigitalMaskingConvert.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DigitalMaskingConvert.java index 5e38c0390b6c93..95d25e3c592454 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DigitalMaskingConvert.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DigitalMaskingConvert.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.scalar.Concat; import org.apache.doris.nereids.trees.expressions.functions.scalar.DigitalMasking; @@ -26,16 +26,25 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Right; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Convert DigitalMasking to Concat */ -public class DigitalMaskingConvert extends AbstractExpressionRewriteRule { - +public class DigitalMaskingConvert implements ExpressionPatternRuleFactory { public static DigitalMaskingConvert INSTANCE = new DigitalMaskingConvert(); @Override - public Expression visitDigitalMasking(DigitalMasking digitalMasking, ExpressionRewriteContext context) { - return new Concat(new Left(digitalMasking.child(), Literal.of(3)), Literal.of("****"), - new Right(digitalMasking.child(), Literal.of(4))); + public List> buildRules() { + return ImmutableList.of( + matchesType(DigitalMasking.class).then(digitalMasking -> + new Concat( + new 
Left(digitalMasking.child(), Literal.of(3)), + Literal.of("****"), + new Right(digitalMasking.child(), Literal.of(4))) + ) + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DistinctPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DistinctPredicatesRule.java index a3466d395d56e0..cf18886cd85fc3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DistinctPredicatesRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DistinctPredicatesRule.java @@ -17,12 +17,13 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.LinkedHashSet; @@ -35,16 +36,21 @@ * transform (a = 1) and (b > 2) and (a = 1) to (a = 1) and (b > 2) * transform (a = 1) or (a = 1) to (a = 1) */ -public class DistinctPredicatesRule extends AbstractExpressionRewriteRule { - +public class DistinctPredicatesRule implements ExpressionPatternRuleFactory { public static final DistinctPredicatesRule INSTANCE = new DistinctPredicatesRule(); @Override - public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesTopType(CompoundPredicate.class).then(DistinctPredicatesRule::distinct) + ); + } + + private static Expression distinct(CompoundPredicate expr) 
{ List extractExpressions = ExpressionUtils.extract(expr); Set distinctExpressions = new LinkedHashSet<>(extractExpressions); if (distinctExpressions.size() != extractExpressions.size()) { - return ExpressionUtils.combine(expr.getClass(), Lists.newArrayList(distinctExpressions)); + return ExpressionUtils.combineAsLeftDeepTree(expr.getClass(), Lists.newArrayList(distinctExpressions)); } return expr; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ExtractCommonFactorRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ExtractCommonFactorRule.java index dd457e01d8d9cc..4032db4aadf550 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ExtractCommonFactorRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ExtractCommonFactorRule.java @@ -18,21 +18,28 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.nereids.annotation.Developing; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Multimaps; +import com.google.common.collect.SetMultimap; import com.google.common.collect.Sets; -import java.util.Collections; -import java.util.HashSet; 
+import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Map.Entry; import java.util.Set; -import java.util.stream.Collectors; /** * Extract common expr for `CompoundPredicate`. @@ -41,42 +48,197 @@ * transform (a and b) or (a and c) to a and (b or c) */ @Developing -public class ExtractCommonFactorRule extends AbstractExpressionRewriteRule { - +public class ExtractCommonFactorRule implements ExpressionPatternRuleFactory { public static final ExtractCommonFactorRule INSTANCE = new ExtractCommonFactorRule(); @Override - public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesTopType(CompoundPredicate.class).then(ExtractCommonFactorRule::extractCommonFactor) + ); + } + + private static Expression extractCommonFactor(CompoundPredicate originExpr) { + // fast return + if (!(originExpr.left() instanceof CompoundPredicate || originExpr.left() instanceof BooleanLiteral) + && !(originExpr.right() instanceof CompoundPredicate || originExpr.right() instanceof BooleanLiteral)) { + return originExpr; + } - Expression rewrittenChildren = ExpressionUtils.combine(expr.getClass(), ExpressionUtils.extract(expr).stream() - .map(predicate -> rewrite(predicate, context)).collect(ImmutableList.toImmutableList())); - if (!(rewrittenChildren instanceof CompoundPredicate)) { - return rewrittenChildren; + // flatten same type to a list + // e.g. ((a and (b or c)) and c) -> [a, (b or c), c] + List flatten = ExpressionUtils.extract(originExpr); + + // combine and delete some boolean literal predicate + // e.g. 
(a and true) -> true + Expression simplified = ExpressionUtils.combineAsLeftDeepTree(originExpr.getClass(), flatten); + if (!(simplified instanceof CompoundPredicate)) { + return simplified; } - CompoundPredicate compoundPredicate = (CompoundPredicate) rewrittenChildren; + // separate two levels CompoundPredicate to partitions + // e.g. ((a and (b or c)) and c) -> [[a], [b, c], c] + CompoundPredicate leftDeapTree = (CompoundPredicate) simplified; + ImmutableSet.Builder> partitionsBuilder + = ImmutableSet.builderWithExpectedSize(flatten.size()); + for (Expression onPartition : ExpressionUtils.extract(leftDeapTree)) { + if (onPartition instanceof CompoundPredicate) { + partitionsBuilder.add(ExpressionUtils.extract((CompoundPredicate) onPartition)); + } else { + partitionsBuilder.add(ImmutableList.of(onPartition)); + } + } + Set> partitions = partitionsBuilder.build(); - List> partitions = ExpressionUtils.extract(compoundPredicate).stream() - .map(predicate -> predicate instanceof CompoundPredicate ? ExpressionUtils.extract( - (CompoundPredicate) predicate) : Lists.newArrayList(predicate)).collect(Collectors.toList()); + Expression result = extractCommonFactors(originExpr, leftDeapTree, Utils.fastToImmutableList(partitions)); + return result; + } - Set commons = partitions.stream() - .>map(HashSet::new) - .reduce(Sets::intersection) - .orElse(Collections.emptySet()); + private static Expression extractCommonFactors(CompoundPredicate originPredicate, + CompoundPredicate leftDeapTreePredicate, List> initPartitions) { + // extract factor and fill into commonFactorToPartIds + // e.g. + // originPredicate: (a and (b and c)) and (b or c) + // leftDeapTreePredicate: ((a and b) and c) and (b or c) + // initPartitions: [[a], [b], [c], [b, c]] + // + // -> commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]}. 
+ // so we can know `b` and `c` is a common factors + SetMultimap commonFactorToPartIds = Multimaps.newSetMultimap( + Maps.newLinkedHashMap(), LinkedHashSet::new + ); + int originExpressionNum = 0; + int partId = 0; + for (List partition : initPartitions) { + for (Expression expression : partition) { + commonFactorToPartIds.put(expression, partId); + originExpressionNum++; + } + partId++; + } - List> uncorrelated = partitions.stream() - .map(predicates -> predicates.stream().filter(p -> !commons.contains(p)).collect(Collectors.toList())) - .collect(Collectors.toList()); + // commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]} + // + // -> reverse key value of commonFactorToPartIds and remove intersecting partition + // + // -> 1. reverse: {[0]: [a], [1, 3]: [b], [2, 3]: [c]} + // -> 2. sort by key size desc: {[1, 3]: [b], [2, 3]: [c], [0]: [a]} + // -> 3. remove intersection partition: {[1, 3]: [b], [2]: [c], [0]: [a]}, + // because first part and second part intersect by partition 3 + LinkedHashMap, Set> commonFactorPartitions + = partitionByMostCommonFactors(commonFactorToPartIds); - Expression combineUncorrelated = ExpressionUtils.combine(compoundPredicate.getClass(), - uncorrelated.stream() - .map(predicates -> ExpressionUtils.combine(compoundPredicate.flipType(), predicates)) - .collect(Collectors.toList())); + int extractedExpressionNum = 0; + for (Set exprs : commonFactorPartitions.values()) { + extractedExpressionNum += exprs.size(); + } + + // no any common factor + if (commonFactorPartitions.entrySet().iterator().next().getKey().size() <= 1 + && !(originPredicate.getWidth() > leftDeapTreePredicate.getWidth()) + && originExpressionNum <= extractedExpressionNum) { + // this condition is important because it can avoid deap loop: + // origin originExpr: A = 1 and (B > 0 and B < 10) + // after ExtractCommonFactorRule: (A = 1 and B > 0) and (B < 10) (left deap tree) + // after SimplifyRange: A = 1 and (B > 0 and B < 10) (right deap tree) + return 
originPredicate; + } + + // now we can do extract common factors for each part: + // originPredicate: (a and (b and c)) and (b or c) + // leftDeapTreePredicate: ((a and b) and c) and (b or c) + // initPartitions: [[a], [b], [c], [b, c]] + // commonFactorPartitions: {[1, 3]: [b], [0]: [a]} + // + // -> extractedExprs: [ + // b or (false and c) = b, + // a, + // c + // ] + // + // -> result: (b or c) and a and c + ImmutableList.Builder extractedExprs + = ImmutableList.builderWithExpectedSize(commonFactorPartitions.size()); + for (Entry, Set> kv : commonFactorPartitions.entrySet()) { + Expression extracted = doExtractCommonFactors( + leftDeapTreePredicate, initPartitions, kv.getKey(), kv.getValue() + ); + extractedExprs.add(extracted); + } + + // combine and eliminate some boolean literal predicate + return ExpressionUtils.combineAsLeftDeepTree(leftDeapTreePredicate.getClass(), extractedExprs.build()); + } - List finalCompound = Lists.newArrayList(commons); - finalCompound.add(combineUncorrelated); + private static Expression doExtractCommonFactors( + CompoundPredicate originPredicate, + List> partitions, Set partitionIds, Set commonFactors) { + ImmutableList.Builder uncorrelatedExprPartitionsBuilder + = ImmutableList.builderWithExpectedSize(partitionIds.size()); + for (Integer partitionId : partitionIds) { + List partition = partitions.get(partitionId); + ImmutableSet.Builder uncorrelatedBuilder + = ImmutableSet.builderWithExpectedSize(partition.size()); + for (Expression exprOfPart : partition) { + if (!commonFactors.contains(exprOfPart)) { + uncorrelatedBuilder.add(exprOfPart); + } + } + + Set uncorrelated = uncorrelatedBuilder.build(); + Expression partitionWithoutCommonFactor + = ExpressionUtils.combineAsLeftDeepTree(originPredicate.flipType(), uncorrelated); + if (partitionWithoutCommonFactor instanceof CompoundPredicate) { + partitionWithoutCommonFactor = extractCommonFactor((CompoundPredicate) partitionWithoutCommonFactor); + } + 
uncorrelatedExprPartitionsBuilder.add(partitionWithoutCommonFactor); + } + + ImmutableList uncorrelatedExprPartitions = uncorrelatedExprPartitionsBuilder.build(); + ImmutableList.Builder allExprs = ImmutableList.builderWithExpectedSize(commonFactors.size() + 1); + allExprs.addAll(commonFactors); + + Expression combineUncorrelatedExpr = ExpressionUtils.combineAsLeftDeepTree( + originPredicate.getClass(), uncorrelatedExprPartitions); + allExprs.add(combineUncorrelatedExpr); + + Expression result = ExpressionUtils.combineAsLeftDeepTree(originPredicate.flipType(), allExprs.build()); + return result; + } + + private static LinkedHashMap, Set> partitionByMostCommonFactors( + SetMultimap commonFactorToPartIds) { + SetMultimap, Expression> partWithCommonFactors = Multimaps.newSetMultimap( + Maps.newLinkedHashMap(), LinkedHashSet::new + ); + + for (Entry> factorToId : commonFactorToPartIds.asMap().entrySet()) { + partWithCommonFactors.put((Set) factorToId.getValue(), factorToId.getKey()); + } + + List> sortedPartitionIdHasCommonFactor = Lists.newArrayList(partWithCommonFactors.keySet()); + // place the most common factor at the head of this list + sortedPartitionIdHasCommonFactor.sort((p1, p2) -> p2.size() - p1.size()); + + LinkedHashMap, Set> shouldExtractFactors = Maps.newLinkedHashMap(); + + Set allocatedPartitions = Sets.newLinkedHashSet(); + for (Set originMostCommonFactorPartitions : sortedPartitionIdHasCommonFactor) { + ImmutableSet.Builder notAllocatePartitions = ImmutableSet.builderWithExpectedSize( + originMostCommonFactorPartitions.size()); + for (Integer partId : originMostCommonFactorPartitions) { + if (allocatedPartitions.add(partId)) { + notAllocatePartitions.add(partId); + } + } + + Set mostCommonFactorPartitions = notAllocatePartitions.build(); + if (!mostCommonFactorPartitions.isEmpty()) { + Set commonFactors = partWithCommonFactors.get(originMostCommonFactorPartitions); + shouldExtractFactors.put(mostCommonFactorPartitions, commonFactors); + } + } - 
return ExpressionUtils.combine(compoundPredicate.flipType(), finalCompound); + return shouldExtractFactors; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRule.java index c801f749ee0ef0..04acb91d9e2d39 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRule.java @@ -17,24 +17,46 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Expression; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Constant evaluation of an expression. 
*/ -public class FoldConstantRule extends AbstractExpressionRewriteRule { +public class FoldConstantRule implements ExpressionPatternRuleFactory { public static final FoldConstantRule INSTANCE = new FoldConstantRule(); + private static final ExpressionBottomUpRewriter FULL_FOLD_REWRITER = ExpressionRewrite.bottomUp( + FoldConstantRuleOnFE.VISITOR_INSTANCE, + FoldConstantRuleOnBE.INSTANCE + ); + + /** evaluate by pattern match */ @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + public List> buildRules() { + return ImmutableList.>builder() + .addAll(FoldConstantRuleOnFE.PATTERN_MATCH_INSTANCE.buildRules()) + .addAll(FoldConstantRuleOnBE.INSTANCE.buildRules()) + .build(); + } + + /** evaluate by visitor */ + public static Expression evaluate(Expression expr, ExpressionRewriteContext ctx) { if (ctx.cascadesContext != null && ctx.cascadesContext.getConnectContext() != null && ctx.cascadesContext.getConnectContext().getSessionVariable().isEnableFoldConstantByBe()) { - return new FoldConstantRuleOnBE().rewrite(expr, ctx); + return FULL_FOLD_REWRITER.rewrite(expr, ctx); + } else { + return FoldConstantRuleOnFE.VISITOR_INSTANCE.rewrite(expr, ctx); } - return FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java index 38c6a483c9f777..09e9bbe0b91e37 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java @@ -27,8 +27,9 @@ import org.apache.doris.common.util.DebugUtil; import org.apache.doris.common.util.TimeUtils; import org.apache.doris.nereids.glue.translator.ExpressionTranslator; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import 
org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; @@ -55,6 +56,7 @@ import org.apache.doris.thrift.TQueryOptions; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -73,24 +75,38 @@ /** * Constant evaluation of an expression. */ -public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule { +public class FoldConstantRuleOnBE implements ExpressionPatternRuleFactory { + + public static final FoldConstantRuleOnBE INSTANCE = new FoldConstantRuleOnBE(); private static final Logger LOG = LogManager.getLogger(FoldConstantRuleOnBE.class); - private final IdGenerator idGenerator = ExprId.createGenerator(); @Override - public Expression rewrite(Expression expression, ExpressionRewriteContext ctx) { - expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expression, ctx); - return foldByBE(expression, ctx); + public List> buildRules() { + return ImmutableList.of( + root(Expression.class) + .whenCtx(FoldConstantRuleOnBE::isEnableFoldByBe) + .thenApply(FoldConstantRuleOnBE::foldByBE) + ); + } + + public static boolean isEnableFoldByBe(ExpressionMatchingContext ctx) { + return ctx.cascadesContext != null + && ctx.cascadesContext.getConnectContext() != null + && ctx.cascadesContext.getConnectContext().getSessionVariable().isEnableFoldConstantByBe(); } - private Expression foldByBE(Expression root, ExpressionRewriteContext context) { + /** foldByBE */ + public static Expression 
foldByBE(ExpressionMatchingContext context) { + IdGenerator idGenerator = ExprId.createGenerator(); + + Expression root = context.expr; Map constMap = Maps.newHashMap(); Map staleConstTExprMap = Maps.newHashMap(); Expression rootWithoutAlias = root; if (root instanceof Alias) { rootWithoutAlias = ((Alias) root).child(); } - collectConst(rootWithoutAlias, constMap, staleConstTExprMap); + collectConst(rootWithoutAlias, constMap, staleConstTExprMap, idGenerator); if (constMap.isEmpty()) { return root; } @@ -103,7 +119,8 @@ private Expression foldByBE(Expression root, ExpressionRewriteContext context) { return root; } - private Expression replace(Expression root, Map constMap, Map resultMap) { + private static Expression replace( + Expression root, Map constMap, Map resultMap) { for (Entry entry : constMap.entrySet()) { if (entry.getValue().equals(root)) { return resultMap.get(entry.getKey()); @@ -121,7 +138,8 @@ private Expression replace(Expression root, Map constMap, Ma return hasNewChildren ? 
root.withChildren(newChildren) : root; } - private void collectConst(Expression expr, Map constMap, Map tExprMap) { + private static void collectConst(Expression expr, Map constMap, + Map tExprMap, IdGenerator idGenerator) { if (expr.isConstant()) { // Do not constant fold cast(null as dataType) because we cannot preserve the // cast-to-types and that can lead to query failures, e.g., CTAS @@ -157,13 +175,13 @@ private void collectConst(Expression expr, Map constMap, Map } else { for (int i = 0; i < expr.children().size(); i++) { final Expression child = expr.children().get(i); - collectConst(child, constMap, tExprMap); + collectConst(child, constMap, tExprMap, idGenerator); } } } // if sleep(5) will cause rpc timeout - private boolean skipSleepFunction(Expression expr) { + private static boolean skipSleepFunction(Expression expr) { if (expr instanceof Sleep) { Expression param = expr.child(0); if (param instanceof Cast) { @@ -176,7 +194,7 @@ private boolean skipSleepFunction(Expression expr) { return false; } - private Map evalOnBE(Map> paramMap, + private static Map evalOnBE(Map> paramMap, Map constMap, ConnectContext context) { Map resultMap = new HashMap<>(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index 05165f6c312c56..cf3d1a88d8cff1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -22,7 +22,13 @@ import org.apache.doris.cluster.ClusterNamespace; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionListenerMatcher; +import 
org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionTraverseListener; +import org.apache.doris.nereids.rules.expression.ExpressionTraverseListenerFactory; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; @@ -80,6 +86,8 @@ import com.google.common.base.Preconditions; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import org.apache.commons.codec.digest.DigestUtils; @@ -87,13 +95,78 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Predicate; /** * evaluate an expression on fe. 
*/ -public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule { +public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule + implements ExpressionPatternRuleFactory, ExpressionTraverseListenerFactory { - public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE(); + public static final FoldConstantRuleOnFE VISITOR_INSTANCE = new FoldConstantRuleOnFE(true); + public static final FoldConstantRuleOnFE PATTERN_MATCH_INSTANCE = new FoldConstantRuleOnFE(false); + + // record whether current expression is in an aggregate function with distinct, + // if is, we will skip to fold constant + private static final ListenAggDistinct LISTEN_AGG_DISTINCT = new ListenAggDistinct(); + private static final CheckWhetherUnderAggDistinct NOT_UNDER_AGG_DISTINCT = new CheckWhetherUnderAggDistinct(); + + private final boolean deepRewrite; + + public FoldConstantRuleOnFE(boolean deepRewrite) { + this.deepRewrite = deepRewrite; + } + + public static Expression evaluate(Expression expression, ExpressionRewriteContext expressionRewriteContext) { + return VISITOR_INSTANCE.rewrite(expression, expressionRewriteContext); + } + + @Override + public List> buildListeners() { + return ImmutableList.of( + listenerType(AggregateFunction.class) + .when(AggregateFunction::isDistinct) + .then(LISTEN_AGG_DISTINCT.as()), + + listenerType(AggregateExpression.class) + .when(AggregateExpression::isDistinct) + .then(LISTEN_AGG_DISTINCT.as()) + ); + } + + @Override + public List> buildRules() { + return ImmutableList.of( + matches(EncryptKeyRef.class, this::visitEncryptKeyRef), + matches(EqualTo.class, this::visitEqualTo), + matches(GreaterThan.class, this::visitGreaterThan), + matches(GreaterThanEqual.class, this::visitGreaterThanEqual), + matches(LessThan.class, this::visitLessThan), + matches(LessThanEqual.class, this::visitLessThanEqual), + matches(NullSafeEqual.class, this::visitNullSafeEqual), + matches(Not.class, this::visitNot), + matches(Database.class, 
this::visitDatabase), + matches(CurrentUser.class, this::visitCurrentUser), + matches(CurrentCatalog.class, this::visitCurrentCatalog), + matches(User.class, this::visitUser), + matches(ConnectionId.class, this::visitConnectionId), + matches(And.class, this::visitAnd), + matches(Or.class, this::visitOr), + matches(Cast.class, this::visitCast), + matches(BoundFunction.class, this::visitBoundFunction), + matches(BinaryArithmetic.class, this::visitBinaryArithmetic), + matches(CaseWhen.class, this::visitCaseWhen), + matches(If.class, this::visitIf), + matches(InPredicate.class, this::visitInPredicate), + matches(IsNull.class, this::visitIsNull), + matches(TimestampArithmetic.class, this::visitTimestampArithmetic), + matches(Password.class, this::visitPassword), + matches(Array.class, this::visitArray), + matches(Date.class, this::visitDate), + matches(Version.class, this::visitVersion) + ); + } @Override public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { @@ -253,7 +326,7 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) { List nonTrueLiteral = Lists.newArrayList(); int nullCount = 0; for (Expression e : and.children()) { - e = e.accept(this, context); + e = deepRewrite ? e.accept(this, context) : e; if (BooleanLiteral.FALSE.equals(e)) { return BooleanLiteral.FALSE; } else if (e instanceof NullLiteral) { @@ -294,7 +367,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) { List nonFalseLiteral = Lists.newArrayList(); int nullCount = 0; for (Expression e : or.children()) { - e = e.accept(this, context); + e = deepRewrite ? e.accept(this, context) : e; if (BooleanLiteral.TRUE.equals(e)) { return BooleanLiteral.TRUE; } else if (e instanceof NullLiteral) { @@ -412,9 +485,13 @@ public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext cont } } - Expression defaultResult = caseWhen.getDefaultValue().isPresent() - ? 
rewrite(caseWhen.getDefaultValue().get(), context) - : null; + Expression defaultResult = null; + if (caseWhen.getDefaultValue().isPresent()) { + defaultResult = caseWhen.getDefaultValue().get(); + if (deepRewrite) { + defaultResult = rewrite(defaultResult, context); + } + } if (foundNewDefault) { defaultResult = newDefault; } @@ -537,28 +614,83 @@ public Expression visitVersion(Version version, ExpressionRewriteContext context return new StringLiteral(GlobalVariable.version); } - private E rewriteChildren(Expression expr, ExpressionRewriteContext ctx) { - return (E) super.visit(expr, ctx); - } - - private boolean allArgsIsAllLiteral(Expression expression) { - return ExpressionUtils.isAllLiteral(expression.getArguments()); - } - - private boolean argsHasNullLiteral(Expression expression) { - return ExpressionUtils.hasNullLiteral(expression.getArguments()); + private E rewriteChildren(E expr, ExpressionRewriteContext context) { + if (!deepRewrite) { + return expr; + } + switch (expr.arity()) { + case 1: { + Expression originChild = expr.child(0); + Expression newChild = originChild.accept(this, context); + return (originChild != newChild) ? (E) expr.withChildren(ImmutableList.of(newChild)) : expr; + } + case 2: { + Expression originLeft = expr.child(0); + Expression newLeft = originLeft.accept(this, context); + Expression originRight = expr.child(1); + Expression newRight = originRight.accept(this, context); + return (originLeft != newLeft || originRight != newRight) + ? (E) expr.withChildren(ImmutableList.of(newLeft, newRight)) + : expr; + } + case 0: { + return expr; + } + default: { + boolean hasNewChildren = false; + Builder newChildren = ImmutableList.builderWithExpectedSize(expr.arity()); + for (Expression child : expr.children()) { + Expression newChild = child.accept(this, context); + if (newChild != child) { + hasNewChildren = true; + } + newChildren.add(newChild); + } + return hasNewChildren ? 
(E) expr.withChildren(newChildren.build()) : expr; + } + } } private Optional preProcess(Expression expression) { if (expression instanceof AggregateFunction || expression instanceof TableGeneratingFunction) { return Optional.of(expression); } - if (expression instanceof PropagateNullable && argsHasNullLiteral(expression)) { + if (expression instanceof PropagateNullable && ExpressionUtils.hasNullLiteral(expression.getArguments())) { return Optional.of(new NullLiteral(expression.getDataType())); } - if (!allArgsIsAllLiteral(expression)) { + if (!ExpressionUtils.isAllLiteral(expression.getArguments())) { return Optional.of(expression); } return Optional.empty(); } + + private static class ListenAggDistinct implements ExpressionTraverseListener { + @Override + public void onEnter(ExpressionMatchingContext context) { + context.cascadesContext.incrementDistinctAggLevel(); + } + + @Override + public void onExit(ExpressionMatchingContext context, Expression rewritten) { + context.cascadesContext.decrementDistinctAggLevel(); + } + } + + private static class CheckWhetherUnderAggDistinct implements Predicate> { + @Override + public boolean test(ExpressionMatchingContext context) { + return context.cascadesContext.getDistinctAggLevel() == 0; + } + + public Predicate> as() { + return (Predicate) this; + } + } + + private ExpressionPatternMatcher matches( + Class clazz, BiFunction visitMethod) { + return matchesType(clazz) + .whenCtx(NOT_UNDER_AGG_DISTINCT.as()) + .thenApply(ctx -> visitMethod.apply(ctx.expr, ctx.rewriteContext)); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java index 32f8e46da7553f..3760dcf0e72420 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java @@ -17,13 +17,14 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; -import java.util.ArrayList; -import java.util.HashSet; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + import java.util.List; import java.util.Set; @@ -31,25 +32,32 @@ * Deduplicate InPredicate, For example: * where A in (x, x) ==> where A in (x) */ -public class InPredicateDedup extends AbstractExpressionRewriteRule { - - public static InPredicateDedup INSTANCE = new InPredicateDedup(); +public class InPredicateDedup implements ExpressionPatternRuleFactory { + public static final InPredicateDedup INSTANCE = new InPredicateDedup(); @Override - public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(InPredicate.class).then(InPredicateDedup::dedup) + ); + } + + /** dedup */ + public static Expression dedup(InPredicate inPredicate) { // In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options. // It takes a long time to apply this rule. So set a threshold for the max number. 
- if (inPredicate.getOptions().size() > 200) { + int optionSize = inPredicate.getOptions().size(); + if (optionSize > 200) { return inPredicate; } - Set dedupExpr = new HashSet<>(); - List newOptions = new ArrayList<>(); + ImmutableSet.Builder newOptionsBuilder = ImmutableSet.builderWithExpectedSize(inPredicate.arity()); for (Expression option : inPredicate.getOptions()) { - if (dedupExpr.contains(option)) { - continue; - } - dedupExpr.add(option); - newOptions.add(option); + newOptionsBuilder.add(option); + } + + Set newOptions = newOptionsBuilder.build(); + if (newOptions.size() == optionSize) { + return inPredicate; } return new InPredicate(inPredicate.getCompareExpr(), newOptions); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateToEqualToRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateToEqualToRule.java index b076cadd53358d..353de7f41f62a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateToEqualToRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateToEqualToRule.java @@ -17,12 +17,14 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; +import com.google.common.collect.ImmutableList; + import java.util.List; /** @@ -36,17 +38,16 @@ * NOTICE: it's related with `SimplifyRange`. * They are same processes, so must change synchronously. 
*/ -public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule { - - public static InPredicateToEqualToRule INSTANCE = new InPredicateToEqualToRule(); +public class InPredicateToEqualToRule implements ExpressionPatternRuleFactory { + public static final InPredicateToEqualToRule INSTANCE = new InPredicateToEqualToRule(); @Override - public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { - Expression left = inPredicate.getCompareExpr(); - List right = inPredicate.getOptions(); - if (right.size() != 1) { - return new InPredicate(left.accept(this, context), right); - } - return new EqualTo(left.accept(this, context), right.get(0).accept(this, context)); + public List> buildRules() { + return ImmutableList.of( + matchesType(InPredicate.class) + .when(in -> in.getOptions().size() == 1) + .then(in -> new EqualTo(in.getCompareExpr(), in.getOptions().get(0)) + ) + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeBinaryPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeBinaryPredicatesRule.java index 9b1c88b930ba22..e73104793cd916 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeBinaryPredicatesRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeBinaryPredicatesRule.java @@ -17,22 +17,31 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; +import 
com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Normalizes binary predicates of the form 'expr' op 'slot' so that the slot is on the left-hand side. * For example: * 5 > id -> id < 5 */ -public class NormalizeBinaryPredicatesRule extends AbstractExpressionRewriteRule { - - public static NormalizeBinaryPredicatesRule INSTANCE = new NormalizeBinaryPredicatesRule(); +public class NormalizeBinaryPredicatesRule implements ExpressionPatternRuleFactory { + public static final NormalizeBinaryPredicatesRule INSTANCE = new NormalizeBinaryPredicatesRule(); @Override - public Expression visitComparisonPredicate(ComparisonPredicate expr, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(ComparisonPredicate.class).then(NormalizeBinaryPredicatesRule::normalize) + ); + } + + public static Expression normalize(ComparisonPredicate expr) { return expr.left().isConstant() && !expr.right().isConstant() ? expr.commute() : expr; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java index 6507f49825c7c5..e8eedb1e1980ff 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java @@ -17,31 +17,34 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import 
org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.collect.ImmutableList; + +import java.util.List; /** * convert "<=>" to "=", if any side is not nullable * convert "A <=> null" to "A is null" */ -public class NullSafeEqualToEqual extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class NullSafeEqualToEqual implements ExpressionPatternRuleFactory { public static final NullSafeEqualToEqual INSTANCE = new NullSafeEqualToEqual(); @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, null); + public List> buildRules() { + return ImmutableList.of( + matchesType(NullSafeEqual.class).then(NullSafeEqualToEqual::rewrite) + ); } - @Override - public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext ctx) { + private static Expression rewrite(NullSafeEqual nullSafeEqual) { if (nullSafeEqual.left() instanceof NullLiteral) { if (nullSafeEqual.right().nullable()) { return new IsNull(nullSafeEqual.right()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java index dd71ed8e99ff76..b9bdf520e3d6d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java @@ -82,7 +82,7 @@ public Expression visit(Expression expr, Map context) expr = super.visit(expr, context); if (!(expr instanceof Literal)) { // just forward to 
fold constant rule - return expr.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + return FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext); } return expr; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneRangePartitionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneRangePartitionEvaluator.java index 4e2ba5be909f1d..2c0f8c13939aa1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneRangePartitionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneRangePartitionEvaluator.java @@ -92,7 +92,7 @@ public class OneRangePartitionEvaluator /** OneRangePartitionEvaluator */ public OneRangePartitionEvaluator(long partitionId, List partitionSlots, - RangePartitionItem partitionItem, CascadesContext cascadesContext) { + RangePartitionItem partitionItem, CascadesContext cascadesContext, int expandThreshold) { this.partitionId = partitionId; this.partitionSlots = Objects.requireNonNull(partitionSlots, "partitionSlots cannot be null"); this.partitionItem = Objects.requireNonNull(partitionItem, "partitionItem cannot be null"); @@ -103,41 +103,46 @@ public OneRangePartitionEvaluator(long partitionId, List partitionSlots, this.lowers = toNereidsLiterals(range.lowerEndpoint()); this.uppers = toNereidsLiterals(range.upperEndpoint()); - PartitionRangeExpander expander = new PartitionRangeExpander(); - this.partitionSlotTypes = expander.computePartitionSlotTypes(lowers, uppers); - this.slotToType = Maps.newHashMapWithExpectedSize(16); - for (int i = 0; i < partitionSlots.size(); i++) { - slotToType.put(partitionSlots.get(i), partitionSlotTypes.get(i)); - } + this.partitionSlotTypes = PartitionRangeExpander.computePartitionSlotTypes(lowers, uppers); - this.partitionSlotContainsNull = Maps.newHashMapWithExpectedSize(16); - for (int i = 0; i < partitionSlots.size(); i++) { - Slot slot = 
partitionSlots.get(i); - if (!slot.nullable()) { - partitionSlotContainsNull.put(slot, false); - continue; + if (partitionSlots.size() == 1) { + // fast path + Slot partSlot = partitionSlots.get(0); + this.slotToType = ImmutableMap.of(partSlot, partitionSlotTypes.get(0)); + this.partitionSlotContainsNull + = ImmutableMap.of(partSlot, range.lowerEndpoint().getKeys().get(0).isMinValue()); + } else { + // slow path + this.slotToType = Maps.newHashMap(); + for (int i = 0; i < partitionSlots.size(); i++) { + slotToType.put(partitionSlots.get(i), partitionSlotTypes.get(i)); } - PartitionSlotType partitionSlotType = partitionSlotTypes.get(i); - boolean maybeNull = false; - switch (partitionSlotType) { - case CONST: - case RANGE: - maybeNull = range.lowerEndpoint().getKeys().get(i).isMinValue(); - break; - case OTHER: - maybeNull = true; - break; - default: - throw new AnalysisException("Unknown partition slot type: " + partitionSlotType); + + this.partitionSlotContainsNull = Maps.newHashMap(); + for (int i = 0; i < partitionSlots.size(); i++) { + Slot slot = partitionSlots.get(i); + if (!slot.nullable()) { + partitionSlotContainsNull.put(slot, false); + continue; + } + PartitionSlotType partitionSlotType = partitionSlotTypes.get(i); + boolean maybeNull; + switch (partitionSlotType) { + case CONST: + case RANGE: + maybeNull = range.lowerEndpoint().getKeys().get(i).isMinValue(); + break; + case OTHER: + maybeNull = true; + break; + default: + throw new AnalysisException("Unknown partition slot type: " + partitionSlotType); + } + partitionSlotContainsNull.put(slot, maybeNull); } - partitionSlotContainsNull.put(slot, maybeNull); } - int expandThreshold = cascadesContext.getAndCacheSessionVariable( - "partitionPruningExpandThreshold", - 10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold); - - List> expandInputs = expander.tryExpandRange( + List> expandInputs = PartitionRangeExpander.tryExpandRange( partitionSlots, lowers, uppers, partitionSlotTypes, 
expandThreshold); // after expand range, we will get 2 dimension list like list: // part_col1: [1], part_col2:[4, 5, 6], we should combine it to @@ -451,10 +456,13 @@ public EvaluateRangeResult visitNot(Not not, EvaluateRangeInput context) { private EvaluateRangeResult evaluateChildrenThenThis(Expression expr, EvaluateRangeInput context) { // evaluate children - List newChildren = new ArrayList<>(); - List childrenResults = new ArrayList<>(); + List children = expr.children(); + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(children.size()); + List childrenResults = new ArrayList<>(children.size()); boolean hasNewChildren = false; - for (Expression child : expr.children()) { + + for (int i = 0; i < children.size(); i++) { + Expression child = children.get(i); EvaluateRangeResult childResult = child.accept(this, context); if (childResult.result != child) { hasNewChildren = true; @@ -463,11 +471,11 @@ private EvaluateRangeResult evaluateChildrenThenThis(Expression expr, EvaluateRa newChildren.add(childResult.result); } if (hasNewChildren) { - expr = expr.withChildren(newChildren); + expr = expr.withChildren(newChildren.build()); } // evaluate this - expr = expr.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + expr = FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext); return new EvaluateRangeResult(expr, context.defaultColumnRanges, childrenResults); } @@ -575,9 +583,28 @@ private EvaluateRangeResult mergeRanges( } private List toNereidsLiterals(PartitionKey partitionKey) { - List literals = Lists.newArrayListWithCapacity(partitionKey.getKeys().size()); - for (int i = 0; i < partitionKey.getKeys().size(); i++) { - LiteralExpr literalExpr = partitionKey.getKeys().get(i); + if (partitionKey.getKeys().size() == 1) { + // fast path + return toSingleNereidsLiteral(partitionKey); + } + + // slow path + return toMultiNereidsLiterals(partitionKey); + } + + private List toSingleNereidsLiteral(PartitionKey partitionKey) { 
+ List keys = partitionKey.getKeys(); + LiteralExpr literalExpr = keys.get(0); + PrimitiveType primitiveType = partitionKey.getTypes().get(0); + Type type = Type.fromPrimitiveType(primitiveType); + return ImmutableList.of(Literal.fromLegacyLiteral(literalExpr, type)); + } + + private List toMultiNereidsLiterals(PartitionKey partitionKey) { + List keys = partitionKey.getKeys(); + List literals = Lists.newArrayListWithCapacity(keys.size()); + for (int i = 0; i < keys.size(); i++) { + LiteralExpr literalExpr = keys.get(i); PrimitiveType primitiveType = partitionKey.getTypes().get(i); Type type = Type.fromPrimitiveType(primitiveType); literals.add(Literal.fromLegacyLiteral(literalExpr, type)); @@ -613,8 +640,8 @@ public EvaluateRangeResult visitDate(Date date, EvaluateRangeInput context) { Literal lower = span.lowerEndpoint().getValue(); Literal upper = span.upperEndpoint().getValue(); - Expression lowerDate = new Date(lower).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); - Expression upperDate = new Date(upper).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + Expression lowerDate = FoldConstantRuleOnFE.evaluate(new Date(lower), expressionRewriteContext); + Expression upperDate = FoldConstantRuleOnFE.evaluate(new Date(upper), expressionRewriteContext); if (lowerDate instanceof Literal && upperDate instanceof Literal && lowerDate.equals(upperDate)) { return new EvaluateRangeResult(lowerDate, result.columnRanges, result.childrenResult); @@ -696,7 +723,7 @@ public EvaluateRangeResult(Expression result, Map columnRange public EvaluateRangeResult(Expression result, Map columnRanges, List childrenResult) { - this(result, columnRanges, childrenResult, childrenResult.stream().allMatch(r -> r.isRejectNot())); + this(result, columnRanges, childrenResult, allIsRejectNot(childrenResult)); } public EvaluateRangeResult withRejectNot(boolean rejectNot) { @@ -706,6 +733,15 @@ public EvaluateRangeResult withRejectNot(boolean rejectNot) { public 
boolean isRejectNot() { return rejectNot; } + + private static boolean allIsRejectNot(List childrenResult) { + for (EvaluateRangeResult evaluateRangeResult : childrenResult) { + if (!evaluateRangeResult.isRejectNot()) { + return false; + } + } + return true; + } } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java index b085f70da6e9c9..83da8055037242 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java @@ -17,15 +17,17 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; @@ -54,20 +56,25 @@ * adding any additional rule-specific fields to the default ExpressionRewriteContext. 
However, the entire expression * rewrite framework always passes an ExpressionRewriteContext of type context to all rules. */ -public class OrToIn extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class OrToIn implements ExpressionPatternRuleFactory { public static final OrToIn INSTANCE = new OrToIn(); public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2; @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, null); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(Or.class).then(OrToIn::rewrite) + ); } - @Override - public Expression visitOr(Or or, ExpressionRewriteContext ctx) { + public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) { + ExpressionBottomUpRewriter bottomUpRewriter = ExpressionRewrite.bottomUp(this); + return bottomUpRewriter.rewrite(expr, context); + } + + private static Expression rewrite(Or or) { // NOTICE: use linked hash map to avoid unstable order or entry. // unstable order entry lead to dead loop since return expression always un-equals to original one. 
Map> slotNameToLiteral = Maps.newLinkedHashMap(); @@ -80,6 +87,10 @@ public Expression visitOr(Or or, ExpressionRewriteContext ctx) { handleInPredicate((InPredicate) expression, slotNameToLiteral, disConjunctToSlot); } } + if (disConjunctToSlot.isEmpty()) { + return or; + } + List rewrittenOr = new ArrayList<>(); for (Map.Entry> entry : slotNameToLiteral.entrySet()) { Set literals = entry.getValue(); @@ -90,7 +101,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext ctx) { } for (Expression expression : expressions) { if (disConjunctToSlot.get(expression) == null) { - rewrittenOr.add(expression.accept(this, null)); + rewrittenOr.add(expression); } else { Set literals = slotNameToLiteral.get(disConjunctToSlot.get(expression)); if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { @@ -102,7 +113,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext ctx) { return ExpressionUtils.or(rewrittenOr); } - private void handleEqualTo(EqualTo equal, Map> slotNameToLiteral, + private static void handleEqualTo(EqualTo equal, Map> slotNameToLiteral, Map disConjunctToSlot) { Expression left = equal.left(); Expression right = equal.right(); @@ -115,7 +126,7 @@ private void handleEqualTo(EqualTo equal, Map> slo } } - private void handleInPredicate(InPredicate inPredicate, Map> slotNameToLiteral, + private static void handleInPredicate(InPredicate inPredicate, Map> slotNameToLiteral, Map disConjunctToSlot) { // TODO a+b in (1,2,3...) 
is not supported now if (inPredicate.getCompareExpr() instanceof NamedExpression @@ -127,10 +138,9 @@ private void handleInPredicate(InPredicate inPredicate, Map> slotNameToLiteral) { Set literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new LinkedHashSet<>()); literals.add(literal); } - } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java index b84407778729e6..04cef999fda27c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.PartitionItem; import org.apache.doris.catalog.RangePartitionItem; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; @@ -81,14 +82,19 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) && ((Cast) right).child().getDataType().isDateType()) { DateTimeLiteral dt = (DateTimeLiteral) left; Cast cast = (Cast) right; - return cp.withChildren(new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()), cast.child()); + return cp.withChildren( + ImmutableList.of(new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()), cast.child()) + ); } else if (right instanceof DateTimeLiteral && ((DateTimeLiteral) right).isMidnight() && left instanceof Cast && ((Cast) left).child() instanceof SlotReference && ((Cast) left).child().getDataType().isDateType()) { DateTimeLiteral dt = (DateTimeLiteral) right; Cast cast = (Cast) left; - return cp.withChildren(cast.child(), new DateLiteral(dt.getYear(), 
dt.getMonth(), dt.getDay())); + return cp.withChildren(ImmutableList.of( + cast.child(), + new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay())) + ); } else { return cp; } @@ -115,13 +121,18 @@ public static List prune(List partitionSlots, Expression partitionPr partitionPredicate, ImmutableSet.copyOf(partitionSlots), cascadesContext); partitionPredicate = PredicateRewriteForPartitionPrune.rewrite(partitionPredicate, cascadesContext); + int expandThreshold = cascadesContext.getAndCacheSessionVariable( + "partitionPruningExpandThreshold", + 10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold); + List evaluators = Lists.newArrayListWithCapacity(idToPartitions.size()); for (Entry kv : idToPartitions.entrySet()) { evaluators.add(toPartitionEvaluator( - kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, partitionTableType)); + kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, expandThreshold)); } - partitionPredicate = OrToIn.INSTANCE.rewrite(partitionPredicate, null); + partitionPredicate = OrToIn.INSTANCE.rewriteTree( + partitionPredicate, new ExpressionRewriteContext(cascadesContext)); PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate); //TODO: we keep default partition because it's too hard to prune it, we return false in canPrune(). 
return partitionPruner.prune(); @@ -131,13 +142,13 @@ public static List prune(List partitionSlots, Expression partitionPr * convert partition item to partition evaluator */ public static final OnePartitionEvaluator toPartitionEvaluator(long id, PartitionItem partitionItem, - List partitionSlots, CascadesContext cascadesContext, PartitionTableType partitionTableType) { + List partitionSlots, CascadesContext cascadesContext, int expandThreshold) { if (partitionItem instanceof ListPartitionItem) { return new OneListPartitionEvaluator( id, partitionSlots, (ListPartitionItem) partitionItem, cascadesContext); } else if (partitionItem instanceof RangePartitionItem) { return new OneRangePartitionEvaluator( - id, partitionSlots, (RangePartitionItem) partitionItem, cascadesContext); + id, partitionSlots, (RangePartitionItem) partitionItem, cascadesContext, expandThreshold); } else { return new UnknownPartitionEvaluator(id, partitionItem); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionRangeExpander.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionRangeExpander.java index 071ab8f11572c6..01a674488da50a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionRangeExpander.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionRangeExpander.java @@ -41,7 +41,6 @@ import java.time.ZoneOffset; import java.util.Iterator; import java.util.List; -import java.util.NoSuchElementException; import java.util.function.Function; /** @@ -74,10 +73,44 @@ public enum PartitionSlotType { } /** expandRangeLiterals */ - public final List> tryExpandRange( + public static final List> tryExpandRange( List partitionSlots, List lowers, List uppers, List partitionSlotTypes, int expandThreshold) { + if (partitionSlots.size() == 1) { + return tryExpandSingleColumnRange(partitionSlots.get(0), lowers.get(0), + uppers.get(0), 
expandThreshold); + } else { + // slow path + return commonTryExpandRange(partitionSlots, lowers, uppers, partitionSlotTypes, expandThreshold); + } + } + + private static List> tryExpandSingleColumnRange(Slot partitionSlot, Literal lower, + Literal upper, int expandThreshold) { + // must be range slot + try { + if (canExpandRange(partitionSlot, lower, upper, 1, expandThreshold)) { + Iterator iterator = enumerableIterator( + partitionSlot, lower, upper, true); + if (iterator instanceof SingletonIterator) { + return ImmutableList.of(ImmutableList.of(iterator.next())); + } else { + return ImmutableList.of( + ImmutableList.copyOf(iterator) + ); + } + } else { + return ImmutableList.of(ImmutableList.of(partitionSlot)); + } + } catch (Throwable t) { + // catch for safety, should not invoke here + return ImmutableList.of(ImmutableList.of(partitionSlot)); + } + } + private static List> commonTryExpandRange( + List partitionSlots, List lowers, List uppers, + List partitionSlotTypes, int expandThreshold) { long expandedCount = 1; List> expandedLists = Lists.newArrayListWithCapacity(lowers.size()); for (int i = 0; i < partitionSlotTypes.size(); i++) { @@ -126,7 +159,7 @@ public final List> tryExpandRange( return expandedLists; } - private boolean canExpandRange(Slot slot, Literal lower, Literal upper, + private static boolean canExpandRange(Slot slot, Literal lower, Literal upper, long expandedCount, int expandThreshold) { DataType type = slot.getDataType(); if (!type.isIntegerLikeType() && !type.isDateType() && !type.isDateV2Type()) { @@ -139,7 +172,7 @@ private boolean canExpandRange(Slot slot, Literal lower, Literal upper, } // too much expanded will consuming resources of frontend, // e.g. [1, 100000000), we should skip expand it - return (expandedCount * count) <= expandThreshold; + return count == 1 || (expandedCount * count) <= expandThreshold; } catch (Throwable t) { // e.g. 
max_value can not expand return false; @@ -147,7 +180,7 @@ private boolean canExpandRange(Slot slot, Literal lower, Literal upper, } /** the types will like this: [CONST, CONST, ..., RANGE, OTHER, OTHER, ...] */ - public List computePartitionSlotTypes(List lowers, List uppers) { + public static List computePartitionSlotTypes(List lowers, List uppers) { PartitionSlotType previousType = PartitionSlotType.CONST; List types = Lists.newArrayListWithCapacity(lowers.size()); for (int i = 0; i < lowers.size(); ++i) { @@ -167,7 +200,7 @@ public List computePartitionSlotTypes(List lowers, L return types; } - private long enumerableCount(DataType dataType, Literal startInclusive, Literal endExclusive) { + private static long enumerableCount(DataType dataType, Literal startInclusive, Literal endExclusive) { if (dataType.isIntegerLikeType()) { BigInteger start = new BigInteger(startInclusive.getStringValue()); BigInteger end = new BigInteger(endExclusive.getStringValue()); @@ -175,6 +208,12 @@ private long enumerableCount(DataType dataType, Literal startInclusive, Literal } else if (dataType.isDateType()) { DateLiteral startInclusiveDate = (DateLiteral) startInclusive; DateLiteral endExclusiveDate = (DateLiteral) endExclusive; + + if (startInclusiveDate.getYear() == endExclusiveDate.getYear() + && startInclusiveDate.getMonth() == endExclusiveDate.getMonth()) { + return endExclusiveDate.getDay() - startInclusiveDate.getDay(); + } + LocalDate startDate = LocalDate.of( (int) startInclusiveDate.getYear(), (int) startInclusiveDate.getMonth(), @@ -192,6 +231,12 @@ private long enumerableCount(DataType dataType, Literal startInclusive, Literal } else if (dataType.isDateV2Type()) { DateV2Literal startInclusiveDate = (DateV2Literal) startInclusive; DateV2Literal endExclusiveDate = (DateV2Literal) endExclusive; + + if (startInclusiveDate.getYear() == endExclusiveDate.getYear() + && startInclusiveDate.getMonth() == endExclusiveDate.getMonth()) { + return endExclusiveDate.getDay() - 
startInclusiveDate.getDay(); + } + LocalDate startDate = LocalDate.of( (int) startInclusiveDate.getYear(), (int) startInclusiveDate.getMonth(), @@ -212,7 +257,7 @@ private long enumerableCount(DataType dataType, Literal startInclusive, Literal return -1; } - private Iterator enumerableIterator( + private static Iterator enumerableIterator( Slot slot, Literal startInclusive, Literal endLiteral, boolean endExclusive) { DataType dataType = slot.getDataType(); if (dataType.isIntegerLikeType()) { @@ -237,6 +282,12 @@ private Iterator enumerableIterator( } else if (dataType.isDateType()) { DateLiteral startInclusiveDate = (DateLiteral) startInclusive; DateLiteral endLiteralDate = (DateLiteral) endLiteral; + if (endExclusive && startInclusiveDate.getYear() == endLiteralDate.getYear() + && startInclusiveDate.getMonth() == endLiteralDate.getMonth() + && startInclusiveDate.getDay() + 1 == endLiteralDate.getDay()) { + return new SingletonIterator(startInclusive); + } + LocalDate startDate = LocalDate.of( (int) startInclusiveDate.getYear(), (int) startInclusiveDate.getMonth(), @@ -258,6 +309,13 @@ private Iterator enumerableIterator( } else if (dataType.isDateV2Type()) { DateV2Literal startInclusiveDate = (DateV2Literal) startInclusive; DateV2Literal endLiteralDate = (DateV2Literal) endLiteral; + + if (endExclusive && startInclusiveDate.getYear() == endLiteralDate.getYear() + && startInclusiveDate.getMonth() == endLiteralDate.getMonth() + && startInclusiveDate.getDay() + 1 == endLiteralDate.getDay()) { + return new SingletonIterator(startInclusive); + } + LocalDate startDate = LocalDate.of( (int) startInclusiveDate.getYear(), (int) startInclusiveDate.getMonth(), @@ -282,7 +340,7 @@ private Iterator enumerableIterator( return Iterators.singletonIterator(slot); } - private class IntegerLikeRangePartitionValueIterator + private static class IntegerLikeRangePartitionValueIterator extends RangePartitionValueIterator { public IntegerLikeRangePartitionValueIterator(BigInteger 
startInclusive, BigInteger end, @@ -296,7 +354,7 @@ protected BigInteger doGetNext(BigInteger current) { } } - private class DateLikeRangePartitionValueIterator + private static class DateLikeRangePartitionValueIterator extends RangePartitionValueIterator { public DateLikeRangePartitionValueIterator( @@ -309,43 +367,4 @@ protected LocalDate doGetNext(LocalDate current) { return current.plusDays(1); } } - - private abstract class RangePartitionValueIterator - implements Iterator { - private final C startInclusive; - private final C end; - private final boolean endExclusive; - private C current; - - private final Function toLiteral; - - public RangePartitionValueIterator(C startInclusive, C end, boolean endExclusive, Function toLiteral) { - this.startInclusive = startInclusive; - this.end = end; - this.endExclusive = endExclusive; - this.current = this.startInclusive; - this.toLiteral = toLiteral; - } - - @Override - public boolean hasNext() { - if (endExclusive) { - return current.compareTo(end) < 0; - } else { - return current.compareTo(end) <= 0; - } - } - - @Override - public L next() { - if (hasNext()) { - C value = current; - current = doGetNext(current); - return toLiteral.apply(value); - } - throw new NoSuchElementException(); - } - - protected abstract C doGetNext(C current); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PredicateRewriteForPartitionPrune.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PredicateRewriteForPartitionPrune.java index c227c89b939188..87646fbd582d3c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PredicateRewriteForPartitionPrune.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PredicateRewriteForPartitionPrune.java @@ -70,7 +70,7 @@ public Expression visitInPredicate(InPredicate in, CascadesContext context) { } } if (convertable) { - Expression or = ExpressionUtils.combine(Or.class, 
splitIn); + Expression or = ExpressionUtils.combineAsLeftDeepTree(Or.class, splitIn); return or; } } else if (dateChild.getDataType() instanceof DateTimeV2Type) { @@ -87,7 +87,7 @@ public Expression visitInPredicate(InPredicate in, CascadesContext context) { } } if (convertable) { - Expression or = ExpressionUtils.combine(Or.class, splitIn); + Expression or = ExpressionUtils.combineAsLeftDeepTree(Or.class, splitIn); return or; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangePartitionValueIterator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangePartitionValueIterator.java new file mode 100644 index 00000000000000..79ee33d1ebb815 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangePartitionValueIterator.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.trees.expressions.literal.Literal; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.function.Function; + +/** RangePartitionValueIterator */ +public abstract class RangePartitionValueIterator + implements Iterator { + private final C startInclusive; + private final C end; + private final boolean endExclusive; + private C current; + + private final Function toLiteral; + + public RangePartitionValueIterator(C startInclusive, C end, boolean endExclusive, Function toLiteral) { + this.startInclusive = startInclusive; + this.end = end; + this.endExclusive = endExclusive; + this.current = this.startInclusive; + this.toLiteral = toLiteral; + } + + @Override + public boolean hasNext() { + if (endExclusive) { + return current.compareTo(end) < 0; + } else { + return current.compareTo(end) <= 0; + } + } + + @Override + public L next() { + if (hasNext()) { + C value = current; + current = doGetNext(current); + return toLiteral.apply(value); + } + throw new NoSuchElementException(); + } + + protected abstract C doGetNext(C current); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceVariableByLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceVariableByLiteral.java index 3fd5330395e7fc..b4c5552706c589 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceVariableByLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceVariableByLiteral.java @@ -17,20 +17,25 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import 
org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Variable; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * replace varaible to real expression */ -public class ReplaceVariableByLiteral extends AbstractExpressionRewriteRule { - +public class ReplaceVariableByLiteral implements ExpressionPatternRuleFactory { public static ReplaceVariableByLiteral INSTANCE = new ReplaceVariableByLiteral(); @Override - public Expression visitVariable(Variable variable, ExpressionRewriteContext context) { - return variable.getRealExpression(); + public List> buildRules() { + return ImmutableList.of( + matchesType(Variable.class).then(Variable::getRealExpression) + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java index 7606d082479227..6d18bc7b3807a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java @@ -17,7 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; @@ -43,6 +44,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.TypeCoercionUtils; 
+import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; @@ -55,11 +57,11 @@ * a + 1 > 1 => a > 0 * a / -2 > 1 => a < -2 */ -public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteRule { - public static final SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule(); +public class SimplifyArithmeticComparisonRule implements ExpressionPatternRuleFactory { + public static SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule(); // don't rearrange multiplication because divide may loss precision - final Map, Class> rearrangementMap = ImmutableMap + private static final Map, Class> REARRANGEMENT_MAP = ImmutableMap ., Class>builder() .put(Add.class, Subtract.class) .put(Subtract.class, Add.class) @@ -81,41 +83,54 @@ public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteR .build(); @Override - public Expression visitComparisonPredicate(ComparisonPredicate comparison, ExpressionRewriteContext context) { - if (couldRearrange(comparison)) { - ComparisonPredicate newComparison = normalize(comparison); - if (newComparison == null) { - return comparison; - } - try { - List children = - tryRearrangeChildren(newComparison.left(), newComparison.right(), context); - newComparison = (ComparisonPredicate) visitComparisonPredicate( - (ComparisonPredicate) newComparison.withChildren(children), context); - } catch (Exception e) { - return comparison; - } + public List> buildRules() { + return ImmutableList.of( + matchesType(ComparisonPredicate.class) + .thenApply(ctx -> simplify(ctx.expr, new ExpressionRewriteContext(ctx.cascadesContext))) + ); + } + + /** simplify */ + public static Expression simplify(ComparisonPredicate comparison, ExpressionRewriteContext context) { + if (!couldRearrange(comparison)) { + return comparison; + } + ComparisonPredicate newComparison = normalize(comparison); + if (newComparison == null) { + 
return comparison; + } + try { + List children = tryRearrangeChildren(newComparison.left(), newComparison.right(), context); + newComparison = (ComparisonPredicate) simplify( + (ComparisonPredicate) newComparison.withChildren(children), context); return TypeCoercionUtils.processComparisonPredicate(newComparison); - } else { + } catch (Exception e) { return comparison; } } - private boolean couldRearrange(ComparisonPredicate cmp) { - return rearrangementMap.containsKey(cmp.left().getClass()) - && !cmp.left().isConstant() - && cmp.left().children().stream().anyMatch(Expression::isConstant); + private static boolean couldRearrange(ComparisonPredicate cmp) { + if (!REARRANGEMENT_MAP.containsKey(cmp.left().getClass()) || cmp.left().isConstant()) { + return false; + } + + for (Expression child : cmp.left().children()) { + if (child.isConstant()) { + return true; + } + } + return false; } - private List tryRearrangeChildren(Expression left, Expression right, + private static List tryRearrangeChildren(Expression left, Expression right, ExpressionRewriteContext context) throws Exception { if (!left.child(1).isConstant()) { throw new RuntimeException(String.format("Expected literal when arranging children for Expr %s", left)); } - Literal leftLiteral = (Literal) FoldConstantRule.INSTANCE.rewrite(left.child(1), context); + Literal leftLiteral = (Literal) FoldConstantRule.evaluate(left.child(1), context); Expression leftExpr = left.child(0); - Class oppositeOperator = rearrangementMap.get(left.getClass()); + Class oppositeOperator = REARRANGEMENT_MAP.get(left.getClass()); Expression newChild = oppositeOperator.getConstructor(Expression.class, Expression.class) .newInstance(right, leftLiteral); @@ -127,25 +142,25 @@ private List tryRearrangeChildren(Expression left, Expression right, } // Ensure that the second child must be Literal, such as - private @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) { - if (!(comparison.left().child(1) instanceof 
Literal)) { - Expression left = comparison.left(); - if (comparison.left() instanceof Add) { - // 1 + a > 1 => a + 1 > 1 - Expression newLeft = left.withChildren(left.child(1), left.child(0)); - comparison = (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right()); - } else if (comparison.left() instanceof Subtract) { - // 1 - a > 1 => a + 1 < 1 - Expression newLeft = left.child(0); - Expression newRight = new Add(left.child(1), comparison.right()); - comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight); - comparison = comparison.commute(); - } else { - // Don't normalize division/multiplication because the slot sign is undecided. - return null; - } + private static @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) { + Expression left = comparison.left(); + Expression leftRight = left.child(1); + if (leftRight instanceof Literal) { + return comparison; + } + if (left instanceof Add) { + // 1 + a > 1 => a + 1 > 1 + Expression newLeft = left.withChildren(leftRight, left.child(0)); + return (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right()); + } else if (left instanceof Subtract) { + // 1 - a > 1 => a + 1 < 1 + Expression newLeft = left.child(0); + Expression newRight = new Add(leftRight, comparison.right()); + comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight); + return comparison.commute(); + } else { + // Don't normalize division/multiplication because the slot sign is undecided. 
+ return null; } - return comparison; } - } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java index fc7431a9994d98..b9fd91f64387ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.Divide; @@ -27,7 +27,9 @@ import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.util.TypeCoercionUtils; import org.apache.doris.nereids.util.TypeUtils; +import org.apache.doris.nereids.util.Utils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.List; @@ -43,27 +45,24 @@ * * TODO: handle cases like: '1 - IA < 1' to 'IA > 0' */ -public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule { +public class SimplifyArithmeticRule implements ExpressionPatternRuleFactory { public static final SimplifyArithmeticRule INSTANCE = new SimplifyArithmeticRule(); @Override - public Expression visitAdd(Add add, ExpressionRewriteContext context) { - return process(add, true); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(BinaryArithmetic.class).then(SimplifyArithmeticRule::simplify) + ); 
} - @Override - public Expression visitSubtract(Subtract subtract, ExpressionRewriteContext context) { - return process(subtract, true); - } - - @Override - public Expression visitDivide(Divide divide, ExpressionRewriteContext context) { - return process(divide, false); - } - - @Override - public Expression visitMultiply(Multiply multiply, ExpressionRewriteContext context) { - return process(multiply, false); + /** simplify */ + public static Expression simplify(BinaryArithmetic binaryArithmetic) { + if (binaryArithmetic instanceof Add || binaryArithmetic instanceof Subtract) { + return process(binaryArithmetic, true); + } else if (binaryArithmetic instanceof Multiply || binaryArithmetic instanceof Divide) { + return process(binaryArithmetic, false); + } + return binaryArithmetic; } /** @@ -75,7 +74,7 @@ public Expression visitMultiply(Multiply multiply, ExpressionRewriteContext cont * 3.build new arithmetic expression. * (a + b - c + d) + (1 - 2 - 1) */ - private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { + private static Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { // 1. flatten the arithmetic expression. List flattedExpressions = flatten(arithmetic, isAddOrSub); @@ -83,22 +82,24 @@ private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { List constants = Lists.newArrayList(); // TODO currently we don't process decimal for simplicity. - if (flattedExpressions.stream().anyMatch(operand -> operand.expression.getDataType().isDecimalLikeType())) { - return arithmetic; + for (Operand operand : flattedExpressions) { + if (operand.expression.getDataType().isDecimalLikeType()) { + return arithmetic; + } } // 2. move variables to left side and move constants to right sid. - flattedExpressions.forEach(operand -> { + for (Operand operand : flattedExpressions) { if (operand.expression.isConstant()) { constants.add(operand); } else { variables.add(operand); } - }); + } // 3. 
build new arithmetic expression. if (!constants.isEmpty()) { boolean isOpposite = !constants.get(0).flag; - Optional c = constants.stream().reduce((x, y) -> { + Optional c = Utils.fastReduce(constants, (x, y) -> { Expression expr; if (isOpposite && y.flag || !isOpposite && !y.flag) { expr = getSubOrDivide(isAddOrSub, x, y); @@ -115,10 +116,10 @@ private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { } } - Optional result = variables.stream().reduce((x, y) -> !y.flag + Optional result = Utils.fastReduce(variables, (x, y) -> !y.flag ? Operand.of(true, getSubOrDivide(isAddOrSub, x, y)) - : Operand.of(true, getAddOrMultiply(isAddOrSub, x, y))); - + : Operand.of(true, getAddOrMultiply(isAddOrSub, x, y)) + ); if (result.isPresent()) { return TypeCoercionUtils.castIfNotSameType(result.get().expression, arithmetic.getDataType()); } else { @@ -126,7 +127,7 @@ private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { } } - private List flatten(Expression expr, boolean isAddOrSub) { + private static List flatten(Expression expr, boolean isAddOrSub) { List result = Lists.newArrayList(); if (isAddOrSub) { flattenAddSubtract(true, expr, result); @@ -136,7 +137,7 @@ private List flatten(Expression expr, boolean isAddOrSub) { return result; } - private void flattenAddSubtract(boolean flag, Expression expr, List result) { + private static void flattenAddSubtract(boolean flag, Expression expr, List result) { if (TypeUtils.isAddOrSubtract(expr)) { BinaryArithmetic arithmetic = (BinaryArithmetic) expr; flattenAddSubtract(flag, arithmetic.left(), result); @@ -152,7 +153,7 @@ private void flattenAddSubtract(boolean flag, Expression expr, List res } } - private void flattenMultiplyDivide(boolean flag, Expression expr, List result) { + private static void flattenMultiplyDivide(boolean flag, Expression expr, List result) { if (TypeUtils.isMultiplyOrDivide(expr)) { BinaryArithmetic arithmetic = (BinaryArithmetic) expr; flattenMultiplyDivide(flag, 
arithmetic.left(), result); @@ -168,13 +169,13 @@ private void flattenMultiplyDivide(boolean flag, Expression expr, List } } - private Expression getSubOrDivide(boolean isAddOrSub, Operand x, Operand y) { - return isAddOrSub ? new Subtract(x.expression, y.expression) + private static Expression getSubOrDivide(boolean isSubOrDivide, Operand x, Operand y) { + return isSubOrDivide ? new Subtract(x.expression, y.expression) : new Divide(x.expression, y.expression); } - private Expression getAddOrMultiply(boolean isAddOrSub, Operand x, Operand y) { - return isAddOrSub ? new Add(x.expression, y.expression) + private static Expression getAddOrMultiply(boolean isAddOrMultiply, Operand x, Operand y) { + return isAddOrMultiply ? new Add(x.expression, y.expression) : new Multiply(x.expression, y.expression); } @@ -204,3 +205,4 @@ public String toString() { } } } + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java index 34143043a07022..ded0a2f558f8d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; @@ -37,7 +37,10 @@ import org.apache.doris.nereids.types.StringType; import 
org.apache.doris.nereids.types.VarcharType; +import com.google.common.collect.ImmutableList; + import java.math.BigDecimal; +import java.util.List; /** * Rewrite rule of simplify CAST expression. @@ -46,17 +49,19 @@ * Merge cast like * - cast(cast(1 as bigint) as string) -> cast(1 as string). */ -public class SimplifyCastRule extends AbstractExpressionRewriteRule { - +public class SimplifyCastRule implements ExpressionPatternRuleFactory { public static SimplifyCastRule INSTANCE = new SimplifyCastRule(); @Override - public Expression visitCast(Cast origin, ExpressionRewriteContext context) { - return simplify(origin, context); + public List> buildRules() { + return ImmutableList.of( + matchesType(Cast.class).then(SimplifyCastRule::simplifyCast) + ); } - private Expression simplify(Cast cast, ExpressionRewriteContext context) { - Expression child = rewrite(cast.child(), context); + /** simplifyCast */ + public static Expression simplifyCast(Cast cast) { + Expression child = cast.child(); // remove redundant cast // CAST(value as type), value is type diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index 03958f3d55f6f8..d26b5a53036897 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -18,6 +18,8 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.And; 
import org.apache.doris.nereids.trees.expressions.Cast; @@ -55,17 +57,18 @@ import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.List; /** * simplify comparison * such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral * cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type) */ -public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule { - +public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); enum AdjustType { @@ -74,10 +77,20 @@ enum AdjustType { NONE } + @Override + public List> buildRules() { + return ImmutableList.of( + matchesType(ComparisonPredicate.class).then(SimplifyComparisonPredicate::simplify) + ); + } + @Override public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) { - cp = (ComparisonPredicate) visit(cp, context); + return simplify(cp); + } + /** simplify */ + public static Expression simplify(ComparisonPredicate cp) { if (cp.left() instanceof Literal && !(cp.right() instanceof Literal)) { cp = cp.commute(); } @@ -146,7 +159,7 @@ private static Expression processComparisonPredicateDateTimeV2Literal( return comparisonPredicate; } - private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) { + private static Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) { if (left instanceof Cast && right instanceof DateLiteral) { Cast cast = (Cast) left; if (cast.child().getDataType() instanceof DateTimeType) { @@ -196,7 +209,7 @@ private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expressio } } - private Expression 
processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate, + private static Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate, Expression left, Expression right) { if (left instanceof Cast && left.child(0).getDataType().isIntegerLikeType() && (right instanceof DoubleLiteral || right instanceof FloatLiteral)) { @@ -209,7 +222,7 @@ private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPr } } - private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate, + private static Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate, Expression left, Expression right) { if (left instanceof Cast && right instanceof DecimalV3Literal) { Cast cast = (Cast) left; @@ -264,7 +277,7 @@ private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPr return comparisonPredicate; } - private Expression processIntegerDecimalLiteralComparison( + private static Expression processIntegerDecimalLiteralComparison( ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) { // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { @@ -306,7 +319,7 @@ private Expression processIntegerDecimalLiteralComparison( return comparisonPredicate; } - private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { + private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { Preconditions.checkArgument( decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, "decimal literal must have 0 scale and smaller than Long.MAX_VALUE"); @@ -322,15 +335,15 @@ private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal } } - private Expression migrateToDateTime(DateTimeV2Literal l) { + private static Expression migrateToDateTime(DateTimeV2Literal l) { return new DateTimeLiteral(l.getYear(), 
l.getMonth(), l.getDay(), l.getHour(), l.getMinute(), l.getSecond()); } - private boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) { + private static boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) { return cp instanceof EqualTo && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0); } - private Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) { + private static Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) { DateV2Literal d = new DateV2Literal(l.getYear(), l.getMonth(), l.getDay()); if (type == AdjustType.UPPER && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0)) { d = ((DateV2Literal) d.plusDays(1)); @@ -338,7 +351,7 @@ private Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) { return d; } - private Expression migrateToDate(DateV2Literal l) { + private static Expression migrateToDate(DateV2Literal l) { return new DateLiteral(l.getYear(), l.getMonth(), l.getDay()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java index 6b0426adaad5e9..c3c3c17dd55f42 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; 
import org.apache.doris.nereids.trees.expressions.Expression; @@ -26,8 +26,10 @@ import org.apache.doris.nereids.types.DecimalV3Type; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import java.math.BigDecimal; +import java.util.List; /** * if we have a column with decimalv3 type and set enable_decimal_conversion = false. @@ -37,14 +39,20 @@ * and the col1 need to convert to decimalv3(27, 9) to match the precision of right hand * this rule simplify it from cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6 */ -public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule { - +public class SimplifyDecimalV3Comparison implements ExpressionPatternRuleFactory { public static SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison(); @Override - public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) { - Expression left = rewrite(cp.left(), context); - Expression right = rewrite(cp.right(), context); + public List> buildRules() { + return ImmutableList.of( + matchesType(ComparisonPredicate.class).then(SimplifyDecimalV3Comparison::simplify) + ); + } + + /** simplify */ + public static Expression simplify(ComparisonPredicate cp) { + Expression left = cp.left(); + Expression right = cp.right(); if (left.getDataType() instanceof DecimalV3Type && left instanceof Cast @@ -60,7 +68,7 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRew } } - private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { + private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros(); int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue); int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(trailingZerosValue); diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java index 3e194a4edde398..bf1b194a6ac7f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; @@ -33,12 +33,18 @@ /** * SimplifyInPredicate */ -public class SimplifyInPredicate extends AbstractExpressionRewriteRule { - +public class SimplifyInPredicate implements ExpressionPatternRuleFactory { public static final SimplifyInPredicate INSTANCE = new SimplifyInPredicate(); @Override - public Expression visitInPredicate(InPredicate expr, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(InPredicate.class).then(SimplifyInPredicate::simplify) + ); + } + + /** simplify */ + public static Expression simplify(InPredicate expr) { if (expr.children().size() > 1) { if (expr.getCompareExpr() instanceof Cast) { Cast cast = (Cast) expr.getCompareExpr(); @@ -58,7 +64,7 @@ && canLosslessConvertToDateV2Literal((DateTimeV2Literal) literal))) { DateTimeV2Type compareType = (DateTimeV2Type) cast.child().getDataType(); if (literals.stream().allMatch(literal -> literal instanceof DateTimeV2Literal && 
canLosslessConvertToLowScaleLiteral( - (DateTimeV2Literal) literal, compareType.getScale()))) { + (DateTimeV2Literal) literal, compareType.getScale()))) { ImmutableList.Builder children = ImmutableList.builder(); children.add(cast.child()); literals.forEach(l -> children.add(new DateTimeV2Literal(compareType, @@ -86,7 +92,7 @@ private static boolean canLosslessConvertToDateV2Literal(DateTimeV2Literal liter | literal.getMicroSecond()) == 0L; } - private DateV2Literal convertToDateV2Literal(DateTimeV2Literal literal) { + private static DateV2Literal convertToDateV2Literal(DateTimeV2Literal literal) { return new DateV2Literal(literal.getYear(), literal.getMonth(), literal.getDay()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyNotExprRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyNotExprRule.java index 7268d6e8328a9c..484d68f0d7317d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyNotExprRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyNotExprRule.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; @@ -28,6 +28,10 @@ import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Not; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Rewrite rule of NOT expression. 
* For example: @@ -42,12 +46,19 @@ * not and(a >= b, a <= c) -> or(a < b, a > c) * not or(a >= b, a <= c) -> and(a < b, a > c) */ -public class SimplifyNotExprRule extends AbstractExpressionRewriteRule { +public class SimplifyNotExprRule implements ExpressionPatternRuleFactory { public static SimplifyNotExprRule INSTANCE = new SimplifyNotExprRule(); @Override - public Expression visitNot(Not not, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(Not.class).then(SimplifyNotExprRule::simplify) + ); + } + + /** simplifyNot */ + public static Expression simplify(Not not) { Expression child = not.child(); if (child instanceof ComparisonPredicate) { ComparisonPredicate cp = (ComparisonPredicate) not.child(); @@ -55,23 +66,22 @@ public Expression visitNot(Not not, ExpressionRewriteContext context) { Expression right = cp.right(); if (child instanceof GreaterThan) { - return new LessThanEqual(left, right).accept(this, context); + return new LessThanEqual(left, right); } else if (child instanceof GreaterThanEqual) { - return new LessThan(left, right).accept(this, context); + return new LessThan(left, right); } else if (child instanceof LessThan) { - return new GreaterThanEqual(left, right).accept(this, context); + return new GreaterThanEqual(left, right); } else if (child instanceof LessThanEqual) { - return new GreaterThan(left, right).accept(this, context); + return new GreaterThan(left, right); } } else if (child instanceof CompoundPredicate) { CompoundPredicate cp = (CompoundPredicate) child; Not left = new Not(cp.left()); Not right = new Not(cp.right()); - return cp.flip(left, right).accept(this, context); + return cp.flip(left, right); } else if (child instanceof Not) { - return child.child(0).accept(this, context); + return child.child(0); } - - return super.visitNot(not, context); + return not; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index 4dbfdb2f35a249..98d752facb464a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -17,9 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; @@ -41,15 +40,17 @@ import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; import com.google.common.collect.Range; import com.google.common.collect.Sets; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.function.BinaryOperator; @@ -74,18 +75,21 @@ * 2. for `Or` expression (similar to `And`). 
* todo: support a > 10 and (a < 10 or a > 20 ) => a > 20 */ -public class SimplifyRange extends AbstractExpressionRewriteRule { - +public class SimplifyRange implements ExpressionPatternRuleFactory { public static final SimplifyRange INSTANCE = new SimplifyRange(); @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - if (expr instanceof CompoundPredicate) { - ValueDesc valueDesc = expr.accept(new RangeInference(), null); - Expression simplifiedExpr = valueDesc.toExpression(); - return simplifiedExpr == null ? valueDesc.expr : simplifiedExpr; - } - return expr; + public List> buildRules() { + return ImmutableList.of( + matchesTopType(CompoundPredicate.class).then(SimplifyRange::rewrite) + ); + } + + /** rewrite */ + public static Expression rewrite(CompoundPredicate expr) { + ValueDesc valueDesc = expr.accept(new RangeInference(), null); + Expression simplifiedExpr = valueDesc.toExpression(); + return simplifiedExpr == null ? valueDesc.expr : simplifiedExpr; } private static class RangeInference extends ExpressionVisitor { @@ -96,21 +100,20 @@ public ValueDesc visit(Expression expr, Void context) { } private ValueDesc buildRange(ComparisonPredicate predicate) { - Expression rewrite = ExpressionRuleExecutor.normalize(predicate); - Expression right = rewrite.child(1); + Expression right = predicate.child(1); if (right.isNullLiteral()) { // it's safe to return empty value if >, >=, <, <= and = with null if ((predicate instanceof GreaterThan || predicate instanceof GreaterThanEqual || predicate instanceof LessThan || predicate instanceof LessThanEqual || predicate instanceof EqualTo)) { - return new EmptyValue(rewrite.child(0), rewrite); + return new EmptyValue(predicate.child(0), predicate); } else { return new UnknownValue(predicate); } } // only handle `NumericType` and `DateLikeType` if (right.isLiteral() && (right.getDataType().isNumericType() || right.getDataType().isDateLikeType())) { - return 
ValueDesc.range((ComparisonPredicate) rewrite); + return ValueDesc.range(predicate); } return new UnknownValue(predicate); } @@ -164,18 +167,23 @@ public ValueDesc visitOr(Or or, Void context) { private ValueDesc simplify(Expression originExpr, List predicates, BinaryOperator op, BinaryOperator exprOp) { - Map> groupByReference = predicates.stream() - .map(predicate -> predicate.accept(this, null)) - .collect(Collectors.groupingBy(p -> p.reference, LinkedHashMap::new, Collectors.toList())); + Multimap groupByReference + = Multimaps.newListMultimap(new LinkedHashMap<>(), ArrayList::new); + for (Expression predicate : predicates) { + ValueDesc valueDesc = predicate.accept(this, null); + List valueDescs = (List) groupByReference.get(valueDesc.reference); + valueDescs.add(valueDesc); + } List valuePerRefs = Lists.newArrayList(); - for (Entry> referenceValues : groupByReference.entrySet()) { - List valuePerReference = referenceValues.getValue(); + for (Entry> referenceValues : groupByReference.asMap().entrySet()) { + List valuePerReference = (List) referenceValues.getValue(); // merge per reference - ValueDesc simplifiedValue = valuePerReference.stream() - .reduce(op) - .get(); + ValueDesc simplifiedValue = valuePerReference.get(0); + for (int i = 1; i < valuePerReference.size(); i++) { + simplifiedValue = op.apply(simplifiedValue, valuePerReference.get(i)); + } valuePerRefs.add(simplifiedValue); } @@ -245,6 +253,7 @@ public static ValueDesc range(ComparisonPredicate predicate) { } public static ValueDesc discrete(InPredicate in) { + // Set literals = (Set) Utils.fastToImmutableSet(in.getOptions()); Set literals = in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet()); return new DiscreteValue(in.getCompareExpr(), in, literals); } @@ -427,7 +436,9 @@ public Expression toExpression() { // They are same processes, so must change synchronously. 
if (values.size() == 1) { return new EqualTo(reference, values.iterator().next()); - } else if (values.size() <= OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { + + // this condition should as same as OrToIn, or else meet dead loop + } else if (values.size() < OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { Iterator iterator = values.iterator(); return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next())); } else { @@ -478,10 +489,12 @@ public Expression toExpression() { if (sourceValues.isEmpty()) { return expr; } - return sourceValues.stream() - .map(ValueDesc::toExpression) - .reduce(mergeExprOp) - .get(); + + Expression result = sourceValues.get(0).toExpression(); + for (int i = 1; i < sourceValues.size(); i++) { + result = mergeExprOp.apply(result, sourceValues.get(i).toExpression()); + } + return result; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SupportJavaDateFormatter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SupportJavaDateFormatter.java index 17f4b7d239a237..27b929a2b9f865 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SupportJavaDateFormatter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SupportJavaDateFormatter.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.scalar.DateFormat; import org.apache.doris.nereids.trees.expressions.functions.scalar.FromUnixtime; @@ -26,54 +26,46 @@ import 
org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.List; -/** SupportJavaDateFormatter */ -public class SupportJavaDateFormatter extends AbstractExpressionRewriteRule { +/** SupportJavaDateFormatter2 */ +public class SupportJavaDateFormatter implements ExpressionPatternRuleFactory { public static final SupportJavaDateFormatter INSTANCE = new SupportJavaDateFormatter(); @Override - public Expression visitDateFormat(DateFormat dateFormat, ExpressionRewriteContext context) { - Expression expr = super.visitDateFormat(dateFormat, context); - if (!(expr instanceof DateFormat)) { - return expr; - } - dateFormat = (DateFormat) expr; + public List> buildRules() { + return ImmutableList.of( + matchesType(DateFormat.class).then(SupportJavaDateFormatter::rewriteDateFormat), + matchesType(FromUnixtime.class).then(SupportJavaDateFormatter::rewriteFromUnixtime), + matchesType(UnixTimestamp.class).then(SupportJavaDateFormatter::rewriteUnixTimestamp) + ); + } + + public static Expression rewriteDateFormat(DateFormat dateFormat) { if (dateFormat.arity() > 1) { return translateJavaFormatter(dateFormat, 1); } return dateFormat; } - @Override - public Expression visitFromUnixtime(FromUnixtime fromUnixtime, ExpressionRewriteContext context) { - Expression expr = super.visitFromUnixtime(fromUnixtime, context); - if (!(expr instanceof FromUnixtime)) { - return expr; - } - fromUnixtime = (FromUnixtime) expr; + public static Expression rewriteFromUnixtime(FromUnixtime fromUnixtime) { if (fromUnixtime.arity() > 1) { return translateJavaFormatter(fromUnixtime, 1); } return fromUnixtime; } - @Override - public Expression visitUnixTimestamp(UnixTimestamp unixTimestamp, ExpressionRewriteContext context) { - Expression expr = super.visitUnixTimestamp(unixTimestamp, context); - if (!(expr instanceof 
UnixTimestamp)) { - return expr; - } - unixTimestamp = (UnixTimestamp) expr; + public static Expression rewriteUnixTimestamp(UnixTimestamp unixTimestamp) { if (unixTimestamp.arity() > 1) { return translateJavaFormatter(unixTimestamp, 1); } return unixTimestamp; } - private Expression translateJavaFormatter(Expression function, int formatterIndex) { + private static Expression translateJavaFormatter(Expression function, int formatterIndex) { Expression formatterExpr = function.getArgument(formatterIndex); Expression newFormatterExpr = translateJavaFormatter(formatterExpr); if (newFormatterExpr != formatterExpr) { @@ -84,7 +76,7 @@ private Expression translateJavaFormatter(Expression function, int formatterInde return function; } - private Expression translateJavaFormatter(Expression formatterExpr) { + private static Expression translateJavaFormatter(Expression formatterExpr) { if (formatterExpr.isLiteral() && formatterExpr.getDataType().isStringLikeType()) { Literal literal = (Literal) formatterExpr; String originFormatter = literal.getStringValue(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java index 30e76cfe226f5b..318cb6ec6031af 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java @@ -17,39 +17,38 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; -import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.TopN; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.collect.ImmutableList; + +import java.util.List; /** * Convert topn(x, 1) to max(x) */ -public class TopnToMax extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class TopnToMax implements ExpressionPatternRuleFactory { public static final TopnToMax INSTANCE = new TopnToMax(); @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, null); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(TopN.class).then(TopnToMax::rewrite) + ); } - @Override - public Expression visitAggregateFunction(AggregateFunction aggregateFunction, ExpressionRewriteContext context) { - if (!(aggregateFunction instanceof TopN)) { - return aggregateFunction; - } - TopN topN = (TopN) aggregateFunction; + /** rewrite */ + public static Expression rewrite(TopN topN) { if (topN.arity() == 2 && topN.child(1) instanceof IntegerLikeLiteral && ((IntegerLikeLiteral) topN.child(1)).getIntValue() == 1) { return new Max(topN.child(0)); } else { - return aggregateFunction; + return topN; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java index 3faf56f0f3829e..ce23219bcc93e2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.expression.rules.TryEliminateUninterestedPredicates.Context; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; @@ -51,10 +52,17 @@ private TryEliminateUninterestedPredicates(Set interestedSlots, CascadesCo this.expressionRewriteContext = new ExpressionRewriteContext(cascadesContext); } + /** rewrite */ public static Expression rewrite(Expression expression, Set interestedSlots, CascadesContext cascadesContext) { // before eliminate uninterested predicate, we must push down `Not` under CompoundPredicate - expression = expression.accept(new SimplifyNotExprRule(), null); + expression = expression.rewriteUp(expr -> { + if (expr instanceof Not) { + return SimplifyNotExprRule.simplify((Not) expr); + } else { + return expr; + } + }); TryEliminateUninterestedPredicates rewriter = new TryEliminateUninterestedPredicates( interestedSlots, cascadesContext); return expression.accept(rewriter, new Context()); @@ -89,7 +97,7 @@ public Expression visit(Expression originExpr, Context parentContext) { // -> ((interested slot a) and true) or true // -> (interested slot a) or true // -> true - expr = expr.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + expr = FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext); } } else { // ((uninterested slot b > 0) + 1) > 1 @@ -122,7 +130,7 @@ public Expression visitAnd(And and, Context parentContext) { if (rightContext.childrenContainsNonInterestedSlots) { newRight = BooleanLiteral.TRUE; } - Expression expr = new 
And(newLeft, newRight).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + Expression expr = FoldConstantRuleOnFE.evaluate(new And(newLeft, newRight), expressionRewriteContext); parentContext.childrenContainsInterestedSlots = rightContext.childrenContainsInterestedSlots || leftContext.childrenContainsInterestedSlots; return expr; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 10b21d0b979ae0..61aac4d2407462 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -1611,7 +1611,7 @@ private Pair, List> countDistinctMultiEx } private boolean containsCountDistinctMultiExpr(LogicalAggregate aggregate) { - return ExpressionUtils.anyMatch(aggregate.getOutputExpressions(), expr -> + return ExpressionUtils.deapAnyMatch(aggregate.getOutputExpressions(), expr -> expr instanceof Count && ((Count) expr).isDistinct() && expr.arity() > 1); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java index fcf3e82737bf9b..ebc67e7c515930 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java @@ -27,6 +27,8 @@ import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.util.TypeCoercionUtils; +import com.google.common.collect.ImmutableSet; + import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -51,7 +53,7 @@ public Plan visitLogicalFilter(LogicalFilter filter, Void contex filter = (LogicalFilter) 
super.visit(filter, context); Set conjuncts = filter.getConjuncts().stream() .map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE)) - .collect(Collectors.toSet()); + .collect(ImmutableSet.toImmutableSet()); return filter.withConjuncts(conjuncts); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java index f30d55ad0fc294..a608448e023f07 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java @@ -268,11 +268,19 @@ private T updateExpression(T input, Map rep } private List updateExpressions(List inputs, Map replaceMap) { - return inputs.stream().map(i -> updateExpression(i, replaceMap)).collect(ImmutableList.toImmutableList()); + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(inputs.size()); + for (T input : inputs) { + result.add(updateExpression(input, replaceMap)); + } + return result.build(); } private Set updateExpressions(Set inputs, Map replaceMap) { - return inputs.stream().map(i -> updateExpression(i, replaceMap)).collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder result = ImmutableSet.builderWithExpectedSize(inputs.size()); + for (T input : inputs) { + result.add(updateExpression(input, replaceMap)); + } + return result.build(); } private Map collectChildrenOutputMap(LogicalPlan plan) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java index 907d34c07c0a12..8c73991f3638aa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java @@ -27,6 +27,7 @@ import 
org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.util.ExpressionUtils; import java.util.List; @@ -38,7 +39,7 @@ public class CheckMatchExpression extends OneRewriteRuleFactory { @Override public Rule build() { return logicalFilter(logicalOlapScan()) - .when(filter -> containsMatchExpression(filter.getExpressions())) + .when(filter -> ExpressionUtils.containsType(filter.getExpressions(), Match.class)) .then(this::checkChildren) .toRule(RuleType.CHECK_MATCH_EXPRESSION); } @@ -60,8 +61,4 @@ private Plan checkChildren(LogicalFilter filter) { } return filter; } - - private boolean containsMatchExpression(List expressions) { - return expressions.stream().anyMatch(expr -> expr.anyMatch(Match.class::isInstance)); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckPrivileges.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckPrivileges.java index b4d7b8005132ff..70a5c593ee3dc8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckPrivileges.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckPrivileges.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.rules.analysis.UserAuthentication; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; @@ -30,9 +29,12 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalView; import org.apache.doris.qe.ConnectContext; +import com.google.common.collect.Sets; + +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import 
java.util.Set; -import java.util.stream.Collectors; /** CheckPrivileges */ public class CheckPrivileges extends ColumnPruning { @@ -65,15 +67,20 @@ public Plan visitLogicalRelation(LogicalRelation relation, PruneContext context) } private Set computeUsedColumns(Plan plan, Set requiredSlots) { - Map idToSlot = plan.getOutputSet() - .stream() - .collect(Collectors.toMap(slot -> slot.getExprId().asInt(), slot -> slot)); - return requiredSlots - .stream() - .map(slot -> idToSlot.get(slot.getExprId().asInt())) - .filter(slot -> slot != null) - .map(NamedExpression::getName) - .collect(Collectors.toSet()); + List outputs = plan.getOutput(); + Map idToSlot = new LinkedHashMap<>(outputs.size()); + for (Slot output : outputs) { + idToSlot.putIfAbsent(output.getExprId().asInt(), output); + } + + Set usedColumns = Sets.newLinkedHashSetWithExpectedSize(requiredSlots.size()); + for (Slot requiredSlot : requiredSlots) { + Slot slot = idToSlot.get(requiredSlot.getExprId().asInt()); + if (slot != null) { + usedColumns.add(slot.getName()); + } + } + return usedColumns; } private void checkColumnPrivileges(TableIf table, Set usedColumns) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java index f33f1658c32e29..e36c0f5172ad70 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import 
org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; @@ -39,18 +40,17 @@ import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.function.Function; -import java.util.stream.Collectors; import java.util.stream.IntStream; /** @@ -97,13 +97,11 @@ public Plan visit(Plan plan, JobContext jobContext) { for (Plan child : plan.children()) { child.accept(this, jobContext); } - plan.getExpressions().stream().filter( - expression -> !(expression instanceof SlotReference) - ).forEach( - expression -> { - keys.addAll(expression.getInputSlots()); - } - ); + for (Expression expression : plan.getExpressions()) { + if (!(expression instanceof SlotReference)) { + keys.addAll(expression.getInputSlots()); + } + } return plan; } } @@ -212,39 +210,42 @@ private Plan pruneAggregate(Aggregate agg, PruneContext context) { } private Plan skipPruneThisAndFirstLevelChildren(Plan plan) { - Set requireAllOutputOfChildren = plan.children() - .stream() - .flatMap(child -> child.getOutputSet().stream()) - .collect(Collectors.toSet()); - return pruneChildren(plan, requireAllOutputOfChildren); + ImmutableSet.Builder requireAllOutputOfChildren = ImmutableSet.builder(); + for (Plan child : plan.children()) { + requireAllOutputOfChildren.addAll(child.getOutput()); + } + return pruneChildren(plan, requireAllOutputOfChildren.build()); } - private static Aggregate fillUpGroupByAndOutput(Aggregate prunedOutputAgg) { + private static Aggregate fillUpGroupByAndOutput(Aggregate prunedOutputAgg) { List groupBy = 
prunedOutputAgg.getGroupByExpressions(); List output = prunedOutputAgg.getOutputExpressions(); if (!(prunedOutputAgg instanceof LogicalAggregate)) { return prunedOutputAgg; } - // add back group by keys which eliminated by rule ELIMINATE_GROUP_BY_KEY - // if related output expressions are not in pruned output list. - List remainedOutputExprs = Lists.newArrayList(output); - remainedOutputExprs.removeAll(groupBy); - List newOutputList = Lists.newArrayList(); - newOutputList.addAll((List) groupBy); - newOutputList.addAll(remainedOutputExprs); + ImmutableList.Builder newOutputListBuilder + = ImmutableList.builderWithExpectedSize(output.size()); + newOutputListBuilder.addAll((List) groupBy); + for (NamedExpression ne : output) { + if (!groupBy.contains(ne)) { + newOutputListBuilder.add(ne); + } + } - if (!(prunedOutputAgg instanceof LogicalAggregate)) { - return prunedOutputAgg.withAggOutput(newOutputList); - } else { - List newGroupByExprList = newOutputList.stream().filter(e -> - !(prunedOutputAgg.getAggregateFunctions().contains(e) - || e instanceof Alias && prunedOutputAgg.getAggregateFunctions() - .contains(((Alias) e).child())) - ).collect(Collectors.toList()); - return ((LogicalAggregate) prunedOutputAgg).withGroupByAndOutput(newGroupByExprList, newOutputList); + List newOutputList = newOutputListBuilder.build(); + Set aggregateFunctions = prunedOutputAgg.getAggregateFunctions(); + ImmutableList.Builder newGroupByExprList + = ImmutableList.builderWithExpectedSize(newOutputList.size()); + for (NamedExpression e : newOutputList) { + if (!(aggregateFunctions.contains(e) + || (e instanceof Alias && aggregateFunctions.contains(e.child(0))))) { + newGroupByExprList.add(e); + } } + return ((LogicalAggregate) prunedOutputAgg).withGroupByAndOutput( + newGroupByExprList.build(), newOutputList); } /** prune output */ @@ -253,9 +254,8 @@ public

P pruneOutput(P plan, List originOutput if (originOutput.isEmpty()) { return plan; } - List prunedOutputs = originOutput.stream() - .filter(output -> context.requiredSlots.contains(output.toSlot())) - .collect(ImmutableList.toImmutableList()); + List prunedOutputs = + Utils.filterImmutableList(originOutput, output -> context.requiredSlots.contains(output.toSlot())); if (prunedOutputs.isEmpty()) { List candidates = Lists.newArrayList(originOutput); @@ -281,7 +281,6 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) } List prunedOutputs = Lists.newArrayList(); List> constantExprsList = union.getConstantExprsList(); - List> prunedConstantExprsList = Lists.newArrayList(); List extractColumnIndex = Lists.newArrayList(); for (int i = 0; i < originOutput.size(); i++) { NamedExpression output = originOutput.get(i); @@ -291,12 +290,14 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) } } int len = extractColumnIndex.size(); + ImmutableList.Builder> prunedConstantExprsList + = ImmutableList.builderWithExpectedSize(constantExprsList.size()); for (List row : constantExprsList) { - ArrayList newRow = new ArrayList<>(len); + ImmutableList.Builder newRow = ImmutableList.builderWithExpectedSize(len); for (int idx : extractColumnIndex) { newRow.add(row.get(idx)); } - prunedConstantExprsList.add(newRow); + prunedConstantExprsList.add(newRow.build()); } if (prunedOutputs.isEmpty()) { @@ -312,7 +313,7 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) if (prunedOutputs.equals(originOutput)) { return union; } else { - return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList); + return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList.build()); } } @@ -329,24 +330,31 @@ private

P pruneChildren(P plan, Set parentRequiredSlots) Set currentUsedSlots = plan.getInputSlots(); Set childrenRequiredSlots = parentRequiredSlots.isEmpty() ? currentUsedSlots - : ImmutableSet.builder() + : ImmutableSet.builderWithExpectedSize(parentRequiredSlots.size() + currentUsedSlots.size()) .addAll(parentRequiredSlots) .addAll(currentUsedSlots) .build(); - List newChildren = new ArrayList<>(); + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(plan.arity()); boolean hasNewChildren = false; for (Plan child : plan.children()) { - Set childOutputSet = child.getOutputSet(); - Set childRequiredSlots = childOutputSet.stream() - .filter(childrenRequiredSlots::contains).collect(Collectors.toSet()); + Set childRequiredSlots; + List childOutputs = child.getOutput(); + ImmutableSet.Builder childRequiredSlotBuilder + = ImmutableSet.builderWithExpectedSize(childOutputs.size()); + for (Slot childOutput : childOutputs) { + if (childrenRequiredSlots.contains(childOutput)) { + childRequiredSlotBuilder.add(childOutput); + } + } + childRequiredSlots = childRequiredSlotBuilder.build(); Plan prunedChild = doPruneChild(plan, child, childRequiredSlots); if (prunedChild != child) { hasNewChildren = true; } newChildren.add(prunedChild); } - return hasNewChildren ? (P) plan.withChildren(newChildren) : plan; + return hasNewChildren ? 
(P) plan.withChildren(newChildren.build()) : plan; } private Plan doPruneChild(Plan plan, Plan child, Set childRequiredSlots) { @@ -358,7 +366,7 @@ private Plan doPruneChild(Plan plan, Plan child, Set childRequiredSlots) { // the case 2 in the class comment, prune child's output failed if (!isProject && !Sets.difference(prunedChild.getOutputSet(), childRequiredSlots).isEmpty()) { - prunedChild = new LogicalProject<>(ImmutableList.copyOf(childRequiredSlots), prunedChild); + prunedChild = new LogicalProject<>(Utils.fastToImmutableList(childRequiredSlots), prunedChild); } return prunedChild; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctRewrite.java index f2ccf55ac50731..3d106078a03319 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctRewrite.java @@ -24,9 +24,12 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.types.DataType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import java.util.List; @@ -38,35 +41,42 @@ public class CountDistinctRewrite extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalAggregate().then(agg -> { - List output = agg.getOutputExpressions() - .stream() - .map(CountDistinctRewriter::rewrite) - .map(NamedExpression.class::cast) - .collect(ImmutableList.toImmutableList()); - 
return agg.withAggOutput(output); + return logicalAggregate().when(CountDistinctRewrite::containsCountObject).then(agg -> { + List outputExpressions = agg.getOutputExpressions(); + Builder newOutputs + = ImmutableList.builderWithExpectedSize(outputExpressions.size()); + for (NamedExpression outputExpression : outputExpressions) { + NamedExpression newOutput = (NamedExpression) outputExpression.rewriteUp(expr -> { + if (expr instanceof Count && ((Count) expr).isDistinct() && expr.arity() == 1) { + Expression child = expr.child(0); + if (child.getDataType().isBitmapType()) { + return new BitmapUnionCount(child); + } + if (child.getDataType().isHllType()) { + return new HllUnionAgg(child); + } + } + return expr; + }); + newOutputs.add(newOutput); + } + return agg.withAggOutput(newOutputs.build()); }).toRule(RuleType.COUNT_DISTINCT_REWRITE); } - private static class CountDistinctRewriter extends DefaultExpressionRewriter { - private static final CountDistinctRewriter INSTANCE = new CountDistinctRewriter(); - - public static Expression rewrite(Expression expr) { - return expr.accept(INSTANCE, null); - } - - @Override - public Expression visitCount(Count count, Void context) { - if (count.isDistinct() && count.arity() == 1) { - Expression child = count.child(0); - if (child.getDataType().isBitmapType()) { - return new BitmapUnionCount(child); - } - if (child.getDataType().isHllType()) { - return new HllUnionAgg(child); + private static boolean containsCountObject(LogicalAggregate agg) { + for (NamedExpression ne : agg.getOutputExpressions()) { + boolean needRewrite = ne.anyMatch(expr -> { + if (expr instanceof Count && ((Count) expr).isDistinct() && expr.arity() == 1) { + DataType dataType = expr.child(0).getDataType(); + return dataType.isBitmapType() || dataType.isHllType(); } + return false; + }); + if (needRewrite) { + return true; } - return count; } + return false; } } diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java index dfe13b388f5b56..bfbd6599cf8acf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java @@ -27,13 +27,14 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; /** * count(1) ==> count(*) @@ -50,21 +51,31 @@ public Rule build() { return agg; } - Map> projectsAndAggFunc = newExprs.stream() - .collect(Collectors.partitioningBy(Expression::isConstant)); + List projectFuncs = Lists.newArrayListWithCapacity(newExprs.size()); + Builder aggFuncsBuilder + = ImmutableList.builderWithExpectedSize(newExprs.size()); + for (NamedExpression newExpr : newExprs) { + if (newExpr.isConstant()) { + projectFuncs.add(newExpr); + } else { + aggFuncsBuilder.add(newExpr); + } + } - if (projectsAndAggFunc.get(false).isEmpty()) { + List aggFuncs = aggFuncsBuilder.build(); + if (aggFuncs.isEmpty()) { // if there is no group by keys and other agg func, don't rewrite return null; } else { // if there is group by keys, put count(null) in projects, such as // project(0 as count(null)) // --Aggregate(k1, group by k1) - Plan plan = agg.withAggOutput(projectsAndAggFunc.get(false)); - if (!projectsAndAggFunc.get(true).isEmpty()) { - projectsAndAggFunc.get(false).stream().map(NamedExpression::toSlot) - .forEach(projectsAndAggFunc.get(true)::add); - plan = new LogicalProject<>(projectsAndAggFunc.get(true), plan); + Plan plan = 
agg.withAggOutput(aggFuncs); + if (!projectFuncs.isEmpty()) { + for (NamedExpression aggFunc : aggFuncs) { + projectFuncs.add(aggFunc.toSlot()); + } + plan = new LogicalProject<>(projectFuncs, plan); } return plan; } @@ -77,9 +88,11 @@ private boolean rewriteCountLiteral(List oldExprs, List replaced = new HashMap<>(); Set oldAggFuncSet = expr.collect(AggregateFunction.class::isInstance); - oldAggFuncSet.stream() - .filter(this::isCountLiteral) - .forEach(c -> replaced.put(c, rewrite((Count) c))); + for (AggregateFunction aggFun : oldAggFuncSet) { + if (isCountLiteral(aggFun)) { + replaced.put(aggFun, rewrite((Count) aggFun)); + } + } expr = expr.rewriteUp(s -> replaced.getOrDefault(s, s)); changed |= !replaced.isEmpty(); newExprs.add((NamedExpression) expr); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java index a3de71a770e30f..ef9e418f58d185 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java @@ -43,10 +43,10 @@ public class EliminateFilter implements RewriteRuleFactory { @Override public List buildRules() { return ImmutableList.of(logicalFilter().when( - filter -> filter.getConjuncts().stream().anyMatch(BooleanLiteral.class::isInstance)) + filter -> ExpressionUtils.containsType(filter.getConjuncts(), BooleanLiteral.class)) .thenApply(ctx -> { LogicalFilter filter = ctx.root; - ImmutableSet.Builder newConjuncts = ImmutableSet.builder(); + ImmutableSet.Builder newConjuncts = ImmutableSet.builder(); for (Expression expression : filter.getConjuncts()) { if (expression == BooleanLiteral.FALSE) { return new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), @@ -73,8 +73,7 @@ public List buildRules() { new ExpressionRewriteContext(ctx.cascadesContext); for (Expression expression : 
filter.getConjuncts()) { Expression newExpr = ExpressionUtils.replace(expression, replaceMap); - Expression foldExpression = - FoldConstantRule.INSTANCE.rewrite(newExpr, context); + Expression foldExpression = FoldConstantRule.evaluate(newExpr, context); if (foldExpression == BooleanLiteral.FALSE) { return new LogicalEmptyRelation( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java index 3b95e9b44e06f0..109cff192f22f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -31,11 +32,14 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSet.Builder; + import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; /** * Eliminate GroupBy. 
@@ -45,39 +49,53 @@ public class EliminateGroupBy extends OneRewriteRuleFactory { @Override public Rule build() { return logicalAggregate() - .when(agg -> agg.getGroupByExpressions().stream().allMatch(expr -> expr instanceof Slot)) + .when(agg -> ExpressionUtils.allMatch(agg.getGroupByExpressions(), Slot.class::isInstance)) .then(agg -> { - Set groupby = agg.getGroupByExpressions().stream().map(e -> (Slot) e) - .collect(Collectors.toSet()); + List groupByExpressions = agg.getGroupByExpressions(); + Builder groupBySlots + = ImmutableSet.builderWithExpectedSize(groupByExpressions.size()); + for (Expression groupByExpression : groupByExpressions) { + groupBySlots.add((Slot) groupByExpression); + } Plan child = agg.child(); - boolean unique = child.getLogicalProperties().getFunctionalDependencies() - .isUniqueAndNotNull(groupby); + boolean unique = child.getLogicalProperties() + .getFunctionalDependencies() + .isUniqueAndNotNull(groupBySlots.build()); if (!unique) { return null; } - Set aggregateFunctions = agg.getAggregateFunctions(); - if (!aggregateFunctions.stream().allMatch( - f -> (f instanceof Sum || f instanceof Count || f instanceof Min || f instanceof Max) - && (f.arity() == 1 && f.child(0) instanceof Slot))) { - return null; + for (AggregateFunction f : agg.getAggregateFunctions()) { + if (!((f instanceof Sum || f instanceof Count || f instanceof Min || f instanceof Max) + && (f.arity() == 1 && f.child(0) instanceof Slot))) { + return null; + } } + List outputExpressions = agg.getOutputExpressions(); + + ImmutableList.Builder newOutput + = ImmutableList.builderWithExpectedSize(outputExpressions.size()); - List newOutput = agg.getOutputExpressions().stream().map(ne -> { + for (NamedExpression ne : outputExpressions) { if (ne instanceof Alias && ne.child(0) instanceof AggregateFunction) { AggregateFunction f = (AggregateFunction) ne.child(0); if (f instanceof Sum || f instanceof Min || f instanceof Max) { - return new Alias(ne.getExprId(), f.child(0), 
ne.getName()); + newOutput.add(new Alias(ne.getExprId(), f.child(0), ne.getName())); } else if (f instanceof Count) { - return (NamedExpression) ne.withChildren( - new If(new IsNull(f.child(0)), Literal.of(0), Literal.of(1))); + newOutput.add((NamedExpression) ne.withChildren( + new If( + new IsNull(f.child(0)), + Literal.of(0), + Literal.of(1) + ) + )); } else { throw new IllegalStateException("Unexpected aggregate function: " + f); } } else { - return ne; + newOutput.add(ne); } - }).collect(Collectors.toList()); - return PlanUtils.projectOrSelf(newOutput, child); + } + return PlanUtils.projectOrSelf(newOutput.build(), child); }).toRule(RuleType.ELIMINATE_GROUP_BY); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateMarkJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateMarkJoin.java index 2c5a4bbdd14e61..2e426beae46537 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateMarkJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateMarkJoin.java @@ -19,9 +19,11 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; @@ -38,15 +40,22 @@ public class EliminateMarkJoin extends OneRewriteRuleFactory { public Rule build() { return logicalFilter(logicalJoin().when( join -> join.getJoinType().isSemiJoin() && !join.getMarkJoinConjuncts().isEmpty())) - .when(filter -> canSimplifyMarkJoin(filter.getConjuncts())) - .then(filter -> 
filter.withChildren(eliminateMarkJoin(filter.child()))) + .when(filter -> canSimplifyMarkJoin(filter.getConjuncts(), null)) + .thenApply(ctx -> { + LogicalFilter> filter = ctx.root; + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext); + if (canSimplifyMarkJoin(filter.getConjuncts(), rewriteContext)) { + return filter.withChildren(eliminateMarkJoin(filter.child())); + } + return filter; + }) .toRule(RuleType.ELIMINATE_MARK_JOIN); } - private boolean canSimplifyMarkJoin(Set predicates) { + private boolean canSimplifyMarkJoin(Set predicates, ExpressionRewriteContext rewriteContext) { return ExpressionUtils .canInferNotNullForMarkSlot(TrySimplifyPredicateWithMarkJoinSlot.INSTANCE - .rewrite(ExpressionUtils.and(predicates), null)); + .rewrite(ExpressionUtils.and(predicates), rewriteContext), rewriteContext); } private LogicalJoin eliminateMarkJoin(LogicalJoin join) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNull.java index db95d1fefa03be..22393cb55f8335 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNull.java @@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; @@ -41,7 +40,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * Eliminate Predicate `is not null`, like @@ -85,29 +83,34 @@ private List 
removeGeneratedNotNull(Collection exprs, Ca // remove `name` (it's generated), remove `id` (because `id > 0` already contains it) Set predicatesNotContainIsNotNull = Sets.newHashSet(); List slotsFromIsNotNull = Lists.newArrayList(); - exprs.stream() - .filter(expr -> !(expr instanceof Not) - || !((Not) expr).isGeneratedIsNotNull()) // remove generated `is not null` - .forEach(expr -> { - Optional notNullSlot = TypeUtils.isNotNull(expr); - if (notNullSlot.isPresent()) { - slotsFromIsNotNull.add(notNullSlot.get()); - } else { - predicatesNotContainIsNotNull.add(expr); - } - }); + + for (Expression expr : exprs) { + // remove generated `is not null` + if (!(expr instanceof Not) || !((Not) expr).isGeneratedIsNotNull()) { + Optional notNullSlot = TypeUtils.isNotNull(expr); + if (notNullSlot.isPresent()) { + slotsFromIsNotNull.add(notNullSlot.get()); + } else { + predicatesNotContainIsNotNull.add(expr); + } + } + } + Set inferNonNotSlots = ExpressionUtils.inferNotNullSlots( predicatesNotContainIsNotNull, ctx); - Set keepIsNotNull = slotsFromIsNotNull.stream() - .filter(ExpressionTrait::nullable) - .filter(slot -> !inferNonNotSlots.contains(slot)) - .map(slot -> new Not(new IsNull(slot))).collect(Collectors.toSet()); + ImmutableSet.Builder keepIsNotNull + = ImmutableSet.builderWithExpectedSize(slotsFromIsNotNull.size()); + for (Slot slot : slotsFromIsNotNull) { + if (slot.nullable() && !inferNonNotSlots.contains(slot)) { + keepIsNotNull.add(new Not(new IsNull(slot))); + } + } // merge predicatesNotContainIsNotNull and keepIsNotNull into a new List return ImmutableList.builder() .addAll(predicatesNotContainIsNotNull) - .addAll(keepIsNotNull) + .addAll(keepIsNotNull.build()) .build(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOrderByConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOrderByConstant.java index 969d6e6b045b9b..021cae2d6533f5 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOrderByConstant.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOrderByConstant.java @@ -33,13 +33,19 @@ public class EliminateOrderByConstant extends OneRewriteRuleFactory { @Override public Rule build() { return logicalSort().then(sort -> { - List orderKeysWithoutConst = sort - .getOrderKeys() - .stream() - .filter(k -> !(k.getExpr().isConstant())) - .collect(ImmutableList.toImmutableList()); + List orderKeys = sort.getOrderKeys(); + ImmutableList.Builder orderKeysWithoutConstBuilder + = ImmutableList.builderWithExpectedSize(orderKeys.size()); + for (OrderKey orderKey : orderKeys) { + if (!orderKey.getExpr().isConstant()) { + orderKeysWithoutConstBuilder.add(orderKey); + } + } + List orderKeysWithoutConst = orderKeysWithoutConstBuilder.build(); if (orderKeysWithoutConst.isEmpty()) { return sort.child(); + } else if (orderKeysWithoutConst.size() == orderKeys.size()) { + return sort; } return sort.withOrderKeys(orderKeysWithoutConst); }).toRule(RuleType.ELIMINATE_ORDER_BY_CONSTANT); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java index 697eb8fa5a3fa9..5ec0f0cd698d5e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.WindowExpression; 
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; @@ -30,6 +31,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; import org.apache.doris.nereids.util.ExpressionUtils; @@ -50,75 +52,94 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i @Override public Rule build() { - return logicalProject().when(project -> containsWindowExpression(project.getProjects())).then(project -> { - List outputs = - ExpressionUtils.rewriteDownShortCircuit(project.getProjects(), output -> { - if (output instanceof WindowExpression) { - WindowExpression windowExpression = (WindowExpression) output; - Expression expression = ((WindowExpression) output).getFunction(); - if (expression instanceof Sum || expression instanceof Max - || expression instanceof Min || expression instanceof Avg) { - // sum, max, min and avg in window function should be always nullable - windowExpression = ((WindowExpression) output) - .withFunction(((NullableAggregateFunction) expression) - .withAlwaysNullable(true)); + return logicalProject() + .when(project -> ExpressionUtils.containsWindowExpression(project.getProjects())) + .then(this::normalize) + .toRule(RuleType.EXTRACT_AND_NORMALIZE_WINDOW_EXPRESSIONS); + } + + private Plan normalize(LogicalProject project) { + List outputs = + ExpressionUtils.rewriteDownShortCircuit(project.getProjects(), output -> { + if (output instanceof WindowExpression) { + WindowExpression windowExpression = (WindowExpression) output; + Expression expression = ((WindowExpression) output).getFunction(); + if (expression instanceof Sum || expression instanceof Max + || expression instanceof Min || expression instanceof Avg) { + // sum, max, 
min and avg in window function should be always nullable + windowExpression = ((WindowExpression) output) + .withFunction( + ((NullableAggregateFunction) expression).withAlwaysNullable(true) + ); + } + + ImmutableList.Builder nonLiteralPartitionKeys = + ImmutableList.builderWithExpectedSize(windowExpression.getPartitionKeys().size()); + for (Expression partitionKey : windowExpression.getPartitionKeys()) { + if (!partitionKey.isConstant()) { + nonLiteralPartitionKeys.add(partitionKey); } - // remove literal partition by and order by keys - return windowExpression.withPartitionKeysOrderKeys( - windowExpression.getPartitionKeys().stream() - .filter(partitionExpr -> !partitionExpr.isConstant()) - .collect(Collectors.toList()), - windowExpression.getOrderKeys().stream() - .filter(orderExpression -> !orderExpression - .getOrderKey().getExpr().isConstant()) - .collect(Collectors.toList())); } - return output; - }); - - // 1. handle bottom projects - Set existedAlias = ExpressionUtils.collect(outputs, Alias.class::isInstance); - Set toBePushedDown = collectExpressionsToBePushedDown(outputs); - NormalizeToSlotContext context = NormalizeToSlotContext.buildContext(existedAlias, toBePushedDown); - // set toBePushedDown exprs as NamedExpression, e.g. (a+1) -> Alias(a+1) - Set bottomProjects = context.pushDownToNamedExpression(toBePushedDown); - Plan normalizedChild; - if (bottomProjects.isEmpty()) { - normalizedChild = project.child(); - } else { - normalizedChild = project.withProjectsAndChild( - ImmutableList.copyOf(bottomProjects), project.child()); - } - - // 2. 
handle window's outputs and windowExprs - // need to replace exprs with SlotReference in WindowSpec, due to LogicalWindow.getExpressions() - - // because alias is pushed down to bottom project - // we need replace alias's child expr with corresponding alias's slot in output - // so create a customNormalizeMap alias's child -> alias.toSlot to do it - Map customNormalizeMap = toBePushedDown.stream() - .filter(expr -> expr instanceof Alias) - .collect(Collectors.toMap(expr -> ((Alias) expr).child(), expr -> ((Alias) expr).toSlot(), - (oldExpr, newExpr) -> oldExpr)); - - List normalizedOutputs = context.normalizeToUseSlotRef(outputs, - (ctx, expr) -> customNormalizeMap.getOrDefault(expr, null)); - Set normalizedWindows = - ExpressionUtils.collect(normalizedOutputs, WindowExpression.class::isInstance); - - existedAlias = ExpressionUtils.collect(normalizedOutputs, Alias.class::isInstance); - NormalizeToSlotContext ctxForWindows = NormalizeToSlotContext.buildContext( - existedAlias, Sets.newHashSet(normalizedWindows)); - - Set normalizedWindowWithAlias = ctxForWindows.pushDownToNamedExpression(normalizedWindows); - // only need normalized windowExpressions - LogicalWindow normalizedLogicalWindow = - new LogicalWindow<>(ImmutableList.copyOf(normalizedWindowWithAlias), normalizedChild); - - // 3. 
handle top projects - List topProjects = ctxForWindows.normalizeToUseSlotRef(normalizedOutputs); - return project.withProjectsAndChild(topProjects, normalizedLogicalWindow); - }).toRule(RuleType.EXTRACT_AND_NORMALIZE_WINDOW_EXPRESSIONS); + + ImmutableList.Builder nonLiteralOrderExpressions = + ImmutableList.builderWithExpectedSize(windowExpression.getOrderKeys().size()); + for (OrderExpression orderExpr : windowExpression.getOrderKeys()) { + if (!orderExpr.getOrderKey().getExpr().isConstant()) { + nonLiteralOrderExpressions.add(orderExpr); + } + } + + // remove literal partition by and order by keys + return windowExpression.withPartitionKeysOrderKeys( + nonLiteralPartitionKeys.build(), + nonLiteralOrderExpressions.build() + ); + } + return output; + }); + + // 1. handle bottom projects + Set existedAlias = ExpressionUtils.collect(outputs, Alias.class::isInstance); + Set toBePushedDown = collectExpressionsToBePushedDown(outputs); + NormalizeToSlotContext context = NormalizeToSlotContext.buildContext(existedAlias, toBePushedDown); + // set toBePushedDown exprs as NamedExpression, e.g. (a+1) -> Alias(a+1) + Set bottomProjects = context.pushDownToNamedExpression(toBePushedDown); + Plan normalizedChild; + if (bottomProjects.isEmpty()) { + normalizedChild = project.child(); + } else { + normalizedChild = project.withProjectsAndChild( + ImmutableList.copyOf(bottomProjects), project.child()); + } + + // 2. 
handle window's outputs and windowExprs + // need to replace exprs with SlotReference in WindowSpec, due to LogicalWindow.getExpressions() + + // because alias is pushed down to bottom project + // we need replace alias's child expr with corresponding alias's slot in output + // so create a customNormalizeMap alias's child -> alias.toSlot to do it + Map customNormalizeMap = toBePushedDown.stream() + .filter(expr -> expr instanceof Alias) + .collect(Collectors.toMap(expr -> ((Alias) expr).child(), expr -> ((Alias) expr).toSlot(), + (oldExpr, newExpr) -> oldExpr)); + + List normalizedOutputs = context.normalizeToUseSlotRef(outputs, + (ctx, expr) -> customNormalizeMap.getOrDefault(expr, null)); + Set normalizedWindows = + ExpressionUtils.collect(normalizedOutputs, WindowExpression.class::isInstance); + + existedAlias = ExpressionUtils.collect(normalizedOutputs, Alias.class::isInstance); + NormalizeToSlotContext ctxForWindows = NormalizeToSlotContext.buildContext( + existedAlias, Sets.newHashSet(normalizedWindows)); + + Set normalizedWindowWithAlias = ctxForWindows.pushDownToNamedExpression(normalizedWindows); + // only need normalized windowExpressions + LogicalWindow normalizedLogicalWindow = + new LogicalWindow<>(ImmutableList.copyOf(normalizedWindowWithAlias), normalizedChild); + + // 3. 
handle top projects + List topProjects = ctxForWindows.normalizeToUseSlotRef(normalizedOutputs); + return project.withProjectsAndChild(topProjects, normalizedLogicalWindow); } private Set collectExpressionsToBePushedDown(List expressions) { @@ -161,10 +182,4 @@ private Set collectExpressionsToBePushedDown(List e }) .collect(ImmutableSet.toImmutableSet()); } - - private boolean containsWindowExpression(List expressions) { - // WindowExpression in top LogicalProject will be normalized as Alias(SlotReference) after this rule, - // so it will not be normalized infinitely - return expressions.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance)); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java index 4ecc79ae94e7b0..2f8e1404b7199e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java @@ -30,7 +30,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * Paper: Quantifying TPC-H Choke Points and Their Optimizations @@ -84,13 +83,9 @@ private List extractDependentConjuncts(Set conjuncts) { } // only check table in first disjunct. 
// In our example, qualifiers = { n1, n2 } - Expression first = disjuncts.get(0); - Set qualifiers = first.getInputSlots() - .stream() - .map(slot -> String.join(".", slot.getQualifier())) - .collect(Collectors.toSet()); // try to extract - for (String qualifier : qualifiers) { + for (Slot inputSlot : disjuncts.get(0).getInputSlots()) { + String qualifier = String.join(".", inputSlot.getQualifier()); List extractForAll = Lists.newArrayList(); boolean success = true; for (Expression expr : ExpressionUtils.extractDisjunction(conjunct)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java index e7168ca0e99148..39a0e63ff21deb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java @@ -26,7 +26,7 @@ import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Set; /** @@ -46,7 +46,7 @@ public Rule build() { .whenNot(LogicalJoin::isMarkJoin) .thenApply(ctx -> { LogicalJoin join = ctx.root; - Set conjuncts = new HashSet<>(); + Set conjuncts = new LinkedHashSet<>(); conjuncts.addAll(join.getHashJoinConjuncts()); conjuncts.addAll(join.getOtherJoinConjuncts()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 3bdfbc582acc99..9a0b9f8b5e0353 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -201,7 +201,7 @@ private boolean canMergeAggregateWithProject(LogicalAggregate !(expr instanceof SlotReference) && !(expr instanceof 
Alias))) { return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeProjects.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeProjects.java index d152178b5238de..3ea903f8565928 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeProjects.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeProjects.java @@ -20,9 +20,9 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; import java.util.List; @@ -43,8 +43,8 @@ public Rule build() { // TODO modify ExtractAndNormalizeWindowExpression to handle nested window functions // here we just don't merge two projects if there is any window function return logicalProject(logicalProject()) - .whenNot(project -> containsWindowExpression(project.getProjects()) - && containsWindowExpression(project.child().getProjects())) + .whenNot(project -> ExpressionUtils.containsWindowExpression(project.getProjects()) + && ExpressionUtils.containsWindowExpression(project.child().getProjects())) .then(MergeProjects::mergeProjects).toRule(RuleType.MERGE_PROJECTS); } @@ -54,8 +54,4 @@ public static Plan mergeProjects(LogicalProject project) { LogicalProject newProject = childProject.canEliminate() ? 
project : childProject; return newProject.withProjectsAndChild(projectExpressions, childProject.child(0)); } - - private boolean containsWindowExpression(List expressions) { - return expressions.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance)); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeSort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeSort.java index b36d0e63b85423..b7554582885c0c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeSort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeSort.java @@ -24,13 +24,15 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import java.util.List; -import java.util.stream.Stream; /** * SortNode on BE always output order keys because BE needs them to do merge sort. 
So we normalize LogicalSort as BE @@ -40,29 +42,44 @@ public class NormalizeSort extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalSort().whenNot(sort -> sort.getOrderKeys().stream() - .map(OrderKey::getExpr).allMatch(Slot.class::isInstance)) + return logicalSort().whenNot(this::allOrderKeyIsSlot) .then(sort -> { List newProjects = Lists.newArrayList(); - List newOrderKeys = sort.getOrderKeys().stream() - .map(orderKey -> { - Expression expr = orderKey.getExpr(); - if (!(expr instanceof Slot)) { - Alias alias = new Alias(expr); - newProjects.add(alias); - expr = alias.toSlot(); - } - return orderKey.withExpression(expr); - }).collect(ImmutableList.toImmutableList()); - List bottomProjections = Stream.concat( - sort.child().getOutput().stream(), - newProjects.stream() - ).collect(ImmutableList.toImmutableList()); - List topProjections = sort.getOutput().stream() - .map(NamedExpression.class::cast) - .collect(ImmutableList.toImmutableList()); - return new LogicalProject<>(topProjections, sort.withOrderKeysAndChild(newOrderKeys, + + Builder newOrderKeys = + ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); + for (OrderKey orderKey : sort.getOrderKeys()) { + Expression expr = orderKey.getExpr(); + if (!(expr instanceof Slot)) { + Alias alias = new Alias(expr); + newProjects.add(alias); + expr = alias.toSlot(); + newOrderKeys.add(orderKey.withExpression(expr)); + } else { + newOrderKeys.add(orderKey); + } + } + + List childOutput = sort.child().getOutput(); + List bottomProjections = ImmutableList.builderWithExpectedSize( + childOutput.size() + newProjects.size()) + .addAll(childOutput) + .addAll(newProjects) + .build(); + + List topProjections = (List) sort.getOutput(); + return new LogicalProject<>(topProjections, sort.withOrderKeysAndChild( + newOrderKeys.build(), new LogicalProject<>(bottomProjections, sort.child()))); }).toRule(RuleType.NORMALIZE_SORT); } + + private boolean allOrderKeyIsSlot(LogicalSort sort) { + for 
(OrderKey orderKey : sort.getOrderKeys()) { + if (!(orderKey.getExpr() instanceof Slot)) { + return false; + } + } + return true; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java index 683841a5f8fb2c..ea2fb8f4beb538 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java @@ -104,15 +104,19 @@ public List normalizeToUseSlotRef(Collection expres */ public List normalizeToUseSlotRef(Collection expressions, BiFunction customNormalize) { - return expressions.stream() - .map(expr -> (E) expr.rewriteDownShortCircuit(child -> { - Expression newChild = customNormalize.apply(this, child); - if (newChild != null && newChild != child) { - return newChild; - } - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); - return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; - })).collect(ImmutableList.toImmutableList()); + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(expressions.size()); + for (E expr : expressions) { + Expression rewriteExpr = expr.rewriteDownShortCircuit(child -> { + Expression newChild = customNormalize.apply(this, child); + if (newChild != null && newChild != child) { + return newChild; + } + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); + return normalizeToSlotTriplet == null ? 
child : normalizeToSlotTriplet.remainExpr; + }); + result.add((E) rewriteExpr); + } + return result.build(); } public List normalizeToUseSlotRefWithoutWindowFunction( @@ -131,13 +135,20 @@ public List normalizeToUseSlotRefWithoutWindowFunction * bottom: k1#0, (k2#1 + 1) AS (k2 + 1)#2; */ public Set pushDownToNamedExpression(Collection needToPushExpressions) { - return needToPushExpressions.stream() - .map(expr -> { - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); - return normalizeToSlotTriplet == null - ? (NamedExpression) expr - : normalizeToSlotTriplet.pushedExpr; - }).collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder result + = ImmutableSet.builderWithExpectedSize(needToPushExpressions.size()); + for (Expression expr : needToPushExpressions) { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); + result.add(normalizeToSlotTriplet == null + ? (NamedExpression) expr + : normalizeToSlotTriplet.pushedExpr); + } + return result.build(); + } + + public NamedExpression pushDownToNamedExpression(Expression expr) { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); + return normalizeToSlotTriplet == null ? 
(NamedExpression) expr : normalizeToSlotTriplet.pushedExpr; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PruneOlapScanPartition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PruneOlapScanPartition.java index 60df874f2a1004..0aacde1cc1984c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PruneOlapScanPartition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PruneOlapScanPartition.java @@ -36,7 +36,6 @@ import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Maps; import java.util.ArrayList; import java.util.List; @@ -50,6 +49,23 @@ * MergeConsecutiveProjects and all predicate push down related rules. */ public class PruneOlapScanPartition implements RewriteRuleFactory { + + @Override + public List buildRules() { + return ImmutableList.of( + logicalFilter(logicalOlapScan()) + .when(p -> !p.child().isPartitionPruned()) + .thenApply(ctx -> prunePartitions(ctx.cascadesContext, ctx.root.child(), ctx.root)) + .toRule(RuleType.OLAP_SCAN_PARTITION_PRUNE), + + logicalFilter(logicalProject(logicalOlapScan())) + .when(p -> !p.child().child().isPartitionPruned()) + .when(p -> p.child().hasPushedDownToProjectionFunctions()) + .thenApply(ctx -> prunePartitions(ctx.cascadesContext, ctx.root.child().child(), ctx.root)) + .toRule(RuleType.OLAP_SCAN_WITH_PROJECT_PARTITION_PRUNE) + ); + } + private Plan prunePartitions(CascadesContext ctx, LogicalOlapScan scan, LogicalFilter originalFilter) { OlapTable table = scan.getTable(); @@ -59,20 +75,22 @@ private Plan prunePartitions(CascadesContext ctx, } List output = scan.getOutput(); - Map scanOutput = Maps.newHashMapWithExpectedSize(output.size() * 2); - for (Slot slot : output) { - scanOutput.put(slot.getName().toLowerCase(), slot); - } - PartitionInfo partitionInfo = table.getPartitionInfo(); List partitionColumns = partitionInfo.getPartitionColumns(); 
List partitionSlots = new ArrayList<>(partitionColumns.size()); for (Column column : partitionColumns) { - Slot slot = scanOutput.get(column.getName().toLowerCase()); - if (slot == null) { + Slot partitionSlot = null; + // loop search is faster than build a map + for (Slot slot : output) { + if (slot.getName().equalsIgnoreCase(column.getName())) { + partitionSlot = slot; + break; + } + } + if (partitionSlot == null) { return originalFilter; } else { - partitionSlots.add(slot); + partitionSlots.add(partitionSlot); } } @@ -105,19 +123,4 @@ private Plan prunePartitions(CascadesContext ctx, } return originalFilter.withChildren(ImmutableList.of(rewrittenScan)); } - - @Override - public List buildRules() { - return ImmutableList.of( - logicalFilter(logicalOlapScan()).when(p -> !p.child().isPartitionPruned()).thenApply(ctx -> { - return prunePartitions(ctx.cascadesContext, ctx.root.child(), ctx.root); - }).toRule(RuleType.OLAP_SCAN_PARTITION_PRUNE), - - logicalFilter(logicalProject(logicalOlapScan())) - .when(p -> !p.child().child().isPartitionPruned()) - .when(p -> p.child().hasPushedDownToProjectionFunctions()).thenApply(ctx -> { - return prunePartitions(ctx.cascadesContext, ctx.root.child().child(), ctx.root); - }).toRule(RuleType.OLAP_SCAN_WITH_PROJECT_PARTITION_PRUNE) - ); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 26e1358c2e5e11..b02c51b1fe906e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -31,16 +31,17 @@ import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSet.Builder; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import 
com.google.common.collect.Sets; -import java.util.Collection; import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; import java.util.function.Supplier; -import java.util.stream.Collectors; /** * poll up effective predicates from operator's children. @@ -60,7 +61,7 @@ public ImmutableSet visit(Plan plan, Void context) { @Override public ImmutableSet visitLogicalFilter(LogicalFilter filter, Void context) { return cacheOrElse(filter, () -> { - List predicates = Lists.newArrayList(filter.getConjuncts()); + Set predicates = Sets.newLinkedHashSet(filter.getConjuncts()); predicates.addAll(filter.child().accept(this, context)); return getAvailableExpressions(predicates, filter); }); @@ -82,14 +83,14 @@ public ImmutableSet visitLogicalJoin(LogicalJoin visitLogicalProject(LogicalProject project, Void context) { return cacheOrElse(project, () -> { ImmutableSet childPredicates = project.child().accept(this, context); - - Set allPredicates = Sets.newHashSet(childPredicates); - project.getAliasToProducer().forEach((k, v) -> { - Set expressions = childPredicates.stream() - .map(e -> e.rewriteDownShortCircuit(c -> c.equals(v) ? k : c)).collect(Collectors.toSet()); - allPredicates.addAll(expressions); - }); - + Set allPredicates = Sets.newLinkedHashSet(childPredicates); + for (Entry kv : project.getAliasToProducer().entrySet()) { + Slot k = kv.getKey(); + Expression v = kv.getValue(); + for (Expression childPredicate : childPredicates) { + allPredicates.add(childPredicate.rewriteDownShortCircuit(c -> c.equals(v) ? 
k : c)); + } + } return getAvailableExpressions(allPredicates, project); }); } @@ -99,21 +100,22 @@ public ImmutableSet visitLogicalAggregate(LogicalAggregate { ImmutableSet childPredicates = aggregate.child().accept(this, context); // TODO - Map expressionSlotMap = aggregate.getOutputExpressions() - .stream() - .filter(this::hasAgg) - .collect(Collectors.toMap( - namedExpr -> { - if (namedExpr instanceof Alias) { - return ((Alias) namedExpr).child(); - } else { - return namedExpr; - } - }, NamedExpression::toSlot) + List outputExpressions = aggregate.getOutputExpressions(); + + Map expressionSlotMap + = Maps.newLinkedHashMapWithExpectedSize(outputExpressions.size()); + for (NamedExpression output : outputExpressions) { + if (hasAgg(output)) { + expressionSlotMap.putIfAbsent( + output instanceof Alias ? output.child(0) : output, output.toSlot() ); - Expression expression = ExpressionUtils.replace(ExpressionUtils.and(Lists.newArrayList(childPredicates)), - expressionSlotMap); - List predicates = ExpressionUtils.extractConjunction(expression); + } + } + Expression expression = ExpressionUtils.replace( + ExpressionUtils.and(Lists.newArrayList(childPredicates)), + expressionSlotMap + ); + Set predicates = Sets.newLinkedHashSet(ExpressionUtils.extractConjunction(expression)); return getAvailableExpressions(predicates, aggregate); }); } @@ -128,12 +130,23 @@ private ImmutableSet cacheOrElse(Plan plan, Supplier getAvailableExpressions(Collection predicates, Plan plan) { - Set expressions = Sets.newHashSet(predicates); - expressions.addAll(PredicatePropagation.infer(expressions)); - return expressions.stream() - .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots())) - .collect(ImmutableSet.toImmutableSet()); + private ImmutableSet getAvailableExpressions(Set predicates, Plan plan) { + Set inferPredicates = PredicatePropagation.infer(predicates); + Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size() + 10); + Set outputSet = 
plan.getOutputSet(); + + for (Expression predicate : predicates) { + if (outputSet.containsAll(predicate.getInputSlots())) { + newPredicates.add(predicate); + } + } + + for (Expression inferPredicate : inferPredicates) { + if (outputSet.containsAll(inferPredicate.getInputSlots())) { + newPredicates.add(inferPredicate); + } + } + return newPredicates.build(); } private boolean hasAgg(Expression expression) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregation.java index f3a54fd8eeaa8b..798a41b37643dc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregation.java @@ -29,7 +29,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Set; /** @@ -60,9 +60,9 @@ public Rule build() { LogicalAggregate aggregate = filter.child(); Set canPushDownSlots = getCanPushDownSlots(aggregate); - Set pushDownPredicates = Sets.newHashSet(); - Set filterPredicates = Sets.newHashSet(); - filter.getConjuncts().forEach(conjunct -> { + Set pushDownPredicates = Sets.newLinkedHashSet(); + Set filterPredicates = Sets.newLinkedHashSet(); + for (Expression conjunct : filter.getConjuncts()) { Set conjunctSlots = conjunct.getInputSlots(); // NOTICE: filter not contain slot should not be pushed. e.g. 
'a' = 'b' if (!conjunctSlots.isEmpty() && canPushDownSlots.containsAll(conjunctSlots)) { @@ -70,7 +70,7 @@ public Rule build() { } else { filterPredicates.add(conjunct); } - }); + } if (pushDownPredicates.isEmpty()) { return null; } @@ -84,7 +84,7 @@ public Rule build() { * get the slots that can be pushed down */ public static Set getCanPushDownSlots(LogicalAggregate aggregate) { - Set canPushDownSlots = new HashSet<>(); + Set canPushDownSlots = new LinkedHashSet<>(); if (aggregate.getSourceRepeat().isPresent()) { // When there is a repeat, the push-down condition is consistent with the repeat aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions().stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughProject.java index 77c90820a258c4..71834a66b19a2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughProject.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; @@ -49,16 +48,16 @@ public class PushDownFilterThroughProject implements RewriteRuleFactory { public List buildRules() { return ImmutableList.of( logicalFilter(logicalProject()) - .whenNot(filter -> filter.child().getProjects().stream().anyMatch( - expr -> expr.anyMatch(WindowExpression.class::isInstance))) + .whenNot(filter -> ExpressionUtils.containsWindowExpression(filter.child().getProjects())) .whenNot(filter -> 
filter.child().hasPushedDownToProjectionFunctions()) .then(PushDownFilterThroughProject::pushdownFilterThroughProject) .toRule(RuleType.PUSH_DOWN_FILTER_THROUGH_PROJECT), // filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit, // then we should change filter(limit(project)) to project(filter(limit)) logicalFilter(logicalLimit(logicalProject())) - .whenNot(filter -> filter.child().child().getProjects().stream() - .anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) + .whenNot(filter -> + ExpressionUtils.containsWindowExpression(filter.child().child().getProjects()) + ) .whenNot(filter -> filter.child().child().hasPushedDownToProjectionFunctions()) .then(PushDownFilterThroughProject::pushdownFilterThroughLimitProject) .toRule(RuleType.PUSH_DOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT) @@ -111,14 +110,14 @@ private static Pair, Set> splitConjunctsByChildOutpu Set conjuncts, Set childOutputs) { Set pushDownPredicates = Sets.newLinkedHashSet(); Set remainPredicates = Sets.newLinkedHashSet(); - conjuncts.forEach(conjunct -> { + for (Expression conjunct : conjuncts) { Set conjunctSlots = conjunct.getInputSlots(); if (childOutputs.containsAll(conjunctSlots)) { pushDownPredicates.add(conjunct); } else { remainPredicates.add(conjunct); } - }); + } return Pair.of(remainPredicates, pushDownPredicates); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReorderJoin.java index 97af548ce02b8c..238db403be632b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReorderJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReorderJoin.java @@ -42,7 +42,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ 
-293,10 +293,10 @@ public Plan multiJoinToJoin(MultiJoin multiJoin, Map p } // following this multiJoin just contain INNER/CROSS. - Set joinFilter = new HashSet<>(multiJoinHandleChildren.getJoinFilter()); + Set joinFilter = new LinkedHashSet<>(multiJoinHandleChildren.getJoinFilter()); Plan left = multiJoinHandleChildren.child(0); - Set usedPlansIndex = new HashSet<>(); + Set usedPlansIndex = new LinkedHashSet<>(); usedPlansIndex.add(0); while (usedPlansIndex.size() != multiJoinHandleChildren.children().size()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java index cc0c7f12f33cbc..6dc446d88ca882 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java @@ -19,16 +19,18 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import java.util.List; import java.util.Set; -import java.util.stream.Collectors; /** * Simplify Aggregate group by Multiple to One. 
For example @@ -41,20 +43,25 @@ public class SimplifyAggGroupBy extends OneRewriteRuleFactory { @Override public Rule build() { return logicalAggregate() - .when(agg -> agg.getGroupByExpressions().size() > 1) - .when(agg -> agg.getGroupByExpressions().stream().allMatch(this::isBinaryArithmeticSlot)) + .when(agg -> agg.getGroupByExpressions().size() > 1 + && ExpressionUtils.allMatch(agg.getGroupByExpressions(), this::isBinaryArithmeticSlot)) .then(agg -> { - Set slots = agg.getGroupByExpressions().stream() - .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet()); + List groupByExpressions = agg.getGroupByExpressions(); + ImmutableSet.Builder inputSlots + = ImmutableSet.builderWithExpectedSize(groupByExpressions.size()); + for (Expression groupByExpression : groupByExpressions) { + inputSlots.addAll(groupByExpression.getInputSlots()); + } + Set slots = inputSlots.build(); if (slots.size() != 1) { return null; } - return agg.withGroupByAndOutput(ImmutableList.copyOf(slots), agg.getOutputExpressions()); + return agg.withGroupByAndOutput(Utils.fastToImmutableList(slots), agg.getOutputExpressions()); }) .toRule(RuleType.SIMPLIFY_AGG_GROUP_BY); } - private boolean isBinaryArithmeticSlot(Expression expr) { + private boolean isBinaryArithmeticSlot(TreeNode expr) { if (expr instanceof Slot) { return true; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java index 7ca697726ee447..e69bffb301b31e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java @@ -434,7 +434,12 @@ private static int indexKeyPrefixMatchCount( } protected static boolean preAggEnabledByHint(LogicalOlapScan olapScan) { - return 
olapScan.getHints().stream().anyMatch("PREAGGOPEN"::equalsIgnoreCase); + for (String hint : olapScan.getHints()) { + if ("PREAGGOPEN".equalsIgnoreCase(hint)) { + return true; + } + } + return false; } public static String normalizeName(String name) { @@ -447,11 +452,11 @@ public static Expression slotToCaseWhen(Expression expression) { } protected SlotContext generateBaseScanExprToMvExpr(LogicalOlapScan mvPlan) { - Map baseSlotToMvSlot = new HashMap<>(); - Map mvNameToMvSlot = new HashMap<>(); if (mvPlan.getSelectedIndexId() == mvPlan.getTable().getBaseIndexId()) { - return new SlotContext(baseSlotToMvSlot, mvNameToMvSlot, new TreeSet()); + return SlotContext.EMPTY; } + Map baseSlotToMvSlot = new HashMap<>(); + Map mvNameToMvSlot = new HashMap<>(); for (Slot mvSlot : mvPlan.getOutputByIndex(mvPlan.getSelectedIndexId())) { boolean isPushed = false; for (Slot baseSlot : mvPlan.getOutput()) { @@ -505,6 +510,8 @@ protected SlotContext generateBaseScanExprToMvExpr(LogicalOlapScan mvPlan, Set baseSlotToMvSlot; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java index bd3494378ae9fe..b1a06e3875a466 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java @@ -74,8 +74,10 @@ import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import org.apache.doris.planner.PlanNode; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -256,9 +258,10 @@ public 
List buildRules() { if (result.indexId == scan.getTable().getBaseIndexId()) { LogicalOlapScan mvPlanWithoutAgg = SelectMaterializedIndexWithoutAggregate.select(scan, project::getInputSlots, filter::getConjuncts, - Stream.concat(filter.getExpressions().stream(), - project.getExpressions().stream()) - .collect(ImmutableSet.toImmutableSet())); + Suppliers.memoize(() -> Utils.concatToSet( + filter.getExpressions(), project.getExpressions() + )) + ); SlotContext slotContextWithoutAgg = generateBaseScanExprToMvExpr(mvPlanWithoutAgg); return agg.withChildren(new LogicalProject( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java index 7960dd73df9a8c..e05a1eda3e63fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java @@ -32,7 +32,9 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -42,7 +44,6 @@ import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; -import java.util.stream.Stream; /** * Select materialized index, i.e., both for rollup and materialized view when aggregate is not present. 
@@ -70,11 +71,13 @@ public List buildRules() { LogicalOlapScan mvPlan = select( scan, project::getInputSlots, filter::getConjuncts, - Stream.concat(filter.getExpressions().stream(), - project.getExpressions().stream()).collect(ImmutableSet.toImmutableSet())); + Suppliers.memoize(() -> + Utils.concatToSet(filter.getExpressions(), project.getExpressions()) + ) + ); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); - return new LogicalProject( + return new LogicalProject<>( generateProjectsAlias(project.getOutput(), slotContext), new ReplaceExpressions(slotContext).replace( project.withChildren(filter.withChildren(mvPlan)), mvPlan)); @@ -90,7 +93,7 @@ public List buildRules() { LogicalOlapScan mvPlan = select( scan, project::getInputSlots, ImmutableSet::of, - new HashSet<>(project.getExpressions())); + () -> Utils.fastToImmutableSet(project.getExpressions())); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); return new LogicalProject( @@ -107,8 +110,10 @@ public List buildRules() { LogicalOlapScan scan = filter.child(); LogicalOlapScan mvPlan = select( scan, filter::getOutputSet, filter::getConjuncts, - Stream.concat(filter.getExpressions().stream(), - filter.getOutputSet().stream()).collect(ImmutableSet.toImmutableSet())); + Suppliers.memoize(() -> + Utils.concatToSet(filter.getExpressions(), filter.getOutputSet()) + ) + ); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); return new LogicalProject( @@ -127,7 +132,8 @@ public List buildRules() { LogicalOlapScan mvPlan = select( scan, project::getInputSlots, ImmutableSet::of, - new HashSet<>(project.getExpressions())); + Suppliers.memoize(() -> Utils.fastToImmutableSet(project.getExpressions())) + ); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); return new LogicalProject( @@ -145,7 +151,7 @@ public List buildRules() { LogicalOlapScan mvPlan = select( scan, scan::getOutputSet, ImmutableSet::of, - scan.getOutputSet()); + () -> (Set) scan.getOutputSet()); 
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); return new LogicalProject( @@ -169,7 +175,7 @@ public static LogicalOlapScan select( LogicalOlapScan scan, Supplier> requiredScanOutputSupplier, Supplier> predicatesSupplier, - Set requiredExpr) { + Supplier> requiredExpr) { OlapTable table = scan.getTable(); long baseIndexId = table.getBaseIndexId(); KeysType keysType = scan.getTable().getKeysType(); @@ -186,21 +192,24 @@ public static LogicalOlapScan select( throw new RuntimeException("Not supported keys type: " + keysType); } - Set requiredSlots = new HashSet<>(); - requiredSlots.addAll(requiredScanOutputSupplier.get()); - requiredSlots.addAll(ExpressionUtils.getInputSlotSet(requiredExpr)); - requiredSlots.addAll(ExpressionUtils.getInputSlotSet(predicatesSupplier.get())); + Supplier> requiredSlots = Suppliers.memoize(() -> { + Set set = new HashSet<>(); + set.addAll(requiredScanOutputSupplier.get()); + set.addAll(ExpressionUtils.getInputSlotSet(requiredExpr.get())); + set.addAll(ExpressionUtils.getInputSlotSet(predicatesSupplier.get())); + return set; + }); if (scan.getTable().isDupKeysOrMergeOnWrite()) { // Set pre-aggregation to `on` to keep consistency with legacy logic. List candidates = scan .getTable().getVisibleIndex().stream().filter(index -> index.getId() != baseIndexId) .filter(index -> !indexHasAggregate(index, scan)).filter(index -> containAllRequiredColumns(index, - scan, requiredScanOutputSupplier.get(), requiredExpr, predicatesSupplier.get())) + scan, requiredScanOutputSupplier.get(), requiredExpr.get(), predicatesSupplier.get())) .collect(Collectors.toList()); long bestIndex = selectBestIndex(candidates, scan, predicatesSupplier.get()); // this is fail-safe for select mv // select baseIndex if bestIndex's slots' data types are different from baseIndex - bestIndex = isSameDataType(scan, bestIndex, requiredSlots) ? bestIndex : baseIndexId; + bestIndex = isSameDataType(scan, bestIndex, requiredSlots.get()) ? 
bestIndex : baseIndexId; return scan.withMaterializedIndexSelected(PreAggStatus.on(), bestIndex); } else { final PreAggStatus preAggStatus; @@ -221,7 +230,7 @@ public static LogicalOlapScan select( List candidates = table.getVisibleIndex().stream() .filter(index -> table.getKeyColumnsByIndexId(index.getId()).size() == baseIndexKeySize) .filter(index -> containAllRequiredColumns(index, scan, requiredScanOutputSupplier.get(), - requiredExpr, predicatesSupplier.get())) + requiredExpr.get(), predicatesSupplier.get())) .collect(Collectors.toList()); if (candidates.size() == 1) { @@ -231,7 +240,7 @@ public static LogicalOlapScan select( long bestIndex = selectBestIndex(candidates, scan, predicatesSupplier.get()); // this is fail-safe for select mv // select baseIndex if bestIndex's slots' data types are different from baseIndex - bestIndex = isSameDataType(scan, bestIndex, requiredSlots) ? bestIndex : baseIndexId; + bestIndex = isSameDataType(scan, bestIndex, requiredSlots.get()) ? bestIndex : baseIndexId; return scan.withMaterializedIndexSelected(preAggStatus, bestIndex); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 2ed1afc56772c2..57a79037d801d3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -132,6 +132,7 @@ import org.apache.doris.statistics.StatisticsBuilder; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.apache.commons.collections.CollectionUtils; @@ -753,8 +754,14 @@ private ColumnStatistic getColumnStatistic(TableIf table, String colName, long i // 2. Consider the influence of runtime filter // 3. 
Get NDV and column data size from StatisticManger, StatisticManager doesn't support it now. private Statistics computeCatalogRelation(CatalogRelation catalogRelation) { - Set slotSet = catalogRelation.getOutput().stream().filter(SlotReference.class::isInstance) - .map(s -> (SlotReference) s).collect(Collectors.toSet()); + List output = catalogRelation.getOutput(); + ImmutableSet.Builder slotSetBuilder = ImmutableSet.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + if (slot instanceof SlotReference) { + slotSetBuilder.add((SlotReference) slot); + } + } + Set slotSet = slotSetBuilder.build(); Map columnStatisticMap = new HashMap<>(); TableIf table = catalogRelation.getTable(); double rowCount = catalogRelation.getTable().getRowCountForNereids(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java index 59d2acbe22bd6b..92bbcdb9b38fd0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java @@ -17,9 +17,12 @@ package org.apache.doris.nereids.trees; +import org.apache.doris.nereids.util.MutableState; +import org.apache.doris.nereids.util.MutableState.EmptyMutableState; import org.apache.doris.nereids.util.Utils; import java.util.List; +import java.util.Optional; /** * Abstract class for plan node in Nereids, include plan node and expression. @@ -30,8 +33,13 @@ public abstract class AbstractTreeNode> implements TreeNode { protected final List children; - // TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression. - // https://github.com/apache/doris/pull/9807#discussion_r884829067 + + // this field is special, because other fields in tree node is immutable, but in some scenes, mutable + // state is necessary. e.g. 
the rewrite framework need distinguish whether the plan is created by + // rules, the framework can set this field to a state variable to quickly judge without new big plan. + // we should avoid using it as much as possible, because mutable state is easy to cause bugs and + // difficult to locate. + private MutableState mutableState = EmptyMutableState.INSTANCE; protected AbstractTreeNode(NODE_TYPE... children) { // NOTE: ImmutableList.copyOf has additional clone of the list, so here we @@ -55,6 +63,16 @@ public List children() { return children; } + @Override + public Optional getMutableState(String key) { + return mutableState.get(key); + } + + @Override + public void setMutableState(String key, Object state) { + this.mutableState = this.mutableState.set(key, state); + } + public int arity() { return children.size(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index d37070865e22eb..6d1a298eb79fe2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -28,11 +28,13 @@ import java.util.Deque; import java.util.LinkedList; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; /** * interface for all node in Nereids, include plan node and expression. 
@@ -48,6 +50,21 @@ public interface TreeNode> { int arity(); + Optional getMutableState(String key); + + /** getOrInitMutableState */ + default T getOrInitMutableState(String key, Supplier initState) { + Optional mutableState = getMutableState(key); + if (!mutableState.isPresent()) { + T state = initState.get(); + setMutableState(key, state); + return state; + } + return mutableState.get(); + } + + void setMutableState(String key, Object value); + default NODE_TYPE withChildren(NODE_TYPE... children) { return withChildren(Utils.fastToImmutableList(children)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java index 01a61d576d25aa..750f3a77881430 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java @@ -24,7 +24,6 @@ import com.google.common.collect.ImmutableList; import java.util.List; -import java.util.Objects; /** * Abstract for all binary operator, include binary arithmetic, compound predicate, comparison predicate. 
@@ -63,9 +62,4 @@ public String toString() { public String shapeInfo() { return "(" + left().shapeInfo() + " " + symbol + " " + right().shapeInfo() + ")"; } - - @Override - public int hashCode() { - return Objects.hash(symbol, left(), right()); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java index c9d10bde36d3c1..d343f6f93566cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java @@ -61,10 +61,10 @@ public DataType inputType() { @Override public void checkLegalityBeforeTypeCoercion() { - children().forEach(c -> { + for (Expression c : children) { if (c.getDataType().isComplexType() && !c.getDataType().isArrayType()) { throw new AnalysisException("comparison predicate could not contains complex type: " + this.toSql()); } - }); + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index a7947c82a565bf..75cef0fc94677b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -39,6 +39,7 @@ import org.apache.doris.nereids.types.MapType; import org.apache.doris.nereids.types.StructField; import org.apache.doris.nereids.types.StructType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; @@ -70,20 +71,43 @@ public abstract class Expression extends AbstractTreeNode implements protected Expression(Expression... 
children) { super(children); - int maxChildDepth = 0; - int sumChildWidth = 0; + boolean hasUnbound = false; - boolean compareWidthAndDepth = true; - for (int i = 0; i < children.length; ++i) { - Expression child = children[i]; - maxChildDepth = Math.max(child.depth, maxChildDepth); - sumChildWidth += child.width; - hasUnbound |= child.hasUnbound; - compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth()); + switch (children.length) { + case 0: + this.depth = 1; + this.width = 1; + this.compareWidthAndDepth = supportCompareWidthAndDepth(); + break; + case 1: + Expression child = children[0]; + this.depth = child.depth + 1; + this.width = child.width; + this.compareWidthAndDepth = child.compareWidthAndDepth && supportCompareWidthAndDepth(); + break; + case 2: + Expression left = children[0]; + Expression right = children[1]; + this.depth = Math.max(left.depth, right.depth) + 1; + this.width = left.width + right.width; + this.compareWidthAndDepth = + left.compareWidthAndDepth && right.compareWidthAndDepth && supportCompareWidthAndDepth(); + break; + default: + int maxChildDepth = 0; + int sumChildWidth = 0; + boolean compareWidthAndDepth = true; + for (Expression expression : children) { + child = expression; + maxChildDepth = Math.max(child.depth, maxChildDepth); + sumChildWidth += child.width; + hasUnbound |= child.hasUnbound; + compareWidthAndDepth &= child.compareWidthAndDepth; + } + this.depth = maxChildDepth + 1; + this.width = sumChildWidth; + this.compareWidthAndDepth = compareWidthAndDepth; } - this.depth = maxChildDepth + 1; - this.width = sumChildWidth + ((children.length == 0) ? 
1 : 0); - this.compareWidthAndDepth = compareWidthAndDepth; checkLimit(); this.inferred = false; @@ -96,20 +120,43 @@ protected Expression(List children) { protected Expression(List children, boolean inferred) { super(children); - int maxChildDepth = 0; - int sumChildWidth = 0; + boolean hasUnbound = false; - boolean compareWidthAndDepth = true; - for (int i = 0; i < children.size(); ++i) { - Expression child = children.get(i); - maxChildDepth = Math.max(child.depth, maxChildDepth); - sumChildWidth += child.width; - hasUnbound |= child.hasUnbound; - compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth()); + switch (children.size()) { + case 0: + this.depth = 1; + this.width = 1; + this.compareWidthAndDepth = supportCompareWidthAndDepth(); + break; + case 1: + Expression child = children.get(0); + this.depth = child.depth + 1; + this.width = child.width; + this.compareWidthAndDepth = child.compareWidthAndDepth && supportCompareWidthAndDepth(); + break; + case 2: + Expression left = children.get(0); + Expression right = children.get(1); + this.depth = Math.max(left.depth, right.depth) + 1; + this.width = left.width + right.width; + this.compareWidthAndDepth = + left.compareWidthAndDepth && right.compareWidthAndDepth && supportCompareWidthAndDepth(); + break; + default: + int maxChildDepth = 0; + int sumChildWidth = 0; + boolean compareWidthAndDepth = true; + for (Expression expression : children) { + child = expression; + maxChildDepth = Math.max(child.depth, maxChildDepth); + sumChildWidth += child.width; + hasUnbound |= child.hasUnbound; + compareWidthAndDepth &= child.compareWidthAndDepth; + } + this.depth = maxChildDepth + 1; + this.width = sumChildWidth; + this.compareWidthAndDepth = compareWidthAndDepth && supportCompareWidthAndDepth(); } - this.depth = maxChildDepth + 1; - this.width = sumChildWidth + ((children.isEmpty()) ? 
1 : 0); - this.compareWidthAndDepth = compareWidthAndDepth; checkLimit(); this.inferred = inferred; @@ -284,7 +331,7 @@ public boolean isConstant() { if (this instanceof LeafExpression) { return this instanceof Literal; } else { - return !(this instanceof Nondeterministic) && children().stream().allMatch(Expression::isConstant); + return !(this instanceof Nondeterministic) && ExpressionUtils.allMatch(children(), Expression::isConstant); } } @@ -376,7 +423,7 @@ protected boolean extraEquals(Expression that) { @Override public int hashCode() { - return 0; + return getClass().hashCode(); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java index bcebdca4f5b8a5..53a753c4535dd1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java @@ -28,6 +28,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import java.util.Collection; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -42,13 +43,13 @@ public class InPredicate extends Expression { private final Expression compareExpr; private final List options; - public InPredicate(Expression compareExpr, List options) { + public InPredicate(Expression compareExpr, Collection options) { super(new Builder().add(compareExpr).addAll(options).build()); this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null"); this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null")); } - public InPredicate(Expression compareExpr, List options, boolean inferred) { + public InPredicate(Expression compareExpr, Collection options, boolean inferred) { super(new Builder().add(compareExpr).addAll(options).build(), inferred); 
this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null"); this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null")); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 7cfaad72a2c546..760836455127ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -132,8 +132,8 @@ public static SlotReference of(String name, DataType type) { */ public static SlotReference fromColumn(TableIf table, Column column, List qualifier, Relation relation) { DataType dataType = DataType.fromCatalogType(column.getType()); - SlotReference slot = new SlotReference(StatementScopeIdGenerator.newExprId(), column.getName(), dataType, - column.isAllowNull(), qualifier, table, column, Optional.empty(), null); + SlotReference slot = new SlotReference(StatementScopeIdGenerator.newExprId(), () -> column.getName(), dataType, + column.isAllowNull(), qualifier, table, column, () -> Optional.of(column.getName()), null); if (relation != null && ConnectContext.get() != null && ConnectContext.get().getStatementContext() != null) { ConnectContext.get().getStatementContext().addSlotToRelation(slot, relation); @@ -260,6 +260,9 @@ public SlotReference withQualifier(List qualifier) { @Override public SlotReference withName(String name) { + if (this.name.get().equals(name)) { + return this; + } return new SlotReference( exprId, () -> name, dataType, nullable, qualifier, table, column, internalName, subColPath); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java index 
6d0a5d85de5557..2cdbe43c12ecb5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java @@ -38,6 +38,8 @@ import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -413,9 +415,12 @@ private static FunctionSignature defaultDateTimeV2PrecisionPromotion( return signature; } DateTimeV2Type argType = finalType; - List newArgTypes = signature.argumentsTypes.stream() - .map(at -> TypeCoercionUtils.replaceDateTimeV2WithTarget(at, argType)) - .collect(Collectors.toList()); + + ImmutableList.Builder newArgTypesBuilder = ImmutableList.builderWithExpectedSize(signature.arity); + for (DataType at : signature.argumentsTypes) { + newArgTypesBuilder.add(TypeCoercionUtils.replaceDateTimeV2WithTarget(at, argType)); + } + List newArgTypes = newArgTypesBuilder.build(); signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes); signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes); if (signature.returnType instanceof DateTimeV2Type) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index 4f53b383d244eb..e45d3fb4da8b0b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -108,11 +108,18 @@ public boolean hasVarArguments() { @Override public String toSql() throws 
UnboundException { - String args = children() - .stream() - .map(Expression::toSql) - .collect(Collectors.joining(", ")); - return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")"; + StringBuilder sql = new StringBuilder(getName()).append("("); + if (distinct) { + sql.append("DISTINCT "); + } + int arity = arity(); + for (int i = 0; i < arity; i++) { + sql.append(child(i).toSql()); + if (i + 1 < arity) { + sql.append(", "); + } + } + return sql.append(")").toString(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/PushDownToProjectionFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/PushDownToProjectionFunction.java index 81678153cd6206..d8e3642c36ffdc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/PushDownToProjectionFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/PushDownToProjectionFunction.java @@ -43,10 +43,9 @@ public PushDownToProjectionFunction(String name, Expression... 
arguments) { */ public static boolean validToPushDown(Expression pushDownExpr) { // Currently only element at for variant type could be pushed down - return pushDownExpr != null && !pushDownExpr.collectToList( - PushDownToProjectionFunction.class::isInstance).stream().filter( - x -> ((Expression) x).getDataType().isVariantType()).collect( - Collectors.toList()).isEmpty(); + return pushDownExpr != null && pushDownExpr.anyMatch(expr -> + expr instanceof PushDownToProjectionFunction && ((Expression) expr).getDataType().isVariantType() + ); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java index a33cc32c16f2a6..38951ea9e453b4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java @@ -31,6 +31,7 @@ import com.google.common.collect.ImmutableSet; +import java.time.LocalDate; import java.time.LocalDateTime; import java.time.Year; import java.time.temporal.ChronoField; @@ -158,7 +159,9 @@ private static void replacePunctuation(String s, StringBuilder sb, char c, int i static String normalize(String s) { // merge consecutive space - s = s.replaceAll(" +", " "); + if (s.contains(" ")) { + s = s.replaceAll(" +", " "); + } StringBuilder sb = new StringBuilder(); @@ -261,6 +264,14 @@ static String normalize(String s) { } protected static TemporalAccessor parse(String s) { + // fast parse '2022-01-01' + if (s.length() == 10 && s.charAt(4) == '-' && s.charAt(7) == '-') { + TemporalAccessor date = fastParseDate(s); + if (date != null) { + return date; + } + } + String originalString = s; try { TemporalAccessor dateTime; @@ -477,4 +488,30 @@ public DateTimeLiteral toBeginOfTomorrow() { return toEndOfTheDay(); } } + + private static TemporalAccessor fastParseDate(String date) { + 
Integer year = readNextInt(date, 0, 4); + Integer month = readNextInt(date, 5, 2); + Integer day = readNextInt(date, 8, 2); + if (year != null && month != null && day != null) { + return LocalDate.of(year, month, day); + } else { + return null; + } + } + + private static Integer readNextInt(String str, int offset, int readLength) { + int value = 0; + int realReadLength = 0; + for (int i = offset; i < str.length(); i++) { + char c = str.charAt(i); + if ('0' <= c && c <= '9') { + realReadLength++; + value = value * 10 + (c - '0'); + } else { + break; + } + } + return readLength == realReadLength ? value : null; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/DefaultExpressionRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/DefaultExpressionRewriter.java index 2248666dbca12f..fd25f9368ef0b5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/DefaultExpressionRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/DefaultExpressionRewriter.java @@ -34,13 +34,13 @@ public Expression visit(Expression expr, C context) { } /** rewriteChildren */ - public static final Expression rewriteChildren( - ExpressionVisitor rewriter, Expression expr, C context) { + public static final E rewriteChildren( + ExpressionVisitor rewriter, E expr, C context) { switch (expr.arity()) { case 1: { Expression originChild = expr.child(0); Expression newChild = originChild.accept(rewriter, context); - return (originChild != newChild) ? expr.withChildren(ImmutableList.of(newChild)) : expr; + return (originChild != newChild) ? 
(E) expr.withChildren(ImmutableList.of(newChild)) : expr; } case 2: { Expression originLeft = expr.child(0); @@ -48,7 +48,7 @@ public static final Expression rewriteChildren( Expression originRight = expr.child(1); Expression newRight = originRight.accept(rewriter, context); return (originLeft != newLeft || originRight != newRight) - ? expr.withChildren(ImmutableList.of(newLeft, newRight)) + ? (E) expr.withChildren(ImmutableList.of(newLeft, newRight)) : expr; } case 0: { @@ -64,7 +64,7 @@ public static final Expression rewriteChildren( } newChildren.add(newChild); } - return hasNewChildren ? expr.withChildren(newChildren.build()) : expr; + return hasNewChildren ? (E) expr.withChildren(newChildren.build()) : expr; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java index 4be6d35dc94692..286a92aab768f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.MutableState; -import org.apache.doris.nereids.util.MutableState.EmptyMutableState; import org.apache.doris.nereids.util.TreeStringUtils; import org.apache.doris.statistics.Statistics; @@ -58,13 +57,6 @@ public abstract class AbstractPlan extends AbstractTreeNode implements Pla protected final Optional groupExpression; protected final Supplier logicalPropertiesSupplier; - // this field is special, because other fields in tree node is immutable, but in some scenes, mutable - // state is necessary. e.g. the rewrite framework need distinguish whether the plan is created by - // rules, the framework can set this field to a state variable to quickly judge without new big plan. 
- // we should avoid using it as much as possible, because mutable state is easy to cause bugs and - // difficult to locate. - private MutableState mutableState = EmptyMutableState.INSTANCE; - /** * all parameter constructor. */ @@ -108,7 +100,15 @@ public Statistics getStats() { @Override public boolean canBind() { - return !bound() && children().stream().allMatch(Plan::bound); + if (bound()) { + return false; + } + for (Plan child : children()) { + if (!child.bound()) { + return false; + } + } + return true; } /** @@ -185,16 +185,6 @@ public LogicalProperties computeLogicalProperties() { } } - @Override - public Optional getMutableState(String key) { - return mutableState.get(key); - } - - @Override - public void setMutableState(String key, Object state) { - this.mutableState = this.mutableState.set(key, state); - } - public int getId() { return id.asInt(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java index 1b237c72fdc207..d73b7390ce8d59 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java @@ -23,10 +23,10 @@ import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.MutableState; +import org.apache.doris.nereids.util.PlanUtils; import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; @@ -36,8 +36,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.function.Supplier; -import java.util.stream.Collectors; /** * Abstract class for all plan 
node. @@ -104,12 +102,46 @@ default Set getOutputSet() { return ImmutableSet.copyOf(getOutput()); } + /** getOutputExprIds */ default List getOutputExprIds() { - return getOutput().stream().map(NamedExpression::getExprId).collect(Collectors.toList()); + List output = getOutput(); + ImmutableList.Builder exprIds = ImmutableList.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + exprIds.add(slot.getExprId()); + } + return exprIds.build(); } + /** getOutputExprIdSet */ default Set getOutputExprIdSet() { - return getOutput().stream().map(NamedExpression::getExprId).collect(Collectors.toSet()); + List output = getOutput(); + ImmutableSet.Builder exprIds = ImmutableSet.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + exprIds.add(slot.getExprId()); + } + return exprIds.build(); + } + + /** getChildrenOutputExprIdSet */ + default Set getChildrenOutputExprIdSet() { + switch (arity()) { + case 0: return ImmutableSet.of(); + case 1: return child(0).getOutputExprIdSet(); + default: { + int exprIdSize = 0; + for (Plan child : children()) { + exprIdSize += child.getOutput().size(); + } + + ImmutableSet.Builder exprIds = ImmutableSet.builderWithExpectedSize(exprIdSize); + for (Plan child : children()) { + for (Slot slot : child.getOutput()) { + exprIds.add(slot.getExprId()); + } + } + return exprIds.build(); + } + } } /** @@ -119,9 +151,7 @@ default Set getOutputExprIdSet() { * Note that the input slots of subquery's inner plan are not included. 
*/ default Set getInputSlots() { - return getExpressions().stream() - .flatMap(expr -> expr.getInputSlots().stream()) - .collect(ImmutableSet.toImmutableSet()); + return PlanUtils.fastGetInputSlots(this.getExpressions()); } default List computeOutput() { @@ -147,21 +177,6 @@ default Set getInputRelations() { Plan withGroupExprLogicalPropChildren(Optional groupExpression, Optional logicalProperties, List children); - Optional getMutableState(String key); - - /** getOrInitMutableState */ - default T getOrInitMutableState(String key, Supplier initState) { - Optional mutableState = getMutableState(key); - if (!mutableState.isPresent()) { - T state = initState.get(); - setMutableState(key, state); - return state; - } - return mutableState.get(); - } - - void setMutableState(String key, Object value); - /** * a simple version of explain, used to verify plan shape * @param prefix " " diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java index 15fd5bec868eeb..e7d09b8cf8b9ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java @@ -53,10 +53,19 @@ default Set getAggregateFunctions() { return ExpressionUtils.collect(getOutputExpressions(), AggregateFunction.class::isInstance); } + /** getDistinctArguments */ default Set getDistinctArguments() { - return getAggregateFunctions().stream() - .filter(AggregateFunction::isDistinct) - .flatMap(aggregateFunction -> aggregateFunction.getDistinctArguments().stream()) - .collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder distinctArguments = ImmutableSet.builder(); + for (NamedExpression outputExpression : getOutputExpressions()) { + outputExpression.foreach(expr -> { + if (expr instanceof AggregateFunction) { + AggregateFunction aggFun = (AggregateFunction) expr; 
+ if (aggFun.isDistinct()) { + distinctArguments.addAll(aggFun.getDistinctArguments()); + } + } + }); + } + return distinctArguments.build(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java index b734bba576df26..7fa62f7628fc2d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java @@ -74,15 +74,24 @@ default List mergeProjections(Project childProject) { * And check if contains PushDownToProjectionFunction that can pushed down to project */ default boolean hasPushedDownToProjectionFunctions() { - return ConnectContext.get() != null - && ConnectContext.get().getSessionVariable() != null - && ConnectContext.get().getSessionVariable().isEnableRewriteElementAtToSlot() - && getProjects().stream().allMatch(namedExpr -> - namedExpr instanceof SlotReference - || (namedExpr instanceof Alias - && PushDownToProjectionFunction.validToPushDown(((Alias) namedExpr).child()))) - && getProjects().stream().anyMatch((namedExpr -> namedExpr instanceof Alias - && PushDownToProjectionFunction.validToPushDown(((Alias) namedExpr).child()))); + if ((ConnectContext.get() == null + || ConnectContext.get().getSessionVariable() == null + || !ConnectContext.get().getSessionVariable().isEnableRewriteElementAtToSlot())) { + return false; + } + + boolean hasValidAlias = false; + for (NamedExpression namedExpr : getProjects()) { + if (namedExpr instanceof Alias) { + if (!PushDownToProjectionFunction.validToPushDown(((Alias) namedExpr).child())) { + return false; + } + hasValidAlias = true; + } else if (!(namedExpr instanceof SlotReference)) { + return false; + } + } + return hasValidAlias; } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 20647a3808ebe7..fa4f891e7a20b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -200,9 +200,11 @@ public String toString() { @Override public List computeOutput() { - return outputExpressions.stream() - .map(NamedExpression::toSlot) - .collect(ImmutableList.toImmutableList()); + ImmutableList.Builder outputSlots = ImmutableList.builderWithExpectedSize(outputExpressions.size()); + for (NamedExpression outputExpression : outputExpressions) { + outputSlots.add(outputExpression.toSlot()); + } + return outputSlots.build(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java index 4076e8348e2208..b4dbc9444604da 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java @@ -22,6 +22,8 @@ import org.apache.doris.catalog.Env; import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.TableIf; +import org.apache.doris.catalog.constraint.PrimaryKeyConstraint; +import org.apache.doris.catalog.constraint.UniqueConstraint; import org.apache.doris.datasource.CatalogIf; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.memo.GroupExpression; @@ -31,7 +33,6 @@ import org.apache.doris.nereids.properties.FunctionalDependencies.Builder; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.TableFdItem; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import 
org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.PlanType; @@ -128,85 +129,80 @@ public String qualifiedName() { @Override public FunctionalDependencies computeFuncDeps(Supplier> outputSupplier) { Builder fdBuilder = new Builder(); - Set output = ImmutableSet.copyOf(outputSupplier.get()); + Set outputSet = Utils.fastToImmutableSet(outputSupplier.get()); if (table instanceof OlapTable && ((OlapTable) table).getKeysType().isAggregationFamily()) { - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && s.getColumn().get().isKey()) - .collect(ImmutableSet.toImmutableSet()); - fdBuilder.addUniqueSlot(slotSet); + ImmutableSet.Builder uniqSlots = ImmutableSet.builderWithExpectedSize(outputSet.size()); + for (Slot slot : outputSet) { + if (!(slot instanceof SlotReference)) { + continue; + } + SlotReference slotRef = (SlotReference) slot; + if (slotRef.getColumn().isPresent() && slotRef.getColumn().get().isKey()) { + uniqSlots.add(slot); + } + } + fdBuilder.addUniqueSlot(uniqSlots.build()); } - table.getPrimaryKeyConstraints().forEach(c -> { - Set columns = c.getPrimaryKeys(this.getTable()); - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - fdBuilder.addUniqueSlot(slotSet); - }); - table.getUniqueConstraints().forEach(c -> { - Set columns = c.getUniqueKeys(this.getTable()); - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - fdBuilder.addUniqueSlot(slotSet); - }); - 
ImmutableSet fdItems = computeFdItems(outputSupplier); - fdBuilder.addFdItems(fdItems); + + for (PrimaryKeyConstraint c : table.getPrimaryKeyConstraints()) { + Set columns = c.getPrimaryKeys(table); + fdBuilder.addUniqueSlot((ImmutableSet) findSlotsByColumn(outputSet, columns)); + } + + for (UniqueConstraint c : table.getUniqueConstraints()) { + Set columns = c.getUniqueKeys(table); + fdBuilder.addUniqueSlot((ImmutableSet) findSlotsByColumn(outputSet, columns)); + } + fdBuilder.addFdItems(computeFdItems(outputSet)); return fdBuilder.build(); } @Override public ImmutableSet computeFdItems(Supplier> outputSupplier) { - Set output = ImmutableSet.copyOf(outputSupplier.get()); + return computeFdItems(Utils.fastToImmutableSet(outputSupplier.get())); + } + + private ImmutableSet computeFdItems(Set outputSet) { ImmutableSet.Builder builder = ImmutableSet.builder(); - table.getPrimaryKeyConstraints().forEach(c -> { + + for (PrimaryKeyConstraint c : table.getPrimaryKeyConstraints()) { Set columns = c.getPrimaryKeys(this.getTable()); - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem(slotSet, true, - false, ImmutableSet.of(table)); + ImmutableSet slotSet = findSlotsByColumn(outputSet, columns); + TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem( + slotSet, true, false, ImmutableSet.of(table)); builder.add(tableFdItem); - }); - table.getUniqueConstraints().forEach(c -> { + } + + for (UniqueConstraint c : table.getUniqueConstraints()) { Set columns = c.getUniqueKeys(this.getTable()); - boolean allNotNull = columns.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .allMatch(s -> !s.nullable()); - if (allNotNull) { - ImmutableSet slotSet = output.stream() - 
.filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem(slotSet, - true, false, ImmutableSet.of(table)); - builder.add(tableFdItem); - } else { - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem(slotSet, - true, true, ImmutableSet.of(table)); - builder.add(tableFdItem); + boolean allNotNull = true; + + for (Column column : columns) { + if (column.isAllowNull()) { + allNotNull = false; + break; + } } - }); + + ImmutableSet slotSet = findSlotsByColumn(outputSet, columns); + TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem( + slotSet, true, !allNotNull, ImmutableSet.of(table)); + builder.add(tableFdItem); + } return builder.build(); } + + private ImmutableSet findSlotsByColumn(Set outputSet, Set columns) { + ImmutableSet.Builder slotSet = ImmutableSet.builderWithExpectedSize(columns.size()); + for (Slot slot : outputSet) { + if (!(slot instanceof SlotReference)) { + continue; + } + SlotReference slotRef = (SlotReference) slot; + if (slotRef.getColumn().isPresent() && columns.contains(slotRef.getColumn().get())) { + slotSet.add(slotRef); + } + } + return slotSet.build(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java index cb5d1847fef2a5..d83a2f59f7fb79 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java @@ -38,16 +38,17 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.commons.lang3.tuple.Pair; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * Logical OlapScan. @@ -174,9 +175,23 @@ public LogicalOlapScan(RelationId id, Table table, List qualifier, this.indexSelected = indexSelected; this.preAggStatus = preAggStatus; this.manuallySpecifiedPartitions = ImmutableList.copyOf(specifiedPartitions); - this.selectedPartitionIds = selectedPartitionIds.stream() - .filter(partitionId -> this.getTable().getPartition(partitionId) != null) - .collect(Collectors.toList()); + + switch (selectedPartitionIds.size()) { + case 0: { + this.selectedPartitionIds = ImmutableList.of(); + break; + } + default: { + ImmutableList.Builder existPartitions + = ImmutableList.builderWithExpectedSize(selectedPartitionIds.size()); + for (Long partitionId : selectedPartitionIds) { + if (((OlapTable) table).getPartition(partitionId) != null) { + existPartitions.add(partitionId); + } + } + this.selectedPartitionIds = existPartitions.build(); + } + } this.hints = Objects.requireNonNull(hints, "hints can not be null"); this.cacheSlotWithSlotName = Objects.requireNonNull(cacheSlotWithSlotName, "mvNameToSlot can not be null"); @@ -333,14 +348,17 @@ public List computeOutput() { return getOutputByIndex(selectedIndexId); } List baseSchema = table.getBaseSchema(true); + List slotFromColumn = createSlotsVectorized(baseSchema); + Builder slots = ImmutableList.builder(); - for (Column col : baseSchema) { + for (int i = 0; i < baseSchema.size(); i++) 
{ + Column col = baseSchema.get(i); Pair key = Pair.of(selectedIndexId, col.getName()); Slot slot = cacheSlotWithSlotName.get(key); if (slot != null) { slots.add(slot); } else { - slot = SlotReference.fromColumn(table, col, qualified(), this); + slot = slotFromColumn.get(i); cacheSlotWithSlotName.put(key, slot); slots.add(slot); } @@ -363,27 +381,27 @@ public List getOutputByIndex(long indexId) { OlapTable olapTable = (OlapTable) table; // PhysicalStorageLayerAggregateTest has no visible index // when we have a partitioned table without any partition, visible index is empty - if (-1 == indexId || olapTable.getIndexMetaByIndexId(indexId) == null) { - return olapTable.getIndexMetaByIndexId(indexId).getSchema().stream() - .map(c -> generateUniqueSlot(olapTable, c, - indexId == ((OlapTable) table).getBaseIndexId(), indexId)) - .collect(Collectors.toList()); + List schema = olapTable.getIndexMetaByIndexId(indexId).getSchema(); + List slots = Lists.newArrayListWithCapacity(schema.size()); + for (Column c : schema) { + Slot slot = generateUniqueSlot( + olapTable, c, indexId == ((OlapTable) table).getBaseIndexId(), indexId); + slots.add(slot); } - return olapTable.getIndexMetaByIndexId(indexId).getSchema().stream() - .map(s -> generateUniqueSlot(olapTable, s, - indexId == ((OlapTable) table).getBaseIndexId(), indexId)) - .collect(ImmutableList.toImmutableList()); + return slots; } private Slot generateUniqueSlot(OlapTable table, Column column, boolean isBaseIndex, long indexId) { String name = isBaseIndex || directMvScan ? column.getName() : AbstractSelectMaterializedIndexRule.parseMvColumnToMvName(column.getName(), column.isAggregated() ? 
Optional.of(column.getAggregationType().toSql()) : Optional.empty()); - if (cacheSlotWithSlotName.containsKey(Pair.of(indexId, name))) { - return cacheSlotWithSlotName.get(Pair.of(indexId, name)); + Pair key = Pair.of(indexId, name); + Slot slot = cacheSlotWithSlotName.get(key); + if (slot != null) { + return slot; } - Slot slot = SlotReference.fromColumn(table, column, name, qualified()); - cacheSlotWithSlotName.put(Pair.of(indexId, name), slot); + slot = SlotReference.fromColumn(table, column, name, qualified()); + cacheSlotWithSlotName.put(key, slot); return slot; } @@ -402,4 +420,13 @@ public Optional getTableSample() { public boolean isDirectMvScan() { return directMvScan; } + + private List createSlotsVectorized(List columns) { + List qualified = qualified(); + Object[] slots = new Object[columns.size()]; + for (int i = 0; i < columns.size(); i++) { + slots[i] = SlotReference.fromColumn(table, columns.get(i), qualified, this); + } + return (List) Arrays.asList(slots); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java index d899d228fb66bc..89dd7d49677d3e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java @@ -114,8 +114,14 @@ public List getExcepts() { return excepts; } + /** isAllSlots */ public boolean isAllSlots() { - return projects.stream().allMatch(NamedExpression::isSlot); + for (NamedExpression project : projects) { + if (!project.isSlot()) { + return false; + } + } + return true; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSort.java index 9d9d321e659636..607fcf25bca7fe 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSort.java @@ -30,11 +30,13 @@ import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Supplier; /** * Logical Sort plan. @@ -47,6 +49,7 @@ public class LogicalSort extends LogicalUnary orderKeys; + private final Supplier> expressions; public LogicalSort(List orderKeys, CHILD_TYPE child) { this(orderKeys, Optional.empty(), Optional.empty(), child); @@ -58,7 +61,17 @@ public LogicalSort(List orderKeys, CHILD_TYPE child) { public LogicalSort(List orderKeys, Optional groupExpression, Optional logicalProperties, CHILD_TYPE child) { super(PlanType.LOGICAL_SORT, groupExpression, logicalProperties, child); - this.orderKeys = ImmutableList.copyOf(Objects.requireNonNull(orderKeys, "orderKeys can not be null")); + this.orderKeys = Utils.fastToImmutableList( + Objects.requireNonNull(orderKeys, "orderKeys can not be null") + ); + this.expressions = Suppliers.memoize(() -> { + ImmutableList.Builder exprs + = ImmutableList.builderWithExpectedSize(orderKeys.size()); + for (OrderKey orderKey : orderKeys) { + exprs.add(orderKey.getExpr()); + } + return exprs.build(); + }); } @Override @@ -100,9 +113,7 @@ public R accept(PlanVisitor visitor, C context) { @Override public List getExpressions() { - return orderKeys.stream() - .map(OrderKey::getExpr) - .collect(ImmutableList.toImmutableList()); + return expressions.get(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java index 63e3dc9c0b9717..1fb5dbaab7271c 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java @@ -35,6 +35,7 @@ import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -51,6 +52,7 @@ public class LogicalTopN extends LogicalUnary orderKeys; private final long limit; private final long offset; + private final Supplier> expressions; public LogicalTopN(List orderKeys, long limit, long offset, CHILD_TYPE child) { this(orderKeys, limit, offset, Optional.empty(), Optional.empty(), child); @@ -65,6 +67,13 @@ public LogicalTopN(List orderKeys, long limit, long offset, Optional { + ImmutableList.Builder exprs = ImmutableList.builderWithExpectedSize(orderKeys.size()); + for (OrderKey orderKey : orderKeys) { + exprs.add(orderKey.getExpr()); + } + return exprs.build(); + }); } @Override @@ -120,9 +129,7 @@ public R accept(PlanVisitor visitor, C context) { @Override public List getExpressions() { - return orderKeys.stream() - .map(OrderKey::getExpr) - .collect(ImmutableList.toImmutableList()); + return expressions.get(); } public LogicalTopN withOrderKeys(List orderKeys) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java index 9b342c8e618060..d1a554820f796d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java @@ -246,7 +246,8 @@ public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator probExprList = Sets.newHashSet(probeExpr); + Set probExprList = Sets.newLinkedHashSet(); + probExprList.add(probeExpr); Pair 
srcPair = ctx.getAliasTransferMap().get(srcExpr); PhysicalRelation srcNode = (srcPair == null) ? null : srcPair.first; Pair targetPair = ctx.getAliasTransferMap().get(probeExpr); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 1f44d128b23be7..64685bc55e3f1d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; @@ -59,7 +60,6 @@ import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -89,7 +89,7 @@ public static List extractConjunction(Expression expr) { } public static Set extractConjunctionToSet(Expression expr) { - Set exprSet = Sets.newHashSet(); + Set exprSet = Sets.newLinkedHashSet(); extract(And.class, expr, exprSet); return exprSet; } @@ -162,11 +162,11 @@ public static Optional optionalAnd(Collection collection } public static Expression and(Collection expressions) { - return combine(And.class, expressions); + return combineAsLeftDeepTree(And.class, expressions); } public static Expression and(Expression... 
expressions) { - return combine(And.class, Lists.newArrayList(expressions)); + return combineAsLeftDeepTree(And.class, Lists.newArrayList(expressions)); } public static Optional optionalOr(List expressions) { @@ -178,17 +178,18 @@ public static Optional optionalOr(List expressions) { } public static Expression or(Expression... expressions) { - return combine(Or.class, Lists.newArrayList(expressions)); + return combineAsLeftDeepTree(Or.class, Lists.newArrayList(expressions)); } public static Expression or(Collection expressions) { - return combine(Or.class, expressions); + return combineAsLeftDeepTree(Or.class, expressions); } /** * Use AND/OR to combine expressions together. */ - public static Expression combine(Class type, Collection expressions) { + public static Expression combineAsLeftDeepTree( + Class type, Collection expressions) { /* * (AB) (CD) E ((AB)(CD)) E (((AB)(CD))E) * â–² â–² â–² â–² â–² â–² @@ -209,9 +210,20 @@ public static Expression combine(Class type, Collection extractSlotOrCastOnSlot(Expression expr) { * Generate replaceMap Slot -> Expression from NamedExpression[Expression as name] */ public static Map generateReplaceMap(List namedExpressions) { - ImmutableMap.Builder replaceMap = ImmutableMap.builderWithExpectedSize( - namedExpressions.size() * 2); + Map replaceMap = Maps.newLinkedHashMapWithExpectedSize(namedExpressions.size()); for (NamedExpression namedExpression : namedExpressions) { if (namedExpression instanceof Alias) { // Avoid cast to alias, retrieving the first child expression. 
- replaceMap.put(namedExpression.toSlot(), namedExpression.child(0)); + Slot slot = namedExpression.toSlot(); + replaceMap.putIfAbsent(slot, namedExpression.child(0)); } } - return replaceMap.build(); + return replaceMap; } /** @@ -360,16 +372,20 @@ public static Expression replace(Expression expr, Map replace(List exprs, Map replaceMap) { - return exprs.stream() - .map(expr -> replace(expr, replaceMap)) - .collect(ImmutableList.toImmutableList()); + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(exprs.size()); + for (Expression expr : exprs) { + result.add(replace(expr, replaceMap)); + } + return result.build(); } public static Set replace(Set exprs, Map replaceMap) { - return exprs.stream() - .map(expr -> replace(expr, replaceMap)) - .collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder result = ImmutableSet.builderWithExpectedSize(exprs.size()); + for (Expression expr : exprs) { + result.add(replace(expr, replaceMap)); + } + return result.build(); } /** @@ -456,34 +472,60 @@ public static List mergeArguments(Object... 
arguments) { return builder.build(); } + /** isAllLiteral */ public static boolean isAllLiteral(List children) { - return children.stream().allMatch(c -> c instanceof Literal); + for (Expression child : children) { + if (!(child instanceof Literal)) { + return false; + } + } + return true; } + /** matchNumericType */ public static boolean matchNumericType(List children) { - return children.stream().allMatch(c -> c.getDataType().isNumericType()); + for (Expression child : children) { + if (!child.getDataType().isNumericType()) { + return false; + } + } + return true; } + /** matchDateLikeType */ public static boolean matchDateLikeType(List children) { - return children.stream().allMatch(c -> c.getDataType().isDateLikeType()); + for (Expression child : children) { + if (!child.getDataType().isDateLikeType()) { + return false; + } + } + return true; } + /** hasNullLiteral */ public static boolean hasNullLiteral(List children) { - return children.stream().anyMatch(c -> c instanceof NullLiteral); + for (Expression child : children) { + if (child instanceof NullLiteral) { + return true; + } + } + return false; } + /** hasOnlyMetricType */ public static boolean hasOnlyMetricType(List children) { - return children.stream().anyMatch(c -> c.getDataType().isOnlyMetricType()); - } - - public static boolean isAllNullLiteral(List children) { - return children.stream().allMatch(c -> c instanceof NullLiteral); + for (Expression child : children) { + if (child.getDataType().isOnlyMetricType()) { + return true; + } + } + return false; } /** * canInferNotNullForMarkSlot */ - public static boolean canInferNotNullForMarkSlot(Expression predicate) { + public static boolean canInferNotNullForMarkSlot(Expression predicate, ExpressionRewriteContext ctx) { /* * assume predicate is from LogicalFilter * the idea is replacing each mark join slot with null and false literal then run FoldConstant rule @@ -523,9 +565,10 @@ public static boolean canInferNotNullForMarkSlot(Expression predicate) { 
for (int j = 0; j < markSlotSize; ++j) { replaceMap.put(markJoinSlotReferenceList.get(j), literals.get((i >> j) & 1)); } - Expression evalResult = FoldConstantRule.INSTANCE.rewrite( + Expression evalResult = FoldConstantRule.evaluate( ExpressionUtils.replace(predicate, replaceMap), - new ExpressionRewriteContext(null)); + ctx + ); if (evalResult.equals(BooleanLiteral.TRUE)) { if (meetNullOrFalse) { @@ -553,30 +596,33 @@ private static boolean isNullOrFalse(Expression expression) { * infer notNulls slot from predicate */ public static Set inferNotNullSlots(Set predicates, CascadesContext cascadesContext) { - Set notNullSlots = Sets.newHashSet(); + ImmutableSet.Builder notNullSlots = ImmutableSet.builderWithExpectedSize(predicates.size()); for (Expression predicate : predicates) { for (Slot slot : predicate.getInputSlots()) { Map replaceMap = new HashMap<>(); Literal nullLiteral = new NullLiteral(slot.getDataType()); replaceMap.put(slot, nullLiteral); - Expression evalExpr = FoldConstantRule.INSTANCE.rewrite( + Expression evalExpr = FoldConstantRule.evaluate( ExpressionUtils.replace(predicate, replaceMap), - new ExpressionRewriteContext(cascadesContext)); + new ExpressionRewriteContext(cascadesContext) + ); if (evalExpr.isNullLiteral() || BooleanLiteral.FALSE.equals(evalExpr)) { notNullSlots.add(slot); } } } - return notNullSlots; + return notNullSlots.build(); } /** * infer notNulls slot from predicate */ public static Set inferNotNull(Set predicates, CascadesContext cascadesContext) { - return inferNotNullSlots(predicates, cascadesContext).stream() - .map(slot -> new Not(new IsNull(slot), false)) - .collect(Collectors.toSet()); + ImmutableSet.Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size()); + for (Slot slot : inferNotNullSlots(predicates, cascadesContext)) { + newPredicates.add(new Not(new IsNull(slot), false)); + } + return newPredicates.build(); } /** @@ -584,37 +630,90 @@ public static Set inferNotNull(Set predicates, CascadesC */ 
public static Set inferNotNull(Set predicates, Set slots, CascadesContext cascadesContext) { - return inferNotNullSlots(predicates, cascadesContext).stream() - .filter(slots::contains) - .map(slot -> new Not(new IsNull(slot), true)) - .collect(Collectors.toSet()); + ImmutableSet.Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size()); + for (Slot slot : inferNotNullSlots(predicates, cascadesContext)) { + if (slots.contains(slot)) { + newPredicates.add(new Not(new IsNull(slot), true)); + } + } + return newPredicates.build(); } - public static List flatExpressions(List> expressions) { - return expressions.stream() - .flatMap(List::stream) - .collect(ImmutableList.toImmutableList()); + /** flatExpressions */ + public static List flatExpressions(List> expressionLists) { + int num = 0; + for (List expressionList : expressionLists) { + num += expressionList.size(); + } + + ImmutableList.Builder flatten = ImmutableList.builderWithExpectedSize(num); + for (List expressionList : expressionLists) { + flatten.addAll(expressionList); + } + return flatten.build(); + } + + /** containsType */ + public static boolean containsType(Collection expressions, Class type) { + for (Expression expression : expressions) { + if (expression.anyMatch(expr -> expr.anyMatch(type::isInstance))) { + return true; + } + } + return false; } - public static boolean anyMatch(List expressions, Predicate> predicate) { - return expressions.stream() - .anyMatch(expr -> expr.anyMatch(predicate)); + /** allMatch */ + public static boolean allMatch( + Collection expressions, Predicate predicate) { + for (Expression expression : expressions) { + if (!predicate.test(expression)) { + return false; + } + } + return true; } - public static boolean noneMatch(List expressions, Predicate> predicate) { - return expressions.stream() - .noneMatch(expr -> expr.anyMatch(predicate)); + /** anyMatch */ + public static boolean anyMatch( + Collection expressions, Predicate predicate) { + for 
(Expression expression : expressions) { + if (predicate.test(expression)) { + return true; + } + } + return false; } - public static boolean containsType(List expressions, Class type) { - return anyMatch(expressions, type::isInstance); + /** deapAnyMatch */ + public static boolean deapAnyMatch( + Collection expressions, Predicate> predicate) { + for (Expression expression : expressions) { + if (expression.anyMatch(expr -> expr.anyMatch(predicate))) { + return true; + } + } + return false; + } + + /** deapNoneMatch */ + public static boolean deapNoneMatch( + Collection expressions, Predicate> predicate) { + for (Expression expression : expressions) { + if (expression.anyMatch(expr -> expr.anyMatch(predicate))) { + return false; + } + } + return true; } public static Set collect(Collection expressions, Predicate> predicate) { - return expressions.stream() - .flatMap(expr -> expr.>collect(predicate).stream()) - .collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder set = ImmutableSet.builder(); + for (Expression expr : expressions) { + set.addAll(expr.collectToList(predicate)); + } + return set.build(); } /** @@ -652,11 +751,19 @@ public static Set mutableCollect(List expressions, return set; } + /** collectAll */ public static List collectAll(Collection expressions, Predicate> predicate) { - return expressions.stream() - .flatMap(expr -> expr.>collect(predicate).stream()) - .collect(ImmutableList.toImmutableList()); + switch (expressions.size()) { + case 0: return ImmutableList.of(); + default: { + ImmutableList.Builder result = ImmutableList.builder(); + for (Expression expr : expressions) { + result.addAll((Set) expr.collect(predicate)); + } + return result.build(); + } + } } public static List> rollupToGroupingSets(List rollupExpressions) { @@ -807,4 +914,25 @@ public static List distinctSlotByName(List slots) { } return distinctSlots.build(); } + + /** containsWindowExpression */ + public static boolean containsWindowExpression(List expressions) { + for 
(NamedExpression expression : expressions) { + if (expression.anyMatch(WindowExpression.class::isInstance)) { + return true; + } + } + return false; + } + + /** filter */ + public static List filter(List expressions, Class clazz) { + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(expressions.size()); + for (Expression expression : expressions) { + if (clazz.isInstance(expression)) { + result.add((E) expression); + } + } + return result.build(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java index e8dcbe084d6f80..f5f3dd75b51bfc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java @@ -21,7 +21,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -44,8 +44,8 @@ public static ImmutableEqualSet empty() { * Builder for ImmutableEqualSet. 
*/ public static class Builder { - private final Map parent = new HashMap<>(); - private final Map size = new HashMap<>(); + private final Map parent = new LinkedHashMap<>(); + private final Map size = new LinkedHashMap<>(); /** * Add a equal pair diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java index a73c183e4e75c1..5c01fd4df9a87a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java @@ -47,6 +47,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -98,8 +99,8 @@ public static final class JoinSlotCoverageChecker { Set rightExprIds; public JoinSlotCoverageChecker(List left, List right) { - leftExprIds = left.stream().map(Slot::getExprId).collect(Collectors.toSet()); - rightExprIds = right.stream().map(Slot::getExprId).collect(Collectors.toSet()); + leftExprIds = left.stream().map(Slot::getExprId).collect(ImmutableSet.toImmutableSet()); + rightExprIds = right.stream().map(Slot::getExprId).collect(ImmutableSet.toImmutableSet()); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index 759b96c5b73047..3955b2d0f0c6af 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -155,6 +155,30 @@ public static List fastGetChildrenOutputs(List children) { return output.build(); } + /** fastGetInputSlots */ + public static Set fastGetInputSlots(List expressions) { + switch (expressions.size()) { + case 1: return 
expressions.get(0).getInputSlots(); + case 0: return ImmutableSet.of(); + default: { + } + } + + int inputSlotsNum = 0; + // child.inputSlots is cached by Expression.inputSlots, + // we can compute output num without the overhead of re-compute output + for (Expression expr : expressions) { + Set output = expr.getInputSlots(); + inputSlotsNum += output.size(); + } + // generate output list only copy once and without resize the list + ImmutableSet.Builder inputSlots = ImmutableSet.builderWithExpectedSize(inputSlotsNum); + for (Expression expr : expressions) { + inputSlots.addAll(expr.getInputSlots()); + } + return inputSlots.build(); + } + /** * Check if slot is from the plan. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 4ec437055a3e67..afcdb30f2dcf03 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -1047,11 +1047,20 @@ public static Expression processCaseWhen(CaseWhen caseWhen) { public static Expression processCompoundPredicate(CompoundPredicate compoundPredicate) { // check compoundPredicate.checkLegalityBeforeTypeCoercion(); - List children = compoundPredicate.children().stream() - .map(e -> e.getDataType().isNullType() ? 
new NullLiteral(BooleanType.INSTANCE) - : castIfNotSameType(e, BooleanType.INSTANCE)) - .collect(Collectors.toList()); - return compoundPredicate.withChildren(children); + ImmutableList.Builder newChildren + = ImmutableList.builderWithExpectedSize(compoundPredicate.arity()); + boolean changed = false; + for (Expression child : compoundPredicate.children()) { + Expression newChild; + if (child.getDataType().isNullType()) { + newChild = new NullLiteral(BooleanType.INSTANCE); + } else { + newChild = castIfNotSameType(child, BooleanType.INSTANCE); + } + changed |= child != newChild; + newChildren.add(newChild); + } + return changed ? compoundPredicate.withChildren(newChildren.build()) : compoundPredicate; } private static boolean canCompareDate(DataType t1, DataType t2) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java index df9528cc49e2d6..c28b18e697d34c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java @@ -26,6 +26,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import org.apache.commons.lang3.StringUtils; @@ -34,7 +35,9 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -325,24 +328,50 @@ public static ImmutableList fastToImmutableList(E[] array) { } /** fastToImmutableList */ - public static ImmutableList fastToImmutableList(List originList) { - if (originList instanceof ImmutableList) { - return (ImmutableList) originList; + public static ImmutableList 
fastToImmutableList(Collection collection) { + if (collection instanceof ImmutableList) { + return (ImmutableList) collection; } - switch (originList.size()) { + switch (collection.size()) { case 0: return ImmutableList.of(); - case 1: return ImmutableList.of(originList.get(0)); + case 1: + return collection instanceof List + ? ImmutableList.of(((List) collection).get(0)) + : ImmutableList.of(collection.iterator().next()); default: { // NOTE: ImmutableList.copyOf(list) has additional clone of the list, so here we // direct generate a ImmutableList - Builder copyChildren = ImmutableList.builderWithExpectedSize(originList.size()); - copyChildren.addAll(originList); + Builder copyChildren = ImmutableList.builderWithExpectedSize(collection.size()); + copyChildren.addAll(collection); return copyChildren.build(); } } } + /** fastToImmutableSet */ + public static ImmutableSet fastToImmutableSet(Collection collection) { + if (collection instanceof ImmutableSet) { + return (ImmutableSet) collection; + } + switch (collection.size()) { + case 0: + return ImmutableSet.of(); + case 1: + return collection instanceof List + ? 
ImmutableSet.of(((List) collection).get(0)) + : ImmutableSet.of(collection.iterator().next()); + default: + // NOTE: ImmutableList.copyOf(array) has additional clone of the array, so here we + // direct generate a ImmutableList + ImmutableSet.Builder copyChildren = ImmutableSet.builderWithExpectedSize(collection.size()); + for (E child : collection) { + copyChildren.add(child); + } + return copyChildren.build(); + } + } + /** reverseImmutableList */ public static ImmutableList reverseImmutableList(List list) { Builder reverseList = ImmutableList.builderWithExpectedSize(list.size()); @@ -363,4 +392,26 @@ public static ImmutableList filterImmutableList(List list, P } return newList.build(); } + + /** concatToSet */ + public static Set concatToSet(Collection left, Collection right) { + ImmutableSet.Builder required = ImmutableSet.builderWithExpectedSize( + left.size() + right.size() + ); + required.addAll(left); + required.addAll(right); + return required.build(); + } + + /** fastReduce */ + public static Optional fastReduce(List list, BiFunction reduceOp) { + if (list.isEmpty()) { + return Optional.empty(); + } + M merge = list.get(0); + for (int i = 1; i < list.size(); i++) { + merge = reduceOp.apply(merge, list.get(i)); + } + return Optional.of(merge); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 07357d98521edf..7ae0b9c6301a69 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -58,6 +58,7 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; import java.util.Arrays; +import java.util.BitSet; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -300,6 +301,7 @@ public class SessionVariable implements Serializable, Writable { public static final String NEREIDS_CBO_PENALTY_FACTOR = 
"nereids_cbo_penalty_factor"; public static final String ENABLE_NEREIDS_TRACE = "enable_nereids_trace"; + public static final String ENABLE_EXPR_TRACE = "enable_expr_trace"; public static final String ENABLE_DPHYP_TRACE = "enable_dphyp_trace"; @@ -1136,6 +1138,9 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) { @VariableMgr.VarAttr(name = ENABLE_NEREIDS_TRACE) private boolean enableNereidsTrace = false; + @VariableMgr.VarAttr(name = ENABLE_EXPR_TRACE) + private boolean enableExprTrace = false; + @VariableMgr.VarAttr(name = ENABLE_DPHYP_TRACE, needForward = true) public boolean enableDpHypTrace = false; @@ -2733,15 +2738,20 @@ public Set getDisableNereidsRuleNames() { .collect(ImmutableSet.toImmutableSet()); } - public Set getDisableNereidsRules() { - return Arrays.stream(disableNereidsRules.split(",[\\s]*")) - .filter(rule -> !rule.isEmpty()) - .map(rule -> rule.toUpperCase(Locale.ROOT)) - .map(rule -> RuleType.valueOf(rule)) - .filter(ruleType -> ruleType != RuleType.CHECK_PRIVILEGES - && ruleType != RuleType.CHECK_ROW_POLICY) - .map(RuleType::type) - .collect(ImmutableSet.toImmutableSet()); + public BitSet getDisableNereidsRules() { + BitSet bitSet = new BitSet(); + for (String ruleName : disableNereidsRules.split(",[\\s]*")) { + if (ruleName.isEmpty()) { + continue; + } + ruleName = ruleName.toUpperCase(Locale.ROOT); + RuleType ruleType = RuleType.valueOf(ruleName); + if (ruleType == RuleType.CHECK_PRIVILEGES || ruleType == RuleType.CHECK_ROW_POLICY) { + continue; + } + bitSet.set(ruleType.type()); + } + return bitSet; } public Set getEnableNereidsRules() { @@ -2776,6 +2786,16 @@ public boolean isEnableNereidsTrace() { return isEnableNereidsPlanner() && enableNereidsTrace; } + public void setEnableExprTrace(boolean enableExprTrace) { + this.enableExprTrace = enableExprTrace; + } + + public boolean isEnableExprTrace() { + return enableExprTrace; + } + + + public boolean isEnableSingleReplicaInsert() { return enableSingleReplicaInsert; } diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java index af9c9d7e3c1754..1ab96a6ed74ab7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java @@ -56,7 +56,7 @@ void testStarGraph() { + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN4 [label=\"1.00\",arrowhead=none]\n" + "}\n"; - Assertions.assertEquals(dottyGraph, target); + Assertions.assertEquals(target, dottyGraph); } @Test @@ -85,12 +85,12 @@ void testCircleGraph() { + " LOGICAL_OLAP_SCAN3 [label=\"LOGICAL_OLAP_SCAN3 \n" + " rowCount=40.00\"];\n" + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN1 [label=\"1.00\",arrowhead=none]\n" - + "LOGICAL_OLAP_SCAN1 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n" + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n" - + "LOGICAL_OLAP_SCAN2 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n" + + "LOGICAL_OLAP_SCAN1 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n" + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n" + + "LOGICAL_OLAP_SCAN2 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n" + "}\n"; - Assertions.assertEquals(dottyGraph, target); + Assertions.assertEquals(target, dottyGraph); } @Test @@ -101,8 +101,8 @@ void testRandomQuery() { for (int i = 0; i < 10; i++) { HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder(); HyperGraph hyperGraph = hyperGraphBuilder.randomBuildWith(tableNum, edgeNum); - Assertions.assertEquals(hyperGraph.getNodes().size(), tableNum); - Assertions.assertEquals(hyperGraph.getJoinEdges().size(), edgeNum); + Assertions.assertEquals(tableNum, hyperGraph.getNodes().size()); + Assertions.assertEquals(edgeNum, hyperGraph.getJoinEdges().size()); } } } diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java index de859058ecc7d4..2dd1d6aa459681 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java @@ -22,9 +22,11 @@ import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup; import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule; import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule; +import org.apache.doris.nereids.rules.expression.rules.OrToIn; import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison; import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule; +import org.apache.doris.nereids.rules.expression.rules.SimplifyRange; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; @@ -54,7 +56,9 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper { @Test void testNotRewrite() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + ExpressionRewrite.bottomUp(SimplifyNotExprRule.INSTANCE) + )); assertRewrite("not x", "not x"); assertRewrite("not not x", "x"); @@ -79,7 +83,9 @@ void testNotRewrite() { @Test void testNormalizeExpressionRewrite() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + ExpressionRewrite.bottomUp(NormalizeBinaryPredicatesRule.INSTANCE) + )); assertRewrite("1 = 1", "1 = 1"); 
assertRewrite("2 > x", "x < 2"); @@ -91,7 +97,9 @@ void testNormalizeExpressionRewrite() { @Test void testDistinctPredicatesRewrite() { - executor = new ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(DistinctPredicatesRule.INSTANCE) + )); assertRewrite("a = 1", "a = 1"); assertRewrite("a = 1 and a = 1", "a = 1"); @@ -103,7 +111,9 @@ void testDistinctPredicatesRewrite() { @Test void testExtractCommonFactorRewrite() { - executor = new ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(ExtractCommonFactorRule.INSTANCE) + )); assertRewrite("a", "a"); @@ -112,22 +122,24 @@ void testExtractCommonFactorRewrite() { assertRewrite("a = 1 and b > 2", "a = 1 and b > 2"); assertRewrite("(a and b) or (c and d)", "(a and b) or (c and d)"); - assertRewrite("(a and b) and (c and d)", "((a and b) and c) and d"); + assertRewrite("(a and b) and (c and d)", "((a and b) and (c and d))"); + assertRewrite("(a and (b and c)) and (b or c)", "((b and c) and a)"); assertRewrite("(a or b) and (a or c)", "a or (b and c)"); assertRewrite("(a and b) or (a and c)", "a and (b or c)"); assertRewrite("(a or b) and (a or c) and (a or d)", "a or (b and c and d)"); assertRewrite("(a and b) or (a and c) or (a and d)", "a and (b or c or d)"); - assertRewrite("(a and b) or (a or c) or (a and d)", "((((a and b) or a) or c) or (a and d))"); - assertRewrite("(a and b) or (a and c) or (a or d)", "(((a and b) or (a and c) or a) or d))"); - assertRewrite("(a or b) or (a and c) or (a and d)", "(a or b) or (a and c) or (a and d)"); - assertRewrite("(a or b) or (a and c) or (a or d)", "(((a or b) or (a and c)) or d)"); - assertRewrite("(a or b) or (a or c) or (a and d)", "((a or b) or c) or (a and d)"); + assertRewrite("(a or b) and (a or d)", "a or (b and d)"); + assertRewrite("(a and b) or (a or c) or (a and d)", "a or c"); + 
assertRewrite("(a and b) or (a and c) or (a or d)", "(a or d)"); + assertRewrite("(a or b) or (a and c) or (a and d)", "(a or b)"); + assertRewrite("(a or b) or (a and c) or (a or d)", "((a or b) or d)"); + assertRewrite("(a or b) or (a or c) or (a and d)", "((a or b) or c)"); assertRewrite("(a or b) or (a or c) or (a or d)", "(((a or b) or c) or d)"); - assertRewrite("(a and b) or (d and c) or (d and e)", "(a and b) or (d and c) or (d and e)"); - assertRewrite("(a or b) and (d or c) and (d or e)", "(a or b) and (d or c) and (d or e)"); + assertRewrite("(a and b) or (d and c) or (d and e)", "((d and (c or e)) or (a and b))"); + assertRewrite("(a or b) and (d or c) and (d or e)", "((d or (c and e)) and (a or b))"); assertRewrite("(a and b) or ((d and c) and (d and e))", "(a and b) or (d and c and e)"); assertRewrite("(a or b) and ((d or c) or (d or e))", "(a or b) and (d or c or e)"); @@ -152,11 +164,29 @@ void testExtractCommonFactorRewrite() { assertRewrite("(a or b) and (a or true)", "a or b"); + assertRewrite("a and (b or ((a and e) or (a and f))) and (b or d)", "(b or ((a and (e or f)) and d)) and a"); + + } + + @Test + void testTpcdsCase() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + SimplifyRange.INSTANCE, + OrToIn.INSTANCE, + ExtractCommonFactorRule.INSTANCE + ) + )); + assertRewrite( + "(((((customer_address.ca_country = 'United States') AND ca_state IN ('DE', 'FL', 'TX')) OR ((customer_address.ca_country = 'United States') AND ca_state IN ('ID', 'IN', 'ND'))) OR ((customer_address.ca_country = 'United States') AND ca_state IN ('IL', 'MT', 'OH'))))", + "((customer_address.ca_country = 'United States') AND ca_state IN ('DE', 'FL', 'TX', 'ID', 'IN', 'ND', 'IL', 'MT', 'OH'))"); } @Test void testInPredicateToEqualToRule() { - executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(InPredicateToEqualToRule.INSTANCE) + )); 
assertRewrite("a in (1)", "a = 1"); assertRewrite("a not in (1)", "not a = 1"); @@ -172,14 +202,18 @@ void testInPredicateToEqualToRule() { @Test void testInPredicateDedup() { - executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateDedup.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(InPredicateDedup.INSTANCE) + )); assertRewrite("a in (1, 2, 1, 2)", "a in (1, 2)"); } @Test void testSimplifyCastRule() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyCastRule.INSTANCE) + )); // deduplicate assertRewrite("CAST(1 AS tinyint)", "1"); @@ -211,7 +245,9 @@ void testSimplifyCastRule() { @Test void testSimplifyDecimalV3Comparison() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyDecimalV3Comparison.INSTANCE) + )); // do rewrite Expression left = new DecimalV3Literal(new BigDecimal("12345.67")); @@ -226,4 +262,16 @@ void testSimplifyDecimalV3Comparison() { comparison = new EqualTo(new DecimalV3Literal(new BigDecimal("12345.67")), new DecimalV3Literal(new BigDecimal("76543.21"))); assertRewrite(comparison, comparison); } + + @Test + void testDeadLoop() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + SimplifyRange.INSTANCE, + ExtractCommonFactorRule.INSTANCE + ) + )); + + assertRewrite("a and (b > 0 and b < 10)", "a and (b > 0 and b < 10)"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java index 60d4384207f90f..b252b4650f7315 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java +++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java @@ -46,7 +46,7 @@ import java.util.List; import java.util.Map; -public abstract class ExpressionRewriteTestHelper { +public abstract class ExpressionRewriteTestHelper extends ExpressionRewrite { protected static final NereidsParser PARSER = new NereidsParser(); protected ExpressionRuleExecutor executor; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java index 3b8fbc8526b356..747e72b0a9167c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java @@ -58,7 +58,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper { @Test void testCaseWhenFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); // assertRewriteAfterTypeCoercion("case when 1 = 2 then 1 when '1' < 2 then 2 else 3 end", "2"); // assertRewriteAfterTypeCoercion("case when 1 = 2 then 1 when '1' > 2 then 2 end", "null"); assertRewriteAfterTypeCoercion("case when (1 + 5) / 2 > 2 then 4 when '1' < 2 then 2 else 3 end", "4"); @@ -75,7 +77,9 @@ void testCaseWhenFold() { @Test void testInFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("1 in (1,2,3,4)", "true"); // Type Coercion trans all to string. 
assertRewriteAfterTypeCoercion("3 in ('1', 2 + 8 / 2, 3, 4)", "true"); @@ -88,7 +92,9 @@ void testInFold() { @Test void testLogicalFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("10 + 1 > 1 and 1 > 2", "false"); assertRewriteAfterTypeCoercion("10 + 1 > 1 and 1 < 2", "true"); assertRewriteAfterTypeCoercion("null + 1 > 1 and 1 < 2", "null"); @@ -126,7 +132,9 @@ void testLogicalFold() { @Test void testIsNullFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("100 is null", "false"); assertRewriteAfterTypeCoercion("null is null", "true"); assertRewriteAfterTypeCoercion("null is not null", "false"); @@ -137,7 +145,9 @@ void testIsNullFold() { @Test void testNotPredicateFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("not 1 > 2", "true"); assertRewriteAfterTypeCoercion("not null + 1 > 2", "null"); assertRewriteAfterTypeCoercion("not (1 + 5) / 2 + (10 - 1) * 3 > 3 * 5 + 1", "false"); @@ -145,7 +155,9 @@ void testNotPredicateFold() { @Test void testCastFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); // cast '1' as tinyint Cast c = new Cast(Literal.of("1"), TinyIntType.INSTANCE); @@ -156,7 +168,9 @@ void testCastFold() { @Test void testCompareFold() { - executor = new 
ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("'1' = 2", "false"); assertRewriteAfterTypeCoercion("1 = 2", "false"); assertRewriteAfterTypeCoercion("1 != 2", "true"); @@ -173,7 +187,9 @@ void testCompareFold() { @Test void testArithmeticFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewrite("1 + 1", Literal.of((short) 2)); assertRewrite("1 - 1", Literal.of((short) 0)); assertRewrite("100 + 100", Literal.of((short) 200)); @@ -206,7 +222,9 @@ void testArithmeticFold() { @Test void testTimestampFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); String interval = "'1991-05-01' - interval 1 day"; Expression e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); Expression e8 = Config.enable_date_conversion diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PredicatesSplitterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PredicatesSplitterTest.java index cab2c2f5a64274..a83ac620164806 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PredicatesSplitterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PredicatesSplitterTest.java @@ -48,7 +48,7 @@ public void testSplitPredicates() { "c = d or a = 10"); assetEquals("a = b and c + d = e and a > 7 and 10 > d", "a = b", - "10 > d and a > 7", + "a > 7 and 10 > d", "c + d = e"); assetEquals("a = b and c + d = e or a > 7 and 10 > d", "", diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java index 174592270dd973..e7423fe12d9349 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java @@ -29,9 +29,11 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper { @Test void testSimplifyArithmetic() { executor = new ExpressionRuleExecutor(ImmutableList.of( - SimplifyArithmeticRule.INSTANCE, + bottomUp(SimplifyArithmeticRule.INSTANCE), FunctionBinder.INSTANCE, - FoldConstantRule.INSTANCE + bottomUp( + FoldConstantRule.INSTANCE + ) )); assertRewriteAfterTypeCoercion("IA", "IA"); assertRewriteAfterTypeCoercion("IA + 1", "IA + 1"); @@ -55,7 +57,7 @@ void testSimplifyArithmetic() { @Test void testSimplifyArithmeticRuleOnly() { executor = new ExpressionRuleExecutor(ImmutableList.of( - SimplifyArithmeticRule.INSTANCE + bottomUp(SimplifyArithmeticRule.INSTANCE) )); // add and subtract @@ -67,39 +69,43 @@ void testSimplifyArithmeticRuleOnly() { assertRewriteAfterTypeCoercion("IA - 2 - ((-IB - 1) - (3 + (IC + 4)))", "(((IA + IB) + IC) - ((((2 + 0) - 1) - 3) - 4))"); // multiply and divide - assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", "((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast((3 * IC) as DOUBLE))"); + assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", "((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast((IC * 3) as DOUBLE))"); assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))", "(((cast(IA as DOUBLE) / cast((IB * 1) as DOUBLE)) / cast(IC as DOUBLE)) / ((cast(2 as DOUBLE) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))"); assertRewriteAfterTypeCoercion("IA / 2 
/ ((IB / 1) / (3 / (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast((IC * 4) as DOUBLE)) / ((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)))"); - assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((3 * (IC * 4)) as DOUBLE)) / (cast(2 as DOUBLE) / cast(1 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((IC * (3 * 4)) as DOUBLE)) / (cast(2 as DOUBLE) / cast(1 as DOUBLE)))"); // hybrid // root is subtract assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))", "(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * (cast((1 - IB) as DOUBLE) - (cast(3 as DOUBLE) / cast(IC as DOUBLE)))))"); - assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", "((cast(((0 - IA) - 2) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))"); + assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", "((cast(((0 - 2) - IA) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))"); // root is add - assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast((3 - (IC + 4)) as DOUBLE)))"); + assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast(((3 - 4) - IC) as DOUBLE)))"); assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC + 4)))", "(((((0 + 2) - 1) - IA) + IB) - (3 * (IC + 4)))"); // root is multiply - assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) * cast((((0 - IB) - 1) - (3 + (IC + 4))) as DOUBLE)) / cast(2 as DOUBLE))"); - assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - IB) 
- 1) as DOUBLE)) / cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as DOUBLE)) / cast(2 as DOUBLE))"); + assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - 1) - IB) as DOUBLE)) / cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))"); // root is divide - assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) / cast((((0 - IB) - 1) - (3 + (IC + 4))) as DOUBLE)) / cast(2 as DOUBLE))"); - assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) / (3 + (IC * 4)))", "(((cast((0 - IA) as DOUBLE) / cast(((0 - IB) - 1) as DOUBLE)) * cast((3 + (IC * 4)) as DOUBLE)) / cast(2 as DOUBLE))"); + assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) / cast((((((0 - 1) - 3) - 4) - IB) - IC) as DOUBLE)) / cast(2 as DOUBLE))"); + assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) / (3 + (IC * 4)))", "(((cast((0 - IA) as DOUBLE) / cast(((0 - 1) - IB) as DOUBLE)) * cast(((IC * 4) + 3) as DOUBLE)) / cast(2 as DOUBLE))"); // unsupported decimal - assertRewriteAfterTypeCoercion("-2 - MA - ((1 - IB) - (3 + IC))", "((cast(-2 as DECIMALV3(38, 9)) - MA) - cast(((1 - IB) - (3 + IC)) as DECIMALV3(38, 9)))"); - assertRewriteAfterTypeCoercion("-IA / 2.0 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DECIMALV3(25, 5)) / 2.0) * cast((((0 - IB) - 1) - (3 + (IC + 4))) as DECIMALV3(20, 0)))"); + assertRewriteAfterTypeCoercion("-2 - MA - ((1 - IB) - (3 + IC))", "((cast(-2 as DECIMALV3(38, 9)) - MA) - cast((((1 - 3) - IB) - IC) as DECIMALV3(38, 9)))"); + assertRewriteAfterTypeCoercion("-IA / 2.0 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DECIMALV3(25, 5)) / 2.0) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as DECIMALV3(20, 0)))"); } @Test void 
testSimplifyArithmeticComparison() { executor = new ExpressionRuleExecutor(ImmutableList.of( - SimplifyArithmeticRule.INSTANCE, - FoldConstantRule.INSTANCE, - SimplifyArithmeticComparisonRule.INSTANCE, - SimplifyArithmeticRule.INSTANCE, + bottomUp( + SimplifyArithmeticRule.INSTANCE, + FoldConstantRule.INSTANCE, + SimplifyArithmeticComparisonRule.INSTANCE, + SimplifyArithmeticRule.INSTANCE + ), FunctionBinder.INSTANCE, - FoldConstantRule.INSTANCE + bottomUp( + FoldConstantRule.INSTANCE + ) )); assertRewriteAfterTypeCoercion("IA", "IA"); assertRewriteAfterTypeCoercion("IA > IB", "IA > IB"); @@ -134,12 +140,16 @@ void testSimplifyArithmeticComparison() { @Test void testSimplifyDateTimeComparison() { executor = new ExpressionRuleExecutor(ImmutableList.of( - SimplifyArithmeticRule.INSTANCE, - FoldConstantRule.INSTANCE, - SimplifyArithmeticComparisonRule.INSTANCE, - SimplifyArithmeticRule.INSTANCE, + bottomUp( + SimplifyArithmeticRule.INSTANCE, + FoldConstantRule.INSTANCE, + SimplifyArithmeticComparisonRule.INSTANCE, + SimplifyArithmeticRule.INSTANCE + ), FunctionBinder.INSTANCE, - FoldConstantRule.INSTANCE + bottomUp( + FoldConstantRule.INSTANCE + ) )); assertRewriteAfterTypeCoercion("years_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-01-01 00:00:00')"); assertRewriteAfterTypeCoercion("years_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2022-01-01 00:00:00')"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java index 87c57889b2f6fc..09fc7346f56659 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java @@ -34,8 +34,10 @@ public class SimplifyInPredicateTest extends ExpressionRewriteTestHelper { @Test public void 
test() { executor = new ExpressionRuleExecutor(ImmutableList.of( - FoldConstantRule.INSTANCE, - SimplifyInPredicate.INSTANCE + bottomUp( + FoldConstantRule.INSTANCE, + SimplifyInPredicate.INSTANCE + ) )); Map mem = Maps.newHashMap(); Expression rewrittenExpression = PARSER.parseExpression("cast(CA as DATETIME) in ('1992-01-31 00:00:00', '1992-02-01 00:00:00')"); @@ -48,7 +50,9 @@ public void test() { Expression expectedExpression = PARSER.parseExpression("CA in (cast('1992-01-31' as date), cast('1992-02-01' as date))"); expectedExpression = replaceUnboundSlot(expectedExpression, mem); executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( FoldConstantRule.INSTANCE + ) )); expectedExpression = executor.rewrite(expectedExpression, context); Assertions.assertEquals(expectedExpression, rewrittenExpression); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java index 62f3680e675375..f2c74251935bae 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java @@ -45,7 +45,7 @@ import java.util.List; import java.util.Map; -public class SimplifyRangeTest { +public class SimplifyRangeTest extends ExpressionRewrite { private static final NereidsParser PARSER = new NereidsParser(); private ExpressionRuleExecutor executor; @@ -59,7 +59,9 @@ public SimplifyRangeTest() { @Test public void testSimplify() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyRange.INSTANCE) + )); assertRewrite("TA", "TA"); assertRewrite("TA > 3 or TA > null", "TA > 3"); assertRewrite("TA > 3 or TA < null", "TA > 3"); @@ -100,7 +102,7 @@ public void testSimplify() { assertRewrite("((TA > 10 or TA > 5) and 
TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))"); assertRewrite("TA in (1,2,3) and TA > 10", "FALSE"); assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)"); - assertRewrite("TA in (1,2,3) and TA > 1", "((TA = 2) OR (TA = 3))"); + assertRewrite("TA in (1,2,3) and TA > 1", "TA IN (2, 3)"); assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1"); assertRewrite("TA in (1)", "TA in (1)"); assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)"); @@ -147,7 +149,7 @@ public void testSimplify() { assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))"); assertRewrite("TA + TC in (1,2,3) and TA + TC > 10", "FALSE"); assertRewrite("TA + TC in (1,2,3) and TA + TC >= 1", "TA + TC in (1,2,3)"); - assertRewrite("TA + TC in (1,2,3) and TA + TC > 1", "((TA + TC = 2) OR (TA + TC = 3))"); + assertRewrite("TA + TC in (1,2,3) and TA + TC > 1", "(TA + TC) IN (2, 3)"); assertRewrite("TA + TC in (1,2,3) or TA + TC >= 1", "TA + TC >= 1"); assertRewrite("TA + TC in (1)", "TA + TC in (1)"); assertRewrite("TA + TC in (1,2,3) and TA + TC < 10", "TA + TC in (1,2,3)"); @@ -171,8 +173,10 @@ public void testSimplify() { @Test public void testSimplifyDate() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE)); - // assertRewrite("TA", "TA"); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyRange.INSTANCE) + )); + assertRewrite("TA", "TA"); assertRewrite( "(TA >= date '2024-01-01' and TA <= date '2024-01-03') or (TA > date '2024-01-05' and TA < date '2024-01-07')", "(TA >= date '2024-01-01' and TA <= date '2024-01-03') or (TA > date '2024-01-05' and TA < date '2024-01-07')"); @@ -213,7 +217,7 @@ public void testSimplifyDate() { assertRewrite("TA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') and TA >= date '2024-01-01'", "TA in (date '2024-01-01',date 
'2024-01-02',date '2024-01-03')"); assertRewrite("TA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') and TA > date '2024-01-01'", - "((TA = date '2024-01-02') OR (TA = date '2024-01-03'))"); + "TA IN (date '2024-01-02', date '2024-01-03')"); assertRewrite("TA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') or TA >= date '2024-01-01'", "TA >= date '2024-01-01'"); assertRewrite("TA in (date '2024-01-01')", "TA in (date '2024-01-01')"); @@ -237,8 +241,10 @@ public void testSimplifyDate() { @Test public void testSimplifyDateTime() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE)); - // assertRewrite("TA", "TA"); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyRange.INSTANCE) + )); + assertRewrite("TA", "TA"); assertRewrite( "(TA >= timestamp '2024-01-01 00:00:00' and TA <= timestamp '2024-01-03 00:00:00') or (TA > timestamp '2024-01-05 00:00:00' and TA < timestamp '2024-01-07 00:00:00')", "(TA >= timestamp '2024-01-01 00:00:00' and TA <= timestamp '2024-01-03 00:00:00') or (TA > timestamp '2024-01-05 00:00:00' and TA < timestamp '2024-01-07 00:00:00')"); @@ -279,7 +285,7 @@ public void testSimplifyDateTime() { assertRewrite("TA in (timestamp '2024-01-01 01:00:00',timestamp '2024-01-02 01:50:00',timestamp '2024-01-03 02:00:00') and TA >= timestamp '2024-01-01'", "TA in (timestamp '2024-01-01 01:00:00',timestamp '2024-01-02 01:50:00',timestamp '2024-01-03 02:00:00')"); assertRewrite("TA in (timestamp '2024-01-01 02:00:00',timestamp '2024-01-02 02:00:00',timestamp '2024-01-03 02:00:00') and TA > timestamp '2024-01-01 02:10:00'", - "((TA = timestamp '2024-01-02 02:00:00') OR (TA = timestamp '2024-01-03 02:00:00'))"); + "TA IN (timestamp '2024-01-02 02:00:00', timestamp '2024-01-03 02:00:00')"); assertRewrite("TA in (timestamp '2024-01-01 02:00:00',timestamp '2024-01-02 02:00:00',timestamp '2024-01-03 02:00:00') or TA >= timestamp '2024-01-01 01:00:00'", "TA >= timestamp '2024-01-01 
01:00:00'"); assertRewrite("TA in (timestamp '2024-01-01 02:00:00')", "TA in (timestamp '2024-01-01 02:00:00')"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java index 140f72c57f4a4c..db1186738da713 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java @@ -35,7 +35,9 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper { // "A<=> Null" to "A is null" @Test void testNullSafeEqualToIsNull() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference slot = new SlotReference("a", StringType.INSTANCE, true); assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new IsNull(slot)); } @@ -43,7 +45,9 @@ void testNullSafeEqualToIsNull() { // "A<=> Null" to "False", when A is not nullable @Test void testNullSafeEqualToFalse() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference slot = new SlotReference("a", StringType.INSTANCE, false); assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), BooleanLiteral.FALSE); } @@ -51,7 +55,9 @@ void testNullSafeEqualToFalse() { // "A(nullable)<=>B" not changed @Test void testNullSafeEqualNotChangedLeft() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference a = new SlotReference("a", StringType.INSTANCE, true); 
SlotReference b = new SlotReference("b", StringType.INSTANCE, false); assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); @@ -60,7 +66,9 @@ void testNullSafeEqualNotChangedLeft() { // "A<=>B(nullable)" not changed @Test void testNullSafeEqualNotChangedRight() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference a = new SlotReference("a", StringType.INSTANCE, false); SlotReference b = new SlotReference("b", StringType.INSTANCE, true); assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); @@ -69,7 +77,9 @@ void testNullSafeEqualNotChangedRight() { // "A<=>B" changed @Test void testNullSafeEqualToEqual() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference a = new SlotReference("a", StringType.INSTANCE, false); SlotReference b = new SlotReference("b", StringType.INSTANCE, false); assertRewrite(new NullSafeEqual(a, b), new EqualTo(a, b)); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java index fc31daaa9414d7..4d932187611136 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import 
org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; import org.apache.doris.nereids.trees.expressions.Expression; @@ -37,7 +38,9 @@ class SimplifyArithmeticComparisonRuleTest extends ExpressionRewriteTestHelper { public void testProcess() { Map nameToSlot = new HashMap<>(); nameToSlot.put("a", new SlotReference("a", IntegerType.INSTANCE)); - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyArithmeticComparisonRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + ExpressionRewrite.bottomUp(SimplifyArithmeticComparisonRule.INSTANCE) + )); assertRewriteAfterSimplify("a + 1 > 1", "a > cast((1 - 1) as INT)", nameToSlot); assertRewriteAfterSimplify("a - 1 > 1", "a > cast((1 + 1) as INT)", nameToSlot); assertRewriteAfterSimplify("a / -2 > 1", "cast((1 * -2) as INT) > a", nameToSlot); @@ -82,7 +85,7 @@ private void assertRewriteAfterSimplify(String expr, String expected, Map 2021-01-01 00:00:00.001) Expression expression = new GreaterThan(left, right); Expression rewrittenExpression = executor.rewrite(typeCoercion(expression), context); - Assertions.assertEquals(left.getDataType(), rewrittenExpression.child(0).getDataType()); + Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType()); // (cast(0001-01-01 01:01:01 as DATETIMEV2(0)) < 2021-01-01 00:00:00.001) expression = new GreaterThan(left, right); rewrittenExpression = executor.rewrite(typeCoercion(expression), context); - Assertions.assertEquals(left.getDataType(), rewrittenExpression.child(0).getDataType()); + Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType()); } @Test void testRound() { - executor = new ExpressionRuleExecutor( - ImmutableList.of(SimplifyCastRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + SimplifyCastRule.INSTANCE, + SimplifyComparisonPredicate.INSTANCE + ) + )); Expression left = new Cast(new 
DateTimeLiteral("2021-01-02 00:00:00.00"), DateTimeV2Type.of(1)); Expression right = new DateTimeV2Literal("2021-01-01 23:59:59.99"); @@ -120,13 +132,14 @@ void testRound() { Expression rewrittenExpression = executor.rewrite(typeCoercion(expression), context); // right should round to be 2021-01-02 00:00:00.00 - Assertions.assertEquals(new DateTimeV2Literal("2021-01-02 00:00:00"), rewrittenExpression.child(1)); + Assertions.assertEquals(new DateTimeLiteral("2021-01-02 00:00:00"), rewrittenExpression.child(1)); } @Test void testDoubleLiteral() { - executor = new ExpressionRuleExecutor( - ImmutableList.of(SimplifyComparisonPredicate.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyComparisonPredicate.INSTANCE) + )); Expression leftChild = new BigIntLiteral(999); Expression left = new Cast(leftChild, DoubleType.INSTANCE); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java index ff424e4971145b..ee089a82f88079 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java @@ -39,7 +39,9 @@ public void testSimplifyDecimalV3Comparison() { Config.enable_decimal_conversion = false; Map nameToSlot = new HashMap<>(); nameToSlot.put("col1", new SlotReference("col1", DecimalV3Type.createDecimalV3Type(15, 2))); - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyDecimalV3Comparison.INSTANCE) + )); assertRewriteAfterSimplify("cast(col1 as decimalv3(27, 9)) > 0.6", "cast(col1 as decimalv3(27, 9)) > 0.6", nameToSlot); } @@ -48,7 +50,7 @@ private void 
assertRewriteAfterSimplify(String expr, String expected, Map inPredicates = rewritten.collect(e -> e instanceof InPredicate); Assertions.assertEquals(1, inPredicates.size()); InPredicate inPredicate = inPredicates.iterator().next(); @@ -62,7 +61,7 @@ void test1() { void test2() { String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("((((col1 = 1) AND (col1 = 3)) AND (col2 = 3)) OR (col2 = 4))", rewritten.toSql()); } @@ -71,7 +70,7 @@ void test2() { void test3() { String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); List inPredicates = rewritten.collectToList(e -> e instanceof InPredicate); Assertions.assertEquals(2, inPredicates.size()); InPredicate in1 = inPredicates.get(0); @@ -95,7 +94,7 @@ void test4() { String expr = "case when col = 1 or col = 2 or col = 3 then 1" + " when col = 4 or col = 5 or col = 6 then 1 else 0 end"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END", rewritten.toSql()); } @@ -104,7 +103,7 @@ void test4() { void test5() { String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = 
OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))", rewritten.toSql()); } @@ -113,7 +112,7 @@ void test5() { void test6() { String expr = "col = 1 or col = 2 or col in (1, 2, 3)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("col IN (1, 2, 3)", rewritten.toSql()); } @@ -121,7 +120,7 @@ void test6() { void test7() { String expr = "A = 1 or A = 2 or abs(A)=5 or A in (1, 2, 3) or B = 1 or B = 2 or B in (1, 2, 3) or B+1 in (4, 5, 7)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("(((A IN (1, 2, 3) OR B IN (1, 2, 3)) OR (abs(A) = 5)) OR (B + 1) IN (4, 5, 7))", rewritten.toSql()); } @@ -129,7 +128,7 @@ void test7() { void test8() { String expr = "col = 1 or (col = 2 and (col = 3 or col = '4' or col = 5.0))"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN ('4', 3, 5.0)))", rewritten.toSql()); } @@ -139,7 +138,7 @@ void testEnsureOrder() { // ensure not rewrite to col2 in (1, 2) or cor 1 in (1, 2) String expr = "col1 IN (1, 2) OR col2 IN (1, 2)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))", rewritten.toSql()); } diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java index 36cb8cee8d41ec..8bae1713fe1b51 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java @@ -47,8 +47,8 @@ import java.util.Optional; public class PushDownFilterThroughAggregationTest implements MemoPatternMatchSupported { - private final LogicalOlapScan scan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student, - ImmutableList.of("")); + private final LogicalOlapScan scan = new LogicalOlapScan( + StatementScopeIdGenerator.newRelationId(), PlanConstructor.student, ImmutableList.of("")); /*- * origin plan: diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java index 3fc00ee4bad2ba..ff518fb9d1fa87 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java @@ -49,6 +49,7 @@ import java.math.BigDecimal; import java.util.Collections; import java.util.List; +import java.util.Optional; public class ComputeSignatureHelperTest { @@ -419,6 +420,16 @@ public int arity() { return 0; } + @Override + public Optional getMutableState(String key) { + return Optional.empty(); + } + + @Override + public void setMutableState(String key, Object value) { + + } + @Override public Expression withChildren(List children) { return null; diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index 9471e18f6a5a51..d0bd735ae92f49 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -20,6 +20,7 @@ import org.apache.doris.analysis.ExplainOptions; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.NereidsPlanner; +import org.apache.doris.nereids.PlanProcess; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.glue.LogicalPlanAdapter; @@ -161,6 +162,25 @@ public PlanChecker customRewrite(CustomRewriter customRewriter) { return this; } + public PlanChecker printPlanProcess(String sql) { + List planProcesses = explainPlanProcess(sql); + for (PlanProcess row : planProcesses) { + System.out.println("RULE: " + row.ruleName + "\nBEFORE:\n" + + row.beforeShape + "\nafter:\n" + row.afterShape); + } + return this; + } + + public List explainPlanProcess(String sql) { + NereidsParser parser = new NereidsParser(); + LogicalPlan command = parser.parseSingle(sql); + NereidsPlanner planner = new NereidsPlanner( + new StatementContext(connectContext, new OriginStatement(sql, 0))); + planner.plan(command, PhysicalProperties.ANY, ExplainLevel.ALL_PLAN, true); + this.cascadesContext = planner.getCascadesContext(); + return cascadesContext.getPlanProcesses(); + } + public PlanChecker applyTopDown(RuleFactory ruleFactory) { return applyTopDown(ruleFactory.buildRules()); } diff --git a/regression-test/data/nereids_hint_tpcds_p0/shape/query24.out b/regression-test/data/nereids_hint_tpcds_p0/shape/query24.out index 92e60cfa81846b..ea66a97a2e4084 100644 --- a/regression-test/data/nereids_hint_tpcds_p0/shape/query24.out +++ b/regression-test/data/nereids_hint_tpcds_p0/shape/query24.out @@ -7,21 +7,21 @@ 
PhysicalCteAnchor ( cteId=CTEId#0 ) --------PhysicalDistribute[DistributionSpecHash] ----------hashAgg[LOCAL] ------------PhysicalProject ---------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF5 sr_item_sk->[i_item_sk,ss_item_sk];RF6 sr_ticket_number->[ss_ticket_number] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF5 sr_ticket_number->[ss_ticket_number];RF6 sr_item_sk->[ss_item_sk,i_item_sk] ----------------PhysicalProject ------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF4 i_item_sk->[ss_item_sk] --------------------PhysicalDistribute[DistributionSpecHash] ----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store.s_zip = customer_address.ca_zip) and (store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF2 ca_zip->[s_zip];RF3 c_customer_sk->[ss_customer_sk] +------------------------hashJoin[INNER_JOIN] hashCondition=((store.s_zip = customer_address.ca_zip) and (store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF2 c_customer_sk->[ss_customer_sk];RF3 ca_zip->[s_zip] --------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_store_sk = store.s_store_sk)) otherCondition=() build RFs:RF1 s_store_sk->[ss_store_sk] --------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF3 RF4 RF5 RF6 +----------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF2 RF4 RF5 RF6 
--------------------------------PhysicalDistribute[DistributionSpecReplicated] ----------------------------------PhysicalProject ------------------------------------filter((store.s_market_id = 5)) ---------------------------------------PhysicalOlapScan[store] apply RFs: RF2 +--------------------------------------PhysicalOlapScan[store] apply RFs: RF3 --------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((customer.c_current_addr_sk = customer_address.ca_address_sk)) otherCondition=(( not (c_birth_country = upper(ca_country)))) build RFs:RF0 ca_address_sk->[c_current_addr_sk] @@ -33,7 +33,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------------------------PhysicalOlapScan[customer_address] --------------------PhysicalDistribute[DistributionSpecHash] ----------------------PhysicalProject -------------------------PhysicalOlapScan[item] apply RFs: RF5 +------------------------PhysicalOlapScan[item] apply RFs: RF6 ----------------PhysicalDistribute[DistributionSpecHash] ------------------PhysicalProject --------------------PhysicalOlapScan[store_returns] diff --git a/regression-test/data/nereids_hint_tpcds_p0/shape/query64.out b/regression-test/data/nereids_hint_tpcds_p0/shape/query64.out index a4d47cb3a81d8f..e4932b2437eeb2 100644 --- a/regression-test/data/nereids_hint_tpcds_p0/shape/query64.out +++ b/regression-test/data/nereids_hint_tpcds_p0/shape/query64.out @@ -112,7 +112,7 @@ PhysicalCteAnchor ( cteId=CTEId#1 ) ----------------------PhysicalDistribute[DistributionSpecHash] ------------------------hashAgg[LOCAL] --------------------------PhysicalProject -----------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = catalog_returns.cr_item_sk) and (catalog_sales.cs_order_number = catalog_returns.cr_order_number)) otherCondition=() build RFs:RF0 cr_order_number->[cs_order_number];RF1 cr_item_sk->[cs_item_sk] 
+----------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = catalog_returns.cr_item_sk) and (catalog_sales.cs_order_number = catalog_returns.cr_order_number)) otherCondition=() build RFs:RF0 cr_item_sk->[cs_item_sk];RF1 cr_order_number->[cs_order_number] ------------------------------PhysicalProject --------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1 ------------------------------PhysicalProject diff --git a/regression-test/data/nereids_rules_p0/filter_push_down/push_filter_through.out b/regression-test/data/nereids_rules_p0/filter_push_down/push_filter_through.out index 82e2624942eb04..f32919b7411bf2 100644 --- a/regression-test/data/nereids_rules_p0/filter_push_down/push_filter_through.out +++ b/regression-test/data/nereids_rules_p0/filter_push_down/push_filter_through.out @@ -66,13 +66,7 @@ PhysicalResultSink -- !filter_join_inner -- PhysicalResultSink ---PhysicalDistribute[DistributionSpecGather] -----hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() -------filter((t1.id = 1) and (t1.id = 2)) ---------PhysicalOlapScan[t1] -------PhysicalDistribute[DistributionSpecHash] ---------filter((t2.id = 1) and (t2.id = 2)) -----------PhysicalOlapScan[t2] +--PhysicalEmptyRelation -- !filter_join_inner -- PhysicalResultSink @@ -94,13 +88,7 @@ PhysicalResultSink -- !filter_join_left -- PhysicalResultSink ---PhysicalDistribute[DistributionSpecGather] -----hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() -------filter((t1.id = 1) and (t1.id = 2)) ---------PhysicalOlapScan[t1] -------PhysicalDistribute[DistributionSpecHash] ---------filter((t2.id = 1) and (t2.id = 2)) -----------PhysicalOlapScan[t2] +--PhysicalEmptyRelation -- !filter_join_left -- PhysicalResultSink @@ -133,13 +121,7 @@ PhysicalResultSink -- !filter_join_left -- PhysicalResultSink ---PhysicalDistribute[DistributionSpecGather] -----hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() 
-------filter((t1.id = 1) and (t1.id = 2)) ---------PhysicalOlapScan[t1] -------PhysicalDistribute[DistributionSpecHash] ---------filter((t2.id = 1) and (t2.id = 2)) -----------PhysicalOlapScan[t2] +--PhysicalEmptyRelation -- !filter_join_left -- PhysicalResultSink @@ -412,7 +394,7 @@ PhysicalResultSink --PhysicalDistribute[DistributionSpecGather] ----PhysicalUnion ------PhysicalDistribute[DistributionSpecExecutionAny] ---------filter(((cast(random() as INT) = 2) OR (cast(random() as INT) = 3))) +--------filter(cast(random() as INT) IN (2, 3)) ----------PhysicalOneRowRelation ------PhysicalDistribute[DistributionSpecExecutionAny] --------filter(id IN (2, 3)) @@ -423,7 +405,7 @@ PhysicalResultSink --PhysicalDistribute[DistributionSpecGather] ----PhysicalIntersect ------PhysicalDistribute[DistributionSpecHash] ---------filter(((cast(random() as INT) = 2) OR (cast(random() as INT) = 3))) +--------filter(cast(random() as INT) IN (2, 3)) ----------PhysicalOneRowRelation ------PhysicalDistribute[DistributionSpecHash] --------filter(id IN (2, 3)) diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query13.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query13.out index 92c09059cd5078..0bb5e76d2be2d8 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query13.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query13.out @@ -10,7 +10,7 @@ PhysicalResultSink --------------PhysicalOlapScan[store] apply RFs: RF4 ------------PhysicalDistribute[DistributionSpecHash] --------------PhysicalProject -----------------hashJoin[INNER_JOIN] hashCondition=((customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)) otherCondition=(((((((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)) OR 
((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Primary')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) AND (household_demographics.hd_dep_count = 1))) OR ((((customer_demographics.cd_marital_status = 'W') AND (customer_demographics.cd_education_status = '2 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))) AND (household_demographics.hd_dep_count = 1)))) build RFs:RF3 ss_cdemo_sk->[cd_demo_sk] +----------------hashJoin[INNER_JOIN] hashCondition=((customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)) otherCondition=((((household_demographics.hd_dep_count = 1) AND ((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Primary')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) OR (((customer_demographics.cd_marital_status = 'W') AND (customer_demographics.cd_education_status = '2 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))))) OR ((((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)))) build RFs:RF3 ss_cdemo_sk->[cd_demo_sk] ------------------PhysicalDistribute[DistributionSpecHash] --------------------PhysicalProject ----------------------filter(((((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = 'College')) OR ((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Primary'))) OR ((customer_demographics.cd_marital_status = 'W') AND (customer_demographics.cd_education_status = '2 yr Degree')))) diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query14.out 
b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query14.out index e1d05ef57f443c..c468cad96ccfe2 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query14.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query14.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 class_id->[i_class_id];RF7 category_id->[i_category_id];RF8 brand_id->[i_brand_id] +------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 brand_id->[i_brand_id];RF7 class_id->[i_class_id];RF8 category_id->[i_category_id] --------PhysicalProject ----------PhysicalOlapScan[item] apply RFs: RF6 RF7 RF8 --------PhysicalDistribute[DistributionSpecReplicated] diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query24.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query24.out index 009769ead1cae7..ca81590aa05160 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query24.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query24.out @@ -7,7 +7,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------PhysicalDistribute[DistributionSpecHash] ----------hashAgg[LOCAL] ------------PhysicalProject ---------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF5 sr_item_sk->[i_item_sk,ss_item_sk];RF6 sr_ticket_number->[ss_ticket_number] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) 
otherCondition=() build RFs:RF5 sr_ticket_number->[ss_ticket_number];RF6 sr_item_sk->[ss_item_sk,i_item_sk] ----------------PhysicalProject ------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF4 i_item_sk->[ss_item_sk] --------------------PhysicalDistribute[DistributionSpecHash] @@ -31,7 +31,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------------------------PhysicalOlapScan[customer_address] --------------------PhysicalDistribute[DistributionSpecHash] ----------------------PhysicalProject -------------------------PhysicalOlapScan[item] apply RFs: RF5 +------------------------PhysicalOlapScan[item] apply RFs: RF6 ----------------PhysicalDistribute[DistributionSpecHash] ------------------PhysicalProject --------------------PhysicalOlapScan[store_returns] diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query41.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query41.out index 75ae21d6aa52ca..4bf92b396d61c5 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query41.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query41.out @@ -19,6 +19,6 @@ PhysicalResultSink --------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------hashAgg[LOCAL] ------------------------------PhysicalProject ---------------------------------filter((((((((((((item.i_category = 'Women') AND i_color IN ('forest', 'lime')) AND i_units IN ('Pallet', 'Pound')) AND i_size IN ('economy', 'small')) OR ((((item.i_category = 'Women') AND i_color IN ('navy', 'slate')) AND i_units IN ('Bunch', 'Gross')) AND i_size IN ('extra large', 'petite'))) OR ((((item.i_category = 'Women') AND i_color IN ('aquamarine', 'dark')) AND i_units IN ('Tbl', 'Ton')) AND i_size IN ('economy', 'small'))) OR ((((item.i_category = 'Women') AND i_color IN ('frosted', 'plum')) AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'petite'))) 
OR ((((item.i_category = 'Men') AND i_color IN ('powder', 'sky')) AND i_units IN ('Dozen', 'Lb')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('maroon', 'smoke')) AND i_units IN ('Case', 'Ounce')) AND i_size IN ('economy', 'small'))) OR ((((item.i_category = 'Men') AND i_color IN ('papaya', 'peach')) AND i_units IN ('Bundle', 'Carton')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('firebrick', 'sienna')) AND i_units IN ('Cup', 'Each')) AND i_size IN ('economy', 'small')))) +--------------------------------filter((((item.i_category = 'Men') AND (((((i_size IN ('economy', 'small') AND i_color IN ('maroon', 'smoke')) AND i_units IN ('Case', 'Ounce')) OR ((i_size IN ('economy', 'small') AND i_color IN ('firebrick', 'sienna')) AND i_units IN ('Cup', 'Each'))) OR ((i_color IN ('powder', 'sky') AND i_units IN ('Dozen', 'Lb')) AND i_size IN ('N/A', 'large'))) OR ((i_color IN ('papaya', 'peach') AND i_units IN ('Bundle', 'Carton')) AND i_size IN ('N/A', 'large')))) OR ((item.i_category = 'Women') AND (((((i_color IN ('forest', 'lime') AND i_units IN ('Pallet', 'Pound')) AND i_size IN ('economy', 'small')) OR ((i_color IN ('navy', 'slate') AND i_units IN ('Bunch', 'Gross')) AND i_size IN ('extra large', 'petite'))) OR ((i_color IN ('aquamarine', 'dark') AND i_units IN ('Tbl', 'Ton')) AND i_size IN ('economy', 'small'))) OR ((i_color IN ('frosted', 'plum') AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'petite')))))) ----------------------------------PhysicalOlapScan[item] diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query50.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query50.out index 40f241ec82e976..9a77a7d7d2f549 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query50.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query50.out @@ -12,7 +12,7 @@ PhysicalResultSink ------------------PhysicalProject 
--------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = d1.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] ----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF1 sr_customer_sk->[ss_customer_sk];RF2 sr_item_sk->[ss_item_sk];RF3 sr_ticket_number->[ss_ticket_number] +------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF1 sr_ticket_number->[ss_ticket_number];RF2 sr_item_sk->[ss_item_sk];RF3 sr_customer_sk->[ss_customer_sk] --------------------------PhysicalProject ----------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF2 RF3 RF4 RF5 --------------------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_returned_date_sk = d2.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[sr_returned_date_sk] diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query64.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query64.out index 7c84cef7588bbc..4530a8697367e4 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query64.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query64.out @@ -64,7 +64,7 @@ PhysicalCteAnchor ( cteId=CTEId#1 ) ------------------------------------------------PhysicalDistribute[DistributionSpecHash] --------------------------------------------------hashAgg[LOCAL] ----------------------------------------------------PhysicalProject -------------------------------------------------------hashJoin[INNER_JOIN] 
hashCondition=((catalog_sales.cs_item_sk = catalog_returns.cr_item_sk) and (catalog_sales.cs_order_number = catalog_returns.cr_order_number)) otherCondition=() build RFs:RF5 cr_order_number->[cs_order_number];RF6 cr_item_sk->[cs_item_sk] +------------------------------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = catalog_returns.cr_item_sk) and (catalog_sales.cs_order_number = catalog_returns.cr_order_number)) otherCondition=() build RFs:RF5 cr_item_sk->[cs_item_sk];RF6 cr_order_number->[cs_order_number] --------------------------------------------------------PhysicalProject ----------------------------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF5 RF6 --------------------------------------------------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query85.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query85.out index 547042a9a33744..8e6087c95b31a0 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query85.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query85.out @@ -16,7 +16,7 @@ PhysicalResultSink ----------------------PhysicalProject ------------------------hashJoin[INNER_JOIN] hashCondition=((reason.r_reason_sk = web_returns.wr_reason_sk)) otherCondition=() build RFs:RF8 r_reason_sk->[wr_reason_sk] --------------------------PhysicalProject -----------------------------hashJoin[INNER_JOIN] hashCondition=((cd1.cd_education_status = cd2.cd_education_status) and (cd1.cd_marital_status = cd2.cd_marital_status) and (cd2.cd_demo_sk = web_returns.wr_returning_cdemo_sk)) otherCondition=() build RFs:RF5 cd_marital_status->[cd_marital_status];RF6 cd_education_status->[cd_education_status];RF7 wr_returning_cdemo_sk->[cd_demo_sk] +----------------------------hashJoin[INNER_JOIN] hashCondition=((cd1.cd_education_status = cd2.cd_education_status) and (cd1.cd_marital_status = cd2.cd_marital_status) and 
(cd2.cd_demo_sk = web_returns.wr_returning_cdemo_sk)) otherCondition=() build RFs:RF5 wr_returning_cdemo_sk->[cd_demo_sk];RF6 cd_marital_status->[cd_marital_status];RF7 cd_education_status->[cd_education_status] ------------------------------PhysicalProject --------------------------------PhysicalOlapScan[customer_demographics] apply RFs: RF5 RF6 RF7 ------------------------------PhysicalDistribute[DistributionSpecReplicated] diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query95.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query95.out index c9c64a64e35fe7..b0a0655caff31e 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query95.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query95.out @@ -19,7 +19,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------hashAgg[GLOBAL] ----------------hashAgg[LOCAL] ------------------PhysicalProject ---------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF6 ws_order_number->[ws_order_number,wr_order_number] +--------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF6 ws_order_number->[wr_order_number,ws_order_number] ----------------------PhysicalProject ------------------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_order_number = ws_wh.ws_order_number)) otherCondition=() build RFs:RF5 wr_order_number->[ws_order_number] --------------------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query13.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query13.out index 09eea0486ff63e..05a975ddfa2967 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query13.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query13.out @@ -7,7 
+7,7 @@ PhysicalResultSink --------PhysicalProject ----------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] ------------PhysicalProject ---------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=(((((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)) OR ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) AND (household_demographics.hd_dep_count = 1))) OR ((((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))) AND (household_demographics.hd_dep_count = 1)))) build RFs:RF3 hd_demo_sk->[ss_hdemo_sk] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=((((household_demographics.hd_dep_count = 1) AND ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) OR (((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))))) OR ((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 
3)))) build RFs:RF3 hd_demo_sk->[ss_hdemo_sk] ----------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_addr_sk = customer_address.ca_address_sk)) otherCondition=((((ca_state IN ('KS', 'MI', 'SD') AND ((store_sales.ss_net_profit >= 100.00) AND (store_sales.ss_net_profit <= 200.00))) OR (ca_state IN ('CO', 'MO', 'ND') AND ((store_sales.ss_net_profit >= 150.00) AND (store_sales.ss_net_profit <= 300.00)))) OR (ca_state IN ('NH', 'OH', 'TX') AND ((store_sales.ss_net_profit >= 50.00) AND (store_sales.ss_net_profit <= 250.00))))) build RFs:RF2 ca_address_sk->[ss_addr_sk] ------------------PhysicalProject --------------------hashJoin[INNER_JOIN] hashCondition=((store.s_store_sk = store_sales.ss_store_sk)) otherCondition=() diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query17.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query17.out index 323e0432d7e806..4c6356300e1753 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query17.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query17.out @@ -18,7 +18,7 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------hashJoin[INNER_JOIN] hashCondition=((d1.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=() build RFs:RF5 d_date_sk->[ss_sold_date_sk] ----------------------------------PhysicalProject -------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_customer_sk = catalog_sales.cs_bill_customer_sk) and (store_returns.sr_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF3 cs_bill_customer_sk->[ss_customer_sk,sr_customer_sk];RF4 cs_item_sk->[ss_item_sk,sr_item_sk] +------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_customer_sk = catalog_sales.cs_bill_customer_sk) and (store_returns.sr_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF3 
cs_bill_customer_sk->[sr_customer_sk,ss_customer_sk];RF4 cs_item_sk->[sr_item_sk,ss_item_sk] --------------------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------------------PhysicalProject ------------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query41.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query41.out index f07731a85b9b06..b73a9538e32098 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query41.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query41.out @@ -19,6 +19,6 @@ PhysicalResultSink --------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------hashAgg[LOCAL] ------------------------------PhysicalProject ---------------------------------filter((((((((((((item.i_category = 'Women') AND i_color IN ('aquamarine', 'gainsboro')) AND i_units IN ('Dozen', 'Ounce')) AND i_size IN ('economy', 'medium')) OR ((((item.i_category = 'Women') AND i_color IN ('chiffon', 'violet')) AND i_units IN ('Pound', 'Ton')) AND i_size IN ('extra large', 'small'))) OR ((((item.i_category = 'Women') AND i_color IN ('blanched', 'tomato')) AND i_units IN ('Case', 'Tbl')) AND i_size IN ('economy', 'medium'))) OR ((((item.i_category = 'Women') AND i_color IN ('almond', 'lime')) AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'small'))) OR ((((item.i_category = 'Men') AND i_color IN ('blue', 'chartreuse')) AND i_units IN ('Each', 'Oz')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('dodger', 'tan')) AND i_units IN ('Bunch', 'Tsp')) AND i_size IN ('economy', 'medium'))) OR 
((((item.i_category = 'Men') AND i_color IN ('peru', 'saddle')) AND i_units IN ('Gram', 'Pallet')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('indian', 'spring')) AND i_units IN ('Carton', 'Unknown')) AND i_size IN ('economy', 'medium')))) +--------------------------------filter((((item.i_category = 'Men') AND (((((i_size IN ('economy', 'medium') AND i_color IN ('dodger', 'tan')) AND i_units IN ('Bunch', 'Tsp')) OR ((i_size IN ('economy', 'medium') AND i_color IN ('indian', 'spring')) AND i_units IN ('Carton', 'Unknown'))) OR ((i_color IN ('blue', 'chartreuse') AND i_units IN ('Each', 'Oz')) AND i_size IN ('N/A', 'large'))) OR ((i_color IN ('peru', 'saddle') AND i_units IN ('Gram', 'Pallet')) AND i_size IN ('N/A', 'large')))) OR ((item.i_category = 'Women') AND (((((i_color IN ('aquamarine', 'gainsboro') AND i_units IN ('Dozen', 'Ounce')) AND i_size IN ('economy', 'medium')) OR ((i_color IN ('chiffon', 'violet') AND i_units IN ('Pound', 'Ton')) AND i_size IN ('extra large', 'small'))) OR ((i_color IN ('blanched', 'tomato') AND i_units IN ('Case', 'Tbl')) AND i_size IN ('economy', 'medium'))) OR ((i_color IN ('almond', 'lime') AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'small')))))) ----------------------------------PhysicalOlapScan[item] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query47.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query47.out index 4e5fef7b83fb57..c3aef2e8e016c6 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query47.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query47.out @@ -37,6 +37,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------PhysicalTopN[LOCAL_SORT] ------------PhysicalProject --------------hashJoin[INNER_JOIN] hashCondition=((v1.i_brand = v1_lead.i_brand) and (v1.i_category = v1_lead.i_category) and (v1.rn = expr_(rn - 1)) and (v1.s_company_name = 
v1_lead.s_company_name) and (v1.s_store_name = v1_lead.s_store_name)) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalProject +--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------hashJoin[INNER_JOIN] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() --------------------PhysicalDistribute[DistributionSpecHash] @@ -46,7 +49,3 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------------------PhysicalProject ------------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) --------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) -----------------PhysicalDistribute[DistributionSpecHash] -------------------PhysicalProject ---------------------PhysicalCteConsumer ( cteId=CTEId#0 ) - diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query50.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query50.out index d7b38695e2c2eb..0fbafee1c41fc5 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query50.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query50.out @@ -13,7 +13,7 @@ PhysicalResultSink --------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_store_sk = store.s_store_sk)) otherCondition=() ----------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = d1.d_date_sk)) otherCondition=() ------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and 
(store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF0 sr_customer_sk->[ss_customer_sk];RF1 sr_item_sk->[ss_item_sk];RF2 sr_ticket_number->[ss_ticket_number] +--------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF0 sr_ticket_number->[ss_ticket_number];RF1 sr_item_sk->[ss_item_sk];RF2 sr_customer_sk->[ss_customer_sk] ----------------------------PhysicalProject ------------------------------PhysicalOlapScan[store_sales] apply RFs: RF0 RF1 RF2 ----------------------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query57.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query57.out index c78c8d0fecc73a..8ff4bb6350f34f 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query57.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query57.out @@ -37,6 +37,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------PhysicalTopN[LOCAL_SORT] ------------PhysicalProject --------------hashJoin[INNER_JOIN] hashCondition=((v1.cc_name = v1_lead.cc_name) and (v1.i_brand = v1_lead.i_brand) and (v1.i_category = v1_lead.i_category) and (v1.rn = expr_(rn - 1))) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalProject +--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------hashJoin[INNER_JOIN] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() --------------------PhysicalDistribute[DistributionSpecHash] @@ -46,7 +49,3 @@ PhysicalCteAnchor ( 
cteId=CTEId#0 ) ----------------------PhysicalProject ------------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) --------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) -----------------PhysicalDistribute[DistributionSpecHash] -------------------PhysicalProject ---------------------PhysicalCteConsumer ( cteId=CTEId#0 ) - diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query6.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query6.out index 6ef7f0f37aecee..377e295b173b4e 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query6.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query6.out @@ -11,37 +11,36 @@ PhysicalResultSink ----------------hashAgg[LOCAL] ------------------PhysicalProject --------------------hashJoin[INNER_JOIN] hashCondition=((d.d_month_seq = date_dim.d_month_seq)) otherCondition=() build RFs:RF5 d_month_seq->[d_month_seq] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((j.i_category = i.i_category)) otherCondition=((cast(i_current_price as DECIMALV3(38, 5)) > (1.2 * avg(cast(i_current_price as DECIMALV3(9, 4)))))) ---------------------------PhysicalProject -----------------------------hashJoin[INNER_JOIN] hashCondition=((s.ss_sold_date_sk = d.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[ss_sold_date_sk] -------------------------------hashJoin[INNER_JOIN] hashCondition=((s.ss_item_sk = i.i_item_sk)) otherCondition=() ---------------------------------PhysicalProject -----------------------------------hashJoin[INNER_JOIN] hashCondition=((c.c_customer_sk = s.ss_customer_sk)) otherCondition=() -------------------------------------PhysicalDistribute[DistributionSpecHash] 
---------------------------------------PhysicalProject -----------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF3 -------------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------------PhysicalProject -----------------------------------------hashJoin[INNER_JOIN] hashCondition=((a.ca_address_sk = c.c_current_addr_sk)) otherCondition=() -------------------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------------------PhysicalProject -----------------------------------------------PhysicalOlapScan[customer] -------------------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------------------PhysicalProject -----------------------------------------------PhysicalOlapScan[customer_address] ---------------------------------PhysicalDistribute[DistributionSpecReplicated] -----------------------------------PhysicalProject -------------------------------------PhysicalOlapScan[item] +----------------------hashJoin[INNER_JOIN] hashCondition=((j.i_category = i.i_category)) otherCondition=((cast(i_current_price as DECIMALV3(38, 5)) > (1.2 * avg(cast(i_current_price as DECIMALV3(9, 4)))))) +------------------------PhysicalProject +--------------------------hashJoin[INNER_JOIN] hashCondition=((s.ss_sold_date_sk = d.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[ss_sold_date_sk] +----------------------------hashJoin[INNER_JOIN] hashCondition=((s.ss_item_sk = i.i_item_sk)) otherCondition=() +------------------------------PhysicalProject +--------------------------------hashJoin[INNER_JOIN] hashCondition=((c.c_customer_sk = s.ss_customer_sk)) otherCondition=() +----------------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------------PhysicalProject +--------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF3 
+----------------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------------PhysicalProject +--------------------------------------hashJoin[INNER_JOIN] hashCondition=((a.ca_address_sk = c.c_current_addr_sk)) otherCondition=() +----------------------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------------------PhysicalProject +--------------------------------------------PhysicalOlapScan[customer] +----------------------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------------------PhysicalProject +--------------------------------------------PhysicalOlapScan[customer_address] ------------------------------PhysicalDistribute[DistributionSpecReplicated] --------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[date_dim] apply RFs: RF5 ---------------------------PhysicalDistribute[DistributionSpecReplicated] -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------hashAgg[LOCAL] -----------------------------------PhysicalProject -------------------------------------PhysicalOlapScan[item] +----------------------------------PhysicalOlapScan[item] +----------------------------PhysicalDistribute[DistributionSpecReplicated] +------------------------------PhysicalProject +--------------------------------PhysicalOlapScan[date_dim] apply RFs: RF5 +------------------------PhysicalDistribute[DistributionSpecReplicated] +--------------------------hashAgg[GLOBAL] +----------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------hashAgg[LOCAL] +--------------------------------PhysicalProject +----------------------------------PhysicalOlapScan[item] ----------------------PhysicalDistribute[DistributionSpecReplicated] ------------------------PhysicalAssertNumRows 
--------------------------PhysicalDistribute[DistributionSpecGather] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query65.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query65.out index 0b84e050b6e8ac..f6dace38b89fc9 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query65.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query65.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------hashJoin[INNER_JOIN] hashCondition=((sb.ss_store_sk = sc.ss_store_sk)) otherCondition=((cast(revenue as DOUBLE) <= cast((0.1 * ave) as DOUBLE))) build RFs:RF4 ss_store_sk->[s_store_sk,ss_store_sk] +----------hashJoin[INNER_JOIN] hashCondition=((sb.ss_store_sk = sc.ss_store_sk)) otherCondition=((cast(revenue as DOUBLE) <= cast((0.1 * ave) as DOUBLE))) build RFs:RF4 ss_store_sk->[ss_store_sk,s_store_sk] ------------PhysicalProject --------------hashJoin[INNER_JOIN] hashCondition=((store.s_store_sk = sc.ss_store_sk)) otherCondition=() ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query13.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query13.out index 55f477930971cf..640a112c2b7e72 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query13.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query13.out @@ -7,7 +7,7 @@ PhysicalResultSink --------PhysicalProject ----------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] ------------PhysicalProject ---------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=(((((((customer_demographics.cd_marital_status = 'D') 
AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)) OR ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) AND (household_demographics.hd_dep_count = 1))) OR ((((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))) AND (household_demographics.hd_dep_count = 1)))) build RFs:RF3 hd_demo_sk->[ss_hdemo_sk] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=((((household_demographics.hd_dep_count = 1) AND ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) OR (((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))))) OR ((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)))) build RFs:RF3 hd_demo_sk->[ss_hdemo_sk] ----------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_addr_sk = customer_address.ca_address_sk)) otherCondition=((((ca_state IN ('KS', 'MI', 'SD') AND ((store_sales.ss_net_profit >= 100.00) AND (store_sales.ss_net_profit <= 200.00))) OR (ca_state IN ('CO', 'MO', 'ND') AND ((store_sales.ss_net_profit >= 150.00) AND (store_sales.ss_net_profit <= 300.00)))) OR 
(ca_state IN ('NH', 'OH', 'TX') AND ((store_sales.ss_net_profit >= 50.00) AND (store_sales.ss_net_profit <= 250.00))))) build RFs:RF2 ca_address_sk->[ss_addr_sk] ------------------PhysicalProject --------------------hashJoin[INNER_JOIN] hashCondition=((store.s_store_sk = store_sales.ss_store_sk)) otherCondition=() build RFs:RF1 s_store_sk->[ss_store_sk] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query14.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query14.out index d3ae5a3b3045e1..21d1066c58e2ce 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query14.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query14.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 i_class_id->[i_class_id];RF7 i_category_id->[i_category_id];RF8 i_brand_id->[i_brand_id] +------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 i_brand_id->[i_brand_id];RF7 i_class_id->[i_class_id];RF8 i_category_id->[i_category_id] --------PhysicalIntersect ----------PhysicalDistribute[DistributionSpecHash] ------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query17.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query17.out index ef6485cbcb6e5e..747b928caaafe0 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query17.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query17.out @@ -18,7 +18,7 @@ PhysicalResultSink ------------------------------PhysicalProject 
--------------------------------hashJoin[INNER_JOIN] hashCondition=((d1.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=() build RFs:RF5 d_date_sk->[ss_sold_date_sk] ----------------------------------PhysicalProject -------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_customer_sk = catalog_sales.cs_bill_customer_sk) and (store_returns.sr_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF3 cs_bill_customer_sk->[ss_customer_sk,sr_customer_sk];RF4 cs_item_sk->[ss_item_sk,sr_item_sk] +------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_customer_sk = catalog_sales.cs_bill_customer_sk) and (store_returns.sr_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF3 cs_bill_customer_sk->[sr_customer_sk,ss_customer_sk];RF4 cs_item_sk->[sr_item_sk,ss_item_sk] --------------------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------------------PhysicalProject ------------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF0 sr_customer_sk->[ss_customer_sk];RF1 sr_item_sk->[ss_item_sk];RF2 sr_ticket_number->[ss_ticket_number] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query24.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query24.out index 5f2ab88b0b8f36..489d26eb248480 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query24.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query24.out @@ -14,7 +14,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF2 
c_customer_sk->[ss_customer_sk] ------------------------PhysicalDistribute[DistributionSpecHash] --------------------------PhysicalProject -----------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF0 sr_item_sk->[ss_item_sk];RF1 sr_ticket_number->[ss_ticket_number] +----------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF0 sr_ticket_number->[ss_ticket_number];RF1 sr_item_sk->[ss_item_sk] ------------------------------PhysicalProject --------------------------------PhysicalOlapScan[store_sales] apply RFs: RF0 RF1 RF2 RF4 RF6 ------------------------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query41.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query41.out index f07731a85b9b06..b73a9538e32098 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query41.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query41.out @@ -19,6 +19,6 @@ PhysicalResultSink --------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------hashAgg[LOCAL] ------------------------------PhysicalProject ---------------------------------filter((((((((((((item.i_category = 'Women') AND i_color IN ('aquamarine', 'gainsboro')) AND i_units IN ('Dozen', 'Ounce')) AND i_size IN ('economy', 'medium')) OR ((((item.i_category = 'Women') AND i_color IN ('chiffon', 'violet')) AND i_units IN ('Pound', 'Ton')) AND i_size IN ('extra large', 'small'))) OR ((((item.i_category = 'Women') AND i_color IN ('blanched', 'tomato')) AND i_units IN ('Case', 'Tbl')) AND i_size IN ('economy', 'medium'))) OR ((((item.i_category = 'Women') AND i_color IN 
('almond', 'lime')) AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'small'))) OR ((((item.i_category = 'Men') AND i_color IN ('blue', 'chartreuse')) AND i_units IN ('Each', 'Oz')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('dodger', 'tan')) AND i_units IN ('Bunch', 'Tsp')) AND i_size IN ('economy', 'medium'))) OR ((((item.i_category = 'Men') AND i_color IN ('peru', 'saddle')) AND i_units IN ('Gram', 'Pallet')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('indian', 'spring')) AND i_units IN ('Carton', 'Unknown')) AND i_size IN ('economy', 'medium')))) +--------------------------------filter((((item.i_category = 'Men') AND (((((i_size IN ('economy', 'medium') AND i_color IN ('dodger', 'tan')) AND i_units IN ('Bunch', 'Tsp')) OR ((i_size IN ('economy', 'medium') AND i_color IN ('indian', 'spring')) AND i_units IN ('Carton', 'Unknown'))) OR ((i_color IN ('blue', 'chartreuse') AND i_units IN ('Each', 'Oz')) AND i_size IN ('N/A', 'large'))) OR ((i_color IN ('peru', 'saddle') AND i_units IN ('Gram', 'Pallet')) AND i_size IN ('N/A', 'large')))) OR ((item.i_category = 'Women') AND (((((i_color IN ('aquamarine', 'gainsboro') AND i_units IN ('Dozen', 'Ounce')) AND i_size IN ('economy', 'medium')) OR ((i_color IN ('chiffon', 'violet') AND i_units IN ('Pound', 'Ton')) AND i_size IN ('extra large', 'small'))) OR ((i_color IN ('blanched', 'tomato') AND i_units IN ('Case', 'Tbl')) AND i_size IN ('economy', 'medium'))) OR ((i_color IN ('almond', 'lime') AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'small')))))) ----------------------------------PhysicalOlapScan[item] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query47.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query47.out index 371ab5bf0aaab9..dbc3759558c205 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query47.out +++ 
b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query47.out @@ -37,6 +37,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------PhysicalTopN[LOCAL_SORT] ------------PhysicalProject --------------hashJoin[INNER_JOIN] hashCondition=((v1.i_brand = v1_lead.i_brand) and (v1.i_category = v1_lead.i_category) and (v1.rn = expr_(rn - 1)) and (v1.s_company_name = v1_lead.s_company_name) and (v1.s_store_name = v1_lead.s_store_name)) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalProject +--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------hashJoin[INNER_JOIN] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() --------------------PhysicalDistribute[DistributionSpecHash] @@ -46,7 +49,3 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------------------PhysicalProject ------------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) --------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) -----------------PhysicalDistribute[DistributionSpecHash] -------------------PhysicalProject ---------------------PhysicalCteConsumer ( cteId=CTEId#0 ) - diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query50.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query50.out index 182e94c690f7df..ea53d954beff20 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query50.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query50.out @@ -13,7 +13,7 @@ PhysicalResultSink --------------------hashJoin[INNER_JOIN] 
hashCondition=((store_sales.ss_store_sk = store.s_store_sk)) otherCondition=() build RFs:RF4 s_store_sk->[ss_store_sk] ----------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = d1.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[ss_sold_date_sk] ------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF0 sr_customer_sk->[ss_customer_sk];RF1 sr_item_sk->[ss_item_sk];RF2 sr_ticket_number->[ss_ticket_number] +--------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF0 sr_ticket_number->[ss_ticket_number];RF1 sr_item_sk->[ss_item_sk];RF2 sr_customer_sk->[ss_customer_sk] ----------------------------PhysicalProject ------------------------------PhysicalOlapScan[store_sales] apply RFs: RF0 RF1 RF2 RF3 RF4 ----------------------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query57.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query57.out index 9ebb7ceb6cb32e..40ecf4c1541a61 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query57.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query57.out @@ -37,6 +37,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------PhysicalTopN[LOCAL_SORT] ------------PhysicalProject --------------hashJoin[INNER_JOIN] hashCondition=((v1.cc_name = v1_lead.cc_name) and (v1.i_brand = v1_lead.i_brand) and (v1.i_category = v1_lead.i_category) and (v1.rn = expr_(rn - 1))) otherCondition=() 
+----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalProject +--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------hashJoin[INNER_JOIN] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() --------------------PhysicalDistribute[DistributionSpecHash] @@ -46,7 +49,3 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------------------PhysicalProject ------------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) --------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) -----------------PhysicalDistribute[DistributionSpecHash] -------------------PhysicalProject ---------------------PhysicalCteConsumer ( cteId=CTEId#0 ) - diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query6.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query6.out index c6f981470aadf8..7b7d383540c609 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query6.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query6.out @@ -11,37 +11,36 @@ PhysicalResultSink ----------------hashAgg[LOCAL] ------------------PhysicalProject --------------------hashJoin[INNER_JOIN] hashCondition=((d.d_month_seq = date_dim.d_month_seq)) otherCondition=() build RFs:RF5 d_month_seq->[d_month_seq] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((j.i_category = i.i_category)) otherCondition=((cast(i_current_price as DECIMALV3(38, 5)) > (1.2 * avg(cast(i_current_price as DECIMALV3(9, 4)))))) build RFs:RF4 i_category->[i_category] ---------------------------PhysicalProject 
-----------------------------hashJoin[INNER_JOIN] hashCondition=((s.ss_sold_date_sk = d.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[ss_sold_date_sk] -------------------------------hashJoin[INNER_JOIN] hashCondition=((s.ss_item_sk = i.i_item_sk)) otherCondition=() build RFs:RF2 i_item_sk->[ss_item_sk] ---------------------------------PhysicalProject -----------------------------------hashJoin[INNER_JOIN] hashCondition=((c.c_customer_sk = s.ss_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ss_customer_sk] -------------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------------PhysicalProject -----------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF2 RF3 -------------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------------PhysicalProject -----------------------------------------hashJoin[INNER_JOIN] hashCondition=((a.ca_address_sk = c.c_current_addr_sk)) otherCondition=() build RFs:RF0 ca_address_sk->[c_current_addr_sk] -------------------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------------------PhysicalProject -----------------------------------------------PhysicalOlapScan[customer] apply RFs: RF0 -------------------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------------------PhysicalProject -----------------------------------------------PhysicalOlapScan[customer_address] ---------------------------------PhysicalDistribute[DistributionSpecReplicated] -----------------------------------PhysicalProject -------------------------------------PhysicalOlapScan[item] apply RFs: RF4 +----------------------hashJoin[INNER_JOIN] hashCondition=((j.i_category = i.i_category)) otherCondition=((cast(i_current_price as DECIMALV3(38, 5)) > (1.2 * avg(cast(i_current_price as DECIMALV3(9, 4)))))) build RFs:RF4 i_category->[i_category] 
+------------------------PhysicalProject +--------------------------hashJoin[INNER_JOIN] hashCondition=((s.ss_sold_date_sk = d.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[ss_sold_date_sk] +----------------------------hashJoin[INNER_JOIN] hashCondition=((s.ss_item_sk = i.i_item_sk)) otherCondition=() build RFs:RF2 i_item_sk->[ss_item_sk] +------------------------------PhysicalProject +--------------------------------hashJoin[INNER_JOIN] hashCondition=((c.c_customer_sk = s.ss_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ss_customer_sk] +----------------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------------PhysicalProject +--------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF2 RF3 +----------------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------------PhysicalProject +--------------------------------------hashJoin[INNER_JOIN] hashCondition=((a.ca_address_sk = c.c_current_addr_sk)) otherCondition=() build RFs:RF0 ca_address_sk->[c_current_addr_sk] +----------------------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------------------PhysicalProject +--------------------------------------------PhysicalOlapScan[customer] apply RFs: RF0 +----------------------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------------------PhysicalProject +--------------------------------------------PhysicalOlapScan[customer_address] ------------------------------PhysicalDistribute[DistributionSpecReplicated] --------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[date_dim] apply RFs: RF5 ---------------------------PhysicalDistribute[DistributionSpecReplicated] -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute[DistributionSpecHash] ---------------------------------hashAgg[LOCAL] 
-----------------------------------PhysicalProject -------------------------------------PhysicalOlapScan[item] +----------------------------------PhysicalOlapScan[item] apply RFs: RF4 +----------------------------PhysicalDistribute[DistributionSpecReplicated] +------------------------------PhysicalProject +--------------------------------PhysicalOlapScan[date_dim] apply RFs: RF5 +------------------------PhysicalDistribute[DistributionSpecReplicated] +--------------------------hashAgg[GLOBAL] +----------------------------PhysicalDistribute[DistributionSpecHash] +------------------------------hashAgg[LOCAL] +--------------------------------PhysicalProject +----------------------------------PhysicalOlapScan[item] ----------------------PhysicalDistribute[DistributionSpecReplicated] ------------------------PhysicalAssertNumRows --------------------------PhysicalDistribute[DistributionSpecGather] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query64.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query64.out index ab0640721b4d55..7232618ba59e83 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query64.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query64.out @@ -47,7 +47,7 @@ PhysicalCteAnchor ( cteId=CTEId#1 ) ----------------------------------------------------------------------------PhysicalDistribute[DistributionSpecHash] ------------------------------------------------------------------------------hashAgg[LOCAL] --------------------------------------------------------------------------------PhysicalProject -----------------------------------------------------------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = catalog_returns.cr_item_sk) and (catalog_sales.cs_order_number = catalog_returns.cr_order_number)) otherCondition=() build RFs:RF0 cr_order_number->[cs_order_number];RF1 cr_item_sk->[cs_item_sk] 
+----------------------------------------------------------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = catalog_returns.cr_item_sk) and (catalog_sales.cs_order_number = catalog_returns.cr_order_number)) otherCondition=() build RFs:RF0 cr_item_sk->[cs_item_sk];RF1 cr_order_number->[cs_order_number] ------------------------------------------------------------------------------------PhysicalProject --------------------------------------------------------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1 RF19 ------------------------------------------------------------------------------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query65.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query65.out index 76324a3ac4cf61..308eb4b3a2ec29 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query65.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query65.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------hashJoin[INNER_JOIN] hashCondition=((sb.ss_store_sk = sc.ss_store_sk)) otherCondition=((cast(revenue as DOUBLE) <= cast((0.1 * ave) as DOUBLE))) build RFs:RF4 ss_store_sk->[s_store_sk,ss_store_sk] +----------hashJoin[INNER_JOIN] hashCondition=((sb.ss_store_sk = sc.ss_store_sk)) otherCondition=((cast(revenue as DOUBLE) <= cast((0.1 * ave) as DOUBLE))) build RFs:RF4 ss_store_sk->[ss_store_sk,s_store_sk] ------------PhysicalProject --------------hashJoin[INNER_JOIN] hashCondition=((store.s_store_sk = sc.ss_store_sk)) otherCondition=() build RFs:RF3 s_store_sk->[ss_store_sk] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query85.out 
b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query85.out index fdc9d7548628b1..d0e1696471e278 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query85.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query85.out @@ -15,7 +15,7 @@ PhysicalResultSink ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((customer_address.ca_address_sk = web_returns.wr_refunded_addr_sk)) otherCondition=((((ca_state IN ('DE', 'FL', 'TX') AND ((web_sales.ws_net_profit >= 100.00) AND (web_sales.ws_net_profit <= 200.00))) OR (ca_state IN ('ID', 'IN', 'ND') AND ((web_sales.ws_net_profit >= 150.00) AND (web_sales.ws_net_profit <= 300.00)))) OR (ca_state IN ('IL', 'MT', 'OH') AND ((web_sales.ws_net_profit >= 50.00) AND (web_sales.ws_net_profit <= 250.00))))) build RFs:RF7 ca_address_sk->[wr_refunded_addr_sk] ----------------------------PhysicalProject -------------------------------hashJoin[INNER_JOIN] hashCondition=((cd1.cd_education_status = cd2.cd_education_status) and (cd1.cd_marital_status = cd2.cd_marital_status) and (cd2.cd_demo_sk = web_returns.wr_returning_cdemo_sk)) otherCondition=() build RFs:RF4 cd_marital_status->[cd_marital_status];RF5 cd_education_status->[cd_education_status];RF6 cd_demo_sk->[wr_returning_cdemo_sk] +------------------------------hashJoin[INNER_JOIN] hashCondition=((cd1.cd_education_status = cd2.cd_education_status) and (cd1.cd_marital_status = cd2.cd_marital_status) and (cd2.cd_demo_sk = web_returns.wr_returning_cdemo_sk)) otherCondition=() build RFs:RF4 cd_demo_sk->[wr_returning_cdemo_sk];RF5 cd_marital_status->[cd_marital_status];RF6 cd_education_status->[cd_education_status] --------------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------------PhysicalProject ------------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_web_page_sk = web_page.wp_web_page_sk)) 
otherCondition=() build RFs:RF3 wp_web_page_sk->[ws_web_page_sk] @@ -24,14 +24,14 @@ PhysicalResultSink ------------------------------------------PhysicalProject --------------------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_item_sk = web_returns.wr_item_sk) and (web_sales.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF0 ws_item_sk->[wr_item_sk];RF1 ws_order_number->[wr_order_number] ----------------------------------------------PhysicalProject -------------------------------------------------PhysicalOlapScan[web_returns] apply RFs: RF0 RF1 RF2 RF6 RF7 RF9 +------------------------------------------------PhysicalOlapScan[web_returns] apply RFs: RF0 RF1 RF2 RF4 RF7 RF9 ----------------------------------------------PhysicalProject ------------------------------------------------filter((web_sales.ws_net_profit <= 300.00) and (web_sales.ws_net_profit >= 50.00) and (web_sales.ws_sales_price <= 200.00) and (web_sales.ws_sales_price >= 50.00)) --------------------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF3 RF8 ----------------------------------------PhysicalDistribute[DistributionSpecHash] ------------------------------------------PhysicalProject --------------------------------------------filter(((((cd1.cd_marital_status = 'M') AND (cd1.cd_education_status = '4 yr Degree')) OR ((cd1.cd_marital_status = 'S') AND (cd1.cd_education_status = 'Secondary'))) OR ((cd1.cd_marital_status = 'W') AND (cd1.cd_education_status = 'Advanced Degree')))) -----------------------------------------------PhysicalOlapScan[customer_demographics] apply RFs: RF4 RF5 +----------------------------------------------PhysicalOlapScan[customer_demographics] apply RFs: RF5 RF6 --------------------------------------PhysicalDistribute[DistributionSpecReplicated] ----------------------------------------PhysicalProject ------------------------------------------PhysicalOlapScan[web_page] diff --git 
a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query13.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query13.out index 441b3d76382391..accd0ebbb14132 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query13.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query13.out @@ -10,7 +10,7 @@ PhysicalResultSink --------------PhysicalOlapScan[store] apply RFs: RF4 ------------PhysicalDistribute[DistributionSpecHash] --------------PhysicalProject -----------------hashJoin[INNER_JOIN] hashCondition=((customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)) otherCondition=(((((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)) OR ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) AND (household_demographics.hd_dep_count = 1))) OR ((((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))) AND (household_demographics.hd_dep_count = 1)))) build RFs:RF3 ss_cdemo_sk->[cd_demo_sk] +----------------hashJoin[INNER_JOIN] hashCondition=((customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)) otherCondition=((((household_demographics.hd_dep_count = 1) AND ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) OR (((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND 
(store_sales.ss_sales_price <= 200.00))))) OR ((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)))) build RFs:RF3 ss_cdemo_sk->[cd_demo_sk] ------------------PhysicalProject --------------------filter(((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) OR ((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College'))) OR ((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')))) ----------------------PhysicalOlapScan[customer_demographics] apply RFs: RF3 diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query14.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query14.out index 281da092c6f533..9cd54a646a37c4 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query14.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query14.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 class_id->[i_class_id];RF7 category_id->[i_category_id];RF8 brand_id->[i_brand_id] +------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 brand_id->[i_brand_id];RF7 class_id->[i_class_id];RF8 category_id->[i_category_id] --------PhysicalProject ----------PhysicalOlapScan[item] apply RFs: RF6 RF7 RF8 --------PhysicalDistribute[DistributionSpecReplicated] diff --git 
a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query41.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query41.out index f07731a85b9b06..b73a9538e32098 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query41.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query41.out @@ -19,6 +19,6 @@ PhysicalResultSink --------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------hashAgg[LOCAL] ------------------------------PhysicalProject ---------------------------------filter((((((((((((item.i_category = 'Women') AND i_color IN ('aquamarine', 'gainsboro')) AND i_units IN ('Dozen', 'Ounce')) AND i_size IN ('economy', 'medium')) OR ((((item.i_category = 'Women') AND i_color IN ('chiffon', 'violet')) AND i_units IN ('Pound', 'Ton')) AND i_size IN ('extra large', 'small'))) OR ((((item.i_category = 'Women') AND i_color IN ('blanched', 'tomato')) AND i_units IN ('Case', 'Tbl')) AND i_size IN ('economy', 'medium'))) OR ((((item.i_category = 'Women') AND i_color IN ('almond', 'lime')) AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'small'))) OR ((((item.i_category = 'Men') AND i_color IN ('blue', 'chartreuse')) AND i_units IN ('Each', 'Oz')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('dodger', 'tan')) AND i_units IN ('Bunch', 'Tsp')) AND i_size IN ('economy', 'medium'))) OR ((((item.i_category = 'Men') AND i_color IN ('peru', 'saddle')) AND i_units IN ('Gram', 'Pallet')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('indian', 'spring')) AND i_units IN ('Carton', 'Unknown')) AND i_size IN ('economy', 'medium')))) +--------------------------------filter((((item.i_category = 'Men') AND (((((i_size IN ('economy', 'medium') AND i_color IN ('dodger', 'tan')) AND i_units IN ('Bunch', 'Tsp')) OR ((i_size IN ('economy', 'medium') AND i_color IN ('indian', 'spring')) AND i_units IN ('Carton', 
'Unknown'))) OR ((i_color IN ('blue', 'chartreuse') AND i_units IN ('Each', 'Oz')) AND i_size IN ('N/A', 'large'))) OR ((i_color IN ('peru', 'saddle') AND i_units IN ('Gram', 'Pallet')) AND i_size IN ('N/A', 'large')))) OR ((item.i_category = 'Women') AND (((((i_color IN ('aquamarine', 'gainsboro') AND i_units IN ('Dozen', 'Ounce')) AND i_size IN ('economy', 'medium')) OR ((i_color IN ('chiffon', 'violet') AND i_units IN ('Pound', 'Ton')) AND i_size IN ('extra large', 'small'))) OR ((i_color IN ('blanched', 'tomato') AND i_units IN ('Case', 'Tbl')) AND i_size IN ('economy', 'medium'))) OR ((i_color IN ('almond', 'lime') AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'small')))))) ----------------------------------PhysicalOlapScan[item] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query50.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query50.out index f92a4b59d05a6e..eb871031f07f8c 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query50.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query50.out @@ -12,7 +12,7 @@ PhysicalResultSink ------------------PhysicalProject --------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = d1.d_date_sk)) otherCondition=() ----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF1 sr_customer_sk->[ss_customer_sk];RF2 sr_item_sk->[ss_item_sk];RF3 sr_ticket_number->[ss_ticket_number] +------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build 
RFs:RF1 sr_ticket_number->[ss_ticket_number];RF2 sr_item_sk->[ss_item_sk];RF3 sr_customer_sk->[ss_customer_sk] --------------------------PhysicalProject ----------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF2 RF3 --------------------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_returned_date_sk = d2.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[sr_returned_date_sk] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query85.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query85.out index 2e35ec10486d9c..1e0d108c71a842 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query85.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query85.out @@ -22,7 +22,7 @@ PhysicalResultSink --------------------------------PhysicalOlapScan[web_page] apply RFs: RF7 ------------------------------PhysicalDistribute[DistributionSpecHash] --------------------------------PhysicalProject -----------------------------------hashJoin[INNER_JOIN] hashCondition=((cd1.cd_education_status = cd2.cd_education_status) and (cd1.cd_marital_status = cd2.cd_marital_status) and (cd2.cd_demo_sk = web_returns.wr_returning_cdemo_sk)) otherCondition=() build RFs:RF4 cd_marital_status->[cd_marital_status];RF5 cd_education_status->[cd_education_status];RF6 wr_returning_cdemo_sk->[cd_demo_sk] +----------------------------------hashJoin[INNER_JOIN] hashCondition=((cd1.cd_education_status = cd2.cd_education_status) and (cd1.cd_marital_status = cd2.cd_marital_status) and (cd2.cd_demo_sk = web_returns.wr_returning_cdemo_sk)) otherCondition=() build RFs:RF4 wr_returning_cdemo_sk->[cd_demo_sk];RF5 cd_marital_status->[cd_marital_status];RF6 cd_education_status->[cd_education_status] ------------------------------------PhysicalProject --------------------------------------PhysicalOlapScan[customer_demographics] apply RFs: RF4 RF5 RF6 
------------------------------------PhysicalDistribute[DistributionSpecReplicated] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query95.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query95.out index 976a8937349c99..a835868fd8c78e 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query95.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query95.out @@ -19,7 +19,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------hashAgg[GLOBAL] ----------------hashAgg[LOCAL] ------------------PhysicalProject ---------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF6 ws_order_number->[ws_order_number,wr_order_number] +--------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF6 ws_order_number->[wr_order_number,ws_order_number] ----------------------PhysicalProject ------------------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_order_number = ws_wh.ws_order_number)) otherCondition=() --------------------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query13.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query13.out index 441b3d76382391..accd0ebbb14132 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query13.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query13.out @@ -10,7 +10,7 @@ PhysicalResultSink --------------PhysicalOlapScan[store] apply RFs: RF4 ------------PhysicalDistribute[DistributionSpecHash] --------------PhysicalProject -----------------hashJoin[INNER_JOIN] hashCondition=((customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)) otherCondition=(((((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND 
((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)) OR ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) AND (household_demographics.hd_dep_count = 1))) OR ((((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))) AND (household_demographics.hd_dep_count = 1)))) build RFs:RF3 ss_cdemo_sk->[cd_demo_sk] +----------------hashJoin[INNER_JOIN] hashCondition=((customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk)) otherCondition=((((household_demographics.hd_dep_count = 1) AND ((((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College')) AND ((store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00))) OR (((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')) AND ((store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00))))) OR ((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) AND ((store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00))) AND (household_demographics.hd_dep_count = 3)))) build RFs:RF3 ss_cdemo_sk->[cd_demo_sk] ------------------PhysicalProject --------------------filter(((((customer_demographics.cd_marital_status = 'D') AND (customer_demographics.cd_education_status = 'Unknown')) OR ((customer_demographics.cd_marital_status = 'S') AND (customer_demographics.cd_education_status = 'College'))) OR ((customer_demographics.cd_marital_status = 'M') AND (customer_demographics.cd_education_status = '4 yr Degree')))) 
----------------------PhysicalOlapScan[customer_demographics] apply RFs: RF3 diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query14.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query14.out index faa1a93b48f95e..af56a7fb3ac7da 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query14.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query14.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 class_id->[i_class_id];RF7 category_id->[i_category_id];RF8 brand_id->[i_brand_id] +------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 brand_id->[i_brand_id];RF7 class_id->[i_class_id];RF8 category_id->[i_category_id] --------PhysicalProject ----------PhysicalOlapScan[item] apply RFs: RF6 RF7 RF8 --------PhysicalDistribute[DistributionSpecReplicated] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query24.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query24.out index 479424affd183f..152e46f1661deb 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query24.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query24.out @@ -7,7 +7,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------PhysicalDistribute[DistributionSpecHash] ----------hashAgg[LOCAL] ------------PhysicalProject ---------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF5 sr_item_sk->[i_item_sk,ss_item_sk];RF6 sr_ticket_number->[ss_ticket_number] 
+--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF5 sr_ticket_number->[ss_ticket_number];RF6 sr_item_sk->[ss_item_sk,i_item_sk] ----------------PhysicalProject ------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF4 i_item_sk->[ss_item_sk] --------------------PhysicalDistribute[DistributionSpecHash] @@ -31,7 +31,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------------------------PhysicalOlapScan[customer_address] --------------------PhysicalDistribute[DistributionSpecHash] ----------------------PhysicalProject -------------------------PhysicalOlapScan[item] apply RFs: RF5 +------------------------PhysicalOlapScan[item] apply RFs: RF6 ----------------PhysicalDistribute[DistributionSpecHash] ------------------PhysicalProject --------------------PhysicalOlapScan[store_returns] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query41.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query41.out index f07731a85b9b06..b73a9538e32098 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query41.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query41.out @@ -19,6 +19,6 @@ PhysicalResultSink --------------------------PhysicalDistribute[DistributionSpecHash] ----------------------------hashAgg[LOCAL] ------------------------------PhysicalProject ---------------------------------filter((((((((((((item.i_category = 'Women') AND i_color IN ('aquamarine', 'gainsboro')) AND i_units IN ('Dozen', 'Ounce')) AND i_size IN ('economy', 'medium')) OR ((((item.i_category = 'Women') AND i_color IN ('chiffon', 'violet')) AND i_units IN ('Pound', 'Ton')) AND i_size IN ('extra large', 'small'))) OR ((((item.i_category = 'Women') AND i_color IN ('blanched', 'tomato')) AND i_units IN ('Case', 'Tbl')) AND 
i_size IN ('economy', 'medium'))) OR ((((item.i_category = 'Women') AND i_color IN ('almond', 'lime')) AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'small'))) OR ((((item.i_category = 'Men') AND i_color IN ('blue', 'chartreuse')) AND i_units IN ('Each', 'Oz')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('dodger', 'tan')) AND i_units IN ('Bunch', 'Tsp')) AND i_size IN ('economy', 'medium'))) OR ((((item.i_category = 'Men') AND i_color IN ('peru', 'saddle')) AND i_units IN ('Gram', 'Pallet')) AND i_size IN ('N/A', 'large'))) OR ((((item.i_category = 'Men') AND i_color IN ('indian', 'spring')) AND i_units IN ('Carton', 'Unknown')) AND i_size IN ('economy', 'medium')))) +--------------------------------filter((((item.i_category = 'Men') AND (((((i_size IN ('economy', 'medium') AND i_color IN ('dodger', 'tan')) AND i_units IN ('Bunch', 'Tsp')) OR ((i_size IN ('economy', 'medium') AND i_color IN ('indian', 'spring')) AND i_units IN ('Carton', 'Unknown'))) OR ((i_color IN ('blue', 'chartreuse') AND i_units IN ('Each', 'Oz')) AND i_size IN ('N/A', 'large'))) OR ((i_color IN ('peru', 'saddle') AND i_units IN ('Gram', 'Pallet')) AND i_size IN ('N/A', 'large')))) OR ((item.i_category = 'Women') AND (((((i_color IN ('aquamarine', 'gainsboro') AND i_units IN ('Dozen', 'Ounce')) AND i_size IN ('economy', 'medium')) OR ((i_color IN ('chiffon', 'violet') AND i_units IN ('Pound', 'Ton')) AND i_size IN ('extra large', 'small'))) OR ((i_color IN ('blanched', 'tomato') AND i_units IN ('Case', 'Tbl')) AND i_size IN ('economy', 'medium'))) OR ((i_color IN ('almond', 'lime') AND i_units IN ('Box', 'Dram')) AND i_size IN ('extra large', 'small')))))) ----------------------------------PhysicalOlapScan[item] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query50.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query50.out index 40f241ec82e976..9a77a7d7d2f549 100644 --- 
a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query50.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query50.out @@ -12,7 +12,7 @@ PhysicalResultSink ------------------PhysicalProject --------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = d1.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] ----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF1 sr_customer_sk->[ss_customer_sk];RF2 sr_item_sk->[ss_item_sk];RF3 sr_ticket_number->[ss_ticket_number] +------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = store_returns.sr_customer_sk) and (store_sales.ss_item_sk = store_returns.sr_item_sk) and (store_sales.ss_ticket_number = store_returns.sr_ticket_number)) otherCondition=() build RFs:RF1 sr_ticket_number->[ss_ticket_number];RF2 sr_item_sk->[ss_item_sk];RF3 sr_customer_sk->[ss_customer_sk] --------------------------PhysicalProject ----------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF2 RF3 RF4 RF5 --------------------------hashJoin[INNER_JOIN] hashCondition=((store_returns.sr_returned_date_sk = d2.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[sr_returned_date_sk] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query64.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query64.out index e1bb4def1c6972..fe890559af6f56 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query64.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query64.out @@ -51,7 +51,7 @@ PhysicalCteAnchor ( cteId=CTEId#1 ) 
------------------------------------------------------------------------------PhysicalDistribute[DistributionSpecHash] --------------------------------------------------------------------------------hashAgg[LOCAL] ----------------------------------------------------------------------------------PhysicalProject -------------------------------------------------------------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = catalog_returns.cr_item_sk) and (catalog_sales.cs_order_number = catalog_returns.cr_order_number)) otherCondition=() build RFs:RF4 cr_order_number->[cs_order_number];RF5 cr_item_sk->[cs_item_sk] +------------------------------------------------------------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = catalog_returns.cr_item_sk) and (catalog_sales.cs_order_number = catalog_returns.cr_order_number)) otherCondition=() build RFs:RF4 cr_item_sk->[cs_item_sk];RF5 cr_order_number->[cs_order_number] --------------------------------------------------------------------------------------PhysicalProject ----------------------------------------------------------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF4 RF5 RF11 --------------------------------------------------------------------------------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query85.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query85.out index 77aa3c3b857dd0..607b0f176f6686 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query85.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query85.out @@ -22,7 +22,7 @@ PhysicalResultSink --------------------------------PhysicalOlapScan[web_page] apply RFs: RF7 ------------------------------PhysicalDistribute[DistributionSpecHash] --------------------------------PhysicalProject -----------------------------------hashJoin[INNER_JOIN] 
hashCondition=((cd1.cd_education_status = cd2.cd_education_status) and (cd1.cd_marital_status = cd2.cd_marital_status) and (cd2.cd_demo_sk = web_returns.wr_returning_cdemo_sk)) otherCondition=() build RFs:RF4 cd_marital_status->[cd_marital_status];RF5 cd_education_status->[cd_education_status];RF6 wr_returning_cdemo_sk->[cd_demo_sk] +----------------------------------hashJoin[INNER_JOIN] hashCondition=((cd1.cd_education_status = cd2.cd_education_status) and (cd1.cd_marital_status = cd2.cd_marital_status) and (cd2.cd_demo_sk = web_returns.wr_returning_cdemo_sk)) otherCondition=() build RFs:RF4 wr_returning_cdemo_sk->[cd_demo_sk];RF5 cd_marital_status->[cd_marital_status];RF6 cd_education_status->[cd_education_status] ------------------------------------PhysicalProject --------------------------------------PhysicalOlapScan[customer_demographics] apply RFs: RF4 RF5 RF6 ------------------------------------PhysicalDistribute[DistributionSpecReplicated] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out index 7477231b372203..4763e6aa34cf49 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out @@ -19,7 +19,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --------------hashAgg[GLOBAL] ----------------hashAgg[LOCAL] ------------------PhysicalProject ---------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF12 ws_order_number->[ws_order_number,wr_order_number];RF13 ws_order_number->[ws_order_number,wr_order_number] +--------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((ws1.ws_order_number = web_returns.wr_order_number)) otherCondition=() build RFs:RF12 ws_order_number->[wr_order_number,ws_order_number];RF13 ws_order_number->[wr_order_number,ws_order_number] 
----------------------PhysicalProject ------------------------hashJoin[INNER_JOIN] hashCondition=((web_returns.wr_order_number = ws_wh.ws_order_number)) otherCondition=() build RFs:RF10 wr_order_number->[ws_order_number];RF11 wr_order_number->[ws_order_number] --------------------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q9.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q9.out index f94b72a9700813..ef80c9fce5dfcc 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q9.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q9.out @@ -8,7 +8,7 @@ PhysicalResultSink ----------PhysicalDistribute[DistributionSpecHash] ------------hashAgg[LOCAL] --------------PhysicalProject -----------------hashJoin[INNER_JOIN] hashCondition=((partsupp.ps_partkey = lineitem.l_partkey) and (partsupp.ps_suppkey = lineitem.l_suppkey)) otherCondition=() build RFs:RF4 ps_suppkey->[l_suppkey,s_suppkey];RF5 ps_partkey->[p_partkey,l_partkey] +----------------hashJoin[INNER_JOIN] hashCondition=((partsupp.ps_partkey = lineitem.l_partkey) and (partsupp.ps_suppkey = lineitem.l_suppkey)) otherCondition=() build RFs:RF4 ps_suppkey->[l_suppkey,s_suppkey];RF5 ps_partkey->[l_partkey,p_partkey] ------------------PhysicalProject --------------------hashJoin[INNER_JOIN] hashCondition=((supplier.s_suppkey = lineitem.l_suppkey)) otherCondition=() build RFs:RF3 s_suppkey->[l_suppkey] ----------------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape_no_stats/q9.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape_no_stats/q9.out index 0710be21e188d7..7ca43463734f03 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape_no_stats/q9.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape_no_stats/q9.out @@ -12,7 +12,7 @@ PhysicalResultSink ------------------PhysicalProject 
--------------------hashJoin[INNER_JOIN] hashCondition=((supplier.s_suppkey = lineitem.l_suppkey)) otherCondition=() build RFs:RF4 s_suppkey->[l_suppkey,ps_suppkey] ----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((partsupp.ps_partkey = lineitem.l_partkey) and (partsupp.ps_suppkey = lineitem.l_suppkey)) otherCondition=() build RFs:RF2 ps_suppkey->[l_suppkey];RF3 ps_partkey->[p_partkey,l_partkey] +------------------------hashJoin[INNER_JOIN] hashCondition=((partsupp.ps_partkey = lineitem.l_partkey) and (partsupp.ps_suppkey = lineitem.l_suppkey)) otherCondition=() build RFs:RF2 ps_suppkey->[l_suppkey];RF3 ps_partkey->[l_partkey,p_partkey] --------------------------PhysicalProject ----------------------------hashJoin[INNER_JOIN] hashCondition=((part.p_partkey = lineitem.l_partkey)) otherCondition=() build RFs:RF1 p_partkey->[l_partkey] ------------------------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy index 4c6f7b72e41125..51e97bb9d091ee 100644 --- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy +++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy @@ -867,7 +867,12 @@ class Suite implements GroovyInterceptable { throw new IllegalStateException("Check tag '${tag}' failed, sql:\n${arg}", t) } if (errorMsg != null) { - logger.warn("expect results: " + expectCsvResults + "\nrealResults: " + realResults) + String csvRealResult = realResults.stream() + .map {row -> OutputUtils.toCsvString(row)} + .collect(Collectors.joining("\n")) + def outputFilePath = context.outputFile.getCanonicalPath().substring(context.config.dataPath.length() + 1) + def line = expectCsvResults.currentLine() + logger.warn("expect results in file: ${outputFilePath}, line: 
${line}\nrealResults:\n" + csvRealResult) throw new IllegalStateException("Check tag '${tag}' failed:\n${errorMsg}\n\nsql:\n${arg}") } } diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/util/OutputUtils.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/util/OutputUtils.groovy index 95f6f615c68381..d051a9dc654973 100644 --- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/util/OutputUtils.groovy +++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/util/OutputUtils.groovy @@ -92,7 +92,7 @@ class OutputUtils { return null } } - return "${info}, line ${line}, ${dataType} result mismatch.\nExpect cell is: ${expectCell}\nBut real is: ${realCell}\nrelative error is: ${realRelativeError}, bigger than ${expectRelativeError}" + return "${info}, line ${line}, ${dataType} result mismatch.\nExpect cell is: ${expectCell}\nBut real is : ${realCell}\nrelative error is: ${realRelativeError}, bigger than ${expectRelativeError}" } } } else if(dataType == "DATE" || dataType =="DATETIME") { @@ -100,11 +100,11 @@ class OutputUtils { realCell = realCell.replace("T", " ") if(!expectCell.equals(realCell)) { - return "${info}, line ${line}, ${dataType} result mismatch.\nExpect cell is: ${expectCell}\nBut real is: ${realCell}" + return "${info}, line ${line}, ${dataType} result mismatch.\nExpect cell is: ${expectCell}\nBut real is : ${realCell}" } } else { if(!expectCell.equals(realCell)) { - return "${info}, line ${line}, ${dataType} result mismatch.\nExpect cell is: ${expectCell}\nBut real is: ${realCell}" + return "${info}, line ${line}, ${dataType} result mismatch.\nExpect cell is: ${expectCell}\nBut real is : ${realCell}" } } @@ -141,7 +141,7 @@ class OutputUtils { def res = checkCell(info, line, expectCell, realCell, dataType) if(res != null) { - res += "\nline ${line} mismatch\nExpectRow: ${expectRaw}\nRealRow: ${realRaw}"; + res += "\nline ${line} mismatch\nExpectRow: 
${expectRaw}\nRealRow : ${realRaw}"; return res } } @@ -149,7 +149,7 @@ class OutputUtils { def expectCsvString = transform1.apply(expectRaw) def realCsvString = transform2.apply(realRaw) if (!expectCsvString.equals(realCsvString)) { - return "${info}, line ${line} mismatch.\nExpect line is: ${expectCsvString}\nBut real is: ${realCsvString}" + return "${info}, line ${line} mismatch.\nExpect line is: ${expectCsvString}\nBut real is : ${realCsvString}" } } @@ -222,11 +222,15 @@ class OutputUtils { static class TagBlockIterator implements Iterator> { private final String tag + private final int startLine + private int currentLine private Iterator> it - TagBlockIterator(String tag, Iterator> it) { + TagBlockIterator(String tag, int startLine, Iterator> it) { this.tag = tag + this.startLine = startLine this.it = it + this.currentLine = startLine } String getTag() { @@ -240,7 +244,13 @@ class OutputUtils { @Override List next() { - return it.next() + def next = it.next() + currentLine++ + return next + } + + int currentLine() { + return currentLine } } @@ -284,7 +294,9 @@ class OutputUtils { return false } } - cache = new TagBlockIterator(tag, new CsvParserIterator(new SkipLastEmptyLineIterator(new OutputBlockIterator(lineIt)))) + + def csvIt = new CsvParserIterator(new SkipLastEmptyLineIterator(new OutputBlockIterator(lineIt))) + cache = new TagBlockIterator(tag, lineIt.getCurrentId(), csvIt) cached = true return true } else { diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/util/ReusableIterator.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/util/ReusableIterator.groovy index 68cf9e8cb91706..f9910ae7b8a10b 100644 --- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/util/ReusableIterator.groovy +++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/util/ReusableIterator.groovy @@ -24,9 +24,11 @@ class ReusableIterator implements CloseableIterator { private 
CloseableIterator it private T next private boolean cached + private int currentId ReusableIterator(CloseableIterator it) { this.it = it + this.currentId = 0 } @Override @@ -57,8 +59,13 @@ class ReusableIterator implements CloseableIterator { T next() { if (hasNext()) { cached = false + currentId++ return next } throw new NoSuchElementException() } + + int getCurrentId() { + return currentId + } } From 1149b2ae0e9c79ef5d3307bfa22cf672c7614b5d Mon Sep 17 00:00:00 2001 From: 924060929 <924060929@qq.com> Date: Tue, 2 Apr 2024 10:24:44 +0800 Subject: [PATCH 10/12] [fix](Nereids) fix link children failed (#33134) #32617 introduce a bug: rewrite may not working when plan's arity >= 3. this pr fix it (cherry picked from commit 8b070d1a9d43aa7d25225a79da81573c384ee825) --- .../jobs/rewrite/PlanTreeRewriteJob.java | 7 ++-- .../nereids/trees/plans/SetOperationTest.java | 38 +++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java index 5e5acc29f66edb..c2b136c40fad78 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java @@ -110,16 +110,17 @@ protected static Plan linkChildren(Plan plan, RewriteJobContext[] childrenContex } } default: { - boolean changed = false; + boolean anyChanged = false; int i = 0; Plan[] newChildren = new Plan[childrenContext.length]; for (Plan oldChild : children) { Plan result = childrenContext[i].result; - changed = result != null && result != oldChild; + boolean changed = result != null && result != oldChild; newChildren[i] = changed ? result : oldChild; + anyChanged |= changed; i++; } - return changed ? plan.withChildren(newChildren) : plan; + return anyChanged ? 
plan.withChildren(newChildren) : plan; } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/SetOperationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/SetOperationTest.java index fa7fcddc3f679a..b6932f846692f6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/SetOperationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/SetOperationTest.java @@ -17,9 +17,19 @@ package org.apache.doris.nereids.trees.plans; +import org.apache.doris.nereids.analyzer.UnboundAlias; +import org.apache.doris.nereids.analyzer.UnboundFunction; +import org.apache.doris.nereids.analyzer.UnboundOneRowRelation; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Concat; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; +import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; +import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; public class SetOperationTest extends TestWithFeService { @@ -110,4 +120,32 @@ public void testUnion5() { PlanChecker.from(connectContext) .checkPlannerResult("select 1, 2 union all select 1, 2 union all select 10 e, 20 f;"); } + + @Test + public void testUnion6() { + LogicalOneRowRelation first = new LogicalOneRowRelation( + RelationId.createGenerator().getNextId(), ImmutableList.of( + new Alias(new Concat(new StringLiteral("1"), new StringLiteral("1"))) + )); + + UnboundOneRowRelation second = new UnboundOneRowRelation( + RelationId.createGenerator().getNextId(), ImmutableList.of( + new UnboundAlias(new UnboundFunction( + "concat", + ImmutableList.of(new StringLiteral("2"), new 
StringLiteral("2"))) + ) + )); + + LogicalOneRowRelation third = new LogicalOneRowRelation( + RelationId.createGenerator().getNextId(), ImmutableList.of( + new Alias(new Concat(new StringLiteral("3"), new StringLiteral("3"))) + )); + + LogicalUnion union = new LogicalUnion(Qualifier.ALL, ImmutableList.of( + first, second, third + )); + PlanChecker.from(connectContext, union) + .analyze() + .rewrite(); + } } From 021d39a87e73be3c58aa927239fe5dc0eb27735c Mon Sep 17 00:00:00 2001 From: 924060929 <924060929@qq.com> Date: Mon, 1 Apr 2024 21:28:39 +0800 Subject: [PATCH 11/12] [fix](Nereids) fix group concat (#33091) Fix failed in regression_test/suites/query_p0/group_concat/test_group_concat.groovy select group_concat( distinct b1, '?'), group_concat( distinct b3, '?') from table_group_concat group by b2 exception: lowestCostPlans with physicalProperties(GATHER) doesn't exist in root group The root cause is '?' is push down to slot by NormalizeAggregate, AggregateStrategies treat the slot as a distinct parameter and generate a invalid PhysicalHashAggregate, and then reject by ChildOutputPropertyDeriver. 
I fix this bug by avoid push down literal to slot in NormalizeAggregate, and forbidden generate stream aggregate node when group by slots is empty --- be/src/pipeline/pipeline_fragment_context.cpp | 9 ++++-- .../pipeline_x_fragment_context.cpp | 11 +++++-- be/src/runtime/descriptors.h | 5 +++ be/src/vec/exec/vaggregation_node.h | 1 + .../org/apache/doris/nereids/memo/Group.java | 22 +++++++++++-- .../doris/nereids/memo/GroupExpression.java | 5 +++ .../ChildrenPropertiesRegulator.java | 4 ++- .../properties/PhysicalProperties.java | 6 ++-- .../rules/analysis/NormalizeAggregate.java | 9 +++++- .../implementation/AggregateStrategies.java | 32 +++++++++++++++++++ 10 files changed, 93 insertions(+), 11 deletions(-) diff --git a/be/src/pipeline/pipeline_fragment_context.cpp b/be/src/pipeline/pipeline_fragment_context.cpp index c273a0c3807485..a32d777788d64a 100644 --- a/be/src/pipeline/pipeline_fragment_context.cpp +++ b/be/src/pipeline/pipeline_fragment_context.cpp @@ -559,7 +559,12 @@ Status PipelineFragmentContext::_build_pipelines(ExecNode* node, PipelinePtr cur auto* agg_node = dynamic_cast(node); auto new_pipe = add_pipeline(); RETURN_IF_ERROR(_build_pipelines(node->child(0), new_pipe)); - if (agg_node->is_aggregate_evaluators_empty()) { + if (agg_node->is_probe_expr_ctxs_empty() && node->row_desc().num_slots() == 0) { + return Status::InternalError("Illegal aggregate node " + + std::to_string(agg_node->id()) + + ": group by and output is empty"); + } + if (agg_node->is_aggregate_evaluators_empty() && !agg_node->is_probe_expr_ctxs_empty()) { auto data_queue = std::make_shared(1); OperatorBuilderPtr pre_agg_sink = std::make_shared(node->id(), agg_node, @@ -570,7 +575,7 @@ Status PipelineFragmentContext::_build_pipelines(ExecNode* node, PipelinePtr cur std::make_shared( node->id(), agg_node, data_queue); RETURN_IF_ERROR(cur_pipe->add_operator(pre_agg_source)); - } else if (agg_node->is_streaming_preagg()) { + } else if (agg_node->is_streaming_preagg() && 
!agg_node->is_probe_expr_ctxs_empty()) { auto data_queue = std::make_shared(1); OperatorBuilderPtr pre_agg_sink = std::make_shared( node->id(), agg_node, data_queue); diff --git a/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp b/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp index 744ce754a5969e..5dac71e842057b 100644 --- a/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp +++ b/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp @@ -989,13 +989,20 @@ Status PipelineXFragmentContext::_create_operator(ObjectPool* pool, const TPlanN break; } case TPlanNodeType::AGGREGATION_NODE: { + if (tnode.agg_node.grouping_exprs.empty() && + descs.get_tuple_descriptor(tnode.agg_node.output_tuple_id)->slots().empty()) { + return Status::InternalError("Illegal aggregate node " + std::to_string(tnode.node_id) + + ": group by and output is empty"); + } if (tnode.agg_node.aggregate_functions.empty() && !_runtime_state->enable_agg_spill() && request.query_options.__isset.enable_distinct_streaming_aggregation && - request.query_options.enable_distinct_streaming_aggregation) { + request.query_options.enable_distinct_streaming_aggregation && + !tnode.agg_node.grouping_exprs.empty()) { op.reset(new DistinctStreamingAggOperatorX(pool, next_operator_id(), tnode, descs)); RETURN_IF_ERROR(cur_pipe->add_operator(op)); } else if (tnode.agg_node.__isset.use_streaming_preaggregation && - tnode.agg_node.use_streaming_preaggregation) { + tnode.agg_node.use_streaming_preaggregation && + !tnode.agg_node.grouping_exprs.empty()) { op.reset(new StreamingAggOperatorX(pool, next_operator_id(), tnode, descs)); RETURN_IF_ERROR(cur_pipe->add_operator(op)); } else { diff --git a/be/src/runtime/descriptors.h b/be/src/runtime/descriptors.h index fff1ed339d53ba..7cb7e9fe01540d 100644 --- a/be/src/runtime/descriptors.h +++ b/be/src/runtime/descriptors.h @@ -505,10 +505,12 @@ class RowDescriptor { _has_varlen_slots(desc._has_varlen_slots) { _num_materialized_slots = 0; 
_num_null_slots = 0; + _num_slots = 0; std::vector::const_iterator it = desc._tuple_desc_map.begin(); for (; it != desc._tuple_desc_map.end(); ++it) { _num_materialized_slots += (*it)->num_materialized_slots(); _num_null_slots += (*it)->num_null_slots(); + _num_slots += (*it)->slots().size(); } _num_null_bytes = (_num_null_slots + 7) / 8; } @@ -531,6 +533,8 @@ class RowDescriptor { int num_null_bytes() const { return _num_null_bytes; } + int num_slots() const { return _num_slots; } + static const int INVALID_IDX; // Returns INVALID_IDX if id not part of this row. @@ -585,6 +589,7 @@ class RowDescriptor { int _num_materialized_slots; int _num_null_slots; int _num_null_bytes; + int _num_slots; }; } // namespace doris diff --git a/be/src/vec/exec/vaggregation_node.h b/be/src/vec/exec/vaggregation_node.h index f09ebcbba83a60..f89bbb9d780d1d 100644 --- a/be/src/vec/exec/vaggregation_node.h +++ b/be/src/vec/exec/vaggregation_node.h @@ -416,6 +416,7 @@ class AggregationNode : public ::doris::ExecNode { Status pull(doris::RuntimeState* state, vectorized::Block* output_block, bool* eos) override; Status sink(doris::RuntimeState* state, vectorized::Block* input_block, bool eos) override; Status do_pre_agg(vectorized::Block* input_block, vectorized::Block* output_block); + bool is_probe_expr_ctxs_empty() const { return _probe_expr_ctxs.empty(); } bool is_streaming_preagg() const { return _is_streaming_preagg; } bool is_aggregate_evaluators_empty() const { return _aggregate_evaluators.empty(); } void _make_nullable_output_key(Block* block); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index 5a5abd56f95bd7..01968a03bef373 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -35,6 +35,7 @@ import com.google.common.base.Preconditions; import 
com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -65,7 +66,7 @@ public class Group { // Map of cost lower bounds // Map required plan props to cost lower bound of corresponding plan - private final Map> lowestCostPlans = Maps.newHashMap(); + private final Map> lowestCostPlans = Maps.newLinkedHashMap(); private boolean isExplored = false; @@ -228,6 +229,12 @@ public Optional> getLowestCostPlan(PhysicalPropertie return costAndGroupExpression; } + public Map getLowestCosts() { + return lowestCostPlans.entrySet() + .stream() + .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> kv.getValue().first)); + } + public GroupExpression getBestPlan(PhysicalProperties properties) { if (lowestCostPlans.containsKey(properties)) { return lowestCostPlans.get(properties).second; @@ -489,9 +496,18 @@ public String toString() { public String treeString() { Function toString = obj -> { if (obj instanceof Group) { - return "Group[" + ((Group) obj).groupId + "]"; + Group group = (Group) obj; + Map lowestCosts = group.getLowestCosts(); + return "Group[" + group.groupId + ", lowestCosts: " + lowestCosts + "]"; } else if (obj instanceof GroupExpression) { - return ((GroupExpression) obj).getPlan().toString(); + GroupExpression groupExpression = (GroupExpression) obj; + Map>> lowestCostTable + = groupExpression.getLowestCostTable(); + Map requestPropertiesMap + = groupExpression.getRequestPropertiesMap(); + Cost cost = groupExpression.getCost(); + return groupExpression.getPlan().toString() + " [cost: " + cost + ", lowestCostTable: " + + lowestCostTable + ", requestPropertiesMap: " + requestPropertiesMap + "]"; } else if (obj instanceof Pair) { // print logicalExpressions or physicalExpressions // first is name, second is group expressions diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java index eda7f5c9c3575a..24bc9383b5264c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java @@ -35,6 +35,7 @@ import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -318,6 +319,10 @@ public void setEstOutputRowCount(double estOutputRowCount) { this.estOutputRowCount = estOutputRowCount; } + public Map getRequestPropertiesMap() { + return ImmutableMap.copyOf(requestPropertiesMap); + } + @Override public String toString() { DecimalFormat format = new DecimalFormat("#,###.##"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index 7c5374ebd212ee..366730f7dc521e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -104,6 +104,9 @@ public Boolean visit(Plan plan, Void context) { @Override public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate agg, Void context) { + if (agg.getGroupByExpressions().isEmpty() && agg.getOutputExpressions().isEmpty()) { + return false; + } if (!agg.getAggregateParam().canBeBanned) { return true; } @@ -121,7 +124,6 @@ public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate return true; } return false; - } // forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java index 81e7190e163ecb..031f18ab918800 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/PhysicalProperties.java @@ -88,11 +88,13 @@ public static PhysicalProperties createHash( .map(SlotReference.class::cast) .map(SlotReference::getExprId) .collect(Collectors.toList()); - return createHash(partitionedSlots, shuffleType); + return partitionedSlots.isEmpty() ? PhysicalProperties.GATHER : createHash(partitionedSlots, shuffleType); } public static PhysicalProperties createHash(List orderedShuffledColumns, ShuffleType shuffleType) { - return new PhysicalProperties(new DistributionSpecHash(orderedShuffledColumns, shuffleType)); + return orderedShuffledColumns.isEmpty() + ? PhysicalProperties.GATHER + : new PhysicalProperties(new DistributionSpecHash(orderedShuffledColumns, shuffleType)); } public static PhysicalProperties createHash(DistributionSpecHash distributionSpecHash) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 7f6df51248e34a..e9b3d32da6e2a9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -31,6 +31,7 @@ import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; 
@@ -152,6 +153,9 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional> categorizedNoDistinctAggsChildren = aggFuncs.stream() .filter(aggFunc -> !aggFunc.isDistinct()) .flatMap(agg -> agg.children().stream()) + // should not push down literal under aggregate + // e.g. group_concat(distinct xxx, ','), the ',' literal show stay in aggregate + .filter(arg -> !(arg instanceof Literal)) .collect(Collectors.groupingBy( child -> child.containsType(SubqueryExpr.class, WindowExpression.class), ImmutableSet.toImmutableSet())); @@ -159,9 +163,12 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional> categorizedDistinctAggsChildren = aggFuncs.stream() + Map> categorizedDistinctAggsChildren = aggFuncs.stream() .filter(AggregateFunction::isDistinct) .flatMap(agg -> agg.children().stream()) + // should not push down literal under aggregate + // e.g. group_concat(distinct xxx, ','), the ',' literal show stay in aggregate + .filter(arg -> !(arg instanceof Literal)) .collect( Collectors.groupingBy( child -> !(child instanceof SlotReference), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 61aac4d2407462..edbd28677b4a00 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -70,6 +70,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp; +import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; import 
org.apache.doris.qe.ConnectContext; @@ -1292,6 +1293,15 @@ private List> threePhaseAggregateWithDisti .build(); List localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet); + boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty(); + + // be not recommend generate an aggregate node with empty group by and empty output, + // so add a null int slot to group by slot and output + if (isGroupByEmptySelectEmpty) { + localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE)); + localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE))); + } + boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy); List partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg); RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY); @@ -1317,6 +1327,12 @@ private List> threePhaseAggregateWithDisti .addAll(nonDistinctAggFunctionToAliasPhase2.values()) .build(); + // be not recommend generate an aggregate node with empty group by and empty output, + // so add a null int slot to group by slot and output + if (isGroupByEmptySelectEmpty) { + globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE))); + } + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); PhysicalHashAggregate anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>( localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions), @@ -1680,6 +1696,16 @@ private List> fourPhaseAggregateWithDistin boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy); List partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg); RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY); + + boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty(); + + // be not recommend generate an aggregate node with empty group by and empty output, + // so add a null 
int slot to group by slot and output + if (isGroupByEmptySelectEmpty) { + localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE)); + localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE))); + } + PhysicalHashAggregate anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy, localAggOutput, Optional.of(partitionExpressions), inputToBufferParam, maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(), @@ -1702,6 +1728,12 @@ private List> fourPhaseAggregateWithDistin .addAll(nonDistinctAggFunctionToAliasPhase2.values()) .build(); + // be not recommend generate an aggregate node with empty group by and empty output, + // so add a null int slot to group by slot and output + if (isGroupByEmptySelectEmpty) { + globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE))); + } + RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); RequireProperties requireDistinctHash = RequireProperties.of( From 693690a0013f1aeef136791fae095277dcdbc7da Mon Sep 17 00:00:00 2001 From: 924060929 <924060929@qq.com> Date: Mon, 1 Apr 2024 21:28:53 +0800 Subject: [PATCH 12/12] [fix](Nereids) fix bind group by int literal (#33117) This sql will failed because 2 in the group by will bind to 1 as col2 in BindExpression ResolveOrdinalInOrderByAndGroupBy will replace 1 to MIN (LENGTH (cast(age as varchar))) CheckAnalysis will throw an exception because group by can not contains aggregate function select MIN (LENGTH (cast(age as varchar))), 1 AS col2 from test_bind_groupby_slots group by 2 we should move ResolveOrdinalInOrderByAndGroupBy into BindExpression (cherry picked from commit 3fab4496c3fefe95b4db01f300bf747080bfc3d8) --- .../doris/nereids/jobs/executor/Analyzer.java | 2 - .../rules/analysis/BindExpression.java | 33 +++++- .../ResolveOrdinalInOrderByAndGroupBy.java | 102 ------------------ .../data/nereids_syntax_p0/bind_priority.out | 6 ++ 
.../nereids_syntax_p0/bind_priority.groovy | 28 +++++ 5 files changed, 63 insertions(+), 108 deletions(-) delete mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ResolveOrdinalInOrderByAndGroupBy.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index 9ad10a30aa29bd..a0431e066beee8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -43,7 +43,6 @@ import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate; import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate; import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput; -import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy; import org.apache.doris.nereids.rules.analysis.SubqueryToApply; import org.apache.doris.nereids.rules.rewrite.MergeProjects; import org.apache.doris.nereids.rules.rewrite.SemiJoinCommute; @@ -147,7 +146,6 @@ private static List buildAnalyzeJobs(Optional c // please see rule BindSlotReference or BindFunction for example new EliminateDistinctConstant(), new ProjectWithDistinctToAggregate(), - new ResolveOrdinalInOrderByAndGroupBy(), new ReplaceExpressionByChildOutput(), new OneRowRelationExtractAggregate() ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 6211f493eaf4e4..43f89d5b010c99 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -56,6 +56,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import 
org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement; import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.plans.AbstractPlan; import org.apache.doris.nereids.trees.plans.JoinType; @@ -486,11 +487,12 @@ private LogicalSort bindSortWithSetOperation( LogicalSort sort = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; + List childOutput = sort.child().getOutput(); SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer( sort, cascadesContext, sort.children(), true, true); Builder boundKeys = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); for (OrderKey orderKey : sort.getOrderKeys()) { - Expression boundKey = analyzer.analyze(orderKey.getExpr()); + Expression boundKey = bindWithOrdinal(orderKey.getExpr(), analyzer, childOutput); boundKeys.add(orderKey.withExpression(boundKey)); } return new LogicalSort<>(boundKeys.build(), sort.child()); @@ -699,7 +701,11 @@ private List bindGroupBy( return useOutputExpr.build(); }); - List boundGroupBy = analyzer.analyzeToList(groupBy); + ImmutableList.Builder boundGroupByBuilder = ImmutableList.builderWithExpectedSize(groupBy.size()); + for (Expression key : groupBy) { + boundGroupByBuilder.add(bindWithOrdinal(key, analyzer, boundAggOutput)); + } + List boundGroupBy = boundGroupByBuilder.build(); checkIfOutputAliasNameDuplicatedForGroupBy(boundGroupBy, boundAggOutput); return boundGroupBy; } @@ -723,6 +729,9 @@ private Supplier buildAggOutputScopeWithoutAggFun( private Plan bindSortWithoutSetOperation(MatchingContext> ctx) { LogicalSort sort = ctx.root; Plan input = sort.child(); + + List childOutput = input.getOutput(); + // we should skip LogicalHaving to bind slot in LogicalSort; if (input instanceof LogicalHaving) { input = input.child(0); @@ -744,7 +753,8 
@@ private Plan bindSortWithoutSetOperation(MatchingContext> ctx) // group by col1 // order by col1; # order by order_col1 // bind order_col1 with alias_col1, then, bind it with inner_col1 - Scope inputScope = toScope(cascadesContext, input.getOutput()); + List inputSlots = input.getOutput(); + Scope inputScope = toScope(cascadesContext, inputSlots); final Plan finalInput = input; Supplier inputChildrenScope = Suppliers.memoize( @@ -766,7 +776,7 @@ private Plan bindSortWithoutSetOperation(MatchingContext> ctx) Builder boundOrderKeys = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); for (OrderKey orderKey : sort.getOrderKeys()) { - Expression boundKey = analyzer.analyze(orderKey.getExpr()); + Expression boundKey = bindWithOrdinal(orderKey.getExpr(), analyzer, childOutput); boundOrderKeys.add(orderKey.withExpression(boundKey)); } return new LogicalSort<>(boundOrderKeys.build(), sort.child()); @@ -858,6 +868,21 @@ private boolean isAggregateFunction(UnboundFunction unboundFunction, FunctionReg unboundFunction.getDbName(), unboundFunction.getName()); } + private Expression bindWithOrdinal( + Expression unbound, SimpleExprAnalyzer analyzer, List boundSelectOutput) { + if (unbound instanceof IntegerLikeLiteral) { + int ordinal = ((IntegerLikeLiteral) unbound).getIntValue(); + if (ordinal >= 1 && ordinal <= boundSelectOutput.size()) { + Expression boundSelectItem = boundSelectOutput.get(ordinal - 1); + return boundSelectItem instanceof Alias ? 
boundSelectItem.child(0) : boundSelectItem; + } else { + return unbound; // bound literal + } + } else { + return analyzer.analyze(unbound); + } + } + private E checkBoundExceptLambda(E expression, Plan plan) { if (expression instanceof Lambda) { return expression; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ResolveOrdinalInOrderByAndGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ResolveOrdinalInOrderByAndGroupBy.java deleted file mode 100644 index 1cefd203ff74aa..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ResolveOrdinalInOrderByAndGroupBy.java +++ /dev/null @@ -1,102 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package org.apache.doris.nereids.rules.analysis; - -import org.apache.doris.nereids.properties.OrderKey; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalSort; - -import com.google.common.collect.ImmutableList; - -import java.util.ArrayList; -import java.util.List; - -/** - * SELECT col1, col2 FROM t1 ORDER BY 1 -> SELECT col1, col2 FROM t1 ORDER BY col1 - * SELECT col1, SUM(col2) FROM t1 GROUP BY 1 -> SELECT col1, SUM(col2) FROM t1 GROUP BY col1 - */ -public class ResolveOrdinalInOrderByAndGroupBy implements AnalysisRuleFactory { - - @Override - public List buildRules() { - return ImmutableList.builder() - .add(RuleType.RESOLVE_ORDINAL_IN_ORDER_BY.build( - logicalSort().thenApply(ctx -> { - LogicalSort sort = ctx.root; - List childOutput = sort.child().getOutput(); - List orderKeys = sort.getOrderKeys(); - List orderKeysWithoutOrd = new ArrayList<>(); - for (OrderKey k : orderKeys) { - Expression expression = k.getExpr(); - if (expression instanceof IntegerLikeLiteral) { - IntegerLikeLiteral i = (IntegerLikeLiteral) expression; - int ord = i.getIntValue(); - checkOrd(ord, childOutput.size()); - orderKeysWithoutOrd - .add(new OrderKey(childOutput.get(ord - 1), k.isAsc(), k.isNullFirst())); - } else { - orderKeysWithoutOrd.add(k); - } - } - return sort.withOrderKeys(orderKeysWithoutOrd); - }) - )) - .add(RuleType.RESOLVE_ORDINAL_IN_GROUP_BY.build( - 
logicalAggregate().whenNot(LogicalAggregate::isOrdinalIsResolved).thenApply(ctx -> { - LogicalAggregate agg = ctx.root; - List aggOutput = agg.getOutputExpressions(); - List groupByWithoutOrd = new ArrayList<>(); - boolean ordExists = false; - for (Expression groupByExpr : agg.getGroupByExpressions()) { - if (groupByExpr instanceof IntegerLikeLiteral) { - IntegerLikeLiteral i = (IntegerLikeLiteral) groupByExpr; - int ord = i.getIntValue(); - checkOrd(ord, aggOutput.size()); - Expression aggExpr = aggOutput.get(ord - 1); - if (aggExpr instanceof Alias) { - aggExpr = ((Alias) aggExpr).child(); - } - groupByWithoutOrd.add(aggExpr); - ordExists = true; - } else { - groupByWithoutOrd.add(groupByExpr); - } - } - if (ordExists) { - return new LogicalAggregate<>(groupByWithoutOrd, agg.getOutputExpressions(), - true, agg.child()); - } else { - return agg; - } - }))).build(); - } - - private void checkOrd(int ord, int childOutputSize) { - if (ord < 1 || ord > childOutputSize) { - throw new IllegalStateException(String.format("ordinal exceeds number of items in select list: %s", ord)); - } - } -} diff --git a/regression-test/data/nereids_syntax_p0/bind_priority.out b/regression-test/data/nereids_syntax_p0/bind_priority.out index 53432880c2427a..56706546bab52b 100644 --- a/regression-test/data/nereids_syntax_p0/bind_priority.out +++ b/regression-test/data/nereids_syntax_p0/bind_priority.out @@ -85,3 +85,9 @@ all 2 -- !having_bind_group_by -- 2 1 +-- !sql -- +2 1 + +-- !sql -- +2 1 + diff --git a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy index 84bab14eba0980..769f1771982ab0 100644 --- a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy +++ b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy @@ -309,4 +309,32 @@ suite("bind_priority") { having pk = 2; """ }() + + def bindGroupBy = { + sql "drop table if exists test_bind_groupby_slots" + + sql """create table 
test_bind_groupby_slots + (id int, age int) + distributed by hash(id) + properties('replication_num'='1'); + """ + sql "insert into test_bind_groupby_slots values(1, 10), (2, 20), (3, 30);" + + order_qt_sql "select MIN (LENGTH (cast(age as varchar))), 1 AS col2 from test_bind_groupby_slots group by 2" + }() + + + + def bindOrderBy = { + sql "drop table if exists test_bind_orderby_slots" + + sql """create table test_bind_orderby_slots + (id int, age int) + distributed by hash(id) + properties('replication_num'='1'); + """ + sql "insert into test_bind_orderby_slots values(1, 10), (2, 20), (3, 30);" + + order_qt_sql "select MIN (LENGTH (cast(age as varchar))), 1 AS col2 from test_bind_orderby_slots order by 2" + }() }