From 1369b361856a36ccae8107cfb4f1591b8bfc893b Mon Sep 17 00:00:00 2001 From: David Stancu Date: Fri, 5 Dec 2025 10:24:11 -0500 Subject: [PATCH 1/5] support serialization of asyncfuncexec --- datafusion/physical-plan/src/async_func.rs | 8 ++ datafusion/proto/proto/datafusion.proto | 7 + datafusion/proto/src/generated/pbjson.rs | 141 +++++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 13 +- datafusion/proto/src/physical_plan/mod.rs | 71 +++++++++++ 5 files changed, 239 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index d442307e9488e..c5e67cead16c7 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -100,6 +100,14 @@ impl AsyncFuncExec { input.boundedness(), )) } + + pub fn async_exprs(&self) -> Vec> { + self.async_exprs.clone() + } + + pub fn input(&self) -> Arc { + Arc::clone(&self.input) + } } impl DisplayAs for AsyncFuncExec { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 789176862bf00..cb2eb515ca7d9 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -748,6 +748,7 @@ message PhysicalPlanNode { GenerateSeriesNode generate_series = 33; SortMergeJoinExecNode sort_merge_join = 34; MemoryScanExecNode memory_scan = 35; + AsyncFuncExecNode async_func = 36; } } @@ -1393,3 +1394,9 @@ message SortMergeJoinExecNode { repeated SortExprNode sort_options = 6; datafusion_common.NullEquality null_equality = 7; } + +message AsyncFuncExecNode { + PhysicalPlanNode input = 1; + repeated PhysicalExprNode async_exprs = 2; + repeated string async_expr_names = 3; +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 230bfa495a4b3..9b25b7b80366b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1260,6 +1260,133 @@ impl<'de> serde::Deserialize<'de> for AnalyzedLogicalPlanType { deserializer.deserialize_struct("datafusion.AnalyzedLogicalPlanType", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for AsyncFuncExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if !self.async_exprs.is_empty() { + len += 1; + } + if !self.async_expr_names.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.AsyncFuncExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.async_exprs.is_empty() { + struct_ser.serialize_field("asyncExprs", &self.async_exprs)?; + } + if !self.async_expr_names.is_empty() { + struct_ser.serialize_field("asyncExprNames", &self.async_expr_names)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for AsyncFuncExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "async_exprs", + "asyncExprs", + "async_expr_names", + "asyncExprNames", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + AsyncExprs, + AsyncExprNames, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "asyncExprs" | "async_exprs" => Ok(GeneratedField::AsyncExprs), + "asyncExprNames" | "async_expr_names" => Ok(GeneratedField::AsyncExprNames), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = AsyncFuncExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.AsyncFuncExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut async_exprs__ = None; + let mut async_expr_names__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::AsyncExprs => { + if async_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("asyncExprs")); + } + async_exprs__ = Some(map_.next_value()?); + } + GeneratedField::AsyncExprNames => { + if async_expr_names__.is_some() { + return Err(serde::de::Error::duplicate_field("asyncExprNames")); + } + async_expr_names__ = Some(map_.next_value()?); + } + } + } + Ok(AsyncFuncExecNode { + input: input__, + async_exprs: async_exprs__.unwrap_or_default(), + async_expr_names: async_expr_names__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.AsyncFuncExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AvroScanExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17136,6 +17263,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::MemoryScan(v) => { struct_ser.serialize_field("memoryScan", v)?; } + physical_plan_node::PhysicalPlanType::AsyncFunc(v) => { + struct_ser.serialize_field("asyncFunc", v)?; + } } } struct_ser.end() @@ -17201,6 +17331,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortMergeJoin", "memory_scan", "memoryScan", + "async_func", + "asyncFunc", ]; #[allow(clippy::enum_variant_names)] @@ -17239,6 +17371,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { GenerateSeries, SortMergeJoin, MemoryScan, + AsyncFunc, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17294,6 +17427,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "generateSeries" | "generate_series" => Ok(GeneratedField::GenerateSeries), "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin), "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan), + "asyncFunc" | "async_func" => Ok(GeneratedField::AsyncFunc), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17552,6 +17686,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("memoryScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::MemoryScan) +; + } + GeneratedField::AsyncFunc => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("asyncFunc")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AsyncFunc) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b2d0bc7751f9b..4d16f50e14009 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1076,7 +1076,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" )] pub physical_plan_type: ::core::option::Option, } @@ -1154,6 +1154,8 @@ pub mod physical_plan_node { SortMergeJoin(::prost::alloc::boxed::Box), #[prost(message, tag = "35")] MemoryScan(super::MemoryScanExecNode), + #[prost(message, tag = "36")] + AsyncFunc(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -2104,6 +2106,15 @@ pub struct SortMergeJoinExecNode { #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] pub null_equality: i32, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AsyncFuncExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub async_exprs: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "3")] + pub async_expr_names: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 89dd0b50650b1..6725e85d20590 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -102,6 +102,8 @@ use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, Wind use prost::bytes::BufMut; use prost::Message; +use datafusion::physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion::physical_plan::async_func::AsyncFuncExec; pub mod from_proto; pub mod to_proto; @@ -251,6 +253,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { PhysicalPlanType::SortMergeJoin(sort_join) => { self.try_into_sort_join(sort_join, ctx, extension_codec) } + PhysicalPlanType::AsyncFunc(async_func) => { + self.try_into_async_func_physical_plan( + async_func, + ctx, + runtime, + extension_codec, + ) + } } } @@ -1972,6 +1982,39 @@ impl protobuf::PhysicalPlanNode { Ok(Arc::new(CooperativeExec::new(input))) } + fn try_into_async_func_physical_plan( + &self, + async_func: &protobuf::AsyncFuncExecNode, + ctx: &SessionContext, + runtime: &RuntimeEnv, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&async_func.input, ctx, runtime, extension_codec)?; + + let async_exprs = async_func + .async_exprs + .iter() + .zip(async_func.async_expr_names.iter()) + .map(|(expr, name)| { + let physical_expr = parse_physical_expr( + expr, + ctx, + input.schema().as_ref(), + extension_codec, + )?; + + Ok(Arc::new(AsyncFuncExpr::try_new( + name.clone(), + physical_expr, + input.schema().as_ref(), + )?)) + }) + .collect::>>()?; + + Ok(Arc::new(AsyncFuncExec::try_new(async_exprs, input)?)) + } + fn try_from_explain_exec( exec: &ExplainExec, _extension_codec: &dyn PhysicalExtensionCodec, @@ -3222,6 +3265,34 @@ impl protobuf::PhysicalPlanNode { Ok(None) } + + fn try_from_async_func_exec( + exec: &AsyncFuncExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input(), + extension_codec, + )?; + + let mut async_exprs = vec![]; + let mut async_expr_names = vec![]; + + for async_expr in exec.async_exprs() { + async_exprs.push(serialize_physical_expr(&async_expr.func, extension_codec)?); + async_expr_names.push(async_expr.name.clone()) + } + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::AsyncFunc(Box::new( + protobuf::AsyncFuncExecNode { + input: Some(Box::new(input)), + async_exprs, + async_expr_names, + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { From 21011e3d6886d58da3420f4659f0a286011d6e1c Mon Sep 17 00:00:00 2001 From: David Stancu Date: Fri, 5 Dec 2025 11:21:49 -0500 Subject: [PATCH 2/5] hook up AsExecutionPlan, roundtrip test --- Cargo.lock | 1 + datafusion/proto/Cargo.toml | 1 + datafusion/proto/src/physical_plan/mod.rs | 19 +++- .../tests/cases/roundtrip_physical_plan.rs | 103 +++++++++++++++++- 4 files changed, 116 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6d54d234e023d..54337a83bd729 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2577,6 +2577,7 @@ name = "datafusion-proto" version = "51.0.0" dependencies = [ "arrow", + "async-trait", "chrono", "datafusion", "datafusion-catalog", diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 920e277b8ccc0..0816355a81f9e 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -49,6 +49,7 @@ avro = ["datafusion-datasource-avro", "datafusion-common/avro"] [dependencies] arrow = { workspace = true } +async-trait = { workspace = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 6725e85d20590..e90a3e6aadda7 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -100,10 +100,11 @@ use datafusion_physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr}; -use prost::bytes::BufMut; -use prost::Message; use datafusion::physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion::physical_plan::async_func::AsyncFuncExec; +use datafusion::prelude::SessionContext; +use prost::bytes::BufMut; +use prost::Message; pub mod from_proto; pub mod to_proto; @@ -253,14 +254,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { PhysicalPlanType::SortMergeJoin(sort_join) => { self.try_into_sort_join(sort_join, ctx, extension_codec) } - PhysicalPlanType::AsyncFunc(async_func) => { - self.try_into_async_func_physical_plan( + PhysicalPlanType::AsyncFunc(async_func) => self + .try_into_async_func_physical_plan( async_func, ctx, runtime, extension_codec, - ) - } + ), } } @@ -472,6 +472,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_async_func_exec( + exec, + extension_codec, + ); + } + let mut buf: Vec = vec![]; match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 0bcdd610c26ff..cd2412a9a1c33 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -107,9 +107,11 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, DataFusionError, NullEquality, Result, UnnestOptions, }; +use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, WindowUDF, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -2263,3 +2265,100 @@ async fn roundtrip_listing_table_with_schema_metadata() -> Result<()> { roundtrip_test(plan) } + +#[tokio::test] +async fn roundtrip_async_func_exec() -> Result<()> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestAsyncUDF { + signature: Signature, + } + + impl TestAsyncUDF { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Int64], Volatility::Volatile), + } + } + } + + impl ScalarUDFImpl for TestAsyncUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "test_async_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("Must call from `invoke_async_with_args`") + } + } + + #[async_trait::async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDF { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + Ok(args.args[0].clone()) + } + } + + #[derive(Debug)] + struct TestAsyncUDFPhysicalCodec {} + impl PhysicalExtensionCodec for TestAsyncUDFPhysicalCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + not_impl_err!( + "TestAsyncUDFPhysicalCodec should not be called to decode an extension" + ) + } + + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + Ok(()) + } + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + if name == "test_async_udf" { + Ok(Arc::new(TestAsyncUDF::new().into())) + } else { + not_impl_err!("TestAsyncUDFPhysicalCodec unrecognized UDF {name}") + } + } + + fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } + } + + let ctx = SessionContext::new(); + let async_udf = AsyncScalarUDF::new(Arc::new(TestAsyncUDF::new())); + ctx.register_udf(async_udf.into_scalar_udf()); + + let logical_plan = ctx + .state() + .create_logical_plan("select test_async_udf(1)") + .await?; + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + let codec = TestAsyncUDFPhysicalCodec {}; + roundtrip_test_and_return(physical_plan, &ctx, &codec)?; + + Ok(()) +} From 6336d478e74bc6a21ec998013da0f956474bba19 Mon Sep 17 00:00:00 2001 From: David Stancu Date: Fri, 5 Dec 2025 11:32:50 -0500 Subject: [PATCH 3/5] df 51 changes --- datafusion/proto/src/physical_plan/mod.rs | 21 +++++++------------ .../tests/cases/roundtrip_physical_plan.rs | 2 +- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index e90a3e6aadda7..a1c0887a39d12 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -100,9 +100,8 @@ use datafusion_physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr}; -use datafusion::physical_expr::async_scalar_function::AsyncFuncExpr; -use datafusion::physical_plan::async_func::AsyncFuncExec; -use datafusion::prelude::SessionContext; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_plan::async_func::AsyncFuncExec; use prost::bytes::BufMut; use prost::Message; @@ -254,13 +253,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { PhysicalPlanType::SortMergeJoin(sort_join) => { self.try_into_sort_join(sort_join, ctx, extension_codec) } - PhysicalPlanType::AsyncFunc(async_func) => self - .try_into_async_func_physical_plan( - async_func, - ctx, - runtime, - extension_codec, - ), + PhysicalPlanType::AsyncFunc(async_func) => { + self.try_into_async_func_physical_plan(async_func, ctx, extension_codec) + } } } @@ -1992,12 +1987,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_async_func_physical_plan( &self, async_func: &protobuf::AsyncFuncExecNode, - ctx: &SessionContext, - runtime: &RuntimeEnv, + ctx: &TaskContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let input: Arc = - into_physical_plan(&async_func.input, ctx, runtime, extension_codec)?; + into_physical_plan(&async_func.input, ctx, extension_codec)?; let async_exprs = async_func .async_exprs @@ -3517,7 +3511,6 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { fn into_physical_plan( node: &Option>, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { if let Some(field) = node { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index cd2412a9a1c33..62f0e746734aa 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -2320,7 +2320,7 @@ async fn roundtrip_async_func_exec() -> Result<()> { &self, _buf: &[u8], _inputs: &[Arc], - _registry: &dyn FunctionRegistry, + _ctx: &TaskContext, ) -> Result> { not_impl_err!( "TestAsyncUDFPhysicalCodec should not be called to decode an extension" From c44ab92eadf695c813ef6e8624b6d20120ab720f Mon Sep 17 00:00:00 2001 From: David Stancu Date: Mon, 8 Dec 2025 10:11:21 -0500 Subject: [PATCH 4/5] feedback --- datafusion/physical-plan/src/async_func.rs | 8 ++++---- datafusion/proto/Cargo.toml | 2 +- datafusion/proto/src/physical_plan/mod.rs | 8 +++++++- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 9 +++++---- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index c5e67cead16c7..57b124a618c6b 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -101,12 +101,12 @@ impl AsyncFuncExec { )) } - pub fn async_exprs(&self) -> Vec> { - self.async_exprs.clone() + pub fn async_exprs(&self) -> &[Arc] { + &self.async_exprs } - pub fn input(&self) -> Arc { - Arc::clone(&self.input) + pub fn input(&self) -> &Arc { + &self.input } } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 0816355a81f9e..b00bd0dcc6bfd 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -49,7 +49,6 @@ avro = ["datafusion-datasource-avro", "datafusion-common/avro"] [dependencies] arrow = { workspace = true } -async-trait = { workspace = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } @@ -74,6 +73,7 @@ serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } [dev-dependencies] +async-trait = { workspace = true } datafusion = { workspace = true, default-features = false, features = [ "sql", "datetime_expressions", diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index a1c0887a39d12..6f275d2a2204c 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1993,6 +1993,12 @@ impl protobuf::PhysicalPlanNode { let input: Arc = into_physical_plan(&async_func.input, ctx, extension_codec)?; + if async_func.async_exprs.len() != async_func.async_expr_names.len() { + return internal_err!( + "AsyncFuncExecNode async_exprs length does not match async_expr_names" + ); + } + let async_exprs = async_func .async_exprs .iter() @@ -3272,7 +3278,7 @@ impl protobuf::PhysicalPlanNode { extension_codec: &dyn PhysicalExtensionCodec, ) -> Result { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input(), + Arc::clone(exec.input()), extension_codec, )?; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 62f0e746734aa..241b5754eb0ac 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -2352,11 +2352,12 @@ async fn roundtrip_async_func_exec() -> Result<()> { let async_udf = AsyncScalarUDF::new(Arc::new(TestAsyncUDF::new())); ctx.register_udf(async_udf.into_scalar_udf()); - let logical_plan = ctx - .state() - .create_logical_plan("select test_async_udf(1)") + let physical_plan = ctx + .sql("select test_async_udf(1)") + .await? + .create_physical_plan() .await?; - let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + let codec = TestAsyncUDFPhysicalCodec {}; roundtrip_test_and_return(physical_plan, &ctx, &codec)?; From f025c2d4c433d502bc6f15ef4479840925aa6e31 Mon Sep 17 00:00:00 2001 From: David Stancu Date: Tue, 9 Dec 2025 08:29:19 -0500 Subject: [PATCH 5/5] simplify test --- .../tests/cases/roundtrip_physical_plan.rs | 38 +------------------ 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 241b5754eb0ac..438e65f60b001 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -2313,41 +2313,6 @@ async fn roundtrip_async_func_exec() -> Result<()> { } } - #[derive(Debug)] - struct TestAsyncUDFPhysicalCodec {} - impl PhysicalExtensionCodec for TestAsyncUDFPhysicalCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[Arc], - _ctx: &TaskContext, - ) -> Result> { - not_impl_err!( - "TestAsyncUDFPhysicalCodec should not be called to decode an extension" - ) - } - - fn try_encode( - &self, - _node: Arc, - _buf: &mut Vec, - ) -> Result<()> { - Ok(()) - } - - fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { - if name == "test_async_udf" { - Ok(Arc::new(TestAsyncUDF::new().into())) - } else { - not_impl_err!("TestAsyncUDFPhysicalCodec unrecognized UDF {name}") - } - } - - fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { - Ok(()) - } - } - let ctx = SessionContext::new(); let async_udf = AsyncScalarUDF::new(Arc::new(TestAsyncUDF::new())); ctx.register_udf(async_udf.into_scalar_udf()); @@ -2358,8 +2323,7 @@ async fn roundtrip_async_func_exec() -> Result<()> { .create_physical_plan() .await?; - let codec = TestAsyncUDFPhysicalCodec {}; - roundtrip_test_and_return(physical_plan, &ctx, &codec)?; + roundtrip_test_with_context(physical_plan, &ctx)?; Ok(()) }