diff --git a/README.md b/README.md index 55859e7..4f5bb44 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,8 @@ select * from sqlrs_columns(); select * from read_csv('t1.csv'); select * from read_csv('t1.csv', header=>true, delim=>','); select * from 't1.csv'; +-- copy +copy t1 from 't1.csv' ( DELIMITER '|', HEADER false); ``` diff --git a/src/function/table/read_csv.rs b/src/function/table/read_csv.rs index 2b7f5ed..153e98f 100644 --- a/src/function/table/read_csv.rs +++ b/src/function/table/read_csv.rs @@ -137,7 +137,7 @@ impl ReadCSV { let mut col_names = vec![]; let mut col_types = vec![]; for field in schema.fields() { - col_names.push(field.name().to_string()); + col_names.push(field.name().to_string().to_lowercase()); col_types.push(field.data_type().try_into()?); } Ok((col_names, col_types)) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index b8cafd4..df7b427 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -20,9 +20,20 @@ impl Sqlparser { Ok(stmts) } - pub fn parse_one_query(sql: String) -> Result, ParserError> { + pub fn parse_one_stmt(sql: &str) -> Result { let dialect = PostgreSqlDialect {}; - let stmts = Parser::parse_sql(&dialect, sql.as_str())?; + let stmts = Parser::parse_sql(&dialect, sql)?; + if stmts.len() != 1 { + return Err(ParserError::ParserError( + "not a single statement".to_string(), + )); + } + Ok(stmts[0].clone()) + } + + pub fn parse_one_query(sql: &str) -> Result, ParserError> { + let dialect = PostgreSqlDialect {}; + let stmts = Parser::parse_sql(&dialect, sql)?; if stmts.len() != 1 { return Err(ParserError::ParserError( "not a single statement".to_string(), diff --git a/src/planner_v2/binder/errors.rs b/src/planner_v2/binder/errors.rs index db6903f..af1960b 100644 --- a/src/planner_v2/binder/errors.rs +++ b/src/planner_v2/binder/errors.rs @@ -1,3 +1,5 @@ +use sqlparser::parser::ParserError; + use crate::catalog_v2::CatalogError; use crate::execution::ExecutorError; use crate::function::FunctionError; @@ -38,4 +40,10 @@ pub enum BindError { #[source] ExecutorError, ), + #[error("parse error: {0}")] + ParserError( + #[from] + #[source] + ParserError, + ), } diff --git a/src/planner_v2/binder/sqlparser_util.rs b/src/planner_v2/binder/sqlparser_util.rs index 2ea6ad1..a9689dd 100644 --- a/src/planner_v2/binder/sqlparser_util.rs +++ b/src/planner_v2/binder/sqlparser_util.rs @@ -29,7 +29,7 @@ impl SqlparserResolver { pub fn column_def_to_column_definition( column_def: &ColumnDef, ) -> Result { - let name = column_def.name.value.clone(); + let name = column_def.name.value.clone().to_lowercase(); let ty = column_def.data_type.clone().try_into()?; Ok(ColumnDefinition::new(name, ty)) } diff --git a/src/planner_v2/binder/statement/bind_copy.rs b/src/planner_v2/binder/statement/bind_copy.rs new file mode 100644 index 0000000..4d8b2bf --- /dev/null +++ b/src/planner_v2/binder/statement/bind_copy.rs @@ -0,0 +1,95 @@ +use log::debug; +use sqlparser::ast::{CopyOption, CopyTarget, Ident, ObjectName, Statement}; + +use super::BoundStatement; +use crate::parser::Sqlparser; +use crate::planner_v2::{BindError, Binder, SqlparserResolver, LOGGING_TARGET}; + +impl Binder { + /// convert copy from csv into insert statement from csv_read table function + fn convert_copy_from_to_insert_sql( + table_name: &ObjectName, + columns: &[Ident], + target: &CopyTarget, + options: &[CopyOption], + ) -> Result { + let (schema_name, table_name) = SqlparserResolver::object_name_to_schema_table(table_name)?; + let col_names = columns + .iter() + .map(|c| c.to_string().to_lowercase()) + .collect::>(); + + let (insert_cols, read_cols) = if col_names.is_empty() { + // insert into main.t1 select * from read_csv('file.csv'); + ("".to_string(), "*".to_string()) + } else { + // insert into main.t1(v1) select * from read_csv('file.csv'); + (format!("({})", col_names.join(",")), col_names.join(",")) + }; + let insert_sql = format!("insert into {}.{}{}", schema_name, table_name, insert_cols,); + + let read_csv_sql = Self::build_read_csv_sql(target, options)?; + let csv_read_sql = format!("select {} from {}", read_cols, read_csv_sql); + + Ok(format!("{} {}", insert_sql, csv_read_sql)) + } + + fn build_read_csv_sql( + target: &CopyTarget, + options: &[CopyOption], + ) -> Result { + let filename = match target { + CopyTarget::File { filename } => filename, + _ => { + return Err(BindError::UnsupportedStmt(format!( + "unsupported copy target {:?}", + target + ))) + } + }; + let options_strs = options + .iter() + .filter_map(|o| match o { + CopyOption::Delimiter(v) => Some(format!("delim=>'{}'", v)), + CopyOption::Header(v) => Some(format!("header=>{}", v)), + _ => None, + }) + .collect::>(); + let options_str = if options_strs.is_empty() { + "".to_string() + } else { + format!(" ,{}", options_strs.join(", ")) + }; + Ok(format!("read_csv('{}'{})", filename, options_str)) + } + + pub fn bind_copy(&mut self, stmt: &Statement) -> Result { + match stmt { + Statement::Copy { + table_name, + columns, + to, + target, + options, + legacy_options: _, + values: _, + } => { + if *to { + return Err(BindError::UnsupportedStmt( + "unsupported copy to statement".to_string(), + )); + } + + let insert_from_sql = + Self::convert_copy_from_to_insert_sql(table_name, columns, target, options)?; + debug!( + target: LOGGING_TARGET, + "Copy converted raw sql: {:?}", insert_from_sql + ); + let stmt = Sqlparser::parse_one_stmt(&insert_from_sql)?; + self.bind(&stmt) + } + _ => Err(BindError::UnsupportedStmt(format!("{:?}", stmt))), + } + } +} diff --git a/src/planner_v2/binder/statement/bind_insert.rs b/src/planner_v2/binder/statement/bind_insert.rs index 466b58a..2246dad 100644 --- a/src/planner_v2/binder/statement/bind_insert.rs +++ b/src/planner_v2/binder/statement/bind_insert.rs @@ -5,12 +5,25 @@ use sqlparser::ast::Statement; use super::BoundStatement; use crate::catalog_v2::Catalog; use crate::planner_v2::{ - BindError, Binder, BoundTableRef, LogicalInsert, LogicalOperator, LogicalOperatorBase, - SqlparserResolver, INVALID_INDEX, + BindError, Binder, LogicalInsert, LogicalOperator, LogicalOperatorBase, SqlparserResolver, + INVALID_INDEX, }; use crate::types_v2::LogicalType; impl Binder { + fn check_insert_column_count_mismatch( + expected_columns_cnt: usize, + insert_columns_cnt: usize, + ) -> Result<(), BindError> { + if expected_columns_cnt != insert_columns_cnt { + return Err(BindError::Internal(format!( + "insert column count mismatch, expected: {}, actual: {}", + expected_columns_cnt, insert_columns_cnt + ))); + } + Ok(()) + } + pub fn bind_insert(&mut self, stmt: &Statement) -> Result { match stmt { Statement::Insert { @@ -71,21 +84,13 @@ impl Binder { let select_node = self.bind_select_node(source)?; let expected_columns_cnt = named_column_indices.len(); - // special case: check if we are inserting from a VALUES statement - if let BoundTableRef::BoundExpressionListRef(table_ref) = &select_node.from_table { - // CheckInsertColumnCountMismatch - let insert_columns_cnt = table_ref.values.first().unwrap().len(); - if expected_columns_cnt != insert_columns_cnt { - return Err(BindError::Internal(format!( - "insert column count mismatch, expected: {}, actual: {}", - expected_columns_cnt, insert_columns_cnt - ))); - } - } - let select_node = self.create_plan_for_select_node(select_node)?; let inserted_types = select_node.types; let mut plan = select_node.plan; + Self::check_insert_column_count_mismatch( + expected_columns_cnt, + inserted_types.len(), + )?; // cast inserted types to expected types when necessary self.cast_logical_operator_to_types(&inserted_types, &expected_types, &mut plan)?; diff --git a/src/planner_v2/binder/statement/mod.rs b/src/planner_v2/binder/statement/mod.rs index bcb96b6..23c8c08 100644 --- a/src/planner_v2/binder/statement/mod.rs +++ b/src/planner_v2/binder/statement/mod.rs @@ -1,3 +1,4 @@ +mod bind_copy; mod bind_create; mod bind_explain; mod bind_explain_table; @@ -29,6 +30,7 @@ impl Binder { Statement::Explain { .. } => self.bind_explain(statement), Statement::ShowTables { .. } => self.bind_show_tables(statement), Statement::ExplainTable { .. } => self.bind_explain_table(statement), + Statement::Copy { .. } => self.bind_copy(statement), _ => Err(BindError::UnsupportedStmt(format!("{:?}", statement))), } } diff --git a/tests/slt/csv/csv.slt b/tests/slt/csv/csv.slt new file mode 100644 index 0000000..b8ecfdc --- /dev/null +++ b/tests/slt/csv/csv.slt @@ -0,0 +1,20 @@ +onlyif sqlrs_v2 +statement ok +create table state(id varchar, state_code varchar, state_name varchar); + +onlyif sqlrs_v2 +statement ok +copy state from 'tests/slt/csv/state1.csv' ( DELIMITER '|' ); + +onlyif sqlrs_v2 +statement ok +copy state from 'tests/slt/csv/state2.csv' ( DELIMITER '|', HEADER false); + +onlyif sqlrs_v2 +query I +SELECT id FROM state +---- +1 +2 +3 +4 diff --git a/tests/slt/csv/state1.csv b/tests/slt/csv/state1.csv new file mode 100644 index 0000000..11abb69 --- /dev/null +++ b/tests/slt/csv/state1.csv @@ -0,0 +1,5 @@ +id|state_code|state_name +1|CA|California State +2|CO|Colorado State +3|NJ|New Jersey + diff --git a/tests/slt/csv/state2.csv b/tests/slt/csv/state2.csv new file mode 100644 index 0000000..d032779 --- /dev/null +++ b/tests/slt/csv/state2.csv @@ -0,0 +1,2 @@ +4|CA|California State + diff --git a/tests/slt/insert_table.slt b/tests/slt/insert_table.slt index 569e49d..06c3ab9 100644 --- a/tests/slt/insert_table.slt +++ b/tests/slt/insert_table.slt @@ -3,6 +3,10 @@ onlyif sqlrs_v2 statement ok create table t1(v1 varchar, v2 varchar, v3 varchar); +onlyif sqlrs_v2 +statement error +insert into t1(v3) values ('0','4'); + onlyif sqlrs_v2 statement ok insert into t1(v3, v2) values ('0','4'), ('1','5'); @@ -67,3 +71,21 @@ select v1, v2, v3 from t4; NULL 1 2 (empty) 3 NULL NULL NULL NULL + + +# Test insert from select +onlyif sqlrs_v2 +statement ok +CREATE TABLE integers(i INTEGER); + +onlyif sqlrs_v2 +statement ok +INSERT INTO integers SELECT 42; +INSERT INTO integers SELECT null; + +onlyif sqlrs_v2 +query I +SELECT * FROM integers +---- +42 +NULL diff --git a/tests/slt/table_function.slt b/tests/slt/table_function.slt index d023a43..dce2117 100644 --- a/tests/slt/table_function.slt +++ b/tests/slt/table_function.slt @@ -19,7 +19,7 @@ t1 [v1, v2, v3] [Integer, Integer, Integer] onlyif sqlrs_v2 query III -select column_1 from read_csv('../csv/t1.csv', header=>false); +select column_1 from read_csv('tests/csv/t1.csv', header=>false); ---- a 0 @@ -29,13 +29,13 @@ a onlyif sqlrs_v2 query III -select a from read_csv('../csv/t1.csv', header=>true, delim=>',') where a = 1; +select a from read_csv('tests/csv/t1.csv', header=>true, delim=>',') where a = 1; ---- 1 onlyif sqlrs_v2 query III -select t1.a from '../csv/t1.csv'; +select t1.a from 'tests/csv/t1.csv'; ---- 0 1 @@ -44,7 +44,7 @@ select t1.a from '../csv/t1.csv'; onlyif sqlrs_v2 query III -select tt.a from '../csv/t1.csv' tt; +select tt.a from 'tests/csv/t1.csv' tt; ---- 0 1 diff --git a/tests/sqllogictest/src/lib.rs b/tests/sqllogictest/src/lib.rs index 2b66725..42cc004 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -8,7 +8,7 @@ use sqlrs::main_entry::{ClientContext, DatabaseError as DatabaseErrorV2, Databas use sqlrs::util::record_batch_to_string; fn init_tables(db: Arc) { - const CSV_FILES: &str = "../csv/**/*.csv"; + const CSV_FILES: &str = "tests/csv/**/*.csv"; let csv_files = glob::glob(CSV_FILES).expect("failed to find csv files"); for csv_file in csv_files { diff --git a/tests/sqllogictest/tests/sqllogictest.rs b/tests/sqllogictest/tests/sqllogictest.rs index 7295e7d..778c0bf 100644 --- a/tests/sqllogictest/tests/sqllogictest.rs +++ b/tests/sqllogictest/tests/sqllogictest.rs @@ -1,8 +1,13 @@ +use std::path::Path; + use libtest_mimic::{Arguments, Trial}; use sqllogictest_test::{test_run, test_run_v2}; fn main() { - const SLT_PATTERN: &str = "../slt/**/*.slt"; + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join("..").join(".."); + std::env::set_current_dir(path).unwrap(); + + const SLT_PATTERN: &str = "tests/slt/**/*.slt"; let args = Arguments::from_args(); let mut tests = vec![];