diff --git a/Makefile b/Makefile index 245af80..2a1851e 100644 --- a/Makefile +++ b/Makefile @@ -34,3 +34,6 @@ run: debug: RUST_BACKTRACE=1 cargo run + +debug_v2: + ENABLE_V2=1 RUST_BACKTRACE=1 cargo run diff --git a/src/cli.rs b/src/cli.rs index bd8eb0a..ed6b1ed 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,3 +1,4 @@ +use std::env; use std::fs::File; use std::sync::Arc; @@ -13,7 +14,7 @@ pub async fn interactive(db: Database, client_context: Arc) -> Re let mut rl = Editor::<()>::new()?; load_history(&mut rl); - let mut enable_v2 = false; + let mut enable_v2 = env::var("ENABLE_V2").unwrap_or_else(|_| "0".to_string()) == "1"; loop { let read_sql = read_sql(&mut rl); diff --git a/src/execution/expression_executor.rs b/src/execution/expression_executor.rs index c648b0f..b2aed89 100644 --- a/src/execution/expression_executor.rs +++ b/src/execution/expression_executor.rs @@ -1,4 +1,5 @@ use arrow::array::ArrayRef; +use arrow::compute::cast; use arrow::record_batch::RecordBatch; use super::ExecutorError; @@ -28,6 +29,11 @@ impl ExpressionExecutor { BoundExpression::BoundColumnRefExpression(_) => todo!(), BoundExpression::BoundConstantExpression(e) => e.value.to_array(), BoundExpression::BoundReferenceExpression(e) => input.column(e.index).clone(), + BoundExpression::BoundCastExpression(e) => { + let child_result = Self::execute_internal(&e.child, input)?; + let to_type = e.base.return_type.clone().into(); + cast(&child_result, &to_type)? + } }) } } diff --git a/src/planner_v2/binder/expression/bind_cast_expression.rs b/src/planner_v2/binder/expression/bind_cast_expression.rs new file mode 100644 index 0000000..fd815b3 --- /dev/null +++ b/src/planner_v2/binder/expression/bind_cast_expression.rs @@ -0,0 +1,34 @@ +use derive_new::new; + +use super::{BoundExpression, BoundExpressionBase}; +use crate::types_v2::LogicalType; + +#[derive(new, Debug, Clone)] +pub struct BoundCastExpression { + pub(crate) base: BoundExpressionBase, + /// The child type + pub(crate) child: Box, + #[allow(dead_code)] + /// Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of + /// throwing an error. + pub(crate) try_cast: bool, +} + +impl BoundCastExpression { + pub fn add_cast_to_type( + expr: BoundExpression, + target_type: LogicalType, + alias: String, + try_cast: bool, + ) -> BoundExpression { + if expr.return_type() == target_type { + return expr; + } + let base = BoundExpressionBase::new(alias, target_type); + BoundExpression::BoundCastExpression(BoundCastExpression::new( + base, + Box::new(expr), + try_cast, + )) + } +} diff --git a/src/planner_v2/binder/expression/mod.rs b/src/planner_v2/binder/expression/mod.rs index 19d61b0..0b92b6b 100644 --- a/src/planner_v2/binder/expression/mod.rs +++ b/src/planner_v2/binder/expression/mod.rs @@ -1,8 +1,10 @@ +mod bind_cast_expression; mod bind_column_ref_expression; mod bind_constant_expression; mod bind_reference_expression; mod column_binding; +pub use bind_cast_expression::*; pub use bind_column_ref_expression::*; pub use bind_constant_expression::*; pub use bind_reference_expression::*; @@ -24,6 +26,7 @@ pub enum BoundExpression { BoundColumnRefExpression(BoundColumnRefExpression), BoundConstantExpression(BoundConstantExpression), BoundReferenceExpression(BoundReferenceExpression), + BoundCastExpression(BoundCastExpression), } impl BoundExpression { @@ -32,6 +35,7 @@ impl BoundExpression { BoundExpression::BoundColumnRefExpression(expr) => expr.base.return_type.clone(), BoundExpression::BoundConstantExpression(expr) => expr.base.return_type.clone(), BoundExpression::BoundReferenceExpression(expr) => expr.base.return_type.clone(), + BoundExpression::BoundCastExpression(expr) => expr.base.return_type.clone(), } } @@ -40,6 +44,7 @@ impl BoundExpression { BoundExpression::BoundColumnRefExpression(expr) => expr.base.alias.clone(), BoundExpression::BoundConstantExpression(expr) => expr.base.alias.clone(), BoundExpression::BoundReferenceExpression(expr) => expr.base.alias.clone(), + BoundExpression::BoundCastExpression(expr) => expr.base.alias.clone(), } } } diff --git a/src/planner_v2/binder/query_node/plan_select_node.rs b/src/planner_v2/binder/query_node/plan_select_node.rs index ab803c3..7b2fd02 100644 --- a/src/planner_v2/binder/query_node/plan_select_node.rs +++ b/src/planner_v2/binder/query_node/plan_select_node.rs @@ -1,8 +1,10 @@ use super::BoundSelectNode; use crate::planner_v2::BoundTableRef::{BoundBaseTableRef, BoundExpressionListRef}; use crate::planner_v2::{ - BindError, Binder, BoundStatement, LogicalOperator, LogicalOperatorBase, LogicalProjection, + BindError, Binder, BoundCastExpression, BoundStatement, LogicalOperator, LogicalOperatorBase, + LogicalProjection, }; +use crate::types_v2::LogicalType; impl Binder { pub fn create_plan_for_select_node( @@ -23,4 +25,46 @@ impl Binder { Ok(BoundStatement::new(root, node.types, node.names)) } + + pub fn cast_logical_operator_to_types( + &mut self, + source_types: &[LogicalType], + target_types: &[LogicalType], + op: &mut LogicalOperator, + ) -> Result<(), BindError> { + assert!(source_types.len() == target_types.len()); + if source_types == target_types { + // source and target types are equal: don't need to cast + return Ok(()); + } + if let LogicalOperator::LogicalProjection(node) = op { + // "node" is a projection; we can just do the casts in there + assert!(node.base.expressioins.len() == source_types.len()); + for (idx, (source_type, target_type)) in + source_types.iter().zip(target_types.iter()).enumerate() + { + if source_type != target_type { + if LogicalType::can_implicit_cast(source_type, target_type) { + let alias = node.base.expressioins[idx].alias(); + node.base.expressioins[idx] = BoundCastExpression::add_cast_to_type( + node.base.expressioins[idx].clone(), + target_type.clone(), + alias, + false, + ); + node.base.types[idx] = target_type.clone(); + } else { + return Err(BindError::Internal(format!( + "cannot cast {:?} to {:?}", + source_type, target_type + ))); + } + } + } + Ok(()) + } else { + // found a non-projection operator, push a new projection containing the casts + todo!(); + } + } } diff --git a/src/planner_v2/binder/statement/bind_insert.rs b/src/planner_v2/binder/statement/bind_insert.rs index 854a972..466b58a 100644 --- a/src/planner_v2/binder/statement/bind_insert.rs +++ b/src/planner_v2/binder/statement/bind_insert.rs @@ -70,6 +70,8 @@ impl Binder { let select_node = self.bind_select_node(source)?; let expected_columns_cnt = named_column_indices.len(); + + // special case: check if we are inserting from a VALUES statement if let BoundTableRef::BoundExpressionListRef(table_ref) = &select_node.from_table { // CheckInsertColumnCountMismatch let insert_columns_cnt = table_ref.values.first().unwrap().len(); @@ -79,12 +81,14 @@ impl Binder { expected_columns_cnt, insert_columns_cnt ))); } - }; - - // TODO: cast types + } let select_node = self.create_plan_for_select_node(select_node)?; - let plan = select_node.plan; + let inserted_types = select_node.types; + let mut plan = select_node.plan; + // cast inserted types to expected types when necessary + self.cast_logical_operator_to_types(&inserted_types, &expected_types, &mut plan)?; + let root = LogicalInsert::new( LogicalOperatorBase::new(vec![plan], vec![], vec![]), column_index_list, diff --git a/src/planner_v2/binder/tableref/bind_expression_list_ref.rs b/src/planner_v2/binder/tableref/bind_expression_list_ref.rs index ee89bd4..e365d0b 100644 --- a/src/planner_v2/binder/tableref/bind_expression_list_ref.rs +++ b/src/planner_v2/binder/tableref/bind_expression_list_ref.rs @@ -24,10 +24,21 @@ impl Binder { &mut self, values: &Values, ) -> Result { + // ensure all values lists are the same length + let mut values_cnt = 0; + for val_expr_list in values.0.iter() { + if values_cnt == 0 { + values_cnt = val_expr_list.len(); + } else if values_cnt != val_expr_list.len() { + return Err(BindError::Internal( + "VALUES lists must all be the same length".to_string(), + )); + } + } + let mut bound_expr_list = vec![]; - let mut names = vec![]; - let mut types = vec![]; - let mut finish_name = false; + let mut names = vec!["".to_string(); values_cnt]; + let mut types = vec![LogicalType::Invalid; values_cnt]; let mut expr_binder = ExpressionBinder::new(self); @@ -35,14 +46,15 @@ impl Binder { let mut bound_expr_row = vec![]; for (idx, expr) in val_expr_list.iter().enumerate() { let bound_expr = expr_binder.bind_expression(expr, &mut vec![], &mut vec![])?; - if !finish_name { - names.push(format!("col{}", idx)); - types.push(bound_expr.return_type()); + names[idx] = format!("col{}", idx); + if types[idx] == LogicalType::Invalid { + types[idx] = bound_expr.return_type().clone(); } + // use values max type as the column type + types[idx] = LogicalType::max_logical_type(&types[idx], &bound_expr.return_type())?; bound_expr_row.push(bound_expr); } bound_expr_list.push(bound_expr_row); - finish_name = true; } let table_index = self.generate_table_index(); self.bind_context.add_generic_binding( diff --git a/src/planner_v2/expression_iterator.rs b/src/planner_v2/expression_iterator.rs index 182fbfd..b175dd0 100644 --- a/src/planner_v2/expression_iterator.rs +++ b/src/planner_v2/expression_iterator.rs @@ -3,7 +3,7 @@ use super::BoundExpression; pub struct ExpressionIterator; impl ExpressionIterator { - pub fn enumerate_children(expr: &mut BoundExpression, _callback: F) + pub fn enumerate_children(expr: &mut BoundExpression, callback: F) where F: Fn(&mut BoundExpression), { @@ -13,6 +13,7 @@ impl ExpressionIterator { | BoundExpression::BoundReferenceExpression(_) => { // these node types have no children } + BoundExpression::BoundCastExpression(e) => callback(&mut e.child), } } } diff --git a/src/planner_v2/logical_operator_visitor.rs b/src/planner_v2/logical_operator_visitor.rs index 04753fe..813b694 100644 --- a/src/planner_v2/logical_operator_visitor.rs +++ b/src/planner_v2/logical_operator_visitor.rs @@ -1,6 +1,6 @@ use super::{ - BoundColumnRefExpression, BoundConstantExpression, BoundExpression, BoundReferenceExpression, - ExpressionIterator, LogicalOperator, + BoundCastExpression, BoundColumnRefExpression, BoundConstantExpression, BoundExpression, + BoundReferenceExpression, ExpressionIterator, LogicalOperator, }; /// Visitor pattern on logical operators, also includes rewrite expression ability. @@ -34,6 +34,7 @@ pub trait LogicalOperatorVisitor { BoundExpression::BoundColumnRefExpression(e) => self.visit_replace_column_ref(e), BoundExpression::BoundConstantExpression(e) => self.visit_replace_constant(e), BoundExpression::BoundReferenceExpression(e) => self.visit_replace_reference(e), + BoundExpression::BoundCastExpression(e) => self.visit_replace_cast(e), }; if let Some(new_expr) = result { *expr = new_expr; @@ -55,4 +56,7 @@ pub trait LogicalOperatorVisitor { fn visit_replace_reference(&self, _: &BoundReferenceExpression) -> Option { None } + fn visit_replace_cast(&self, _: &BoundCastExpression) -> Option { + None + } } diff --git a/src/planner_v2/operator/mod.rs b/src/planner_v2/operator/mod.rs index 01c5463..f82991f 100644 --- a/src/planner_v2/operator/mod.rs +++ b/src/planner_v2/operator/mod.rs @@ -53,6 +53,16 @@ impl LogicalOperator { } } + pub fn types(&self) -> &[LogicalType] { + match self { + LogicalOperator::LogicalCreateTable(op) => &op.base.types, + LogicalOperator::LogicalExpressionGet(op) => &op.base.types, + LogicalOperator::LogicalInsert(op) => &op.base.types, + LogicalOperator::LogicalGet(op) => &op.base.types, + LogicalOperator::LogicalProjection(op) => &op.base.types, + } + } + pub fn get_column_bindings(&self) -> Vec { let default = vec![ColumnBinding::new(0, 0)]; match self { diff --git a/src/types_v2/errors.rs b/src/types_v2/errors.rs index babe8ee..f3842b4 100644 --- a/src/types_v2/errors.rs +++ b/src/types_v2/errors.rs @@ -6,4 +6,6 @@ pub enum TypeError { NotImplementedArrowDataType(String), #[error("not implemented sqlparser datatype: {0}")] NotImplementedSqlparserDataType(String), + #[error("internal error: {0}")] + InternalError(String), } diff --git a/src/types_v2/types.rs b/src/types_v2/types.rs index 24f6178..68a6528 100644 --- a/src/types_v2/types.rs +++ b/src/types_v2/types.rs @@ -19,6 +19,110 @@ pub enum LogicalType { Varchar, } +impl LogicalType { + pub fn is_numeric(&self) -> bool { + matches!( + self, + LogicalType::Tinyint + | LogicalType::UTinyint + | LogicalType::Smallint + | LogicalType::USmallint + | LogicalType::Integer + | LogicalType::UInteger + | LogicalType::Bigint + | LogicalType::UBigint + | LogicalType::Float + | LogicalType::Double + ) + } + + pub fn max_logical_type( + left: &LogicalType, + right: &LogicalType, + ) -> Result { + if left == right { + return Ok(left.clone()); + } + if left.is_numeric() && right.is_numeric() { + if LogicalType::can_implicit_cast(left, right) { + return Ok(right.clone()); + } else if LogicalType::can_implicit_cast(right, left) { + return Ok(left.clone()); + } else { + return Err(TypeError::InternalError(format!( + "can not implicit cast {:?} to {:?}", + left, right + ))); + } + } + Err(TypeError::InternalError(format!( + "can not compare two types: {:?} and {:?}", + left, right + ))) + } + + pub fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool { + if from == to { + return true; + } + match from { + LogicalType::Invalid => false, + LogicalType::Boolean => false, + LogicalType::Tinyint => matches!( + to, + LogicalType::Smallint + | LogicalType::Integer + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::UTinyint => matches!( + to, + LogicalType::USmallint + | LogicalType::UInteger + | LogicalType::UBigint + | LogicalType::Smallint + | LogicalType::Integer + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::Smallint => matches!( + to, + LogicalType::Integer + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::USmallint => matches!( + to, + LogicalType::UInteger + | LogicalType::UBigint + | LogicalType::Integer + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::Integer => matches!( + to, + LogicalType::Bigint | LogicalType::Float | LogicalType::Double + ), + LogicalType::UInteger => matches!( + to, + LogicalType::UBigint + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::Bigint => matches!(to, LogicalType::Float | LogicalType::Double), + LogicalType::UBigint => matches!(to, LogicalType::Float | LogicalType::Double), + LogicalType::Float => matches!(to, LogicalType::Double), + LogicalType::Double => false, + LogicalType::Varchar => false, + } + } +} + /// sqlparser datatype to logical type impl TryFrom for LogicalType { type Error = TypeError; diff --git a/tests/slt/create_table.slt b/tests/slt/create_table.slt index 6a70df3..036f35f 100644 --- a/tests/slt/create_table.slt +++ b/tests/slt/create_table.slt @@ -17,3 +17,27 @@ select a, c, b from t1; NULL 0 4 NULL 1 5 2 9 7 + +onlyif sqlrs_v2 +statement ok +create table t2(a int, b int, c int); + +onlyif sqlrs_v2 +statement ok +insert into t2(c, b, a) values (0, 4, 1), (1, 5, 2); + +onlyif sqlrs_v2 +query III +select c, b, a from t2; +---- +0 4 1 +1 5 2 + +# Test insert type cast +onlyif sqlrs_v2 +statement ok +create table t3(a TINYINT UNSIGNED); + +onlyif sqlrs_v2 +statement error +insert into t3(a) values (1481);