From 1d3b076acc1c0f248a19c6149c0634e63a5b836e Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 13 May 2021 18:51:14 +0800 Subject: [PATCH 01/21] add window expr --- ballista/rust/core/proto/ballista.proto | 75 ++++++- .../core/src/serde/logical_plan/from_proto.rs | 190 ++++++++++++++++- .../core/src/serde/logical_plan/to_proto.rs | 114 +++++++++- .../src/serde/physical_plan/from_proto.rs | 80 ++++++- ballista/rust/scheduler/src/planner.rs | 8 + datafusion/src/logical_plan/builder.rs | 43 +++- datafusion/src/logical_plan/expr.rs | 33 ++- datafusion/src/logical_plan/plan.rs | 62 ++++-- datafusion/src/optimizer/constant_folding.rs | 1 + .../src/optimizer/hash_build_probe_order.rs | 5 + .../src/optimizer/projection_push_down.rs | 53 +++++ datafusion/src/optimizer/utils.rs | 21 ++ datafusion/src/physical_plan/aggregates.rs | 5 +- .../src/physical_plan/expressions/count.rs | 2 +- datafusion/src/physical_plan/mod.rs | 19 ++ datafusion/src/physical_plan/planner.rs | 68 +++++- datafusion/src/physical_plan/sort.rs | 1 + .../src/physical_plan/window_functions.rs | 165 +++++++++++++++ datafusion/src/physical_plan/windows.rs | 195 ++++++++++++++++++ datafusion/src/sql/planner.rs | 184 +++++++++++------ datafusion/src/sql/utils.rs | 15 ++ 21 files changed, 1236 insertions(+), 103 deletions(-) create mode 100644 datafusion/src/physical_plan/window_functions.rs create mode 100644 datafusion/src/physical_plan/windows.rs diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 3da0e85437d76..b5f3835f1f8c0 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -39,7 +39,6 @@ message LogicalExprNode { ScalarValue literal = 3; - // binary expressions BinaryExprNode binary_expr = 4; @@ -60,6 +59,9 @@ message LogicalExprNode { bool wildcard = 15; ScalarFunctionNode scalar_function = 16; TryCastNode try_cast = 17; + + // window expressions + WindowExprNode window_expr = 18; } } @@ -151,6 +153,25 @@ message AggregateExprNode { LogicalExprNode expr = 2; } +enum BuiltInWindowFunction { + ROW_NUMBER = 0; + RANK = 1; + DENSE_RANK = 2; + LAG = 3; + LEAD = 4; + FIRST_VALUE = 5; + LAST_VALUE = 6; +} + +message WindowExprNode { + oneof window_function { + AggregateFunction aggr_function = 1; + BuiltInWindowFunction built_in_function = 2; + // udaf = 3 + } + LogicalExprNode expr = 4; +} + message BetweenNode { LogicalExprNode expr = 1; bool negated = 2; @@ -200,6 +221,7 @@ message LogicalPlanNode { EmptyRelationNode empty_relation = 10; CreateExternalTableNode create_external_table = 11; ExplainNode explain = 12; + WindowNode window = 13; } } @@ -288,6 +310,49 @@ message AggregateNode { repeated LogicalExprNode aggr_expr = 3; } +message WindowNode { + LogicalPlanNode input = 1; + repeated LogicalExprNode window_expr = 2; + repeated LogicalExprNode partition_by_expr = 3; + repeated LogicalExprNode order_by_expr = 4; + // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430) + // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) + oneof window_frame { + WindowFrame frame = 5; + } +} + +enum WindowFrameUnits { + ROWS = 0; + RANGE = 1; + GROUPS = 2; +} + +message WindowFrame { + WindowFrameUnits window_frame_units = 1; + WindowFrameBound start_bound = 2; + // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430) + // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) + oneof end_bound { + WindowFrameBound bound = 3; + } +} + +enum WindowFrameBoundType { + CURRENT_ROW = 0; + PRECEDING = 1; + FOLLOWING = 2; +} + +message WindowFrameBound { + WindowFrameBoundType window_frame_bound_type = 1; + // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430) + // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) + oneof bound_value { + uint64 value = 2; + } +} + enum JoinType { INNER = 0; LEFT = 1; @@ -334,6 +399,7 @@ message PhysicalPlanNode { MergeExecNode merge = 14; UnresolvedShuffleExecNode unresolved = 15; RepartitionExecNode repartition = 16; + WindowAggExecNode window = 17; } } @@ -399,6 +465,13 @@ enum AggregateMode { FINAL_PARTITIONED = 2; } +message WindowAggExecNode { + PhysicalPlanNode input = 1; + repeated LogicalExprNode window_expr = 2; + repeated string window_expr_name = 3; + Schema input_schema = 4; +} + message HashAggregateExecNode { repeated LogicalExprNode group_expr = 1; repeated LogicalExprNode aggr_expr = 2; diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 6987035394c6d..2632aedcaaeb2 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -17,15 +17,15 @@ //! Serde code to convert from protocol buffers to Rust data structures. +use crate::error::BallistaError; +use crate::serde::{proto_error, protobuf}; +use crate::{convert_box_required, convert_required}; +use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::{ convert::{From, TryInto}, unimplemented, }; -use crate::error::BallistaError; -use crate::serde::{proto_error, protobuf}; -use crate::{convert_box_required, convert_required}; - use arrow::datatypes::{DataType, Field, Schema}; use datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, @@ -33,6 +33,7 @@ use datafusion::logical_plan::{ }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; +use datafusion::physical_plan::window_functions::BuiltInWindowFunction; use datafusion::scalar::ScalarValue; use protobuf::logical_plan_node::LogicalPlanType; use protobuf::{logical_expr_node::ExprType, scalar_type}; @@ -75,6 +76,33 @@ impl TryInto for &protobuf::LogicalPlanNode { .build() .map_err(|e| e.into()) } + LogicalPlanType::Window(window) => { + let input: LogicalPlan = convert_box_required!(window.input)?; + let window_expr = window + .window_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?; + + // let partition_by_expr = window + // .partition_by_expr + // .iter() + // .map(|expr| expr.try_into()) + // .collect::, _>>()?; + // let order_by_expr = window + // .order_by_expr + // .iter() + // .map(|expr| expr.try_into()) + // .collect::, _>>()?; + // // FIXME parse the window_frame data + // let window_frame = None; + LogicalPlanBuilder::from(&input) + .window( + window_expr, /*, partition_by_expr, order_by_expr, window_frame*/ + )? + .build() + .map_err(|e| e.into()) + } LogicalPlanType::Aggregate(aggregate) => { let input: LogicalPlan = convert_box_required!(aggregate.input)?; let group_expr = aggregate @@ -871,7 +899,10 @@ impl TryInto for &protobuf::LogicalExprNode { type Error = BallistaError; fn try_into(self) -> Result { + use datafusion::physical_plan::window_functions; use protobuf::logical_expr_node::ExprType; + use protobuf::window_expr_node; + use protobuf::WindowExprNode; let expr_type = self .expr_type @@ -889,6 +920,48 @@ impl TryInto for &protobuf::LogicalExprNode { let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?; Ok(Expr::Literal(scalar_value)) } + ExprType::WindowExpr(expr) => { + let window_function = expr + .window_function + .as_ref() + .ok_or_else(|| proto_error("Received empty window function"))?; + match window_function { + window_expr_node::WindowFunction::AggrFunction(i) => { + let aggr_function = protobuf::AggregateFunction::from_i32(*i) + .ok_or_else(|| { + proto_error(format!( + "Received an unknown aggregate window function: {}", + i + )) + })?; + + Ok(Expr::WindowFunction { + fun: window_functions::WindowFunction::AggregateFunction( + AggregateFunction::from(aggr_function), + ), + args: vec![parse_required_expr(&expr.expr)?], + }) + } + window_expr_node::WindowFunction::BuiltInFunction(i) => { + let built_in_function = + protobuf::BuiltInWindowFunction::from_i32(*i).ok_or_else( + || { + proto_error(format!( + "Received an unknown built-in window function: {}", + i + )) + }, + )?; + + Ok(Expr::WindowFunction { + fun: window_functions::WindowFunction::BuiltInWindowFunction( + BuiltInWindowFunction::from(built_in_function), + ), + args: vec![parse_required_expr(&expr.expr)?], + }) + } + } + } ExprType::AggregateExpr(expr) => { let aggr_function = protobuf::AggregateFunction::from_i32(expr.aggr_function) @@ -898,13 +971,7 @@ impl TryInto for &protobuf::LogicalExprNode { expr.aggr_function )) })?; - let fun = match aggr_function { - protobuf::AggregateFunction::Min => AggregateFunction::Min, - protobuf::AggregateFunction::Max => AggregateFunction::Max, - protobuf::AggregateFunction::Sum => AggregateFunction::Sum, - protobuf::AggregateFunction::Avg => AggregateFunction::Avg, - protobuf::AggregateFunction::Count => AggregateFunction::Count, - }; + let fun = AggregateFunction::from(aggr_function); Ok(Expr::AggregateFunction { fun, @@ -1152,6 +1219,7 @@ impl TryInto for &protobuf::Field { } use datafusion::physical_plan::datetime_expressions::{date_trunc, to_timestamp}; +use datafusion::physical_plan::{aggregates, windows}; use datafusion::prelude::{ array, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper, }; @@ -1202,3 +1270,103 @@ fn parse_optional_expr( None => Ok(None), } } + +impl From for WindowFrameUnits { + fn from(units: protobuf::WindowFrameUnits) -> Self { + match units { + protobuf::WindowFrameUnits::Rows => WindowFrameUnits::Rows, + protobuf::WindowFrameUnits::Range => WindowFrameUnits::Range, + protobuf::WindowFrameUnits::Groups => WindowFrameUnits::Groups, + } + } +} + +impl TryFrom for WindowFrameBound { + type Error = BallistaError; + + fn try_from(bound: protobuf::WindowFrameBound) -> Result { + let bound_type = protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type).ok_or_else(|| { + proto_error(format!( + "Received a WindowFrameBound message with unknown WindowFrameBoundType {}", + bound.window_frame_bound_type + )) + })?.into(); + match bound_type { + protobuf::WindowFrameBoundType::CurrentRow => { + Ok(WindowFrameBound::CurrentRow) + } + protobuf::WindowFrameBoundType::Preceding => { + // FIXME implement bound value parsing + Ok(WindowFrameBound::Preceding(Some(1))) + } + protobuf::WindowFrameBoundType::Following => { + // FIXME implement bound value parsing + Ok(WindowFrameBound::Following(Some(1))) + } + } + } +} + +impl TryFrom for WindowFrame { + type Error = BallistaError; + + fn try_from(window: protobuf::WindowFrame) -> Result { + let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units) + .ok_or_else(|| { + proto_error(format!( + "Received a WindowFrame message with unknown WindowFrameUnits {}", + window.window_frame_units + )) + })? + .into(); + let start_bound = window + .start_bound + .ok_or_else(|| { + proto_error( + "Received a WindowFrame message with no start_bound".to_owned(), + ) + })? + .try_into()?; + // FIXME parse end bound + let end_bound = None; + Ok(WindowFrame { + units, + start_bound, + end_bound, + }) + } +} + +impl From for AggregateFunction { + fn from(aggr_function: protobuf::AggregateFunction) -> Self { + match aggr_function { + protobuf::AggregateFunction::Min => AggregateFunction::Min, + protobuf::AggregateFunction::Max => AggregateFunction::Max, + protobuf::AggregateFunction::Sum => AggregateFunction::Sum, + protobuf::AggregateFunction::Avg => AggregateFunction::Avg, + protobuf::AggregateFunction::Count => AggregateFunction::Count, + } + } +} + +impl From for BuiltInWindowFunction { + fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { + match built_in_function { + protobuf::BuiltInWindowFunction::RowNumber => { + BuiltInWindowFunction::RowNumber + } + protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank, + protobuf::BuiltInWindowFunction::DenseRank => { + BuiltInWindowFunction::DenseRank + } + protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag, + protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead, + protobuf::BuiltInWindowFunction::FirstValue => { + BuiltInWindowFunction::FirstValue + } + protobuf::BuiltInWindowFunction::LastValue => { + BuiltInWindowFunction::LastValue + } + } + } +} diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 01b669d264461..ddc2ad85e4d4e 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -26,16 +26,19 @@ use std::{ use crate::datasource::DfTableAdapter; use crate::serde::{protobuf, BallistaError}; - use arrow::datatypes::{DataType, Schema}; use datafusion::datasource::CsvFile; use datafusion::logical_plan::{Expr, JoinType, LogicalPlan}; use datafusion::physical_plan::aggregates::AggregateFunction; +use datafusion::physical_plan::window_functions::{ + BuiltInWindowFunction, WindowFunction, +}; use datafusion::{datasource::parquet::ParquetTable, logical_plan::exprlist_to_fields}; use protobuf::{ arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, Field, PrimitiveScalarType, ScalarListValue, ScalarType, }; +use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use super::super::proto_error; use datafusion::physical_plan::functions::BuiltinScalarFunction; @@ -772,6 +775,39 @@ impl TryInto for &LogicalPlan { ))), }) } + LogicalPlan::Window { + input, + window_expr, + // FIXME implement next + // partition_by_expr, + // FIXME implement next + // order_by_expr, + // FIXME implement next + // window_frame, + .. + } => { + let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?; + // FIXME: implement + let partition_by_expr = vec![]; + // FIXME: implement + let order_by_expr = vec![]; + // FIXME: implement + let window_frame = None; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Window(Box::new( + protobuf::WindowNode { + input: Some(Box::new(input)), + window_expr: window_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, BallistaError>>()?, + partition_by_expr, + order_by_expr, + window_frame, + }, + ))), + }) + } LogicalPlan::Aggregate { input, group_expr, @@ -997,6 +1033,30 @@ impl TryInto for &Expr { expr_type: Some(ExprType::BinaryExpr(binary_expr)), }) } + Expr::WindowFunction { + ref fun, ref args, .. + } => { + let window_function = match fun { + WindowFunction::AggregateFunction(fun) => { + protobuf::window_expr_node::WindowFunction::AggrFunction( + protobuf::AggregateFunction::from(fun).into(), + ) + } + WindowFunction::BuiltInWindowFunction(fun) => { + protobuf::window_expr_node::WindowFunction::BuiltInFunction( + protobuf::BuiltInWindowFunction::from(fun).into(), + ) + } + }; + let arg = &args[0]; + let window_expr = Box::new(protobuf::WindowExprNode { + expr: Some(Box::new(arg.try_into()?)), + window_function: Some(window_function), + }); + Ok(protobuf::LogicalExprNode { + expr_type: Some(ExprType::WindowExpr(window_expr)), + }) + } Expr::AggregateFunction { ref fun, ref args, .. } => { @@ -1178,6 +1238,58 @@ impl Into for &Schema { } } +impl From<&AggregateFunction> for protobuf::AggregateFunction { + fn from(value: &AggregateFunction) -> Self { + match value { + AggregateFunction::Min => Self::Min, + AggregateFunction::Max => Self::Max, + AggregateFunction::Sum => Self::Sum, + AggregateFunction::Avg => Self::Avg, + AggregateFunction::Count => Self::Count, + } + } +} + +impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { + fn from(value: &BuiltInWindowFunction) -> Self { + match value { + BuiltInWindowFunction::FirstValue => Self::FirstValue, + BuiltInWindowFunction::LastValue => Self::LastValue, + BuiltInWindowFunction::RowNumber => Self::RowNumber, + BuiltInWindowFunction::Rank => Self::Rank, + BuiltInWindowFunction::Lag => Self::Lag, + BuiltInWindowFunction::Lead => Self::Lead, + BuiltInWindowFunction::DenseRank => Self::DenseRank, + } + } +} + +impl From for protobuf::WindowFrameUnits { + fn from(units: WindowFrameUnits) -> Self { + match units { + WindowFrameUnits::Rows => protobuf::WindowFrameUnits::Rows, + WindowFrameUnits::Range => protobuf::WindowFrameUnits::Range, + WindowFrameUnits::Groups => protobuf::WindowFrameUnits::Groups, + } + } +} + +impl TryFrom for protobuf::WindowFrameBound { + type Error = BallistaError; + + fn try_from(bound: WindowFrameBound) -> Result { + unimplemented!("not implemented") + } +} + +impl TryFrom for protobuf::WindowFrame { + type Error = BallistaError; + + fn try_from(window: WindowFrame) -> Result { + unimplemented!("not implemented") + } +} + impl TryFrom<&arrow::datatypes::DataType> for protobuf::ScalarType { type Error = BallistaError; fn try_from(value: &arrow::datatypes::DataType) -> Result { diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 97f03948f7bd9..a8afc020fc6b1 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -28,7 +28,6 @@ use crate::serde::protobuf::LogicalExprNode; use crate::serde::scheduler::PartitionLocation; use crate::serde::{proto_error, protobuf}; use crate::{convert_box_required, convert_required}; - use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::catalog::catalog::{ CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider, @@ -43,6 +42,11 @@ use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec use datafusion::physical_plan::hash_join::PartitionMode; use datafusion::physical_plan::merge::MergeExec; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; +use datafusion::physical_plan::window_functions::{ + BuiltInWindowFunction, WindowFunction, +}; +use datafusion::physical_plan::windows::create_window_expr; +use datafusion::physical_plan::windows::WindowAggExec; use datafusion::physical_plan::{ coalesce_batches::CoalesceBatchesExec, csv::CsvExec, @@ -58,7 +62,7 @@ use datafusion::physical_plan::{ sort::{SortExec, SortOptions}, Partitioning, }; -use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr}; +use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; use datafusion::prelude::CsvReadOptions; use log::debug; use protobuf::logical_expr_node::ExprType; @@ -189,6 +193,77 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let input: Arc = convert_box_required!(limit.input)?; Ok(Arc::new(LocalLimitExec::new(input, limit.limit as usize))) } + PhysicalPlanType::Window(window_agg) => { + let input: Arc = + convert_box_required!(window_agg.input)?; + let input_schema = window_agg + .input_schema + .as_ref() + .ok_or_else(|| { + BallistaError::General( + "input_schema in WindowAggrNode is missing.".to_owned(), + ) + })? + .clone(); + + let physical_schema: SchemaRef = + SchemaRef::new((&input_schema).try_into()?); + + let catalog_list = + Arc::new(MemoryCatalogList::new()) as Arc; + let ctx_state = ExecutionContextState { + catalog_list, + scalar_functions: Default::default(), + var_provider: Default::default(), + aggregate_functions: Default::default(), + config: ExecutionConfig::new(), + execution_props: ExecutionProps::new(), + }; + + let window_agg_expr: Vec<(Expr, String)> = window_agg + .window_expr + .iter() + .zip(window_agg.window_expr_name.iter()) + .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) + .collect::, _>>()?; + + let mut physical_window_expr = vec![]; + + let df_planner = DefaultPhysicalPlanner::default(); + + for (expr, name) in &window_agg_expr { + match expr { + Expr::WindowFunction { fun, args } => { + let arg = df_planner + .create_physical_expr( + &args[0], + &physical_schema, + &ctx_state, + ) + .map_err(|e| { + BallistaError::General(format!("{:?}", e)) + })?; + physical_window_expr.push(create_window_expr( + &fun, + &[arg], + &physical_schema, + name.to_owned(), + )?); + } + _ => { + return Err(BallistaError::General( + "Invalid expression for WindowAggrExec".to_string(), + )); + } + } + } + + Ok(Arc::new(WindowAggExec::try_new( + physical_window_expr, + input, + Arc::new((&input_schema).try_into()?), + )?)) + } PhysicalPlanType::HashAggregate(hash_agg) => { let input: Arc = convert_box_required!(hash_agg.input)?; @@ -248,6 +323,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let mut physical_aggr_expr = vec![]; + let df_planner = DefaultPhysicalPlanner::default(); for (expr, name) in &logical_agg_expr { match expr { Expr::AggregateFunction { fun, args, .. } => { diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index 2f01e73e60591..b1d999b733334 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -35,6 +35,7 @@ use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::HashJoinExec; use datafusion::physical_plan::merge::MergeExec; +use datafusion::physical_plan::windows::WindowAggExec; use datafusion::physical_plan::ExecutionPlan; use log::info; @@ -150,6 +151,13 @@ impl DistributedPlanner { } else if let Some(join) = execution_plan.as_any().downcast_ref::() { Ok((join.with_new_children(children)?, stages)) + } else if let Some(window) = + execution_plan.as_any().downcast_ref::() + { + Err(BallistaError::NotImplemented(format!( + "WindowAggExec with window {:?}", + window + ))) } else { // TODO check for compatible partitioning schema, not just count if execution_plan.output_partitioning().partition_count() diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 2e69814d2634e..9e3f4f402b74e 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -24,18 +24,18 @@ use arrow::{ record_batch::RecordBatch, }; +use super::dfschema::ToDFSchema; +use super::{ + col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan, +}; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, Partitioning}; use crate::{ datasource::{empty::EmptyTable, parquet::ParquetTable, CsvFile, MemTable}, prelude::CsvReadOptions, }; - -use super::dfschema::ToDFSchema; -use super::{ - col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan, -}; -use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, Partitioning}; +use sqlparser::ast::WindowFrame; use std::collections::HashSet; /// Builder for logical plans @@ -289,6 +289,37 @@ impl LogicalPlanBuilder { })) } + /// Apply a window + pub fn window( + &self, + window_expr: impl IntoIterator, + // partition_by_expr: impl IntoIterator, + // order_by_expr: impl IntoIterator, + // window_frame: Option, + ) -> Result { + let window_expr = window_expr.into_iter().collect::>(); + // let partition_by_expr = partition_by_expr.into_iter().collect::>(); + // let order_by_expr = order_by_expr.into_iter().collect::>(); + let all_expr = window_expr.iter(); + validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?; + + let mut window_fields: Vec = + exprlist_to_fields(all_expr, self.plan.schema())?; + window_fields.extend_from_slice(self.plan.schema().fields()); + + Ok(Self::from(&LogicalPlan::Window { + input: Arc::new(self.plan.clone()), + // FIXME implement next + // partition_by_expr, + // FIXME implement next + // order_by_expr, + // FIXME implement next + // window_frame, + window_expr, + schema: Arc::new(DFSchema::new(window_fields)?), + })) + } + /// Apply an aggregate: grouping on the `group_expr` expressions /// and calculating `aggr_expr` aggregates for each distinct /// value of the `group_expr`; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 3365bf2603234..ab02559175302 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -30,6 +30,7 @@ use crate::error::{DataFusionError, Result}; use crate::logical_plan::{DFField, DFSchema}; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, + window_functions, }; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; @@ -190,6 +191,13 @@ pub enum Expr { /// Whether this is a DISTINCT aggregation or not distinct: bool, }, + /// Represents the call of a window function with arguments. + WindowFunction { + /// Name of the function + fun: window_functions::WindowFunction, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, /// aggregate function AggregateUDF { /// The function @@ -244,6 +252,13 @@ impl Expr { .collect::>>()?; functions::return_type(fun, &data_types) } + Expr::WindowFunction { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + window_functions::return_type(fun, &data_types) + } Expr::AggregateFunction { fun, args, .. } => { let data_types = args .iter() @@ -316,6 +331,7 @@ impl Expr { Expr::TryCast { .. } => Ok(true), Expr::ScalarFunction { .. } => Ok(true), Expr::ScalarUDF { .. } => Ok(true), + Expr::WindowFunction { .. } => Ok(true), Expr::AggregateFunction { .. } => Ok(true), Expr::AggregateUDF { .. } => Ok(true), Expr::Not(expr) => expr.nullable(input_schema), @@ -571,6 +587,9 @@ impl Expr { Expr::ScalarUDF { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::WindowFunction { args, .. } => args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)), Expr::AggregateFunction { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), @@ -704,6 +723,10 @@ impl Expr { args: rewrite_vec(args, rewriter)?, fun, }, + Expr::WindowFunction { args, fun } => Expr::WindowFunction { + args: rewrite_vec(args, rewriter)?, + fun, + }, Expr::AggregateFunction { args, fun, @@ -1151,7 +1174,7 @@ pub fn create_udf( } /// Creates a new UDAF with a specific signature, state type and return type. -/// The signature and state type must match the `Acumulator's implementation`. +/// The signature and state type must match the `Accumulator's implementation`. #[allow(clippy::rc_buffer)] pub fn create_udaf( name: &str, @@ -1245,6 +1268,9 @@ impl fmt::Debug for Expr { Expr::ScalarUDF { fun, ref args, .. } => { fmt_function(f, &fun.name, false, args) } + Expr::WindowFunction { fun, ref args, .. } => { + fmt_function(f, &fun.to_string(), false, args) + } Expr::AggregateFunction { fun, distinct, @@ -1360,6 +1386,9 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::ScalarUDF { fun, args, .. } => { create_function_name(&fun.name, false, args, input_schema) } + Expr::WindowFunction { fun, args } => { + create_function_name(&fun.to_string(), false, args, input_schema) + } Expr::AggregateFunction { fun, distinct, @@ -1387,7 +1416,7 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { } } other => Err(DataFusionError::NotImplemented(format!( - "Physical plan does not support logical expression {:?}", + "Create name does not support logical expression {:?}", other ))), } diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 8b9aac9ea73b9..ceab20a6a5e09 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -17,24 +17,21 @@ //! This module contains the `LogicalPlan` enum that describes queries //! via a logical query plan. -use std::{ - cmp::min, - fmt::{self, Display}, - sync::Arc, -}; - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - -use crate::datasource::TableProvider; -use crate::sql::parser::FileType; - use super::expr::Expr; use super::extension::UserDefinedLogicalNode; use super::{ col, display::{GraphvizVisitor, IndentVisitor}, }; +use crate::datasource::TableProvider; use crate::logical_plan::dfschema::DFSchemaRef; +use crate::sql::parser::FileType; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::{ + cmp::min, + fmt::{self, Display}, + sync::Arc, +}; /// Join type #[derive(Debug, Clone, Copy)] @@ -83,6 +80,21 @@ pub enum LogicalPlan { /// The incoming logical plan input: Arc, }, + /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) + Window { + /// The incoming logical plan + input: Arc, + /// The window function expression + window_expr: Vec, + /// Partition by expressions + // partition_by_expr: Vec, + /// Order by expressions + // order_by_expr: Vec, + /// Window Frame + // window_frame: Option, + /// The schema description of the window output + schema: DFSchemaRef, + }, /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). Aggregate { @@ -211,6 +223,7 @@ impl LogicalPlan { } => &projected_schema, LogicalPlan::Projection { schema, .. } => &schema, LogicalPlan::Filter { input, .. } => input.schema(), + LogicalPlan::Window { schema, .. } => &schema, LogicalPlan::Aggregate { schema, .. } => &schema, LogicalPlan::Sort { input, .. } => input.schema(), LogicalPlan::Join { schema, .. } => &schema, @@ -230,7 +243,8 @@ impl LogicalPlan { LogicalPlan::TableScan { projected_schema, .. } => vec![&projected_schema], - LogicalPlan::Aggregate { input, schema, .. } + LogicalPlan::Window { input, schema, .. } + | LogicalPlan::Aggregate { input, schema, .. } | LogicalPlan::Projection { input, schema, .. } => { let mut schemas = input.all_schemas(); schemas.insert(0, &schema); @@ -288,6 +302,14 @@ impl LogicalPlan { Partitioning::Hash(expr, _) => expr.clone(), _ => vec![], }, + LogicalPlan::Window { + window_expr, + // FIXME implement next + // partition_by_expr, + // FIXME implement next + // order_by_expr, + .. + } => window_expr.clone(), LogicalPlan::Aggregate { group_expr, aggr_expr, @@ -322,6 +344,7 @@ impl LogicalPlan { LogicalPlan::Projection { input, .. } => vec![input], LogicalPlan::Filter { input, .. } => vec![input], LogicalPlan::Repartition { input, .. } => vec![input], + LogicalPlan::Window { input, .. } => vec![input], LogicalPlan::Aggregate { input, .. } => vec![input], LogicalPlan::Sort { input, .. } => vec![input], LogicalPlan::Join { left, right, .. } => vec![left, right], @@ -415,6 +438,7 @@ impl LogicalPlan { LogicalPlan::Projection { input, .. } => input.accept(visitor)?, LogicalPlan::Filter { input, .. } => input.accept(visitor)?, LogicalPlan::Repartition { input, .. } => input.accept(visitor)?, + LogicalPlan::Window { input, .. } => input.accept(visitor)?, LogicalPlan::Aggregate { input, .. } => input.accept(visitor)?, LogicalPlan::Sort { input, .. } => input.accept(visitor)?, LogicalPlan::Join { left, right, .. } @@ -667,6 +691,20 @@ impl LogicalPlan { predicate: ref expr, .. } => write!(f, "Filter: {:?}", expr), + LogicalPlan::Window { + ref window_expr, + // FIXME implement next + // ref partition_by_expr, + // FIXME implement next + // ref order_by_expr, + .. + } => { + write!( + f, + "WindowAggr: windowExpr=[{:?}] partitionBy=[], orderBy=[]", + window_expr + ) + } LogicalPlan::Aggregate { ref group_expr, ref aggr_expr, diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 51bf0ce1b5054..af89aa13908c4 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -71,6 +71,7 @@ impl OptimizerRule for ConstantFolding { }), // Rest: recurse into plan, apply optimization where possible LogicalPlan::Projection { .. } + | LogicalPlan::Window { .. } | LogicalPlan::Aggregate { .. } | LogicalPlan::Repartition { .. } | LogicalPlan::CreateExternalTable { .. } diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs index 168c4a17edfd0..100ae4fb09b73 100644 --- a/datafusion/src/optimizer/hash_build_probe_order.rs +++ b/datafusion/src/optimizer/hash_build_probe_order.rs @@ -54,6 +54,10 @@ fn get_num_rows(logical_plan: &LogicalPlan) -> Option { let num_rows_input = get_num_rows(input); num_rows_input.map(|rows| std::cmp::min(*limit, rows)) } + LogicalPlan::Window { input, .. } => { + // window functions do not change num of rows + get_num_rows(input) + } LogicalPlan::Aggregate { .. } => { // we cannot yet predict how many rows will be produced by an aggregate because // we do not know the cardinality of the grouping keys @@ -172,6 +176,7 @@ impl OptimizerRule for HashBuildProbeOrder { } // Rest: recurse into plan, apply optimization where possible LogicalPlan::Projection { .. } + | LogicalPlan::Window { .. } | LogicalPlan::Aggregate { .. } | LogicalPlan::TableScan { .. } | LogicalPlan::Limit { .. } diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 21c9caba3316d..353b475358674 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -193,6 +193,59 @@ fn optimize_plan( schema: schema.clone(), }) } + LogicalPlan::Window { + schema, + window_expr, + input, + // FIXME implement next + // partition_by_expr, + // FIXME implement next + // order_by_expr, + // FIXME implement next + // window_frame, + .. + } => { + // Gather all columns needed for expressions in this Window + let mut new_window_expr = Vec::new(); + window_expr.iter().try_for_each(|expr| { + let name = &expr.name(&schema)?; + if required_columns.contains(name) { + new_window_expr.push(expr.clone()); + new_required_columns.insert(name.clone()); + // add to the new set of required columns + utils::expr_to_column_names(expr, &mut new_required_columns) + } else { + Ok(()) + } + })?; + + let new_schema = DFSchema::new( + schema + .fields() + .iter() + .filter(|x| new_required_columns.contains(x.name())) + .cloned() + .collect(), + )?; + + Ok(LogicalPlan::Window { + window_expr: new_window_expr, + // FIXME implement next + // partition_by_expr: partition_by_expr.clone(), + // FIXME implement next + // order_by_expr: order_by_expr.clone(), + // FIXME implement next + // window_frame: window_frame.clone(), + input: Arc::new(optimize_plan( + optimizer, + &input, + &new_required_columns, + true, + execution_props, + )?), + schema: DFSchemaRef::new(new_schema), + }) + } LogicalPlan::Aggregate { schema, input, diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 9288c65ac4dac..db26e5cb40653 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -78,6 +78,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { Expr::Sort { .. } => {} Expr::ScalarFunction { .. } => {} Expr::ScalarUDF { .. } => {} + Expr::WindowFunction { .. } => {} Expr::AggregateFunction { .. } => {} Expr::AggregateUDF { .. } => {} Expr::InList { .. } => {} @@ -188,6 +189,21 @@ pub fn from_plan( input: Arc::new(inputs[0].clone()), }), }, + LogicalPlan::Window { + // FIXME implement next + // partition_by_expr, + // FIXME implement next + // order_by_expr, + // FIXME implement next + // window_frame, + window_expr, + schema, + .. + } => Ok(LogicalPlan::Window { + input: Arc::new(inputs[0].clone()), + window_expr: expr[0..window_expr.len()].to_vec(), + schema: schema.clone(), + }), LogicalPlan::Aggregate { group_expr, schema, .. } => Ok(LogicalPlan::Aggregate { @@ -247,6 +263,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::IsNotNull(e) => Ok(vec![e.as_ref().to_owned()]), Expr::ScalarFunction { args, .. } => Ok(args.clone()), Expr::ScalarUDF { args, .. } => Ok(args.clone()), + Expr::WindowFunction { args, .. } => Ok(args.clone()), Expr::AggregateFunction { args, .. } => Ok(args.clone()), Expr::AggregateUDF { args, .. } => Ok(args.clone()), Expr::Case { @@ -319,6 +336,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions.to_vec(), }), + Expr::WindowFunction { fun, .. } => Ok(Expr::WindowFunction { + fun: fun.clone(), + args: expressions.to_vec(), + }), Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { fun: fun.clone(), args: expressions.to_vec(), diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 9417c7c8f05a5..393122e6cdce9 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -29,7 +29,7 @@ use super::{ functions::Signature, type_coercion::{coerce, data_types}, - Accumulator, AggregateExpr, PhysicalExpr, + Accumulator, AggregateExpr, PhysicalExpr, WindowExpr, }; use crate::error::{DataFusionError, Result}; use crate::physical_plan::distinct_expressions; @@ -37,7 +37,6 @@ use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Schema, TimeUnit}; use expressions::{avg_return_type, sum_return_type}; use std::{fmt, str::FromStr, sync::Arc}; - /// the implementation of an aggregate function pub type AccumulatorFunctionImplementation = Arc Result> + Send + Sync>; @@ -183,7 +182,7 @@ static TIMESTAMPS: &[DataType] = &[ ]; /// the signatures supported by the function `fun`. -fn signature(fun: &AggregateFunction) -> Signature { +pub fn signature(fun: &AggregateFunction) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match fun { AggregateFunction::Count => Signature::Any(1), diff --git a/datafusion/src/physical_plan/expressions/count.rs b/datafusion/src/physical_plan/expressions/count.rs index 4a3fbe4fa7d3d..4dac22c6dc3f1 100644 --- a/datafusion/src/physical_plan/expressions/count.rs +++ b/datafusion/src/physical_plan/expressions/count.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::error::Result; -use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr, WindowExpr}; use crate::scalar::ScalarValue; use arrow::compute; use arrow::datatypes::DataType; diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index e915b2c257ddc..c053229bc000b 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -442,6 +442,23 @@ pub trait AggregateExpr: Send + Sync + Debug { } } +/// A window expression that: +/// * knows its resulting field +pub trait WindowExpr: Send + Sync + Debug { + /// Returns the window expression as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// the field of the final result of this window function. + fn field(&self) -> Result; + + /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default + /// implementation returns placeholder text. + fn name(&self) -> &str { + "WindowExpr: default name" + } +} + /// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and /// generically accumulates values. An accumulator knows how to: /// * update its state from inputs via `update` @@ -530,3 +547,5 @@ pub mod udf; #[cfg(feature = "unicode_expressions")] pub mod unicode_expressions; pub mod union; +pub mod window_functions; +pub mod windows; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 9e7dc7172b820..64ad504bac982 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -21,7 +21,8 @@ use std::sync::Arc; use super::{ aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary, - functions, hash_join::PartitionMode, udaf, union::UnionExec, + functions, hash_join::PartitionMode, udaf, union::UnionExec, window_functions, + windows, }; use crate::execution::context::ExecutionContextState; use crate::logical_plan::{ @@ -39,8 +40,11 @@ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sort::SortExec; use crate::physical_plan::udf; +use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{hash_utils, Partitioning}; -use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner}; +use crate::physical_plan::{ + AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner, WindowExpr, +}; use crate::prelude::JoinType; use crate::scalar::ScalarValue; use crate::variable::VarType; @@ -48,10 +52,9 @@ use crate::{ error::{DataFusionError, Result}, physical_plan::displayable, }; -use arrow::{compute::can_cast_types, datatypes::DataType}; - use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; +use arrow::{compute::can_cast_types, datatypes::DataType}; use expressions::col; use log::debug; @@ -139,6 +142,32 @@ impl DefaultPhysicalPlanner { limit, .. } => source.scan(projection, batch_size, filters, *limit), + LogicalPlan::Window { + input, window_expr, .. + } => { + // Initially need to perform the aggregate and then merge the partitions + let input_exec = self.create_initial_plan(input, ctx_state)?; + let input_schema = input_exec.schema(); + let physical_input_schema = input_exec.as_ref().schema(); + let logical_input_schema = input.as_ref().schema(); + let window_expr = window_expr + .iter() + .map(|e| { + self.create_window_expr( + e, + &logical_input_schema, + &physical_input_schema, + ctx_state, + ) + }) + .collect::>>()?; + + Ok(Arc::new(WindowAggExec::try_new( + window_expr, + input_exec.clone(), + input_schema.clone(), + )?)) + } LogicalPlan::Aggregate { input, group_expr, @@ -700,6 +729,37 @@ impl DefaultPhysicalPlanner { } } + /// Create a window expression from a logical expression + pub fn create_window_expr( + &self, + e: &Expr, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + ctx_state: &ExecutionContextState, + ) -> Result> { + // unpack aliased logical expressions, e.g. "sum(col) over () as total" + let (name, e) = match e { + Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + _ => (e.name(logical_input_schema)?, e), + }; + + match e { + Expr::WindowFunction { fun, args } => { + let args = args + .iter() + .map(|e| { + self.create_physical_expr(e, physical_input_schema, ctx_state) + }) + .collect::>>()?; + windows::create_window_expr(fun, &args, physical_input_schema, name) + } + other => Err(DataFusionError::Internal(format!( + "Invalid window expression '{:?}'", + other + ))), + } + } + /// Create an aggregate expression from a logical expression pub fn create_aggregate_expr( &self, diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index 8229060190215..caa32cfa264e1 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -135,6 +135,7 @@ impl ExecutionPlan for SortExec { "SortExec requires a single input partition".to_owned(), )); } + let input = self.input.execute(0).await?; Ok(Box::pin(SortStream::new( diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs new file mode 100644 index 0000000000000..311b5f9fb4468 --- /dev/null +++ b/datafusion/src/physical_plan/window_functions.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Window functions provide the ability to perform calculations across +//! sets of rows that are related to the current query row. +//! +//! see also https://www.postgresql.org/docs/current/functions-window.html + +use crate::error::{DataFusionError, Result}; +use crate::execution::context::ExecutionContextState; +use crate::physical_plan::{ + aggregates, aggregates::AggregateFunction, functions::Signature, + type_coercion::data_types, PhysicalExpr, +}; +use arrow::datatypes::DataType; +use arrow::datatypes::{Schema, SchemaRef}; +use std::sync::Arc; +use std::{fmt, str::FromStr}; + +/// WindowFunction +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum WindowFunction { + /// window function that leverages an aggregate function + AggregateFunction(AggregateFunction), + /// window function that leverages a built-in window function + BuiltInWindowFunction(BuiltInWindowFunction), +} + +impl FromStr for WindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + if let Ok(aggregate) = AggregateFunction::from_str(name) { + Ok(WindowFunction::AggregateFunction(aggregate)) + } else if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name) { + Ok(WindowFunction::BuiltInWindowFunction(built_in_function)) + } else { + Err(DataFusionError::Plan(format!( + "There is no built-in function named {}", + name + ))) + } + } +} + +impl fmt::Display for BuiltInWindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // uppercase of the debug. + write!(f, "{}", format!("{:?}", self).to_uppercase()) + } +} + +impl fmt::Display for WindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFunction::AggregateFunction(fun) => fun.fmt(f), + WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), + } + } +} + +/// An aggregate function that is part of a built-in window function +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BuiltInWindowFunction { + /// row number + RowNumber, + /// rank + Rank, + /// dense rank + DenseRank, + /// lag + Lag, + /// lead + Lead, + /// first value + FirstValue, + /// last value + LastValue, +} + +impl FromStr for BuiltInWindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name { + "row_number" => BuiltInWindowFunction::RowNumber, + "rank" => BuiltInWindowFunction::Rank, + "dense_rank" => BuiltInWindowFunction::DenseRank, + "first_value" => BuiltInWindowFunction::FirstValue, + "last_value" => BuiltInWindowFunction::LastValue, + "lag" => BuiltInWindowFunction::Lag, + "lead" => BuiltInWindowFunction::Lead, + _ => { + return Err(DataFusionError::Plan(format!( + "There is no built-in function named {}", + name + ))) + } + }) + } +} + +/// Returns the datatype of the scalar function +pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(arg_types, &signature(fun))?; + + match fun { + WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types), + WindowFunction::BuiltInWindowFunction(fun) => match fun { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue => Ok(arg_types[0].clone()), + }, + } +} + +/// the signatures supported by the function `fun`. +fn signature(fun: &WindowFunction) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match fun { + WindowFunction::AggregateFunction(fun) => aggregates::signature(fun), + WindowFunction::BuiltInWindowFunction(fun) => match fun { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Signature::Any(0), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue => Signature::Any(1), + }, + } +} + +/// Create a physical (function) expression. +/// This function errors when `args`' can't be coerced to a valid argument type of the function. +pub fn create_physical_expr( + fun: &WindowFunction, + args: &[Arc], + input_schema: &Schema, + ctx_state: &ExecutionContextState, +) -> Result> { + Err(DataFusionError::NotImplemented(format!( + "Physical expr not implemented" + ))) +} diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs new file mode 100644 index 0000000000000..1cf4b7b5434dd --- /dev/null +++ b/datafusion/src/physical_plan/windows.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for window functions + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ + aggregates, window_functions::WindowFunction, AggregateExpr, Distribution, + ExecutionPlan, Partitioning, PhysicalExpr, SendableRecordBatchStream, WindowExpr, +}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use async_trait::async_trait; +use std::sync::Arc; +use std::{any::Any, pin::Pin}; + +/// Window execution plan +#[derive(Debug)] +pub struct WindowAggExec { + /// Input plan + input: Arc, + /// Window function expression + window_expr: Vec>, + /// Schema after the window is run + schema: SchemaRef, + /// Schema before the window + input_schema: SchemaRef, +} + +/// Create a physical expression for window function +pub fn create_window_expr( + fun: &WindowFunction, + args: &[Arc], + input_schema: &Schema, + name: String, +) -> Result> { + match fun { + WindowFunction::AggregateFunction(fun) => Ok(Arc::new(AggregateWindowExpr { + aggregate: aggregates::create_aggregate_expr( + fun, + false, + args, + input_schema, + name, + )?, + })), + WindowFunction::BuiltInWindowFunction(fun) => { + Err(DataFusionError::NotImplemented(format!( + "window funtion with {:?} not implemented", + fun + ))) + } + } +} + +/// A window expr that takes the form of a built in window function +#[derive(Debug)] +pub struct BuiltInWindowExpr {} + +/// A window expr that takes the form of an aggregate function +#[derive(Debug)] +pub struct AggregateWindowExpr { + aggregate: Arc, +} + +impl WindowExpr for AggregateWindowExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.aggregate.name() + } + + fn field(&self) -> Result { + self.aggregate.field() + } +} + +fn create_schema( + input_schema: &Schema, + window_expr: &[Arc], +) -> Result { + let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); + for expr in window_expr { + fields.push(expr.field()?); + } + fields.extend_from_slice(input_schema.fields()); + Ok(Schema::new(fields)) +} + +impl WindowAggExec { + /// Create a new execution plan for window aggregates + pub fn try_new( + window_expr: Vec>, + input: Arc, + input_schema: SchemaRef, + ) -> Result { + let schema = create_schema(&input.schema(), &window_expr)?; + let schema = Arc::new(schema); + Ok(WindowAggExec { + input, + window_expr, + schema, + input_schema, + }) + } + + /// Input plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// Get the input schema before any aggregates are applied + pub fn input_schema(&self) -> SchemaRef { + self.input_schema.clone() + } +} + +#[async_trait] +impl ExecutionPlan for WindowAggExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn required_child_distribution(&self) -> Distribution { + Distribution::SinglePartition + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(WindowAggExec::try_new( + self.window_expr.clone(), + children[0].clone(), + children[0].schema().clone(), + )?)), + _ => Err(DataFusionError::Internal( + "WindowAggExec wrong number of children".to_owned(), + )), + } + } + + async fn execute(&self, partition: usize) -> Result { + if 0 != partition { + return Err(DataFusionError::Internal(format!( + "WindowAggExec invalid partition {}", + partition + ))); + } + + // window needs to operate on a single partition currently + if 1 != self.input.output_partitioning().partition_count() { + return Err(DataFusionError::Internal( + "WindowAggExec requires a single input partition".to_owned(), + )); + } + + // let input = self.input.execute(0).await?; + + Err(DataFusionError::NotImplemented( + "WindowAggExec::execute".to_owned(), + )) + } +} diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 34c5901b450a2..da0aa22f7742c 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -35,7 +35,7 @@ use crate::{ }; use crate::{ physical_plan::udf::ScalarUDF, - physical_plan::{aggregates, functions}, + physical_plan::{aggregates, functions, window_functions}, sql::parser::{CreateExternalTable, FileType, Statement as DFStatement}, }; @@ -57,7 +57,8 @@ use super::{ parser::DFParser, utils::{ can_columns_satisfy_exprs, expand_wildcard, expr_as_column_expr, extract_aliases, - find_aggregate_exprs, find_column_exprs, rebase_expr, resolve_aliases_to_exprs, + find_aggregate_exprs, find_column_exprs, find_window_exprs, rebase_expr, + resolve_aliases_to_exprs, }, }; @@ -413,7 +414,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )) } JoinConstraint::None => Err(DataFusionError::NotImplemented( - "NONE contraint is not supported".to_string(), + "NONE constraint is not supported".to_string(), )), } } @@ -624,15 +625,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; + // window function + let window_func_exprs = find_window_exprs(&select_exprs_post_aggr); + + let (plan, exprs) = if window_func_exprs.is_empty() { + (plan, select_exprs_post_aggr) + } else { + self.window(&plan, window_func_exprs, &select_exprs_post_aggr)? + }; + let plan = if select.distinct { return LogicalPlanBuilder::from(&plan) - .aggregate(select_exprs_post_aggr, vec![])? + .aggregate(exprs, vec![])? .build(); } else { plan }; - self.project(&plan, select_exprs_post_aggr) + self.project(&plan, exprs) } /// Returns the `Expr`'s corresponding to a SQL query's SELECT expressions. @@ -657,10 +667,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Wrap a plan in a projection fn project(&self, input: &LogicalPlan, expr: Vec) -> Result { self.validate_schema_satisfies_exprs(&input.schema(), &expr)?; - LogicalPlanBuilder::from(input).project(expr)?.build() } + /// Wrap a plan in a window + fn window( + &self, + input: &LogicalPlan, + window_exprs: Vec, + select_exprs: &[Expr], + ) -> Result<(LogicalPlan, Vec)> { + let plan = LogicalPlanBuilder::from(input) + .window(window_exprs.clone())? + .build()?; + let select_exprs = select_exprs + .iter() + .map(|expr| expr_as_column_expr(&expr, &plan)) + .into_iter() + .collect::>>()?; + Ok((plan, select_exprs)) + } + + /// Wrap a plan in an aggregate fn aggregate( &self, input: &LogicalPlan, @@ -1059,70 +1087,69 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // first, scalar built-in if let Ok(fun) = functions::BuiltinScalarFunction::from_str(&name) { - let args = function - .args - .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) - .collect::>>()?; + let args = self.function_args_to_expr(function)?; return Ok(Expr::ScalarFunction { fun, args }); }; + // then, window function + if let Some(window) = &function.over { + if window.partition_by.is_empty() + && window.order_by.is_empty() + && window.window_frame.is_none() + { + let fun = window_functions::WindowFunction::from_str(&name); + if let Ok(window_functions::WindowFunction::AggregateFunction( + aggregate_fun, + )) = fun + { + return Ok(Expr::WindowFunction { + fun: window_functions::WindowFunction::AggregateFunction( + aggregate_fun.clone(), + ), + args: self + .aggregate_fn_to_expr(&aggregate_fun, function)?, + }); + } else if let Ok( + window_functions::WindowFunction::BuiltInWindowFunction( + window_fun, + ), + ) = fun + { + return Ok(Expr::WindowFunction { + fun: window_functions::WindowFunction::BuiltInWindowFunction( + window_fun, + ), + args:self.function_args_to_expr(function)?, + }); + } + } + return Err(DataFusionError::NotImplemented(format!( + "Unsupported OVER clause ({})", + window + ))); + } + // next, aggregate built-ins if let Ok(fun) = aggregates::AggregateFunction::from_str(&name) { - let args = if fun == aggregates::AggregateFunction::Count { - function - .args - .iter() - .map(|a| match a { - FunctionArg::Unnamed(SQLExpr::Value(Value::Number( - _, - _, - ))) => Ok(lit(1_u8)), - FunctionArg::Unnamed(SQLExpr::Wildcard) => Ok(lit(1_u8)), - _ => self.sql_fn_arg_to_logical_expr(a), - }) - .collect::>>()? - } else { - function - .args - .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) - .collect::>>()? - }; - - return match &function.over { - Some(window) => Err(DataFusionError::NotImplemented(format!( - "Unsupported OVER clause ({})", - window - ))), - _ => Ok(Expr::AggregateFunction { - fun, - distinct: function.distinct, - args, - }), - }; + let args = self.aggregate_fn_to_expr(&fun, function)?; + return Ok(Expr::AggregateFunction { + fun, + distinct: function.distinct, + args, + }); }; // finally, user-defined functions (UDF) and UDAF match self.schema_provider.get_function_meta(&name) { Some(fm) => { - let args = function - .args - .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) - .collect::>>()?; + let args = self.function_args_to_expr(function)?; Ok(Expr::ScalarUDF { fun: fm, args }) } None => match self.schema_provider.get_aggregate_meta(&name) { Some(fm) => { - let args = function - .args - .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) - .collect::>>()?; - + let args = self.function_args_to_expr(function)?; Ok(Expr::AggregateUDF { fun: fm, args }) } _ => Err(DataFusionError::Plan(format!( @@ -1142,6 +1169,39 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + fn function_args_to_expr( + &self, + function: &sqlparser::ast::Function, + ) -> Result> { + function + .args + .iter() + .map(|a| self.sql_fn_arg_to_logical_expr(a)) + .collect::>>() + } + + fn aggregate_fn_to_expr( + &self, + fun: &aggregates::AggregateFunction, + function: &sqlparser::ast::Function, + ) -> Result> { + if *fun == aggregates::AggregateFunction::Count { + function + .args + .iter() + .map(|a| match a { + FunctionArg::Unnamed(SQLExpr::Value(Value::Number(_, _))) => { + Ok(lit(1_u8)) + } + FunctionArg::Unnamed(SQLExpr::Wildcard) => Ok(lit(1_u8)), + _ => self.sql_fn_arg_to_logical_expr(a), + }) + .collect::>>() + } else { + self.function_args_to_expr(function) + } + } + fn sql_interval_to_literal( &self, value: &str, @@ -2641,13 +2701,17 @@ mod tests { } #[test] - fn over_not_supported() { + fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "NotImplemented(\"Unsupported OVER clause ()\")", - format!("{:?}", err) - ); + let expected = "Projection: #order_id, AGGREGATEFUNCTION(MAX)(#order_id)\n WindowAggr: windowExpr=[[AGGREGATEFUNCTION(MAX)(#order_id)]] partitionBy=[], orderBy=[]\n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn empty_over_with_alias() { + let sql = "SELECT order_id, MAX(order_id) OVER () max_order_id from orders"; + let expected = "Projection: #order_id, AGGREGATEFUNCTION(MAX)(#order_id) AS max_order_id\n WindowAggr: windowExpr=[[AGGREGATEFUNCTION(MAX)(#order_id)]] partitionBy=[], orderBy=[]\n TableScan: orders projection=None"; + quick_test(sql, expected); } #[test] diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index f41643d2ab449..70b9df0608397 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -46,6 +46,14 @@ pub(crate) fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { }) } +/// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence +/// (depth first), with duplicates omitted. +pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec { + find_exprs_in_exprs(exprs, &|nested_expr| { + matches!(nested_expr, Expr::WindowFunction { .. }) + }) +} + /// Collect all deeply nested `Expr::Column`'s. They are returned in order of /// appearance (depth first), with duplicates omitted. pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec { @@ -217,6 +225,13 @@ where .collect::>>()?, distinct: *distinct, }), + Expr::WindowFunction { fun, args } => Ok(Expr::WindowFunction { + fun: fun.clone(), + args: args + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + }), Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF { fun: fun.clone(), args: args From a0b7526c413abbdd4aadab4af8ca9ad8f323f03b Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Wed, 19 May 2021 22:46:38 +0800 Subject: [PATCH 02/21] fix unused imports --- .../core/src/serde/logical_plan/to_proto.rs | 12 ++++++++---- .../core/src/serde/physical_plan/from_proto.rs | 1 - datafusion/src/logical_plan/builder.rs | 1 - datafusion/src/physical_plan/aggregates.rs | 2 +- .../src/physical_plan/expressions/count.rs | 2 +- datafusion/src/physical_plan/planner.rs | 3 +-- .../src/physical_plan/window_functions.rs | 18 +----------------- datafusion/src/physical_plan/windows.rs | 4 ++-- 8 files changed, 14 insertions(+), 29 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index ddc2ad85e4d4e..a2df1756c927e 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1277,16 +1277,20 @@ impl From for protobuf::WindowFrameUnits { impl TryFrom for protobuf::WindowFrameBound { type Error = BallistaError; - fn try_from(bound: WindowFrameBound) -> Result { - unimplemented!("not implemented") + fn try_from(_bound: WindowFrameBound) -> Result { + Err(BallistaError::NotImplemented( + "WindowFrameBound => protobuf::WindowFrameBound".to_owned(), + )) } } impl TryFrom for protobuf::WindowFrame { type Error = BallistaError; - fn try_from(window: WindowFrame) -> Result { - unimplemented!("not implemented") + fn try_from(_window: WindowFrame) -> Result { + Err(BallistaError::NotImplemented( + "WindowFrame => protobuf::WindowFrame".to_owned(), + )) } } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index a8afc020fc6b1..d034f3ca3bfee 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -297,7 +297,6 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) .collect::, _>>()?; - let df_planner = DefaultPhysicalPlanner::default(); let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; let ctx_state = ExecutionContextState { diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 9e3f4f402b74e..87716aff276c6 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -35,7 +35,6 @@ use crate::{ datasource::{empty::EmptyTable, parquet::ParquetTable, CsvFile, MemTable}, prelude::CsvReadOptions, }; -use sqlparser::ast::WindowFrame; use std::collections::HashSet; /// Builder for logical plans diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 393122e6cdce9..3607f29debba1 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -29,7 +29,7 @@ use super::{ functions::Signature, type_coercion::{coerce, data_types}, - Accumulator, AggregateExpr, PhysicalExpr, WindowExpr, + Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; use crate::physical_plan::distinct_expressions; diff --git a/datafusion/src/physical_plan/expressions/count.rs b/datafusion/src/physical_plan/expressions/count.rs index 4dac22c6dc3f1..4a3fbe4fa7d3d 100644 --- a/datafusion/src/physical_plan/expressions/count.rs +++ b/datafusion/src/physical_plan/expressions/count.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::error::Result; -use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr, WindowExpr}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::ScalarValue; use arrow::compute; use arrow::datatypes::DataType; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 64ad504bac982..36945510da0a2 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -21,8 +21,7 @@ use std::sync::Arc; use super::{ aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary, - functions, hash_join::PartitionMode, udaf, union::UnionExec, window_functions, - windows, + functions, hash_join::PartitionMode, udaf, union::UnionExec, windows, }; use crate::execution::context::ExecutionContextState; use crate::logical_plan::{ diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 311b5f9fb4468..0fa34d66a6cd2 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -21,14 +21,11 @@ //! see also https://www.postgresql.org/docs/current/functions-window.html use crate::error::{DataFusionError, Result}; -use crate::execution::context::ExecutionContextState; use crate::physical_plan::{ aggregates, aggregates::AggregateFunction, functions::Signature, - type_coercion::data_types, PhysicalExpr, + type_coercion::data_types, }; use arrow::datatypes::DataType; -use arrow::datatypes::{Schema, SchemaRef}; -use std::sync::Arc; use std::{fmt, str::FromStr}; /// WindowFunction @@ -150,16 +147,3 @@ fn signature(fun: &WindowFunction) -> Signature { }, } } - -/// Create a physical (function) expression. -/// This function errors when `args`' can't be coerced to a valid argument type of the function. -pub fn create_physical_expr( - fun: &WindowFunction, - args: &[Arc], - input_schema: &Schema, - ctx_state: &ExecutionContextState, -) -> Result> { - Err(DataFusionError::NotImplemented(format!( - "Physical expr not implemented" - ))) -} diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 1cf4b7b5434dd..0c418d2785c53 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -22,10 +22,10 @@ use crate::physical_plan::{ aggregates, window_functions::WindowFunction, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, SendableRecordBatchStream, WindowExpr, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::datatypes::{Field, Schema, SchemaRef}; use async_trait::async_trait; +use std::any::Any; use std::sync::Arc; -use std::{any::Any, pin::Pin}; /// Window execution plan #[derive(Debug)] From 5c4d92dc9f570ba6919d84cb8ac70a736d73f40f Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Wed, 19 May 2021 22:48:26 +0800 Subject: [PATCH 03/21] fix clippy --- ballista/rust/core/src/serde/logical_plan/from_proto.rs | 2 +- datafusion/src/physical_plan/planner.rs | 2 +- datafusion/src/physical_plan/windows.rs | 2 +- datafusion/src/sql/planner.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 2632aedcaaeb2..e963c9c3b2256 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1290,7 +1290,7 @@ impl TryFrom for WindowFrameBound { "Received a WindowFrameBound message with unknown WindowFrameBoundType {}", bound.window_frame_bound_type )) - })?.into(); + })?; match bound_type { protobuf::WindowFrameBoundType::CurrentRow => { Ok(WindowFrameBound::CurrentRow) diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 36945510da0a2..018925d0e5356 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -164,7 +164,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(WindowAggExec::try_new( window_expr, input_exec.clone(), - input_schema.clone(), + input_schema, )?)) } LogicalPlan::Aggregate { diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 0c418d2785c53..ffc65ee3112f6 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -163,7 +163,7 @@ impl ExecutionPlan for WindowAggExec { 1 => Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), children[0].clone(), - children[0].schema().clone(), + children[0].schema(), )?)), _ => Err(DataFusionError::Internal( "WindowAggExec wrong number of children".to_owned(), diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index da0aa22f7742c..24db97c141b59 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -678,7 +678,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec)> { let plan = LogicalPlanBuilder::from(input) - .window(window_exprs.clone())? + .window(window_exprs)? .build()?; let select_exprs = select_exprs .iter() From 3ee87aa3477c160f17a86628d71a353e03d736b3 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Wed, 19 May 2021 22:55:08 +0800 Subject: [PATCH 04/21] fix unit test --- datafusion/src/sql/planner.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 24db97c141b59..029aba90baea8 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -2703,14 +2703,20 @@ mod tests { #[test] fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; - let expected = "Projection: #order_id, AGGREGATEFUNCTION(MAX)(#order_id)\n WindowAggr: windowExpr=[[AGGREGATEFUNCTION(MAX)(#order_id)]] partitionBy=[], orderBy=[]\n TableScan: orders projection=None"; + let expected = "\ + Projection: #order_id, #MAX(order_id)\ + \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[], orderBy=[]\ + \n TableScan: orders projection=None"; quick_test(sql, expected); } #[test] - fn empty_over_with_alias() { - let sql = "SELECT order_id, MAX(order_id) OVER () max_order_id from orders"; - let expected = "Projection: #order_id, AGGREGATEFUNCTION(MAX)(#order_id) AS max_order_id\n WindowAggr: windowExpr=[[AGGREGATEFUNCTION(MAX)(#order_id)]] partitionBy=[], orderBy=[]\n TableScan: orders projection=None"; + fn empty_over_plus() { + let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty Multiply Float64(1.1))\ + \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[], orderBy=[]\ + \n TableScan: orders projection=None"; quick_test(sql, expected); } From f70c739fd40e30c4b476253e58b24b9297b42859 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 20 May 2021 22:33:04 +0800 Subject: [PATCH 05/21] Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb --- datafusion/src/logical_plan/builder.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 87716aff276c6..af0d29ba5e56b 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -289,6 +289,8 @@ impl LogicalPlanBuilder { } /// Apply a window + /// + /// NOTE: this feature is under development and this API will be changing pub fn window( &self, window_expr: impl IntoIterator, From 831c069f02236a953653b8f1ca25124e393ce20b Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 20 May 2021 22:34:04 +0800 Subject: [PATCH 06/21] Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb --- datafusion/src/logical_plan/builder.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index af0d29ba5e56b..e5e9a0e51572a 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -294,6 +294,7 @@ impl LogicalPlanBuilder { pub fn window( &self, window_expr: impl IntoIterator, + // filter: impl IntoIterator, // partition_by_expr: impl IntoIterator, // order_by_expr: impl IntoIterator, // window_frame: Option, From 0cbca53dac642233520f7d32289b1dfad77b882e Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 20 May 2021 22:34:57 +0800 Subject: [PATCH 07/21] Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb --- datafusion/src/physical_plan/window_functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 0fa34d66a6cd2..dca88148180b4 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -101,7 +101,7 @@ impl FromStr for BuiltInWindowFunction { "lead" => BuiltInWindowFunction::Lead, _ => { return Err(DataFusionError::Plan(format!( - "There is no built-in function named {}", + "There is no built-in window function named {}", name ))) } From abf08cd137a80c1381af7de9ae2b3dab05cb4512 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 20 May 2021 22:36:27 +0800 Subject: [PATCH 08/21] Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb --- datafusion/src/physical_plan/window_functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index dca88148180b4..1fac9709be4a1 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -109,7 +109,7 @@ impl FromStr for BuiltInWindowFunction { } } -/// Returns the datatype of the scalar function +/// Returns the datatype of the window function pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. From 8b486d53b09ff1c7a6b9cf4687796ba1c13d6160 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 20 May 2021 23:17:22 +0800 Subject: [PATCH 09/21] adding more built-in functions --- ballista/rust/core/proto/ballista.proto | 12 ++- .../core/src/serde/logical_plan/from_proto.rs | 6 ++ .../core/src/serde/logical_plan/to_proto.rs | 4 + datafusion/src/logical_plan/builder.rs | 18 ++++- .../src/physical_plan/window_functions.rs | 78 ++++++++++++++++--- 5 files changed, 99 insertions(+), 19 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index b5f3835f1f8c0..926aefa6ab01e 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -157,10 +157,14 @@ enum BuiltInWindowFunction { ROW_NUMBER = 0; RANK = 1; DENSE_RANK = 2; - LAG = 3; - LEAD = 4; - FIRST_VALUE = 5; - LAST_VALUE = 6; + PERCENT_RANK = 3; + CUME_DIST = 4; + NTILE = 5; + LAG = 6; + LEAD = 7; + FIRST_VALUE = 8; + LAST_VALUE = 9; + NTH_VALUE = 10; } message WindowExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index e963c9c3b2256..7d0f429f1e4e4 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1356,6 +1356,9 @@ impl From for BuiltInWindowFunction { BuiltInWindowFunction::RowNumber } protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank, + protobuf::BuiltInWindowFunction::PercentRank => { + BuiltInWindowFunction::PercentRank + } protobuf::BuiltInWindowFunction::DenseRank => { BuiltInWindowFunction::DenseRank } @@ -1364,6 +1367,9 @@ impl From for BuiltInWindowFunction { protobuf::BuiltInWindowFunction::FirstValue => { BuiltInWindowFunction::FirstValue } + protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist, + protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile, + protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue, protobuf::BuiltInWindowFunction::LastValue => { BuiltInWindowFunction::LastValue } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index a2df1756c927e..4050d0bea2825 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1255,6 +1255,10 @@ impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { match value { BuiltInWindowFunction::FirstValue => Self::FirstValue, BuiltInWindowFunction::LastValue => Self::LastValue, + BuiltInWindowFunction::NthValue => Self::NthValue, + BuiltInWindowFunction::Ntile => Self::Ntile, + BuiltInWindowFunction::CumeDist => Self::CumeDist, + BuiltInWindowFunction::PercentRank => Self::PercentRank, BuiltInWindowFunction::RowNumber => Self::RowNumber, BuiltInWindowFunction::Rank => Self::Rank, BuiltInWindowFunction::Lag => Self::Lag, diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index e5e9a0e51572a..9515ac2ff3739 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -289,18 +289,30 @@ impl LogicalPlanBuilder { } /// Apply a window - /// - /// NOTE: this feature is under development and this API will be changing + /// + /// NOTE: this feature is under development and this API will be changing + /// + /// - https://github.com/apache/arrow-datafusion/issues/359 basic structure + /// - https://github.com/apache/arrow-datafusion/issues/298 empty over clause + /// - https://github.com/apache/arrow-datafusion/issues/299 with partition clause + /// - https://github.com/apache/arrow-datafusion/issues/360 with order by + /// - https://github.com/apache/arrow-datafusion/issues/361 with window frame pub fn window( &self, window_expr: impl IntoIterator, - // filter: impl IntoIterator, + // FIXME: implement next + // filter_by_expr: impl IntoIterator, + // FIXME: implement next // partition_by_expr: impl IntoIterator, + // FIXME: implement next // order_by_expr: impl IntoIterator, + // FIXME: implement next // window_frame: Option, ) -> Result { let window_expr = window_expr.into_iter().collect::>(); + // FIXME: implement next // let partition_by_expr = partition_by_expr.into_iter().collect::>(); + // FIXME: implement next // let order_by_expr = order_by_expr.into_iter().collect::>(); let all_expr = window_expr.iter(); validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?; diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 1fac9709be4a1..e6267250bdad1 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -72,33 +72,51 @@ impl fmt::Display for WindowFunction { /// An aggregate function that is part of a built-in window function #[derive(Debug, Clone, PartialEq, Eq)] pub enum BuiltInWindowFunction { - /// row number + /// number of the current row within its partition, counting from 1 RowNumber, - /// rank + /// rank of the current row with gaps; same as row_number of its first peer Rank, - /// dense rank + /// ank of the current row without gaps; this function counts peer groups DenseRank, - /// lag + /// relative rank of the current row: (rank - 1) / (total rows - 1) + PercentRank, + /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + CumeDist, + /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + Ntile, + /// returns value evaluated at the row that is offset rows before the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null Lag, - /// lead + /// returns value evaluated at the row that is offset rows after the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null Lead, - /// first value + /// returns value evaluated at the row that is the first row of the window frame FirstValue, - /// last value + /// returns value evaluated at the row that is the last row of the window frame LastValue, + /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + NthValue, } impl FromStr for BuiltInWindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { - Ok(match name { + Ok(match name.to_lowercase().as_str() { "row_number" => BuiltInWindowFunction::RowNumber, "rank" => BuiltInWindowFunction::Rank, "dense_rank" => BuiltInWindowFunction::DenseRank, - "first_value" => BuiltInWindowFunction::FirstValue, - "last_value" => BuiltInWindowFunction::LastValue, + "percent_rank" => BuiltInWindowFunction::PercentRank, + "cume_dist" => BuiltInWindowFunction::CumeDist, + "ntile" => BuiltInWindowFunction::Ntile, "lag" => BuiltInWindowFunction::Lag, "lead" => BuiltInWindowFunction::Lead, + "first_value" => BuiltInWindowFunction::FirstValue, + "last_value" => BuiltInWindowFunction::LastValue, + "nth_value" => BuiltInWindowFunction::NthValue, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in window function named {}", @@ -123,10 +141,15 @@ pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue => Ok(arg_types[0].clone()), + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()), }, } } @@ -139,11 +162,42 @@ fn signature(fun: &WindowFunction) -> Signature { WindowFunction::BuiltInWindowFunction(fun) => match fun { BuiltInWindowFunction::RowNumber | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Signature::Any(0), + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::Any(0), BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead | BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => Signature::Any(1), + BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]), + BuiltInWindowFunction::NthValue => Signature::Any(2), }, } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field}; + + #[test] + fn test_window_function_from_str() -> Result<()> { + assert_eq!( + WindowFunction::from_str("max")?, + WindowFunction::AggregateFunction(AggregateFunction::Max) + ); + assert_eq!( + WindowFunction::from_str("min")?, + WindowFunction::AggregateFunction(AggregateFunction::Min) + ); + assert_eq!( + WindowFunction::from_str("avg")?, + WindowFunction::AggregateFunction(AggregateFunction::Avg) + ); + assert_eq!( + WindowFunction::from_str("cum_dist")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist) + ); + Ok(()) + } +} From 0d2a214131fe69e19e22144c68fbb992228db6b3 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 20 May 2021 23:25:43 +0800 Subject: [PATCH 10/21] adding filter by todo --- ballista/rust/core/proto/ballista.proto | 1 + ballista/rust/core/src/serde/logical_plan/from_proto.rs | 5 +++-- ballista/rust/core/src/serde/logical_plan/to_proto.rs | 4 ++++ datafusion/src/logical_plan/plan.rs | 4 ++++ datafusion/src/optimizer/projection_push_down.rs | 2 ++ datafusion/src/optimizer/utils.rs | 2 ++ 6 files changed, 16 insertions(+), 2 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 926aefa6ab01e..da0c615e3b23e 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -324,6 +324,7 @@ message WindowNode { oneof window_frame { WindowFrame frame = 5; } + // TODO add filter by expr } enum WindowFrameUnits { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 7d0f429f1e4e4..020858fbfc3fe 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -94,11 +94,12 @@ impl TryInto for &protobuf::LogicalPlanNode { // .iter() // .map(|expr| expr.try_into()) // .collect::, _>>()?; - // // FIXME parse the window_frame data + // // FIXME: add filter by expr + // // FIXME: parse the window_frame data // let window_frame = None; LogicalPlanBuilder::from(&input) .window( - window_expr, /*, partition_by_expr, order_by_expr, window_frame*/ + window_expr, /* filter_by_expr, partition_by_expr, order_by_expr, window_frame*/ )? .build() .map_err(|e| e.into()) diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 4050d0bea2825..47e27483ff307 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -779,6 +779,8 @@ impl TryInto for &LogicalPlan { input, window_expr, // FIXME implement next + // filter_by_expr, + // FIXME implement next // partition_by_expr, // FIXME implement next // order_by_expr, @@ -788,6 +790,8 @@ impl TryInto for &LogicalPlan { } => { let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?; // FIXME: implement + // let filter_by_expr = vec![]; + // FIXME: implement let partition_by_expr = vec![]; // FIXME: implement let order_by_expr = vec![]; diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index ceab20a6a5e09..4027916c8a7cd 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -86,6 +86,8 @@ pub enum LogicalPlan { input: Arc, /// The window function expression window_expr: Vec, + /// Filter by expressions + // filter_by_expr: Vec, /// Partition by expressions // partition_by_expr: Vec, /// Order by expressions @@ -305,6 +307,8 @@ impl LogicalPlan { LogicalPlan::Window { window_expr, // FIXME implement next + // filter_by_expr, + // FIXME implement next // partition_by_expr, // FIXME implement next // order_by_expr, diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 353b475358674..e47832b07f921 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -198,6 +198,8 @@ fn optimize_plan( window_expr, input, // FIXME implement next + // filter_by_expr, + // FIXME implement next // partition_by_expr, // FIXME implement next // order_by_expr, diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index db26e5cb40653..284ead252ac67 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -190,6 +190,8 @@ pub fn from_plan( }), }, LogicalPlan::Window { + // FIXME implement next + // filter_by_expr, // FIXME implement next // partition_by_expr, // FIXME implement next From a1eae864926a6acfeeebe995a12de4ad725ea869 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 20 May 2021 23:36:15 +0800 Subject: [PATCH 11/21] enrich unit test --- .../src/physical_plan/window_functions.rs | 101 +++++++++++++++++- 1 file changed, 100 insertions(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index e6267250bdad1..87767a6da4b25 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -178,7 +178,36 @@ fn signature(fun: &WindowFunction) -> Signature { #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::{DataType, Field}; + + #[test] + fn test_window_function_from_str_to_str_round_trip_eq() -> Result<()> { + let names = vec![ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", + "min", + "max", + "count", + "avg", + "sum", + ]; + for name in names { + let fun = WindowFunction::from_str(name)?; + assert_eq!(fun.to_string(), name.to_uppercase()); + + let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?; + assert_eq!(fun, fun2); + } + Ok(()) + } #[test] fn test_window_function_from_str() -> Result<()> { @@ -198,6 +227,76 @@ mod tests { WindowFunction::from_str("cum_dist")?, WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist) ); + assert_eq!( + WindowFunction::from_str("first_value")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue) + ); + assert_eq!( + WindowFunction::from_str("LAST_value")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue) + ); + assert_eq!( + WindowFunction::from_str("LAG")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag) + ); + assert_eq!( + WindowFunction::from_str("LEAD")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead) + ); + Ok(()) + } + + #[test] + fn test_count_return_type() -> Result<()> { + let fun = WindowFunction::from_str("count")?; + let observed = return_type(&fun, &[DataType::Utf8])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_first_value_return_type() -> Result<()> { + let fun = WindowFunction::from_str("first_value")?; + let observed = return_type(&fun, &[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + Ok(()) + } + + #[test] + fn test_last_value_return_type() -> Result<()> { + let fun = WindowFunction::from_str("last_value")?; + let observed = return_type(&fun, &[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + Ok(()) + } + + #[test] + fn test_lead_return_type() -> Result<()> { + let fun = WindowFunction::from_str("lead")?; + let observed = return_type(&fun, &[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + Ok(()) + } + + #[test] + fn test_lag_return_type() -> Result<()> { + let fun = WindowFunction::from_str("lag")?; + let observed = return_type(&fun, &[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + Ok(()) + } + + #[test] + fn test_cume_dist_return_type() -> Result<()> { + let fun = WindowFunction::from_str("cume_dist")?; + let observed = return_type(&fun, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + Ok(()) } } From f5e64de7192a1916df78a4c2fbab7d471c906720 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 20 May 2021 23:41:36 +0800 Subject: [PATCH 12/21] update --- datafusion/src/physical_plan/window_functions.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 87767a6da4b25..d8486989cbb50 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -40,9 +40,12 @@ pub enum WindowFunction { impl FromStr for WindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { - if let Ok(aggregate) = AggregateFunction::from_str(name) { + let name = name.to_lowercase(); + if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { Ok(WindowFunction::AggregateFunction(aggregate)) - } else if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name) { + } else if let Ok(built_in_function) = + BuiltInWindowFunction::from_str(name.as_str()) + { Ok(WindowFunction::BuiltInWindowFunction(built_in_function)) } else { Err(DataFusionError::Plan(format!( @@ -180,7 +183,7 @@ mod tests { use super::*; #[test] - fn test_window_function_from_str_to_str_round_trip_eq() -> Result<()> { + fn test_window_function_case_insensitive() -> Result<()> { let names = vec![ "row_number", "rank", @@ -201,8 +204,6 @@ mod tests { ]; for name in names { let fun = WindowFunction::from_str(name)?; - assert_eq!(fun.to_string(), name.to_uppercase()); - let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?; assert_eq!(fun, fun2); } From c36c04abf06c74d016597983bf3d3a2a5b5cbdd5 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 00:07:54 +0800 Subject: [PATCH 13/21] add more tests --- datafusion/src/physical_plan/windows.rs | 2 +- datafusion/src/sql/planner.rs | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index ffc65ee3112f6..bdd25d69fd553 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -59,7 +59,7 @@ pub fn create_window_expr( })), WindowFunction::BuiltInWindowFunction(fun) => { Err(DataFusionError::NotImplemented(format!( - "window funtion with {:?} not implemented", + "window function with {:?} not implemented", fun ))) } diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 029aba90baea8..e383f0d4fb441 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -2720,6 +2720,16 @@ mod tests { quick_test(sql, expected); } + #[test] + fn empty_over_multiple() { + let sql = "SELECT order_id, MAX(qty) OVER (), CUMe_dist(qty), lag(qty) OVER () from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty Multiply Float64(1.1))\ + \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[], orderBy=[]\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + #[test] fn over_partition_by_not_supported() { let sql = @@ -2731,6 +2741,16 @@ mod tests { ); } + #[test] + fn over_order_by_not_supported() { + let sql = "SELECT order_id, MAX(delivered) OVER (order BY order_id) from orders"; + let err = logical_plan(sql).expect_err("query should have failed"); + assert_eq!( + "NotImplemented(\"Unsupported OVER clause (PARTITION BY order_id)\")", + format!("{:?}", err) + ); + } + #[test] fn only_union_all_supported() { let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM orders"; From 4e792e123a33fd0dcb5f701c679566b55589b0c0 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 08:05:17 +0800 Subject: [PATCH 14/21] fix test --- datafusion/src/physical_plan/window_functions.rs | 11 ++++++++++- datafusion/src/sql/planner.rs | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index d8486989cbb50..02ef3f15eaada 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -49,7 +49,7 @@ impl FromStr for WindowFunction { Ok(WindowFunction::BuiltInWindowFunction(built_in_function)) } else { Err(DataFusionError::Plan(format!( - "There is no built-in function named {}", + "There is no window function named {}", name ))) } @@ -292,6 +292,15 @@ mod tests { Ok(()) } + #[test] + fn test_nth_value_return_type() -> Result<()> { + let fun = WindowFunction::from_str("nth_value")?; + let observed = return_type(&fun, &[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + Ok(()) + } + #[test] fn test_cume_dist_return_type() -> Result<()> { let fun = WindowFunction::from_str("cume_dist")?; diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index e383f0d4fb441..df7e0ce982fcd 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -2746,7 +2746,7 @@ mod tests { let sql = "SELECT order_id, MAX(delivered) OVER (order BY order_id) from orders"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "NotImplemented(\"Unsupported OVER clause (PARTITION BY order_id)\")", + "NotImplemented(\"Unsupported OVER clause (ORDER BY order_id)\")", format!("{:?}", err) ); } From 880b94f6e27df61b4d3877366f71a51b9b2f5d5d Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 08:24:00 +0800 Subject: [PATCH 15/21] fix unit test --- .../src/physical_plan/window_functions.rs | 24 ++++++++++++++++--- datafusion/src/sql/planner.rs | 7 +++--- datafusion/tests/sql.rs | 14 +++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 02ef3f15eaada..91e1f379f5f09 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -225,7 +225,7 @@ mod tests { WindowFunction::AggregateFunction(AggregateFunction::Avg) ); assert_eq!( - WindowFunction::from_str("cum_dist")?, + WindowFunction::from_str("cume_dist")?, WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist) ); assert_eq!( @@ -253,6 +253,9 @@ mod tests { let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::UInt64, observed); + let observed = return_type(&fun, &[DataType::UInt64])?; + assert_eq!(DataType::UInt64, observed); + Ok(()) } @@ -262,6 +265,9 @@ mod tests { let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); + let observed = return_type(&fun, &[DataType::UInt64])?; + assert_eq!(DataType::UInt64, observed); + Ok(()) } @@ -271,6 +277,9 @@ mod tests { let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); + let observed = return_type(&fun, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + Ok(()) } @@ -280,6 +289,9 @@ mod tests { let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); + let observed = return_type(&fun, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + Ok(()) } @@ -289,22 +301,28 @@ mod tests { let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); + let observed = return_type(&fun, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + Ok(()) } #[test] fn test_nth_value_return_type() -> Result<()> { let fun = WindowFunction::from_str("nth_value")?; - let observed = return_type(&fun, &[DataType::Utf8])?; + let observed = return_type(&fun, &[DataType::Utf8, DataType::UInt64])?; assert_eq!(DataType::Utf8, observed); + let observed = return_type(&fun, &[DataType::Float64, DataType::UInt64])?; + assert_eq!(DataType::Float64, observed); + Ok(()) } #[test] fn test_cume_dist_return_type() -> Result<()> { let fun = WindowFunction::from_str("cume_dist")?; - let observed = return_type(&fun, &[DataType::Float64])?; + let observed = return_type(&fun, &[])?; assert_eq!(DataType::Float64, observed); Ok(()) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index df7e0ce982fcd..a3027e589985e 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -2722,10 +2722,11 @@ mod tests { #[test] fn empty_over_multiple() { - let sql = "SELECT order_id, MAX(qty) OVER (), CUMe_dist(qty), lag(qty) OVER () from orders"; + let sql = + "SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders"; let expected = "\ - Projection: #order_id, #MAX(qty Multiply Float64(1.1))\ - \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[], orderBy=[]\ + Projection: #order_id, #MAX(qty), #MIN(qty), #AVG(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[], orderBy=[]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 17e0f13609a38..5b56e76baca89 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -797,6 +797,20 @@ async fn csv_query_count() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_window_with_empty_over() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "SELECT count(c12) over () FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + // FIXME: so far the WindowAggExec is not implemented + // and the current behavior is to return empty result + // when it is done this test shall be updated + let expected: Vec> = vec![]; + assert_eq!(expected, actual); + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_int_count() -> Result<()> { let mut ctx = ExecutionContext::new(); From bc2271d58fd4a9a9cc96126f8abcd6e8f10272ca Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 10:04:29 +0800 Subject: [PATCH 16/21] fix error --- datafusion/tests/sql.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 5b56e76baca89..d23456fed9989 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -802,12 +802,10 @@ async fn csv_query_window_with_empty_over() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "SELECT count(c12) over () FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; // FIXME: so far the WindowAggExec is not implemented - // and the current behavior is to return empty result - // when it is done this test shall be updated - let expected: Vec> = vec![]; - assert_eq!(expected, actual); + // and the current behavior is to throw not implemented exception + let plan = ctx.create_logical_plan(&sql); + assert!(plan.is_err()); Ok(()) } From 1ecae8f6cbc6c1898ccf0b38b1e596b6c2e9bb46 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 12:27:26 +0800 Subject: [PATCH 17/21] fix unit test --- datafusion/tests/sql.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index d23456fed9989..f746932e2a0ef 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -804,8 +804,8 @@ async fn csv_query_window_with_empty_over() -> Result<()> { let sql = "SELECT count(c12) over () FROM aggregate_test_100"; // FIXME: so far the WindowAggExec is not implemented // and the current behavior is to throw not implemented exception - let plan = ctx.create_logical_plan(&sql); - assert!(plan.is_err()); + let result = execute(&mut ctx, sql); + assert!(result.is_err()); Ok(()) } From 5d96e525f587fbfaf3e5e9762c9bb10315fcbc3a Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 14:16:16 +0800 Subject: [PATCH 18/21] fix unit test --- datafusion/tests/sql.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index f746932e2a0ef..927c472cfea3e 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -804,7 +804,11 @@ async fn csv_query_window_with_empty_over() -> Result<()> { let sql = "SELECT count(c12) over () FROM aggregate_test_100"; // FIXME: so far the WindowAggExec is not implemented // and the current behavior is to throw not implemented exception - let result = execute(&mut ctx, sql); + + let plan = ctx.create_logical_plan(&sql)?; + let plan = ctx.optimize(&plan)?; + let plan = ctx.create_physical_plan(&plan)?; + let result = collect(plan).await; assert!(result.is_err()); Ok(()) } From bb57c762b0a1fabc35e207e681bca2bfff7fcf01 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 14:23:34 +0800 Subject: [PATCH 19/21] use upper case --- .../src/physical_plan/window_functions.rs | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 91e1f379f5f09..65d5373d54f47 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -58,8 +58,19 @@ impl FromStr for WindowFunction { impl fmt::Display for BuiltInWindowFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // uppercase of the debug. - write!(f, "{}", format!("{:?}", self).to_uppercase()) + match self { + BuiltInWindowFunction::RowNumber => write!(f, "ROW_NUMBER"), + BuiltInWindowFunction::Rank => write!(f, "RANK"), + BuiltInWindowFunction::DenseRank => write!(f, "DENSE_RANK"), + BuiltInWindowFunction::PercentRank => write!(f, "PERCENT_RANK"), + BuiltInWindowFunction::CumeDist => write!(f, "CUME_DIST"), + BuiltInWindowFunction::Ntile => write!(f, "NTILE"), + BuiltInWindowFunction::Lag => write!(f, "LAG"), + BuiltInWindowFunction::Lead => write!(f, "LEAD"), + BuiltInWindowFunction::FirstValue => write!(f, "FIRST_VALUE"), + BuiltInWindowFunction::LastValue => write!(f, "LAST_VALUE"), + BuiltInWindowFunction::NthValue => write!(f, "NTH_VALUE"), + } } } @@ -108,18 +119,18 @@ pub enum BuiltInWindowFunction { impl FromStr for BuiltInWindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { - Ok(match name.to_lowercase().as_str() { - "row_number" => BuiltInWindowFunction::RowNumber, - "rank" => BuiltInWindowFunction::Rank, - "dense_rank" => BuiltInWindowFunction::DenseRank, - "percent_rank" => BuiltInWindowFunction::PercentRank, - "cume_dist" => BuiltInWindowFunction::CumeDist, - "ntile" => BuiltInWindowFunction::Ntile, - "lag" => BuiltInWindowFunction::Lag, - "lead" => BuiltInWindowFunction::Lead, - "first_value" => BuiltInWindowFunction::FirstValue, - "last_value" => BuiltInWindowFunction::LastValue, - "nth_value" => BuiltInWindowFunction::NthValue, + Ok(match name.to_uppercase().as_str() { + "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, + "RANK" => BuiltInWindowFunction::Rank, + "DENSE_RANK" => BuiltInWindowFunction::DenseRank, + "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, + "CUME_DIST" => BuiltInWindowFunction::CumeDist, + "NTILE" => BuiltInWindowFunction::Ntile, + "LAG" => BuiltInWindowFunction::Lag, + "LEAD" => BuiltInWindowFunction::Lead, + "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, + "LAST_VALUE" => BuiltInWindowFunction::LastValue, + "NTH_VALUE" => BuiltInWindowFunction::NthValue, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in window function named {}", @@ -206,6 +217,7 @@ mod tests { let fun = WindowFunction::from_str(name)?; let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?; assert_eq!(fun, fun2); + assert_eq!(fun.to_string(), name.to_uppercase()); } Ok(()) } From 2af2a270262ff1bc759af39153d7cd681c32dc0a Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 14:25:12 +0800 Subject: [PATCH 20/21] fix unit test --- datafusion/tests/sql.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 927c472cfea3e..94fc2b7ec251e 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -805,11 +805,9 @@ async fn csv_query_window_with_empty_over() -> Result<()> { // FIXME: so far the WindowAggExec is not implemented // and the current behavior is to throw not implemented exception - let plan = ctx.create_logical_plan(&sql)?; - let plan = ctx.optimize(&plan)?; - let plan = ctx.create_physical_plan(&plan)?; - let result = collect(plan).await; - assert!(result.is_err()); + let result = execute(&mut ctx, sql).await; + let expected: Vec> = vec![]; + assert_eq!(result, expected); Ok(()) } From 694115190463f48e85063fef828af63b554086be Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 14:56:00 +0800 Subject: [PATCH 21/21] comment out test --- datafusion/tests/sql.rs | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 94fc2b7ec251e..e68c53b251e6c 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -797,19 +797,20 @@ async fn csv_query_count() -> Result<()> { Ok(()) } -#[tokio::test] -async fn csv_query_window_with_empty_over() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; - let sql = "SELECT count(c12) over () FROM aggregate_test_100"; - // FIXME: so far the WindowAggExec is not implemented - // and the current behavior is to throw not implemented exception - - let result = execute(&mut ctx, sql).await; - let expected: Vec> = vec![]; - assert_eq!(result, expected); - Ok(()) -} +// FIXME uncomment this when exec is done +// #[tokio::test] +// async fn csv_query_window_with_empty_over() -> Result<()> { +// let mut ctx = ExecutionContext::new(); +// register_aggregate_csv(&mut ctx)?; +// let sql = "SELECT count(c12) over () FROM aggregate_test_100"; +// // FIXME: so far the WindowAggExec is not implemented +// // and the current behavior is to throw not implemented exception + +// let result = execute(&mut ctx, sql).await; +// let expected: Vec> = vec![]; +// assert_eq!(result, expected); +// Ok(()) +// } #[tokio::test] async fn csv_query_group_by_int_count() -> Result<()> {