From 603a8a02d2fdae89f529dd4d26eb891510583436 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 14 Nov 2024 17:53:38 +0530 Subject: [PATCH 01/10] Adds roundtrip physical plan test --- .../tests/cases/roundtrip_physical_plan.rs | 44 +++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index aab63dd8bd66a..7ca48d43ae695 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -47,9 +47,10 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; +use datafusion::functions_window::row_number::row_number_udwf; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::Literal; -use datafusion::physical_expr::window::SlidingAggregateWindowExpr; +use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{ LexOrdering, LexRequirement, PhysicalSortRequirement, ScalarFunctionExpr, }; @@ -73,8 +74,13 @@ use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; -use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowAggExec}; -use datafusion::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr, Statistics}; +use datafusion::physical_plan::windows::{ + create_udwf_window_expr, BoundedWindowAggExec, PlainAggregateWindowExpr, + WindowAggExec, +}; +use datafusion::physical_plan::{ + ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, Statistics, +}; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_common::config::TableParquetOptions; @@ -263,6 +269,38 @@ fn roundtrip_nested_loop_join() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_built_in_window() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let built_in_window_expr = Arc::new(BuiltInWindowExpr::new( + create_udwf_window_expr( + &row_number_udwf(), + &[], + &schema, + "row_number() PARTITION BY [a] ORDER BY [b] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?, + &[ + col("a", &schema)? + ], + &LexOrdering::new(vec![ + PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)), + ]), + Arc::new(WindowFrame::new(None)), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + roundtrip_test(Arc::new(BoundedWindowAggExec::try_new( + vec![built_in_window_expr], + input, + vec![col("a", &schema)?], + InputOrderMode::Sorted, + )?)) +} #[test] fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); From aca808539620b9f09923ec251e1a663a235487f7 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 14 Nov 2024 18:02:29 +0530 Subject: [PATCH 02/10] Adds enum for udwf to `WindowFunction` --- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 13 +++++++++++++ datafusion/proto/src/generated/prost.rs | 4 +++- datafusion/proto/src/physical_plan/from_proto.rs | 3 +++ 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 998c172f6ef4d..92e5f59be20aa 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -868,6 +868,7 @@ message PhysicalWindowExprNode { oneof window_function { BuiltInWindowFunction built_in_function = 2; string user_defined_aggr_function = 3; + string user_defined_window_function = 10; } repeated PhysicalExprNode args = 4; repeated PhysicalExprNode partition_by = 5; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b5447ad6f473b..ea31cde91c02d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -16399,6 +16399,9 @@ impl serde::Serialize for PhysicalWindowExprNode { physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(v) => { struct_ser.serialize_field("userDefinedAggrFunction", v)?; } + physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(v) => { + struct_ser.serialize_field("userDefinedWindowFunction", v)?; + } } } struct_ser.end() @@ -16425,6 +16428,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "builtInFunction", "user_defined_aggr_function", "userDefinedAggrFunction", + "user_defined_window_function", + "userDefinedWindowFunction", ]; #[allow(clippy::enum_variant_names)] @@ -16437,6 +16442,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { FunDefinition, BuiltInFunction, UserDefinedAggrFunction, + UserDefinedWindowFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16466,6 +16472,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), + "userDefinedWindowFunction" | "user_defined_window_function" => Ok(GeneratedField::UserDefinedWindowFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16544,6 +16551,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedAggrFunction); } + GeneratedField::UserDefinedWindowFunction => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("userDefinedWindowFunction")); + } + window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedWindowFunction); + } } } Ok(PhysicalWindowExprNode { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 40bc8bd9eaf52..1db91aab72b04 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1267,7 +1267,7 @@ pub struct PhysicalWindowExprNode { pub name: ::prost::alloc::string::String, #[prost(bytes = "vec", optional, tag = "9")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "2, 3")] + #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "2, 3, 10")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, >, @@ -1280,6 +1280,8 @@ pub mod physical_window_expr_node { BuiltInFunction(i32), #[prost(string, tag = "3")] UserDefinedAggrFunction(::prost::alloc::string::String), + #[prost(string, tag = "10")] + UserDefinedWindowFunction(::prost::alloc::string::String), } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 31b59c2a94573..8b05bcfe40a93 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -161,6 +161,9 @@ pub fn parse_physical_window_expr( None => registry.udaf(udaf_name)? }) } + protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { + WindowFunctionDefinition::WindowUDF(registry.udwf(udwf_name)?) + } } } else { return Err(proto_error("Missing required field in protobuf")); From 7d0176ab96a056e1866b64ebf1cbc61ce6ccc3f7 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 14 Nov 2024 18:09:27 +0530 Subject: [PATCH 03/10] initial fix for serializing udwf --- datafusion/physical-plan/src/windows/mod.rs | 8 +++++++- .../proto/src/physical_plan/to_proto.rs | 19 +++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d2eb14638c71c..43253e98b7670 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -198,7 +198,7 @@ pub fn create_udwf_window_expr( /// Implements [`BuiltInWindowFunctionExpr`] for [`WindowUDF`] #[derive(Clone, Debug)] -struct WindowUDFExpr { +pub struct WindowUDFExpr { fun: Arc, args: Vec>, /// Display name @@ -213,6 +213,12 @@ struct WindowUDFExpr { ignore_nulls: bool, } +impl WindowUDFExpr { + pub fn fun(&self) -> &Arc { + &self.fun + } +} + impl BuiltInWindowFunctionExpr for WindowUDFExpr { fn as_any(&self) -> &dyn std::any::Any { self diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 60dcd650191d6..205e6aae7c816 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -19,14 +19,14 @@ use std::sync::Arc; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion::physical_expr::window::SlidingAggregateWindowExpr; +use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; -use datafusion::physical_plan::windows::PlainAggregateWindowExpr; +use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; use datafusion::{ datasource::{ @@ -120,6 +120,21 @@ pub fn serialize_physical_window_expr( window_frame, codec, )? + } else if let Some(built_in_window_expr) = expr.downcast_ref::() { + if let Some(expr) = built_in_window_expr + .get_built_in_func_expr() + .as_any() + .downcast_ref::() + { + ( + physical_window_expr_node::WindowFunction::UserDefinedWindowFunction( + expr.fun().name().to_string(), + ), + None, + ) + } else { + return not_impl_err!("WindowExpr not supported: {window_expr:?}"); + } } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; From e579ec04e63982cb6ac3b2379ebb263ba1b41d60 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 14 Nov 2024 19:00:02 +0530 Subject: [PATCH 04/10] Revives deleted test --- .../tests/cases/roundtrip_physical_plan.rs | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7ca48d43ae695..4814f68e21822 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -47,6 +47,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; +use datafusion::functions_window::nth_value::nth_value_udwf; use datafusion::functions_window::row_number::row_number_udwf; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::Literal; @@ -307,6 +308,29 @@ fn roundtrip_window() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let window_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::Int64(None)), + WindowFrameBound::CurrentRow, + ); + + let nth_value_window = + create_udwf_window_expr(&nth_value_udwf(), &[col("a", &schema)?, lit(2)], schema.as_ref(), "NTH_VALUE(a,2) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), false)?; + let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( + nth_value_window, + &[col("b", &schema)?], + &LexOrdering { + inner: vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + }, + Arc::new(window_frame), + )); + let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( AggregateExprBuilder::new( avg_udaf(), @@ -344,7 +368,11 @@ fn roundtrip_window() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(WindowAggExec::try_new( - vec![plain_aggr_window_expr, sliding_aggr_window_expr], + vec![ + plain_aggr_window_expr, + sliding_aggr_window_expr, + builtin_window_expr, + ], input, vec![col("b", &schema)?], )?)) From 46f4212ee7b61ff4ff21b66646e2d5167b490813 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 14 Nov 2024 22:16:19 +0530 Subject: [PATCH 05/10] Adds codec methods for physical plan --- datafusion/proto/src/physical_plan/from_proto.rs | 5 ++++- datafusion/proto/src/physical_plan/mod.rs | 10 +++++++++- datafusion/proto/src/physical_plan/to_proto.rs | 6 ++++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 8b05bcfe40a93..1f5f552343e09 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -162,7 +162,10 @@ pub fn parse_physical_window_expr( }) } protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { - WindowFunctionDefinition::WindowUDF(registry.udwf(udwf_name)?) + WindowFunctionDefinition::WindowUDF(match &proto.fun_definition { + Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, + None => registry.udwf(udwf_name)? + }) } } } else { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 64e462d1695fd..292ce13d0eded 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -64,7 +64,7 @@ use datafusion::physical_plan::{ ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; -use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::common::{byte_to_string, str_to_byte}; use crate::physical_plan::from_proto::{ @@ -2119,6 +2119,14 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("PhysicalExtensionCodec is not provided for window function {name}") + } + + fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug)] diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 205e6aae7c816..2a2aef01b2f78 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -68,7 +68,7 @@ pub fn serialize_physical_aggr_expr( ordering_req, distinct: aggr_expr.is_distinct(), ignore_nulls: aggr_expr.ignore_nulls(), - fun_definition: (!buf.is_empty()).then_some(buf) + fun_definition: (!buf.is_empty()).then_some(buf), }, )), }) @@ -126,11 +126,13 @@ pub fn serialize_physical_window_expr( .as_any() .downcast_ref::() { + let mut buf = Vec::new(); + codec.try_encode_udwf(expr.fun(), &mut buf)?; ( physical_window_expr_node::WindowFunction::UserDefinedWindowFunction( expr.fun().name().to_string(), ), - None, + (!buf.is_empty()).then_some(buf), ) } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); From 880b3f2433058a5c498c270ebd3157ceb9a9ac64 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 14 Nov 2024 22:22:26 +0530 Subject: [PATCH 06/10] Rewrite error message --- datafusion/proto/src/physical_plan/to_proto.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 2a2aef01b2f78..7d9a524af8288 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -135,7 +135,9 @@ pub fn serialize_physical_window_expr( (!buf.is_empty()).then_some(buf), ) } else { - return not_impl_err!("WindowExpr not supported: {window_expr:?}"); + return not_impl_err!( + "User-defined window function not supported: {window_expr:?}" + ); } } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); From 30538fbca0e5d5463f1b637eaaf1e17b62d22175 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Thu, 14 Nov 2024 22:46:45 +0530 Subject: [PATCH 07/10] Minor: rename binding + formatting fixes --- .../tests/cases/roundtrip_physical_plan.rs | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 4814f68e21822..d81f9f5c799ac 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -271,12 +271,12 @@ fn roundtrip_nested_loop_join() -> Result<()> { } #[test] -fn roundtrip_built_in_window() -> Result<()> { +fn roundtrip_udwf() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let built_in_window_expr = Arc::new(BuiltInWindowExpr::new( + let udwf_expr = Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr( &row_number_udwf(), &[], @@ -296,7 +296,7 @@ fn roundtrip_built_in_window() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(BoundedWindowAggExec::try_new( - vec![built_in_window_expr], + vec![udwf_expr], input, vec![col("a", &schema)?], InputOrderMode::Sorted, @@ -315,8 +315,14 @@ fn roundtrip_window() -> Result<()> { ); let nth_value_window = - create_udwf_window_expr(&nth_value_udwf(), &[col("a", &schema)?, lit(2)], schema.as_ref(), "NTH_VALUE(a,2) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), false)?; - let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( + create_udwf_window_expr( + &nth_value_udwf(), + &[col("a", &schema)?, + lit(2)], schema.as_ref(), + "NTH_VALUE(a, 2) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?; + let udwf_expr = Arc::new(BuiltInWindowExpr::new( nth_value_window, &[col("b", &schema)?], &LexOrdering { @@ -368,11 +374,7 @@ fn roundtrip_window() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(WindowAggExec::try_new( - vec![ - plain_aggr_window_expr, - sliding_aggr_window_expr, - builtin_window_expr, - ], + vec![plain_aggr_window_expr, sliding_aggr_window_expr, udwf_expr], input, vec![col("b", &schema)?], )?)) From 46a00e8afdf464a1b854077820e79c2c5cc36c37 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Fri, 15 Nov 2024 00:40:55 +0530 Subject: [PATCH 08/10] Extends `PhysicalExtensionCodec` for udwf --- datafusion/proto/tests/cases/mod.rs | 60 +++++++++++++- .../tests/cases/roundtrip_physical_plan.rs | 83 ++++++++++++++++++- 2 files changed, 138 insertions(+), 5 deletions(-) diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index fbb2cd8f1e832..3d4ea5ba92167 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. +use arrow::datatypes::{DataType, Field}; use std::any::Any; - -use arrow::datatypes::DataType; +use std::fmt::Debug; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility, + Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl, + Signature, Volatility, WindowUDFImpl, }; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; mod roundtrip_logical_plan; mod roundtrip_physical_plan; @@ -125,3 +128,54 @@ pub struct MyAggregateUdfNode { #[prost(string, tag = "1")] pub result: String, } + +#[derive(Debug)] +pub(in crate::cases) struct CustomUDWF { + signature: Signature, + payload: String, +} + +impl CustomUDWF { + pub fn new(payload: String) -> Self { + Self { + signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable), + payload, + } + } +} + +impl WindowUDFImpl for CustomUDWF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "custom_udwf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> datafusion_common::Result> { + Ok(Box::new(CustomUDWFEvaluator {})) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false)) + } +} + +#[derive(Debug)] +struct CustomUDWFEvaluator; + +impl PartitionEvaluator for CustomUDWFEvaluator {} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CustomUDWFNode { + #[prost(string, tag = "1")] + pub payload: String, +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index d81f9f5c799ac..7ae1b3ec073bc 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -32,7 +32,10 @@ use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message; -use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; +use crate::cases::{ + CustomUDWF, CustomUDWFNode, MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, + MyRegexUdfNode, +}; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -94,7 +97,7 @@ use datafusion_common::{ }; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -1016,6 +1019,33 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { } Ok(()) } + + fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "custom_udwf" { + let proto = CustomUDWFNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode custom_udwf: {err}")) + })?; + + Ok(Arc::new(WindowUDF::from(CustomUDWF::new(proto.payload)))) + } else { + not_impl_err!( + "unrecognized user-defined window function implementation, cannot decode" + ) + } + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udwf) = binding.as_any().downcast_ref::() { + let proto = CustomUDWFNode { + payload: udwf.payload.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udwf: {err:?}")) + })?; + } + Ok(()) + } } #[test] @@ -1073,6 +1103,55 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_udwf_extension_codec() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let custom_udwf = Arc::new(WindowUDF::from(CustomUDWF::new("payload".to_string()))); + let udwf = create_udwf_window_expr( + &custom_udwf, + &[col("a", &schema)?], + schema.as_ref(), + "custom_udwf(a) PARTITION BY [b] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?; + + let window_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::Int64(None)), + WindowFrameBound::CurrentRow, + ); + + let udwf_expr = Arc::new(BuiltInWindowExpr::new( + udwf, + &[col("b", &schema)?], + &LexOrdering { + inner: vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + }, + Arc::new(window_frame), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + let window = Arc::new(BoundedWindowAggExec::try_new( + vec![udwf_expr], + input, + vec![col("b", &schema)?], + InputOrderMode::Sorted, + )?); + + let ctx = SessionContext::new(); + roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec)?; + Ok(()) +} + #[test] fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let field_text = Field::new("text", DataType::Utf8, true); From 496151af960c7a71210ca92b00d9147fe6671253 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Fri, 15 Nov 2024 00:42:56 +0530 Subject: [PATCH 09/10] Minor: formatting --- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7ae1b3ec073bc..efa462aa7a855 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -305,6 +305,7 @@ fn roundtrip_udwf() -> Result<()> { InputOrderMode::Sorted, )?)) } + #[test] fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); From 74e547b3a9a227aebfde8c81b3352cdc4791037b Mon Sep 17 00:00:00 2001 From: jcsherin Date: Fri, 15 Nov 2024 00:57:50 +0530 Subject: [PATCH 10/10] Restricts visibility to tests --- datafusion/proto/tests/cases/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index 3d4ea5ba92167..4d69ca075483b 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -175,7 +175,7 @@ struct CustomUDWFEvaluator; impl PartitionEvaluator for CustomUDWFEvaluator {} #[derive(Clone, PartialEq, ::prost::Message)] -pub struct CustomUDWFNode { +pub(in crate::cases) struct CustomUDWFNode { #[prost(string, tag = "1")] pub payload: String, }