diff --git a/Cargo.lock b/Cargo.lock index 08198cc49b72c..5a9bed2f3cc13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2577,6 +2577,7 @@ name = "datafusion-proto" version = "51.0.0" dependencies = [ "arrow", + "async-trait", "chrono", "datafusion", "datafusion-catalog", diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index d442307e9488e..57b124a618c6b 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -100,6 +100,14 @@ impl AsyncFuncExec { input.boundedness(), )) } + + pub fn async_exprs(&self) -> &[Arc] { + &self.async_exprs + } + + pub fn input(&self) -> &Arc { + &self.input + } } impl DisplayAs for AsyncFuncExec { diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 920e277b8ccc0..b00bd0dcc6bfd 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -73,6 +73,7 @@ serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } [dev-dependencies] +async-trait = { workspace = true } datafusion = { workspace = true, default-features = false, features = [ "sql", "datetime_expressions", diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 789176862bf00..cb2eb515ca7d9 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -748,6 +748,7 @@ message PhysicalPlanNode { GenerateSeriesNode generate_series = 33; SortMergeJoinExecNode sort_merge_join = 34; MemoryScanExecNode memory_scan = 35; + AsyncFuncExecNode async_func = 36; } } @@ -1393,3 +1394,9 @@ message SortMergeJoinExecNode { repeated SortExprNode sort_options = 6; datafusion_common.NullEquality null_equality = 7; } + +message AsyncFuncExecNode { + PhysicalPlanNode input = 1; + repeated PhysicalExprNode async_exprs = 2; + repeated string async_expr_names = 3; +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 230bfa495a4b3..9b25b7b80366b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1260,6 +1260,133 @@ impl<'de> serde::Deserialize<'de> for AnalyzedLogicalPlanType { deserializer.deserialize_struct("datafusion.AnalyzedLogicalPlanType", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for AsyncFuncExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if !self.async_exprs.is_empty() { + len += 1; + } + if !self.async_expr_names.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.AsyncFuncExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.async_exprs.is_empty() { + struct_ser.serialize_field("asyncExprs", &self.async_exprs)?; + } + if !self.async_expr_names.is_empty() { + struct_ser.serialize_field("asyncExprNames", &self.async_expr_names)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for AsyncFuncExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "async_exprs", + "asyncExprs", + "async_expr_names", + "asyncExprNames", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + AsyncExprs, + AsyncExprNames, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "asyncExprs" | "async_exprs" => Ok(GeneratedField::AsyncExprs), + "asyncExprNames" | "async_expr_names" => Ok(GeneratedField::AsyncExprNames), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = AsyncFuncExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.AsyncFuncExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut async_exprs__ = None; + let mut async_expr_names__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::AsyncExprs => { + if async_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("asyncExprs")); + } + async_exprs__ = Some(map_.next_value()?); + } + GeneratedField::AsyncExprNames => { + if async_expr_names__.is_some() { + return Err(serde::de::Error::duplicate_field("asyncExprNames")); + } + async_expr_names__ = Some(map_.next_value()?); + } + } + } + Ok(AsyncFuncExecNode { + input: input__, + async_exprs: async_exprs__.unwrap_or_default(), + async_expr_names: async_expr_names__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.AsyncFuncExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AvroScanExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17136,6 +17263,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::MemoryScan(v) => { struct_ser.serialize_field("memoryScan", v)?; } + physical_plan_node::PhysicalPlanType::AsyncFunc(v) => { + struct_ser.serialize_field("asyncFunc", v)?; + } } } struct_ser.end() @@ -17201,6 +17331,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortMergeJoin", "memory_scan", "memoryScan", + "async_func", + "asyncFunc", ]; #[allow(clippy::enum_variant_names)] @@ -17239,6 +17371,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { GenerateSeries, SortMergeJoin, MemoryScan, + AsyncFunc, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17294,6 +17427,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "generateSeries" | "generate_series" => Ok(GeneratedField::GenerateSeries), "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin), "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan), + "asyncFunc" | "async_func" => Ok(GeneratedField::AsyncFunc), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17552,6 +17686,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("memoryScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::MemoryScan) +; + } + GeneratedField::AsyncFunc => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("asyncFunc")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AsyncFunc) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b2d0bc7751f9b..4d16f50e14009 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1076,7 +1076,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" )] pub physical_plan_type: ::core::option::Option, } @@ -1154,6 +1154,8 @@ pub mod physical_plan_node { SortMergeJoin(::prost::alloc::boxed::Box), #[prost(message, tag = "35")] MemoryScan(super::MemoryScanExecNode), + #[prost(message, tag = "36")] + AsyncFunc(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -2104,6 +2106,15 @@ pub struct SortMergeJoinExecNode { #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] pub null_equality: i32, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AsyncFuncExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub async_exprs: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "3")] + pub async_expr_names: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 89dd0b50650b1..6f275d2a2204c 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -100,6 +100,8 @@ use datafusion_physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr}; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_plan::async_func::AsyncFuncExec; use prost::bytes::BufMut; use prost::Message; @@ -251,6 +253,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { PhysicalPlanType::SortMergeJoin(sort_join) => { self.try_into_sort_join(sort_join, ctx, extension_codec) } + PhysicalPlanType::AsyncFunc(async_func) => { + self.try_into_async_func_physical_plan(async_func, ctx, extension_codec) + } } } @@ -462,6 +467,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_async_func_exec( + exec, + extension_codec, + ); + } + let mut buf: Vec = vec![]; match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { @@ -1972,6 +1984,44 @@ impl protobuf::PhysicalPlanNode { Ok(Arc::new(CooperativeExec::new(input))) } + fn try_into_async_func_physical_plan( + &self, + async_func: &protobuf::AsyncFuncExecNode, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&async_func.input, ctx, extension_codec)?; + + if async_func.async_exprs.len() != async_func.async_expr_names.len() { + return internal_err!( + "AsyncFuncExecNode async_exprs length does not match async_expr_names" + ); + } + + let async_exprs = async_func + .async_exprs + .iter() + .zip(async_func.async_expr_names.iter()) + .map(|(expr, name)| { + let physical_expr = parse_physical_expr( + expr, + ctx, + input.schema().as_ref(), + extension_codec, + )?; + + Ok(Arc::new(AsyncFuncExpr::try_new( + name.clone(), + physical_expr, + input.schema().as_ref(), + )?)) + }) + .collect::>>()?; + + Ok(Arc::new(AsyncFuncExec::try_new(async_exprs, input)?)) + } + fn try_from_explain_exec( exec: &ExplainExec, _extension_codec: &dyn PhysicalExtensionCodec, @@ -3222,6 +3272,34 @@ impl protobuf::PhysicalPlanNode { Ok(None) } + + fn try_from_async_func_exec( + exec: &AsyncFuncExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + Arc::clone(exec.input()), + extension_codec, + )?; + + let mut async_exprs = vec![]; + let mut async_expr_names = vec![]; + + for async_expr in exec.async_exprs() { + async_exprs.push(serialize_physical_expr(&async_expr.func, extension_codec)?); + async_expr_names.push(async_expr.name.clone()) + } + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::AsyncFunc(Box::new( + protobuf::AsyncFuncExecNode { + input: Some(Box::new(input)), + async_exprs, + async_expr_names, + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { @@ -3439,7 +3517,6 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { fn into_physical_plan( node: &Option>, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { if let Some(field) = node { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 0bcdd610c26ff..438e65f60b001 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -107,9 +107,11 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, DataFusionError, NullEquality, Result, UnnestOptions, }; +use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, WindowUDF, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -2263,3 +2265,65 @@ async fn roundtrip_listing_table_with_schema_metadata() -> Result<()> { roundtrip_test(plan) } + +#[tokio::test] +async fn roundtrip_async_func_exec() -> Result<()> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestAsyncUDF { + signature: Signature, + } + + impl TestAsyncUDF { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Int64], Volatility::Volatile), + } + } + } + + impl ScalarUDFImpl for TestAsyncUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "test_async_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("Must call from `invoke_async_with_args`") + } + } + + #[async_trait::async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDF { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + Ok(args.args[0].clone()) + } + } + + let ctx = SessionContext::new(); + let async_udf = AsyncScalarUDF::new(Arc::new(TestAsyncUDF::new())); + ctx.register_udf(async_udf.into_scalar_udf()); + + let physical_plan = ctx + .sql("select test_async_udf(1)") + .await? + .create_physical_plan() + .await?; + + roundtrip_test_with_context(physical_plan, &ctx)?; + + Ok(()) +}