From dee5177ccffa114de3229d4159e4826d94e440ed Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Thu, 7 Nov 2024 23:14:25 +0800 Subject: [PATCH 1/3] introduce `plan_data_type` for ExprPlanner --- datafusion/expr/src/planner.rs | 10 ++++- datafusion/sql/src/planner.rs | 8 ++++ datafusion/sql/tests/common/mod.rs | 41 ++++++++++++++++-- datafusion/sql/tests/sql_integration.rs | 55 ++++++++++++++++++++++++- 4 files changed, 108 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 7dd7360e478f2..e2e05b93924d5 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -25,6 +25,7 @@ use datafusion_common::{ config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, Result, TableReference, }; +use sqlparser::ast; use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; @@ -205,6 +206,13 @@ pub trait ExprPlanner: Debug + Send + Sync { fn plan_any(&self, expr: RawBinaryExpr) -> Result> { Ok(PlannerResult::Original(expr)) } + + /// Plan SQL type to DataFusion data type + /// + /// Returns None if not possible + fn plan_data_type(&self, _sql_type: &ast::DataType) -> Result> { + Ok(None) + } } /// An operator with two arguments to plan @@ -216,7 +224,7 @@ pub trait ExprPlanner: Debug + Send + Sync { /// custom expressions. 
#[derive(Debug, Clone)] pub struct RawBinaryExpr { - pub op: sqlparser::ast::BinaryOperator, + pub op: ast::BinaryOperator, pub left: Expr, pub right: Expr, } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 4d44d5ff25849..b5687ba066f48 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -401,6 +401,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { + // First check if any of the registered expr_planners can handle this type + for expr_planner in self.context_provider.get_expr_planners() { + if let Some(data_type) = expr_planner.plan_data_type(sql_type)? { + return Ok(data_type); + } + } + + // If no expr_planner can handle this type, use the default conversion match sql_type { SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => { // Arrays may be multi-dimensional. diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index b0fa170318493..e18bd213d725d 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -18,15 +18,16 @@ use std::any::Any; #[cfg(test)] use std::collections::HashMap; -use std::fmt::Display; +use std::fmt::{Debug, Display}; use std::{sync::Arc, vec}; use arrow_schema::*; use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; -use datafusion_common::{plan_err, GetExt, Result, TableReference}; -use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; +use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference}; +use datafusion_expr::planner::{ExprPlanner, PlannerResult}; +use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF}; +use datafusion_functions_nested::expr_fn::make_array; use datafusion_sql::planner::ContextProvider; struct MockCsvType {} @@ -280,3 +281,35 @@ impl 
TableSource for EmptyTable { Arc::clone(&self.table_schema) } } + +#[derive(Debug)] +pub struct CustomTypePlanner {} + +impl ExprPlanner for CustomTypePlanner { + fn plan_data_type( + &self, + sql_type: &sqlparser::ast::DataType, + ) -> Result> { + match sql_type { + sqlparser::ast::DataType::Datetime(precision) => { + let precision = match precision { + Some(0) => TimeUnit::Second, + Some(3) => TimeUnit::Millisecond, + Some(6) => TimeUnit::Microsecond, + None | Some(9) => TimeUnit::Nanosecond, + _ => unreachable!(), + }; + Ok(Some(DataType::Timestamp(precision, None))) + } + _ => Ok(None), + } + } + + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + Ok(PlannerResult::Planned(make_array(exprs))) + } +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 698c408e538f5..d4d415e8f0812 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -41,13 +41,14 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; -use crate::common::MockSessionState; +use crate::common::{CustomTypePlanner, MockSessionState}; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf, min_max::min_udaf, }; use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; +use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::rank::rank_udwf; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -4495,3 +4496,55 @@ fn test_no_functions_registered() { "Internal error: No functions registered with this context." 
); } + +#[test] +fn test_custom_type_plan() -> Result<()> { + let sql = "SELECT DATETIME '2001-01-01 18:00:00'"; + + // test the default behavior + let options = ParserOptions::default(); + let dialect = &GenericDialect {}; + let state = MockSessionState::default(); + let context = MockContextProvider { state }; + let planner = SqlToRel::new_with_options(&context, options); + let result = DFParser::parse_sql_with_dialect(sql, dialect); + let mut ast = result.unwrap(); + let err = planner.statement_to_plan(ast.pop_front().unwrap()); + assert_contains!( + err.unwrap_err().to_string(), + "This feature is not implemented: Unsupported SQL type Datetime(None)" + ); + + fn plan_sql(sql: &str) -> LogicalPlan { + let options = ParserOptions::default(); + let dialect = &GenericDialect {}; + let state = MockSessionState::default() + .with_scalar_function(make_array_udf()) + .with_expr_planner(Arc::new(CustomTypePlanner {})); + let context = MockContextProvider { state }; + let planner = SqlToRel::new_with_options(&context, options); + let result = DFParser::parse_sql_with_dialect(sql, dialect); + let mut ast = result.unwrap(); + planner.statement_to_plan(ast.pop_front().unwrap()).unwrap() + } + + let plan = plan_sql(sql); + let expected = + "Projection: CAST(Utf8(\"2001-01-01 18:00:00\") AS Timestamp(Nanosecond, None))\ + \n EmptyRelation"; + assert_eq!(plan.to_string(), expected); + + let plan = plan_sql("SELECT CAST(TIMESTAMP '2001-01-01 18:00:00' AS DATETIME)"); + let expected = "Projection: CAST(CAST(Utf8(\"2001-01-01 18:00:00\") AS Timestamp(Nanosecond, None)) AS Timestamp(Nanosecond, None))\ + \n EmptyRelation"; + assert_eq!(plan.to_string(), expected); + + let plan = plan_sql( + "SELECT ARRAY[DATETIME '2001-01-01 18:00:00', DATETIME '2001-01-02 18:00:00']", + ); + let expected = "Projection: make_array(CAST(Utf8(\"2001-01-01 18:00:00\") AS Timestamp(Nanosecond, None)), CAST(Utf8(\"2001-01-02 18:00:00\") AS Timestamp(Nanosecond, None)))\ + \n EmptyRelation"; + 
assert_eq!(plan.to_string(), expected); + + Ok(()) +} From df814409d64631545b2a36b1b247d734d27acf85 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sat, 9 Nov 2024 17:06:37 +0800 Subject: [PATCH 2/3] implement TypePlanner trait instead of extending ExprPlanner --- datafusion/core/src/execution/context/mod.rs | 52 +++++++++++++++++-- .../core/src/execution/session_state.rs | 30 ++++++++++- datafusion/expr/src/planner.rs | 22 +++++--- datafusion/sql/src/planner.rs | 4 +- datafusion/sql/tests/common/mod.rs | 28 +++++++--- datafusion/sql/tests/sql_integration.rs | 5 +- 6 files changed, 120 insertions(+), 21 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 7868a7f9e59c7..604a72293ea11 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1788,15 +1788,15 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { #[cfg(test)] mod tests { - use std::env; - use std::path::PathBuf; - use super::{super::options::CsvReadOptions, *}; use crate::assert_batches_eq; use crate::execution::memory_pool::MemoryConsumer; use crate::execution::runtime_env::RuntimeEnvBuilder; use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; + use arrow_schema::{DataType, TimeUnit}; + use std::env; + use std::path::PathBuf; use datafusion_common_runtime::SpawnedTask; @@ -1804,6 +1804,8 @@ mod tests { use crate::execution::session_state::SessionStateBuilder; use crate::physical_planner::PhysicalPlanner; use async_trait::async_trait; + use datafusion_expr::planner::TypePlanner; + use sqlparser::ast; use tempfile::TempDir; #[tokio::test] @@ -2200,6 +2202,29 @@ mod tests { Ok(()) } + #[tokio::test] + async fn custom_type_planner() -> Result<()> { + let state = SessionStateBuilder::new() + .with_default_features() + .with_type_planner(Arc::new(MyTypePlanner {})) + .build(); + let ctx = SessionContext::new_with_state(state); + let result = 
ctx + .sql("SELECT DATETIME '2021-01-01 00:00:00'") + .await? + .collect() + .await?; + let expected = [ + "+-----------------------------+", + "| Utf8(\"2021-01-01 00:00:00\") |", + "+-----------------------------+", + "| 2021-01-01T00:00:00 |", + "+-----------------------------+", + ]; + assert_batches_eq!(expected, &result); + Ok(()) + } + struct MyPhysicalPlanner {} #[async_trait] @@ -2260,4 +2285,25 @@ mod tests { Ok(ctx) } + + #[derive(Debug)] + struct MyTypePlanner {} + + impl TypePlanner for MyTypePlanner { + fn plan_type(&self, sql_type: &ast::DataType) -> Result> { + match sql_type { + ast::DataType::Datetime(precision) => { + let precision = match precision { + Some(0) => TimeUnit::Second, + Some(3) => TimeUnit::Millisecond, + Some(6) => TimeUnit::Microsecond, + None | Some(9) => TimeUnit::Nanosecond, + _ => unreachable!(), + }; + Ok(Some(DataType::Timestamp(precision, None))) + } + _ => Ok(None), + } + } + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index ad0c1c2d41a67..3127f634ad508 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -48,7 +48,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::planner::{ExprPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::var_provider::{is_system_variables, VarType}; @@ -128,6 +128,8 @@ pub struct SessionState { analyzer: Analyzer, /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?` expr_planners: Vec>, + /// Provides support for customising the type planning, e.g. 
to add support for planning custom SQL types + type_planner: Option>, /// Responsible for optimizing a logical plan optimizer: Optimizer, /// Responsible for optimizing a physical execution plan @@ -192,6 +194,7 @@ impl Debug for SessionState { .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) .field("expr_planners", &self.expr_planners) + .field("type_planner", &self.type_planner) .field("query_planners", &self.query_planner) .field("analyzer", &self.analyzer) .field("optimizer", &self.optimizer) @@ -943,6 +946,7 @@ pub struct SessionStateBuilder { session_id: Option, analyzer: Option, expr_planners: Option>>, + type_planner: Option>, optimizer: Option, physical_optimizers: Option, query_planner: Option>, @@ -972,6 +976,7 @@ impl SessionStateBuilder { session_id: None, analyzer: None, expr_planners: None, + type_planner: None, optimizer: None, physical_optimizers: None, query_planner: None, @@ -1019,6 +1024,7 @@ impl SessionStateBuilder { session_id: None, analyzer: Some(existing.analyzer), expr_planners: Some(existing.expr_planners), + type_planner: existing.type_planner, optimizer: Some(existing.optimizer), physical_optimizers: Some(existing.physical_optimizers), query_planner: Some(existing.query_planner), @@ -1113,6 +1119,12 @@ impl SessionStateBuilder { self } + /// Set the [`TypePlanner`] used to customize the behavior of the SQL planner. + pub fn with_type_planner(mut self, type_planner: Arc) -> Self { + self.type_planner = Some(type_planner); + self + } + /// Set the [`PhysicalOptimizerRule`]s used to optimize plans. 
pub fn with_physical_optimizer_rules( mut self, @@ -1306,6 +1318,7 @@ impl SessionStateBuilder { session_id, analyzer, expr_planners, + type_planner, optimizer, physical_optimizers, query_planner, @@ -1334,6 +1347,7 @@ impl SessionStateBuilder { session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), + type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), @@ -1444,6 +1458,11 @@ impl SessionStateBuilder { &mut self.expr_planners } + /// Returns the current type_planner value + pub fn type_planner(&mut self) -> &mut Option> { + &mut self.type_planner + } + /// Returns the current optimizer value pub fn optimizer(&mut self) -> &mut Option { &mut self.optimizer @@ -1566,6 +1585,7 @@ impl Debug for SessionStateBuilder { .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) .field("expr_planners", &self.expr_planners) + .field("type_planner", &self.type_planner) .field("query_planners", &self.query_planner) .field("analyzer_rules", &self.analyzer_rules) .field("analyzer", &self.analyzer) @@ -1607,6 +1627,14 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { &self.state.expr_planners } + fn get_type_planner(&self) -> Option> { + if let Some(type_planner) = &self.state.type_planner { + Some(Arc::clone(type_planner)) + } else { + None + } + } + fn get_table_source( &self, name: TableReference, diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index e2e05b93924d5..42047e8e6caa2 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -67,6 +67,11 @@ pub trait ContextProvider { &[] } + /// Getter for the data type planner + fn get_type_planner(&self) -> Option> { + None + } + /// Getter for a UDF description fn get_function_meta(&self, 
name: &str) -> Option>; /// Getter for a UDAF description @@ -206,13 +211,6 @@ pub trait ExprPlanner: Debug + Send + Sync { fn plan_any(&self, expr: RawBinaryExpr) -> Result> { Ok(PlannerResult::Original(expr)) } - - /// Plan SQL type to DataFusion data type - /// - /// Returns None if not possible - fn plan_data_type(&self, _sql_type: &ast::DataType) -> Result> { - Ok(None) - } } /// An operator with two arguments to plan @@ -257,3 +255,13 @@ pub enum PlannerResult { /// The raw expression could not be planned, and is returned unmodified Original(T), } + +/// This trait allows users to customize the behavior of the data type planning +pub trait TypePlanner: Debug + Send + Sync { + /// Plan SQL type to DataFusion data type + /// + /// Returns None if not possible + fn plan_type(&self, _sql_type: &ast::DataType) -> Result> { + Ok(None) + } +} diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b5687ba066f48..f0177a03013ed 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -402,8 +402,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { // First check if any of the registered expr_planners can handle this type - for expr_planner in self.context_provider.get_expr_planners() { - if let Some(data_type) = expr_planner.plan_data_type(sql_type)? { + if let Some(type_planner) = self.context_provider.get_type_planner() { + if let Some(data_type) = type_planner.plan_type(sql_type)? 
{ return Ok(data_type); } } diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index e18bd213d725d..63c296dfbc2f9 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -25,7 +25,7 @@ use arrow_schema::*; use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference}; -use datafusion_expr::planner::{ExprPlanner, PlannerResult}; +use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner}; use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_nested::expr_fn::make_array; use datafusion_sql::planner::ContextProvider; @@ -55,6 +55,7 @@ pub(crate) struct MockSessionState { scalar_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, + type_planner: Option>, window_functions: HashMap>, pub config_options: ConfigOptions, } @@ -65,6 +66,11 @@ impl MockSessionState { self } + pub fn with_type_planner(mut self, type_planner: Arc) -> Self { + self.type_planner = Some(type_planner); + self + } + pub fn with_scalar_function(mut self, scalar_function: Arc) -> Self { self.scalar_functions .insert(scalar_function.name().to_string(), scalar_function); @@ -260,6 +266,14 @@ impl ContextProvider for MockContextProvider { fn get_expr_planners(&self) -> &[Arc] { &self.state.expr_planners } + + fn get_type_planner(&self) -> Option> { + if let Some(type_planner) = &self.state.type_planner { + Some(Arc::clone(type_planner)) + } else { + None + } + } } struct EmptyTable { @@ -285,11 +299,8 @@ impl TableSource for EmptyTable { #[derive(Debug)] pub struct CustomTypePlanner {} -impl ExprPlanner for CustomTypePlanner { - fn plan_data_type( - &self, - sql_type: &sqlparser::ast::DataType, - ) -> Result> { +impl TypePlanner for CustomTypePlanner { + fn plan_type(&self, sql_type: &sqlparser::ast::DataType) -> Result> { match sql_type 
{ sqlparser::ast::DataType::Datetime(precision) => { let precision = match precision { @@ -304,7 +315,12 @@ impl ExprPlanner for CustomTypePlanner { _ => Ok(None), } } +} + +#[derive(Debug)] +pub struct CustomExprPlanner {} +impl ExprPlanner for CustomExprPlanner { fn plan_array_literal( &self, exprs: Vec, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index d4d415e8f0812..f538fb4a4c440 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -41,7 +41,7 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; -use crate::common::{CustomTypePlanner, MockSessionState}; +use crate::common::{CustomExprPlanner, CustomTypePlanner, MockSessionState}; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf, @@ -4520,7 +4520,8 @@ fn test_custom_type_plan() -> Result<()> { let dialect = &GenericDialect {}; let state = MockSessionState::default() .with_scalar_function(make_array_udf()) - .with_expr_planner(Arc::new(CustomTypePlanner {})); + .with_expr_planner(Arc::new(CustomExprPlanner {})) + .with_type_planner(Arc::new(CustomTypePlanner {})); let context = MockContextProvider { state }; let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); From 6ef993842dfbdbe2715c5ecc49d1a01d8e6f2a1b Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Sat, 9 Nov 2024 17:14:50 +0800 Subject: [PATCH 3/3] enhance the document --- datafusion/core/src/execution/session_state.rs | 2 +- datafusion/sql/src/planner.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 3127f634ad508..6654d16db34c6 100644 --- a/datafusion/core/src/execution/session_state.rs +++ 
b/datafusion/core/src/execution/session_state.rs @@ -128,7 +128,7 @@ pub struct SessionState { analyzer: Analyzer, /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?` expr_planners: Vec>, - /// Provides support for customising the type planning, e.g. to add support for planning custom SQL types + /// Provides support for customising the SQL type planning type_planner: Option>, /// Responsible for optimizing a logical plan optimizer: Optimizer, diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index f0177a03013ed..ccb2ccf7126f1 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -401,14 +401,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { - // First check if any of the registered expr_planners can handle this type + // First check if the registered type_planner (if any) can handle this type if let Some(type_planner) = self.context_provider.get_type_planner() { if let Some(data_type) = type_planner.plan_type(sql_type)? { return Ok(data_type); } } - // If no expr_planner can handle this type, use the default conversion + // If no type_planner is registered, or it cannot handle this type, use the default conversion match sql_type { SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => { // Arrays may be multi-dimensional.