Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ select * from sqlrs_columns();
select * from read_csv('t1.csv');
select * from read_csv('t1.csv', header=>true, delim=>',');
select * from 't1.csv';
-- copy
copy t1 from 't1.csv' ( DELIMITER '|', HEADER false);
```


Expand Down
2 changes: 1 addition & 1 deletion src/function/table/read_csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl ReadCSV {
let mut col_names = vec![];
let mut col_types = vec![];
for field in schema.fields() {
col_names.push(field.name().to_string());
col_names.push(field.name().to_string().to_lowercase());
col_types.push(field.data_type().try_into()?);
}
Ok((col_names, col_types))
Expand Down
15 changes: 13 additions & 2 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,20 @@ impl Sqlparser {
Ok(stmts)
}

pub fn parse_one_query(sql: String) -> Result<Box<Query>, ParserError> {
pub fn parse_one_stmt(sql: &str) -> Result<Statement, ParserError> {
let dialect = PostgreSqlDialect {};
let stmts = Parser::parse_sql(&dialect, sql.as_str())?;
let stmts = Parser::parse_sql(&dialect, sql)?;
if stmts.len() != 1 {
return Err(ParserError::ParserError(
"not a single statement".to_string(),
));
}
Ok(stmts[0].clone())
}

pub fn parse_one_query(sql: &str) -> Result<Box<Query>, ParserError> {
let dialect = PostgreSqlDialect {};
let stmts = Parser::parse_sql(&dialect, sql)?;
if stmts.len() != 1 {
return Err(ParserError::ParserError(
"not a single statement".to_string(),
Expand Down
8 changes: 8 additions & 0 deletions src/planner_v2/binder/errors.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use sqlparser::parser::ParserError;

use crate::catalog_v2::CatalogError;
use crate::execution::ExecutorError;
use crate::function::FunctionError;
Expand Down Expand Up @@ -38,4 +40,10 @@ pub enum BindError {
#[source]
ExecutorError,
),
#[error("parse error: {0}")]
ParserError(
#[from]
#[source]
ParserError,
),
}
2 changes: 1 addition & 1 deletion src/planner_v2/binder/sqlparser_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl SqlparserResolver {
pub fn column_def_to_column_definition(
column_def: &ColumnDef,
) -> Result<ColumnDefinition, BindError> {
let name = column_def.name.value.clone();
let name = column_def.name.value.clone().to_lowercase();
let ty = column_def.data_type.clone().try_into()?;
Ok(ColumnDefinition::new(name, ty))
}
Expand Down
95 changes: 95 additions & 0 deletions src/planner_v2/binder/statement/bind_copy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use log::debug;
use sqlparser::ast::{CopyOption, CopyTarget, Ident, ObjectName, Statement};

use super::BoundStatement;
use crate::parser::Sqlparser;
use crate::planner_v2::{BindError, Binder, SqlparserResolver, LOGGING_TARGET};

impl Binder {
/// convert copy from csv into insert statement from csv_read table function
fn convert_copy_from_to_insert_sql(
table_name: &ObjectName,
columns: &[Ident],
target: &CopyTarget,
options: &[CopyOption],
) -> Result<String, BindError> {
let (schema_name, table_name) = SqlparserResolver::object_name_to_schema_table(table_name)?;
let col_names = columns
.iter()
.map(|c| c.to_string().to_lowercase())
.collect::<Vec<_>>();

let (insert_cols, read_cols) = if col_names.is_empty() {
// insert into main.t1 select * from read_csv('file.csv');
("".to_string(), "*".to_string())
} else {
// insert into main.t1(v1) select * from read_csv('file.csv');
(format!("({})", col_names.join(",")), col_names.join(","))
};
let insert_sql = format!("insert into {}.{}{}", schema_name, table_name, insert_cols,);

let read_csv_sql = Self::build_read_csv_sql(target, options)?;
let csv_read_sql = format!("select {} from {}", read_cols, read_csv_sql);

Ok(format!("{} {}", insert_sql, csv_read_sql))
}

fn build_read_csv_sql(
target: &CopyTarget,
options: &[CopyOption],
) -> Result<String, BindError> {
let filename = match target {
CopyTarget::File { filename } => filename,
_ => {
return Err(BindError::UnsupportedStmt(format!(
"unsupported copy target {:?}",
target
)))
}
};
let options_strs = options
.iter()
.filter_map(|o| match o {
CopyOption::Delimiter(v) => Some(format!("delim=>'{}'", v)),
CopyOption::Header(v) => Some(format!("header=>{}", v)),
_ => None,
})
.collect::<Vec<_>>();
let options_str = if options_strs.is_empty() {
"".to_string()
} else {
format!(" ,{}", options_strs.join(", "))
};
Ok(format!("read_csv('{}'{})", filename, options_str))
}

pub fn bind_copy(&mut self, stmt: &Statement) -> Result<BoundStatement, BindError> {
match stmt {
Statement::Copy {
table_name,
columns,
to,
target,
options,
legacy_options: _,
values: _,
} => {
if *to {
return Err(BindError::UnsupportedStmt(
"unsupported copy to statement".to_string(),
));
}

let insert_from_sql =
Self::convert_copy_from_to_insert_sql(table_name, columns, target, options)?;
debug!(
target: LOGGING_TARGET,
"Copy converted raw sql: {:?}", insert_from_sql
);
let stmt = Sqlparser::parse_one_stmt(&insert_from_sql)?;
self.bind(&stmt)
}
_ => Err(BindError::UnsupportedStmt(format!("{:?}", stmt))),
}
}
}
33 changes: 19 additions & 14 deletions src/planner_v2/binder/statement/bind_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,25 @@ use sqlparser::ast::Statement;
use super::BoundStatement;
use crate::catalog_v2::Catalog;
use crate::planner_v2::{
BindError, Binder, BoundTableRef, LogicalInsert, LogicalOperator, LogicalOperatorBase,
SqlparserResolver, INVALID_INDEX,
BindError, Binder, LogicalInsert, LogicalOperator, LogicalOperatorBase, SqlparserResolver,
INVALID_INDEX,
};
use crate::types_v2::LogicalType;

impl Binder {
fn check_insert_column_count_mismatch(
expected_columns_cnt: usize,
insert_columns_cnt: usize,
) -> Result<(), BindError> {
if expected_columns_cnt != insert_columns_cnt {
return Err(BindError::Internal(format!(
"insert column count mismatch, expected: {}, actual: {}",
expected_columns_cnt, insert_columns_cnt
)));
}
Ok(())
}

pub fn bind_insert(&mut self, stmt: &Statement) -> Result<BoundStatement, BindError> {
match stmt {
Statement::Insert {
Expand Down Expand Up @@ -71,21 +84,13 @@ impl Binder {
let select_node = self.bind_select_node(source)?;
let expected_columns_cnt = named_column_indices.len();

// special case: check if we are inserting from a VALUES statement
if let BoundTableRef::BoundExpressionListRef(table_ref) = &select_node.from_table {
// CheckInsertColumnCountMismatch
let insert_columns_cnt = table_ref.values.first().unwrap().len();
if expected_columns_cnt != insert_columns_cnt {
return Err(BindError::Internal(format!(
"insert column count mismatch, expected: {}, actual: {}",
expected_columns_cnt, insert_columns_cnt
)));
}
}

let select_node = self.create_plan_for_select_node(select_node)?;
let inserted_types = select_node.types;
let mut plan = select_node.plan;
Self::check_insert_column_count_mismatch(
expected_columns_cnt,
inserted_types.len(),
)?;
// cast inserted types to expected types when necessary
self.cast_logical_operator_to_types(&inserted_types, &expected_types, &mut plan)?;

Expand Down
2 changes: 2 additions & 0 deletions src/planner_v2/binder/statement/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod bind_copy;
mod bind_create;
mod bind_explain;
mod bind_explain_table;
Expand Down Expand Up @@ -29,6 +30,7 @@ impl Binder {
Statement::Explain { .. } => self.bind_explain(statement),
Statement::ShowTables { .. } => self.bind_show_tables(statement),
Statement::ExplainTable { .. } => self.bind_explain_table(statement),
Statement::Copy { .. } => self.bind_copy(statement),
_ => Err(BindError::UnsupportedStmt(format!("{:?}", statement))),
}
}
Expand Down
20 changes: 20 additions & 0 deletions tests/slt/csv/csv.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
onlyif sqlrs_v2
statement ok
create table state(id varchar, state_code varchar, state_name varchar);

onlyif sqlrs_v2
statement ok
copy state from 'tests/slt/csv/state1.csv' ( DELIMITER '|' );

onlyif sqlrs_v2
statement ok
copy state from 'tests/slt/csv/state2.csv' ( DELIMITER '|', HEADER false);

onlyif sqlrs_v2
query I
SELECT id FROM state
----
1
2
3
4
5 changes: 5 additions & 0 deletions tests/slt/csv/state1.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
id|state_code|state_name
1|CA|California State
2|CO|Colorado State
3|NJ|New Jersey

2 changes: 2 additions & 0 deletions tests/slt/csv/state2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
4|CA|California State

22 changes: 22 additions & 0 deletions tests/slt/insert_table.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ onlyif sqlrs_v2
statement ok
create table t1(v1 varchar, v2 varchar, v3 varchar);

onlyif sqlrs_v2
statement error
insert into t1(v3) values ('0','4');

onlyif sqlrs_v2
statement ok
insert into t1(v3, v2) values ('0','4'), ('1','5');
Expand Down Expand Up @@ -67,3 +71,21 @@ select v1, v2, v3 from t4;
NULL 1 2
(empty) 3 NULL
NULL NULL NULL


# Test insert from select
onlyif sqlrs_v2
statement ok
CREATE TABLE integers(i INTEGER);

onlyif sqlrs_v2
statement ok
INSERT INTO integers SELECT 42;
INSERT INTO integers SELECT null;

onlyif sqlrs_v2
query I
SELECT * FROM integers
----
42
NULL
8 changes: 4 additions & 4 deletions tests/slt/table_function.slt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ t1 [v1, v2, v3] [Integer, Integer, Integer]

onlyif sqlrs_v2
query III
select column_1 from read_csv('../csv/t1.csv', header=>false);
select column_1 from read_csv('tests/csv/t1.csv', header=>false);
----
a
0
Expand All @@ -29,13 +29,13 @@ a

onlyif sqlrs_v2
query III
select a from read_csv('../csv/t1.csv', header=>true, delim=>',') where a = 1;
select a from read_csv('tests/csv/t1.csv', header=>true, delim=>',') where a = 1;
----
1

onlyif sqlrs_v2
query III
select t1.a from '../csv/t1.csv';
select t1.a from 'tests/csv/t1.csv';
----
0
1
Expand All @@ -44,7 +44,7 @@ select t1.a from '../csv/t1.csv';

onlyif sqlrs_v2
query III
select tt.a from '../csv/t1.csv' tt;
select tt.a from 'tests/csv/t1.csv' tt;
----
0
1
Expand Down
2 changes: 1 addition & 1 deletion tests/sqllogictest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use sqlrs::main_entry::{ClientContext, DatabaseError as DatabaseErrorV2, Databas
use sqlrs::util::record_batch_to_string;

fn init_tables(db: Arc<Database>) {
const CSV_FILES: &str = "../csv/**/*.csv";
const CSV_FILES: &str = "tests/csv/**/*.csv";

let csv_files = glob::glob(CSV_FILES).expect("failed to find csv files");
for csv_file in csv_files {
Expand Down
7 changes: 6 additions & 1 deletion tests/sqllogictest/tests/sqllogictest.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use std::path::Path;

use libtest_mimic::{Arguments, Trial};
use sqllogictest_test::{test_run, test_run_v2};

fn main() {
const SLT_PATTERN: &str = "../slt/**/*.slt";
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join("..").join("..");
std::env::set_current_dir(path).unwrap();

const SLT_PATTERN: &str = "tests/slt/**/*.slt";

let args = Arguments::from_args();
let mut tests = vec![];
Expand Down