diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 0bb91536da3ca..ecc3bd2990f4c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1711,7 +1711,7 @@ impl FunctionRegistry for SessionContext { } fn expr_planners(&self) -> Vec> { - self.state.read().expr_planners() + self.state.read().expr_planners().to_vec() } fn register_expr_planner( diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 10b8f13217e64..8aa812cc5258a 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -552,6 +552,11 @@ impl SessionState { &self.optimizer } + /// Returns the [`ExprPlanner`]s for this session + pub fn expr_planners(&self) -> &[Arc] { + &self.expr_planners + } + /// Returns the [`QueryPlanner`] for this session pub fn query_planner(&self) -> &Arc { &self.query_planner @@ -1637,7 +1642,7 @@ struct SessionContextProvider<'a> { impl ContextProvider for SessionContextProvider<'_> { fn get_expr_planners(&self) -> &[Arc] { - &self.state.expr_planners + self.state.expr_planners() } fn get_type_planner(&self) -> Option> { @@ -1959,8 +1964,17 @@ pub(crate) struct PreparedPlan { #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; + use crate::common::assert_contains; + use crate::config::ConfigOptions; + use crate::datasource::empty::EmptyTable; + use crate::datasource::provider_as_source; use crate::datasource::MemTable; use crate::execution::context::SessionState; + use crate::logical_expr::planner::ExprPlanner; + use crate::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use crate::physical_plan::ExecutionPlan; + use crate::sql::planner::ContextProvider; + use crate::sql::{ResolvedTableReference, TableReference}; use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_catalog::MemoryCatalogProviderList; @@ -1970,6 +1984,7 @@ mod tests { use datafusion_expr::Expr; use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; + use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; @@ -2127,4 +2142,148 @@ mod tests { Ok(()) } + + /// This test demonstrates why it's more convenient and somewhat necessary to provide + /// an `expr_planners` method for `SessionState`. + #[tokio::test] + async fn test_with_expr_planners() -> Result<()> { + // A helper method for planning count wildcard with or without expr planners. + async fn plan_count_wildcard( + with_expr_planners: bool, + ) -> Result> { + let mut context_provider = MyContextProvider::new().with_table( + "t", + provider_as_source(Arc::new(EmptyTable::new(Schema::empty().into()))), + ); + if with_expr_planners { + context_provider = context_provider.with_expr_planners(); + } + + let state = &context_provider.state; + let statement = state.sql_to_statement("select count(*) from t", "mysql")?; + let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?; + state.create_physical_plan(&plan).await + } + + // Planning count wildcard without expr planners should fail. + let got = plan_count_wildcard(false).await; + assert_contains!( + got.unwrap_err().to_string(), + "Physical plan does not support logical expression Wildcard" + ); + + // Planning count wildcard with expr planners should succeed. + let got = plan_count_wildcard(true).await?; + let displayable = DisplayableExecutionPlan::new(got.as_ref()); + assert_eq!( + displayable.indent(false).to_string(), + "ProjectionExec: expr=[0 as count(*)]\n PlaceholderRowExec\n" + ); + + Ok(()) + } + + /// A `ContextProvider` based on `SessionState`. + /// + /// Almost all planning context are retrieved from the `SessionState`. + struct MyContextProvider { + /// The session state. + state: SessionState, + /// Registered tables. + tables: HashMap>, + /// Controls whether to return expression planners when called `ContextProvider::expr_planners`. + return_expr_planners: bool, + } + + impl MyContextProvider { + /// Creates a new `SessionContextProvider`. + pub fn new() -> Self { + Self { + state: SessionStateBuilder::default() + .with_default_features() + .build(), + tables: HashMap::new(), + return_expr_planners: false, + } + } + + /// Registers a table. + /// + /// The catalog and schema are provided by default. + pub fn with_table(mut self, table: &str, source: Arc) -> Self { + self.tables.insert( + ResolvedTableReference { + catalog: "default".to_string().into(), + schema: "public".to_string().into(), + table: table.to_string().into(), + }, + source, + ); + self + } + + /// Sets the `return_expr_planners` flag to true. + pub fn with_expr_planners(self) -> Self { + Self { + return_expr_planners: true, + ..self + } + } + } + + impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + let resolved_table_ref = ResolvedTableReference { + catalog: "default".to_string().into(), + schema: "public".to_string().into(), + table: name.table().to_string().into(), + }; + let source = self.tables.get(&resolved_table_ref).cloned().unwrap(); + Ok(source) + } + + /// We use a `return_expr_planners` flag to demonstrate why it's necessary to + /// return the expression planners in the `SessionState`. + /// + /// Note, the default implementation returns an empty slice. + fn get_expr_planners(&self) -> &[Arc] { + if self.return_expr_planners { + self.state.expr_planners() + } else { + &[] + } + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn options(&self) -> &ConfigOptions { + self.state.config_options() + } + + fn udf_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } + } }