From 3e164a89ef7b72c72e6823a2a2d48fd561bdd9fc Mon Sep 17 00:00:00 2001 From: Fedomn Date: Sat, 10 Dec 2022 23:28:00 +0800 Subject: [PATCH] feat(planner): support insert nulls and enable try_cast to throw error when cast failures Signed-off-by: Fedomn --- src/execution/expression_executor.rs | 5 +- .../binder/expression/bind_cast_expression.rs | 1 - .../binder/query_node/plan_select_node.rs | 24 +++--- .../binder/statement/bind_insert.rs | 2 + .../tableref/bind_expression_list_ref.rs | 19 ++++- src/types_v2/types.rs | 78 ++++++++++++++++--- src/types_v2/values.rs | 2 +- tests/slt/create_table.slt | 38 ++------- tests/slt/insert_table.slt | 69 ++++++++++++++++ 9 files changed, 177 insertions(+), 61 deletions(-) create mode 100644 tests/slt/insert_table.slt diff --git a/src/execution/expression_executor.rs b/src/execution/expression_executor.rs index b2aed89..7f3aa6c 100644 --- a/src/execution/expression_executor.rs +++ b/src/execution/expression_executor.rs @@ -1,5 +1,5 @@ use arrow::array::ArrayRef; -use arrow::compute::cast; +use arrow::compute::{cast_with_options, CastOptions}; use arrow::record_batch::RecordBatch; use super::ExecutorError; @@ -32,7 +32,8 @@ impl ExpressionExecutor { BoundExpression::BoundCastExpression(e) => { let child_result = Self::execute_internal(&e.child, input)?; let to_type = e.base.return_type.clone().into(); - cast(&child_result, &to_type)? + let options = CastOptions { safe: e.try_cast }; + cast_with_options(&child_result, &to_type, &options)? } }) } diff --git a/src/planner_v2/binder/expression/bind_cast_expression.rs b/src/planner_v2/binder/expression/bind_cast_expression.rs index fd815b3..0b18384 100644 --- a/src/planner_v2/binder/expression/bind_cast_expression.rs +++ b/src/planner_v2/binder/expression/bind_cast_expression.rs @@ -8,7 +8,6 @@ pub struct BoundCastExpression { pub(crate) base: BoundExpressionBase, /// The child type pub(crate) child: Box, - #[allow(dead_code)] /// Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of /// throwing an error. pub(crate) try_cast: bool, diff --git a/src/planner_v2/binder/query_node/plan_select_node.rs b/src/planner_v2/binder/query_node/plan_select_node.rs index 7b2fd02..864d65e 100644 --- a/src/planner_v2/binder/query_node/plan_select_node.rs +++ b/src/planner_v2/binder/query_node/plan_select_node.rs @@ -44,21 +44,15 @@ impl Binder { source_types.iter().zip(target_types.iter()).enumerate() { if source_type != target_type { - if LogicalType::can_implicit_cast(source_type, target_type) { - let alias = node.base.expressioins[idx].alias(); - node.base.expressioins[idx] = BoundCastExpression::add_cast_to_type( - node.base.expressioins[idx].clone(), - target_type.clone(), - alias, - false, - ); - node.base.types[idx] = target_type.clone(); - } else { - return Err(BindError::Internal(format!( - "cannot cast {:?} to {:?}", - source_type, target_type - ))); - } + // differing types, have to add a cast but may be lossy + let alias = node.base.expressioins[idx].alias(); + node.base.expressioins[idx] = BoundCastExpression::add_cast_to_type( + node.base.expressioins[idx].clone(), + target_type.clone(), + alias, + false, + ); + node.base.types[idx] = target_type.clone(); } } Ok(()) diff --git a/src/planner_v2/binder/statement/bind_insert.rs b/src/planner_v2/binder/statement/bind_insert.rs index 466b58a..adb2d09 100644 --- a/src/planner_v2/binder/statement/bind_insert.rs +++ b/src/planner_v2/binder/statement/bind_insert.rs @@ -88,6 +88,8 @@ impl Binder { let mut plan = select_node.plan; // cast inserted types to expected types when necessary self.cast_logical_operator_to_types(&inserted_types, &expected_types, &mut plan)?; + // TODO: add debug level log for plan + // println!("plan: {:#?}", plan); let root = LogicalInsert::new( LogicalOperatorBase::new(vec![plan], vec![], vec![]), diff --git a/src/planner_v2/binder/tableref/bind_expression_list_ref.rs b/src/planner_v2/binder/tableref/bind_expression_list_ref.rs index b26cf0c..2929d90 100644 --- a/src/planner_v2/binder/tableref/bind_expression_list_ref.rs +++ b/src/planner_v2/binder/tableref/bind_expression_list_ref.rs @@ -1,7 +1,9 @@ use derive_new::new; use sqlparser::ast::Values; -use crate::planner_v2::{BindError, Binder, BoundExpression, ExpressionBinder}; +use crate::planner_v2::{ + BindError, Binder, BoundCastExpression, BoundExpression, ExpressionBinder, +}; use crate::types_v2::LogicalType; pub static VALUES_LIST_ALIAS: &str = "valueslist"; @@ -56,6 +58,21 @@ impl Binder { } bound_expr_list.push(bound_expr_row); } + // insert values contains SqlNull, the expr should be cast to the max logical type + for exprs in bound_expr_list.iter_mut() { + for (idx, bound_expr) in exprs.iter_mut().enumerate() { + if bound_expr.return_type() == LogicalType::SqlNull { + let alias = bound_expr.alias().clone(); + *bound_expr = BoundCastExpression::add_cast_to_type( + bound_expr.clone(), + types[idx].clone(), + alias, + false, + ) + } + } + } + let table_index = self.generate_table_index(); self.bind_context.add_generic_binding( VALUES_LIST_ALIAS.to_string(), diff --git a/src/types_v2/types.rs b/src/types_v2/types.rs index 68a6528..1d6b4b5 100644 --- a/src/types_v2/types.rs +++ b/src/types_v2/types.rs @@ -5,6 +5,7 @@ use super::TypeError; #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum LogicalType { Invalid, + SqlNull, Boolean, Tinyint, UTinyint, @@ -36,6 +37,26 @@ impl LogicalType { ) } + pub fn is_signed_numeric(&self) -> bool { + matches!( + self, + LogicalType::Tinyint + | LogicalType::Smallint + | LogicalType::Integer + | LogicalType::Bigint + ) + } + + pub fn is_unsigned_numeric(&self) -> bool { + matches!( + self, + LogicalType::UTinyint + | LogicalType::USmallint + | LogicalType::UInteger + | LogicalType::UBigint + ) + } + pub fn max_logical_type( left: &LogicalType, right: &LogicalType, @@ -43,17 +64,14 @@ impl LogicalType { if left == right { return Ok(left.clone()); } + match (left, right) { + // SqlNull type can be cast to anything + (LogicalType::SqlNull, _) => return Ok(right.clone()), + (_, LogicalType::SqlNull) => return Ok(left.clone()), + _ => {} + } if left.is_numeric() && right.is_numeric() { - if LogicalType::can_implicit_cast(left, right) { - return Ok(right.clone()); - } else if LogicalType::can_implicit_cast(right, left) { - return Ok(left.clone()); - } else { - return Err(TypeError::InternalError(format!( - "can not implicit cast {:?} to {:?}", - left, right - ))); - } + return LogicalType::combine_numeric_types(left, right); } Err(TypeError::InternalError(format!( "can not compare two types: {:?} and {:?}", @@ -61,12 +79,49 @@ impl LogicalType { ))) } - pub fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool { + fn combine_numeric_types( + left: &LogicalType, + right: &LogicalType, + ) -> Result { + if left == right { + return Ok(left.clone()); + } + if left.is_signed_numeric() && right.is_unsigned_numeric() { + // this method is symmetric + // arrange it so the left type is smaller + // to limit the number of options we need to check + return LogicalType::combine_numeric_types(right, left); + } + + if LogicalType::can_implicit_cast(left, right) { + return Ok(right.clone()); + } + if LogicalType::can_implicit_cast(right, left) { + return Ok(left.clone()); + } + // we can't cast implicitly either way and types are not equal + // this happens when left is signed and right is unsigned + // e.g. INTEGER and UINTEGER + // in this case we need to upcast to make sure the types fit + match (left, right) { + (LogicalType::Bigint, _) | (_, LogicalType::UBigint) => Ok(LogicalType::Double), + (LogicalType::Integer, _) | (_, LogicalType::UInteger) => Ok(LogicalType::Bigint), + (LogicalType::Smallint, _) | (_, LogicalType::USmallint) => Ok(LogicalType::Integer), + (LogicalType::Tinyint, _) | (_, LogicalType::UTinyint) => Ok(LogicalType::Smallint), + _ => Err(TypeError::InternalError(format!( + "can not combine these numeric types {:?} and {:?}", + left, right + ))), + } + } + + fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool { if from == to { return true; } match from { LogicalType::Invalid => false, + LogicalType::SqlNull => true, LogicalType::Boolean => false, LogicalType::Tinyint => matches!( to, @@ -160,6 +215,7 @@ impl From for arrow::datatypes::DataType { use arrow::datatypes::DataType; match value { LogicalType::Invalid => panic!("invalid logical type"), + LogicalType::SqlNull => DataType::Null, LogicalType::Boolean => DataType::Boolean, LogicalType::Tinyint => DataType::Int8, LogicalType::UTinyint => DataType::UInt8, diff --git a/src/types_v2/values.rs b/src/types_v2/values.rs index ab12d89..d98e6fd 100644 --- a/src/types_v2/values.rs +++ b/src/types_v2/values.rs @@ -228,7 +228,7 @@ impl ScalarValue { pub fn get_logical_type(&self) -> LogicalType { match self { - ScalarValue::Null => LogicalType::Invalid, + ScalarValue::Null => LogicalType::SqlNull, ScalarValue::Boolean(_) => LogicalType::Boolean, ScalarValue::Float32(_) => LogicalType::Float, ScalarValue::Float64(_) => LogicalType::Double, diff --git a/tests/slt/create_table.slt b/tests/slt/create_table.slt index 036f35f..67aa439 100644 --- a/tests/slt/create_table.slt +++ b/tests/slt/create_table.slt @@ -1,43 +1,21 @@ onlyif sqlrs_v2 statement ok -create table t1(a varchar, b varchar, c varchar); +create table t1(v1 varchar, v2 varchar, v3 varchar); +insert into t1 values('a', 'b', 'c'); -onlyif sqlrs_v2 -statement ok -insert into t1(c, b) values ('0','4'),('1','5'); - -onlyif sqlrs_v2 -statement ok -insert into t1 values ('2','7','9'); onlyif sqlrs_v2 -query III -select a, c, b from t1; ----- -NULL 0 4 -NULL 1 5 -2 9 7 +statement error +create table t1(v1 int); -onlyif sqlrs_v2 -statement ok -create table t2(a int, b int, c int); onlyif sqlrs_v2 statement ok -insert into t2(c, b, a) values (0, 4, 1), (1, 5, 2); +create table t2(v1 boolean, v2 tinyint, v3 smallint, v4 int, v5 bigint, v6 float, v7 double, v8 varchar); +insert into t2 values(true, 1, 2, 3, 4, 5.1, 6.2, '7'); -onlyif sqlrs_v2 -query III -select c, b, a from t2; ----- -0 4 1 -1 5 2 -# Test insert type cast onlyif sqlrs_v2 statement ok -create table t3(a TINYINT UNSIGNED); - -onlyif sqlrs_v2 -statement error -insert into t3(a) values (1481); +create table t3(v1 boolean, v2 tinyint unsigned, v3 smallint unsigned, v4 int unsigned, v5 bigint unsigned, v6 float, v7 double, v8 varchar); +insert into t3 values(true, 1, 2, 3, 4, 5.1, 6.2, '7'); diff --git a/tests/slt/insert_table.slt b/tests/slt/insert_table.slt new file mode 100644 index 0000000..569e49d --- /dev/null +++ b/tests/slt/insert_table.slt @@ -0,0 +1,69 @@ +# Test common insert case +onlyif sqlrs_v2 +statement ok +create table t1(v1 varchar, v2 varchar, v3 varchar); + +onlyif sqlrs_v2 +statement ok +insert into t1(v3, v2) values ('0','4'), ('1','5'); + +onlyif sqlrs_v2 +statement ok +insert into t1 values ('2','7','9'); + +onlyif sqlrs_v2 +query III +select v1, v3, v2 from t1; +---- +NULL 0 4 +NULL 1 5 +2 9 7 + + +# Test insert value cast type +onlyif sqlrs_v2 +statement ok +create table t2(v1 int, v2 int, v3 int); + +onlyif sqlrs_v2 +statement ok +insert into t2(v3, v2, v1) values (0, 4, 1), (1, 5, 2); + +onlyif sqlrs_v2 +query III +select v3, v2, v1 from t2; +---- +0 4 1 +1 5 2 + + +# Test insert type cast +onlyif sqlrs_v2 +statement ok +create table t3(v1 TINYINT UNSIGNED); + +onlyif sqlrs_v2 +statement error +insert into t3(v1) values (1481); + + +# Test insert null values +onlyif sqlrs_v2 +statement ok +create table t4(v1 varchar, v2 smallint unsigned, v3 bigint unsigned); + +onlyif sqlrs_v2 +statement ok +insert into t4 values (NULL, 1, 2), ('', 3, NULL); + +onlyif sqlrs_v2 +statement ok +insert into t4 values (NULL, NULL, NULL); + +onlyif sqlrs_v2 +query III +select v1, v2, v3 from t4; +---- +NULL 1 2 +(empty) 3 NULL +NULL NULL NULL