From 5967f4c2005c81dc184735add55936bdfad668ab Mon Sep 17 00:00:00 2001 From: Fedomn Date: Fri, 30 Dec 2022 14:01:42 +0800 Subject: [PATCH 1/4] feat(planner): refactor table function to return async stream Signed-off-by: Fedomn --- Cargo.lock | 101 +++++++++++++++++++ Cargo.toml | 2 + src/execution/physical_plan_generator.rs | 4 +- src/execution/volcano_executor/table_scan.rs | 13 +-- src/function/errors.rs | 14 +++ src/function/table/seq_table_scan.rs | 38 +++---- src/function/table/sqlrs_columns.rs | 28 ++--- src/function/table/sqlrs_tables.rs | 28 ++--- src/function/table/table_function.rs | 11 +- src/planner_v2/binder/sqlparser_util.rs | 57 ++++++++++- src/types_v2/types.rs | 46 +++++++++ src/util/tree_render.rs | 1 + 12 files changed, 274 insertions(+), 69 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 18c6913..cbc9ea5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,6 +189,27 @@ dependencies = [ "num", ] +[[package]] +name = "async-stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.56" @@ -503,6 +524,41 @@ dependencies = [ "syn", ] +[[package]] +name = "darling" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0dd3cd20dc6b5a876612a6e5accfe7f3dd883db6d07acfbf14c128f61550dfa" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a784d2ccaf7c98501746bf0be29b2022ba41fd62a2e622af997a03e9f972859f" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7618812407e9402654622dd402b0a89dff9ba93badd6540781526117b92aab7e" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "derive-new" version = "0.5.9" @@ -514,6 +570,37 @@ dependencies = [ "syn", ] +[[package]] +name = "derive_builder" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_builder_macro" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +dependencies = [ + "derive_builder_core", + "syn", +] + [[package]] name = "diff" version = "0.1.13" @@ -683,6 +770,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "futures" version = "0.3.21" @@ -887,6 +980,12 @@ dependencies = [ "cxx-build", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "indexmap" version = "1.9.1" @@ -1723,7 +1822,9 @@ dependencies = [ "ahash", "anyhow", "arrow", + "async-stream", "derive-new", + "derive_builder", "dirs", "downcast-rs", "enum_dispatch", diff --git a/Cargo.toml b/Cargo.toml index 4c59cc0..dc63ddf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,8 @@ ordered-float = "3.0" derive-new = "0.5.9" log = "0.4" env_logger = "0.10" +derive_builder = "0.12.0" +async-stream = "0.3" [dev-dependencies] test-case = "2" diff --git a/src/execution/physical_plan_generator.rs b/src/execution/physical_plan_generator.rs index 650d957..23f67b3 100644 --- a/src/execution/physical_plan_generator.rs +++ b/src/execution/physical_plan_generator.rs @@ -53,8 +53,8 @@ impl PhysicalPlanGenerator { ) -> PhysicalOperatorBase { let children = base .children - .iter() - .map(|op| self.create_plan_internal(op.clone())) + .into_iter() + .map(|op| self.create_plan_internal(op)) .collect::>(); PhysicalOperatorBase::new(children, base.expressioins) } diff --git a/src/execution/volcano_executor/table_scan.rs b/src/execution/volcano_executor/table_scan.rs index 1387605..a872f8e 100644 --- a/src/execution/volcano_executor/table_scan.rs +++ b/src/execution/volcano_executor/table_scan.rs @@ -21,14 +21,15 @@ impl TableScan { let function = self.plan.function; let table_scan_func = function.function; - let mut tabel_scan_input = TableFunctionInput::new(bind_data); + let tabel_scan_input = TableFunctionInput::new(bind_data); - while let Some(batch) = - table_scan_func(context.clone_client_context(), &mut tabel_scan_input)? - { + let scan_stream = table_scan_func(context.clone_client_context(), &tabel_scan_input)?; + + #[for_await] + for batch in scan_stream { + let batch = batch?; let columns = batch.columns().to_vec(); - let try_new = RecordBatch::try_new(schema.clone(), columns)?; - yield try_new + yield RecordBatch::try_new(schema.clone(), columns)? } } } diff --git a/src/function/errors.rs b/src/function/errors.rs index 6b2abf7..3b2be31 100644 --- a/src/function/errors.rs +++ b/src/function/errors.rs @@ -1,8 +1,14 @@ +use std::io; + use arrow::error::ArrowError; use crate::catalog_v2::CatalogError; +use crate::planner_v2::BindError; use crate::types_v2::TypeError; +pub type FunctionResult = Result; + +// TODO: refactor error using https://docs.rs/snafu/latest/snafu/ #[derive(thiserror::Error, Debug)] pub enum FunctionError { #[error("catalog error: {0}")] @@ -31,4 +37,12 @@ pub enum FunctionError { ComparisonError(String), #[error("Conjunction error: {0}")] ConjunctionError(String), + #[error("io error")] + IoError(#[from] io::Error), +} + +impl From for FunctionError { + fn from(e: BindError) -> Self { + FunctionError::InternalError(e.to_string()) + } } diff --git a/src/function/table/seq_table_scan.rs b/src/function/table/seq_table_scan.rs index b04ad7f..7537a71 100644 --- a/src/function/table/seq_table_scan.rs +++ b/src/function/table/seq_table_scan.rs @@ -2,12 +2,13 @@ use std::sync::Arc; use arrow::record_batch::RecordBatch; use derive_new::new; +use futures::stream::BoxStream; use super::{TableFunction, TableFunctionBindInput, TableFunctionInput}; use crate::catalog_v2::TableCatalogEntry; -use crate::function::{FunctionData, FunctionError}; +use crate::function::{FunctionData, FunctionError, FunctionResult}; use crate::main_entry::ClientContext; -use crate::storage_v2::{LocalStorage, LocalStorageReader}; +use crate::storage_v2::LocalStorage; use crate::types_v2::LogicalType; /// The table scan function represents a sequential scan over one of base tables. @@ -16,7 +17,6 @@ pub struct SeqTableScan; #[derive(new, Debug, Clone)] pub struct SeqTableScanInputData { pub(crate) bind_table: TableCatalogEntry, - pub(crate) local_storage_reader: LocalStorageReader, } impl SeqTableScan { @@ -26,12 +26,10 @@ impl SeqTableScan { input: TableFunctionBindInput, _return_types: &mut Vec, _return_names: &mut Vec, - ) -> Result, FunctionError> { + ) -> FunctionResult> { if let Some(table) = input.bind_table { - let res = FunctionData::SeqTableScanInputData(Box::new(SeqTableScanInputData::new( - table.clone(), - LocalStorage::create_reader(&table.storage), - ))); + let res = + FunctionData::SeqTableScanInputData(Box::new(SeqTableScanInputData::new(table))); Ok(Some(res)) } else { Err(FunctionError::InternalError( @@ -42,18 +40,20 @@ impl SeqTableScan { fn scan_func( context: Arc, - input: &mut TableFunctionInput, - ) -> Result, FunctionError> { - if let Some(bind_data) = &mut input.bind_data { - if let FunctionData::SeqTableScanInputData(data) = bind_data { - Ok(data.local_storage_reader.next_batch(context)) - } else { - Err(FunctionError::InternalError( - "unexpected bind data type".to_string(), - )) - } + input: &TableFunctionInput, + ) -> FunctionResult>> { + if let Some(FunctionData::SeqTableScanInputData(data)) = &input.bind_data { + let mut reader = LocalStorage::create_reader(&data.bind_table.storage); + let stream = Box::pin(async_stream::try_stream! { + while let Some(batch) = reader.next_batch(context.clone()){ + yield batch; + } + }); + Ok(stream) } else { - Ok(None) + Err(FunctionError::InternalError( + "unexpected bind data type".to_string(), + )) } } diff --git a/src/function/table/sqlrs_columns.rs b/src/function/table/sqlrs_columns.rs index e9a58bc..bbdf0d7 100644 --- a/src/function/table/sqlrs_columns.rs +++ b/src/function/table/sqlrs_columns.rs @@ -2,12 +2,13 @@ use std::sync::Arc; use arrow::record_batch::RecordBatch; use derive_new::new; +use futures::stream::BoxStream; use itertools::Itertools; use super::{TableFunction, TableFunctionBindInput, TableFunctionInput}; use crate::catalog_v2::{Catalog, CatalogEntry, DEFAULT_SCHEMA}; use crate::execution::SchemaUtil; -use crate::function::{BuiltinFunctions, FunctionData, FunctionError}; +use crate::function::{BuiltinFunctions, FunctionData, FunctionError, FunctionResult}; use crate::main_entry::ClientContext; use crate::types_v2::{LogicalType, ScalarValue}; @@ -18,7 +19,6 @@ pub struct SqlrsColumnsData { pub(crate) entries: Vec, pub(crate) return_types: Vec, pub(crate) return_names: Vec, - pub(crate) current_cursor: usize, } impl SqlrsColumnsFunc { @@ -43,7 +43,7 @@ impl SqlrsColumnsFunc { _input: TableFunctionBindInput, return_types: &mut Vec, return_names: &mut Vec, - ) -> Result, FunctionError> { + ) -> FunctionResult> { let entries = Catalog::scan_entries(context, DEFAULT_SCHEMA.to_string(), &|entry| { matches!(entry, CatalogEntry::TableCatalogEntry(_)) })?; @@ -51,7 +51,6 @@ impl SqlrsColumnsFunc { entries, Self::generate_sqlrs_tables_types(), Self::generate_sqlrs_tables_names(), - 0, ); return_types.extend(data.return_types.clone()); return_names.extend(data.return_names.clone()); @@ -60,18 +59,9 @@ impl SqlrsColumnsFunc { fn tables_func( _context: Arc, - input: &mut TableFunctionInput, - ) -> Result, FunctionError> { - if input.bind_data.is_none() { - return Ok(None); - } - - let bind_data = input.bind_data.as_mut().unwrap(); - if let FunctionData::SqlrsColumnsData(data) = bind_data { - if data.current_cursor >= data.entries.len() { - return Ok(None); - } - + input: &TableFunctionInput, + ) -> FunctionResult>> { + if let Some(FunctionData::SqlrsColumnsData(data)) = &input.bind_data { let schema = SchemaUtil::new_schema_ref(&data.return_names, &data.return_types); let mut table_name = ScalarValue::new_builder(&LogicalType::Varchar)?; let mut column_names = ScalarValue::new_builder(&LogicalType::Varchar)?; @@ -102,9 +92,11 @@ impl SqlrsColumnsFunc { column_names.finish(), column_types.finish(), ]; - data.current_cursor += data.entries.len(); let batch = RecordBatch::try_new(schema, cols)?; - Ok(Some(batch)) + let stream = Box::pin(async_stream::try_stream! { + yield batch; + }); + Ok(stream) } else { Err(FunctionError::InternalError( "unexpected global state type".to_string(), diff --git a/src/function/table/sqlrs_tables.rs b/src/function/table/sqlrs_tables.rs index 3fc506f..011fa6a 100644 --- a/src/function/table/sqlrs_tables.rs +++ b/src/function/table/sqlrs_tables.rs @@ -2,11 +2,12 @@ use std::sync::Arc; use arrow::record_batch::RecordBatch; use derive_new::new; +use futures::stream::BoxStream; use super::{TableFunction, TableFunctionBindInput, TableFunctionInput}; use crate::catalog_v2::{Catalog, CatalogEntry, DEFAULT_SCHEMA}; use crate::execution::SchemaUtil; -use crate::function::{BuiltinFunctions, FunctionData, FunctionError}; +use crate::function::{BuiltinFunctions, FunctionData, FunctionError, FunctionResult}; use crate::main_entry::ClientContext; use crate::types_v2::{LogicalType, ScalarValue}; @@ -17,7 +18,6 @@ pub struct SqlrsTablesData { pub(crate) entries: Vec, pub(crate) return_types: Vec, pub(crate) return_names: Vec, - pub(crate) current_cursor: usize, } impl SqlrsTablesFunc { @@ -44,7 +44,7 @@ impl SqlrsTablesFunc { _input: TableFunctionBindInput, return_types: &mut Vec, return_names: &mut Vec, - ) -> Result, FunctionError> { + ) -> FunctionResult> { let entries = Catalog::scan_entries(context, DEFAULT_SCHEMA.to_string(), &|entry| { matches!(entry, CatalogEntry::TableCatalogEntry(_)) })?; @@ -52,7 +52,6 @@ impl SqlrsTablesFunc { entries, Self::generate_sqlrs_tables_types(), Self::generate_sqlrs_tables_names(), - 0, ); return_types.extend(data.return_types.clone()); return_names.extend(data.return_names.clone()); @@ -61,18 +60,9 @@ impl SqlrsTablesFunc { fn tables_func( _context: Arc, - input: &mut TableFunctionInput, - ) -> Result, FunctionError> { - if input.bind_data.is_none() { - return Ok(None); - } - - let bind_data = input.bind_data.as_mut().unwrap(); - if let FunctionData::SqlrsTablesData(data) = bind_data { - if data.current_cursor >= data.entries.len() { - return Ok(None); - } - + input: &TableFunctionInput, + ) -> FunctionResult>> { + if let Some(FunctionData::SqlrsTablesData(data)) = &input.bind_data { let schema = SchemaUtil::new_schema_ref(&data.return_names, &data.return_types); let mut schema_names = ScalarValue::new_builder(&LogicalType::Varchar)?; let mut schema_oids = ScalarValue::new_builder(&LogicalType::Integer)?; @@ -104,9 +94,11 @@ impl SqlrsTablesFunc { table_names.finish(), table_oids.finish(), ]; - data.current_cursor += data.entries.len(); let batch = RecordBatch::try_new(schema, cols)?; - Ok(Some(batch)) + let stream = Box::pin(async_stream::try_stream! { + yield batch; + }); + Ok(stream) } else { Err(FunctionError::InternalError( "unexpected global state type".to_string(), diff --git a/src/function/table/table_function.rs b/src/function/table/table_function.rs index c85e6b5..2d81f09 100644 --- a/src/function/table/table_function.rs +++ b/src/function/table/table_function.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use arrow::record_batch::RecordBatch; use derive_new::new; +use futures::stream::BoxStream; use sqlparser::ast::FunctionArg; use crate::catalog_v2::TableCatalogEntry; -use crate::function::{FunctionData, FunctionError}; +use crate::function::{FunctionData, FunctionResult}; use crate::main_entry::ClientContext; use crate::types_v2::LogicalType; @@ -27,10 +28,12 @@ pub type TableFunctionBindFunc = fn( TableFunctionBindInput, &mut Vec, &mut Vec, -) -> Result, FunctionError>; +) -> FunctionResult>; -pub type TableFunc = - fn(Arc, &mut TableFunctionInput) -> Result, FunctionError>; +pub type TableFunc = fn( + Arc, + &TableFunctionInput, +) -> FunctionResult>>; #[derive(new, Clone)] pub struct TableFunction { diff --git a/src/planner_v2/binder/sqlparser_util.rs b/src/planner_v2/binder/sqlparser_util.rs index 7142e2d..24bc88f 100644 --- a/src/planner_v2/binder/sqlparser_util.rs +++ b/src/planner_v2/binder/sqlparser_util.rs @@ -1,7 +1,7 @@ use itertools::Itertools; use sqlparser::ast::{ - BinaryOperator, ColumnDef, Expr, Ident, ObjectName, Query, Select, SelectItem, SetExpr, - TableFactor, TableWithJoins, Value, WildcardAdditionalOptions, + BinaryOperator, ColumnDef, Expr, FunctionArgExpr, Ident, ObjectName, Query, Select, SelectItem, + SetExpr, TableFactor, TableWithJoins, Value, WildcardAdditionalOptions, }; use super::BindError; @@ -47,6 +47,59 @@ impl SqlparserResolver { }; Ok((schema_name, table_name, column_name)) } + + pub fn resolve_expr_to_string(e: &Expr) -> Result { + match e { + Expr::Value(v) => match v { + Value::SingleQuotedString(s) => Ok(s.clone()), + Value::DoubleQuotedString(s) => Ok(s.clone()), + _ => Err(BindError::Internal(format!( + "excepted string type, but got: {}", + v + ))), + }, + _ => Err(BindError::Internal(format!( + "excepted value expr, but got: {}", + e + ))), + } + } + + pub fn resolve_expr_to_bool(e: &Expr) -> Result { + match e { + Expr::Value(v) => match v { + Value::Boolean(b) => Ok(*b), + _ => Err(BindError::Internal(format!( + "excepted bool type, but got: {}", + v + ))), + }, + _ => Err(BindError::Internal(format!( + "excepted value expr, but got: {}", + e + ))), + } + } + + pub fn resolve_func_arg_expr_to_string(arg: &FunctionArgExpr) -> Result { + if let FunctionArgExpr::Expr(e) = arg { + return SqlparserResolver::resolve_expr_to_string(e); + } + Err(BindError::Internal(format!( + "expected string arg, but got {}", + arg + ))) + } + + pub fn resolve_func_arg_expr_to_bool(arg: &FunctionArgExpr) -> Result { + if let FunctionArgExpr::Expr(e) = arg { + return SqlparserResolver::resolve_expr_to_bool(e); + } + Err(BindError::Internal(format!( + "expected bool arg, but got {}", + arg + ))) + } } #[derive(Default)] diff --git a/src/types_v2/types.rs b/src/types_v2/types.rs index 1842db9..be88976 100644 --- a/src/types_v2/types.rs +++ b/src/types_v2/types.rs @@ -249,6 +249,52 @@ impl From for arrow::datatypes::DataType { } } +impl TryFrom<&arrow::datatypes::DataType> for LogicalType { + type Error = TypeError; + + fn try_from(value: &arrow::datatypes::DataType) -> Result { + use arrow::datatypes::DataType; + Ok(match value { + DataType::Null => LogicalType::SqlNull, + DataType::Boolean => LogicalType::Boolean, + DataType::Int8 => LogicalType::Tinyint, + DataType::Int16 => LogicalType::Smallint, + DataType::Int32 => LogicalType::Integer, + DataType::Int64 => LogicalType::Bigint, + DataType::UInt8 => LogicalType::UTinyint, + DataType::UInt16 => LogicalType::USmallint, + DataType::UInt32 => LogicalType::UInteger, + DataType::UInt64 => LogicalType::UBigint, + DataType::Float16 => LogicalType::Float, + DataType::Float32 => LogicalType::Float, + DataType::Float64 => LogicalType::Double, + DataType::Utf8 => LogicalType::Varchar, + DataType::LargeUtf8 => LogicalType::Varchar, + DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::Struct(_) + | DataType::Union(_, _, _) + | DataType::Dictionary(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Map(_, _) => { + return Err(TypeError::NotImplementedArrowDataType(value.to_string())) + } + }) + } +} + impl std::fmt::Display for LogicalType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.as_ref()) diff --git a/src/util/tree_render.rs b/src/util/tree_render.rs index 9e11786..f2c1f7d 100644 --- a/src/util/tree_render.rs +++ b/src/util/tree_render.rs @@ -97,6 +97,7 @@ impl TreeRender { } FunctionData::SqlrsColumnsData(_) => "sqlrs_columns".to_string(), FunctionData::SqlrsTablesData(_) => "sqlrs_tables".to_string(), + // FunctionData::ReadCSVInputData(_) => "read_csv".to_string(), }, None => "None".to_string(), }; From c3218df04d2b5797dbcdfbf2891aa8d2c1725b3e Mon Sep 17 00:00:00 2001 From: Fedomn Date: Fri, 30 Dec 2022 14:36:13 +0800 Subject: [PATCH 2/4] feat(planner): support read csv table function Signed-off-by: Fedomn --- src/execution/volcano_executor/table_scan.rs | 2 +- src/function/mod.rs | 2 + src/function/table/mod.rs | 2 + src/function/table/read_csv.rs | 199 +++++++++++++++++++ src/function/table/seq_table_scan.rs | 4 +- src/function/table/sqlrs_columns.rs | 4 +- src/function/table/sqlrs_tables.rs | 4 +- src/function/table/table_function.rs | 2 +- src/util/tree_render.rs | 2 +- tests/slt/table_function.slt | 17 ++ 10 files changed, 229 insertions(+), 9 deletions(-) create mode 100644 src/function/table/read_csv.rs diff --git a/src/execution/volcano_executor/table_scan.rs b/src/execution/volcano_executor/table_scan.rs index a872f8e..a316cf2 100644 --- a/src/execution/volcano_executor/table_scan.rs +++ b/src/execution/volcano_executor/table_scan.rs @@ -23,7 +23,7 @@ impl TableScan { let table_scan_func = function.function; let tabel_scan_input = TableFunctionInput::new(bind_data); - let scan_stream = table_scan_func(context.clone_client_context(), &tabel_scan_input)?; + let scan_stream = table_scan_func(context.clone_client_context(), tabel_scan_input)?; #[for_await] for batch in scan_stream { diff --git a/src/function/mod.rs b/src/function/mod.rs index ce6bd41..f1ea0e6 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -24,6 +24,7 @@ pub enum FunctionData { SeqTableScanInputData(Box), SqlrsTablesData(Box), SqlrsColumnsData(Box), + ReadCSVInputData(Box), } #[derive(new)] @@ -61,6 +62,7 @@ impl BuiltinFunctions { SubtractFunction::register_function(self)?; MultiplyFunction::register_function(self)?; DivideFunction::register_function(self)?; + ReadCSV::register_function(self)?; Ok(()) } } diff --git a/src/function/table/mod.rs b/src/function/table/mod.rs index 26ed776..a5f6047 100644 --- a/src/function/table/mod.rs +++ b/src/function/table/mod.rs @@ -1,7 +1,9 @@ +mod read_csv; mod seq_table_scan; mod sqlrs_columns; mod sqlrs_tables; mod table_function; +pub use read_csv::*; pub use seq_table_scan::*; pub use sqlrs_columns::*; pub use sqlrs_tables::*; diff --git a/src/function/table/read_csv.rs b/src/function/table/read_csv.rs new file mode 100644 index 0000000..2b7f5ed --- /dev/null +++ b/src/function/table/read_csv.rs @@ -0,0 +1,199 @@ +use std::fs::File; +use std::sync::Arc; + +use arrow::csv::{reader, Reader}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use derive_builder::Builder; +use futures::stream::BoxStream; +use sqlparser::ast::FunctionArg; + +use super::{TableFunction, TableFunctionBindInput, TableFunctionInput}; +use crate::function::{BuiltinFunctions, FunctionData, FunctionError, FunctionResult}; +use crate::main_entry::ClientContext; +use crate::planner_v2::SqlparserResolver; +use crate::types_v2::LogicalType; + +pub struct ReadCSV; + +#[derive(Builder, Debug, Clone)] +pub struct ReadCSVInputData { + pub(crate) filename: String, + pub(crate) option: ReadCSVOptions, + pub(crate) schema: SchemaRef, + #[builder(default = "None")] + pub(crate) bounds: Option<(usize, usize)>, + #[builder(default = "None")] + pub(crate) projection: Option>, +} + +#[derive(Builder, Debug, Clone)] +pub struct ReadCSVOptions { + #[builder(default = "1024")] + pub(crate) infer_schema_max_rows: usize, + #[builder(default = "1024")] + pub(crate) read_batch_size: usize, + #[builder(default = "true")] + pub(crate) has_header: bool, + #[builder(default = "b','")] + pub(crate) delimiter: u8, + #[builder(default = "None")] + pub(crate) datetime_format: Option, +} + +impl ReadCSV { + fn parse_filename(args: &[FunctionArg]) -> Result { + if let FunctionArg::Unnamed(e) = &args[0] { + return Ok(SqlparserResolver::resolve_func_arg_expr_to_string(e)?); + } + Err(FunctionError::InternalError(format!( + "unexpected filename arg: {}", + &args[0] + ))) + } + + fn parse_func_args(args: &[FunctionArg]) -> Result<(String, ReadCSVOptions), FunctionError> { + if args.is_empty() { + Err(FunctionError::InternalError( + "filename is required".to_string(), + )) + } else { + let filename = Self::parse_filename(args)?; + let mut options = ReadCSVOptionsBuilder::default(); + for each in args.iter().skip(1) { + if let FunctionArg::Named { name, arg } = each { + match name.value.as_str() { + "delim" => { + let string = SqlparserResolver::resolve_func_arg_expr_to_string(arg)?; + let bytes = string.as_bytes(); + if bytes.len() != 1 { + return Err(FunctionError::InternalError( + "delimiter must be a single byte".to_string(), + )); + } + options.delimiter(bytes[0]); + } + "header" => { + let v = SqlparserResolver::resolve_func_arg_expr_to_bool(arg)?; + options.has_header(v); + } + other => { + return Err(FunctionError::InternalError(format!( + "unexpected arg: {}", + other + ))) + } + } + } else { + return Err(FunctionError::InternalError( + "expected named arg".to_string(), + )); + } + } + Ok((filename, options.build().unwrap())) + } + } + + fn infer_arrow_schema( + filepath: String, + option: &ReadCSVOptions, + ) -> Result { + let mut file = File::open(filepath)?; + let (schema, _) = reader::infer_reader_schema( + &mut file, + option.delimiter, + Some(option.infer_schema_max_rows), + option.has_header, + )?; + Ok(Arc::new(schema)) + } + + fn create_reader(input: ReadCSVInputData) -> Result, FunctionError> { + let file = File::open(input.filename)?; + // convert bounds into csv bounds concept: (min line, max line) + let new_bounds = input.bounds.map(|(offset, limit)| { + if limit == usize::MAX { + (offset, limit) + } else { + (offset, offset + limit + 1) + } + }); + let reader = Reader::new( + file, + input.schema, + input.option.has_header, + Some(input.option.delimiter), + input.option.read_batch_size, + new_bounds, + input.projection, + input.option.datetime_format, + ); + Ok(reader) + } + + fn parse_col_names_types( + schema: &SchemaRef, + ) -> Result<(Vec, Vec), FunctionError> { + let mut col_names = vec![]; + let mut col_types = vec![]; + for field in schema.fields() { + col_names.push(field.name().to_string()); + col_types.push(field.data_type().try_into()?); + } + Ok((col_names, col_types)) + } + + fn bind_func( + _context: Arc, + input: TableFunctionBindInput, + return_types: &mut Vec, + return_names: &mut Vec, + ) -> Result, FunctionError> { + if let Some(args) = input.func_args { + let (filename, option) = Self::parse_func_args(args.as_slice())?; + let schema = Self::infer_arrow_schema(filename.clone(), &option)?; + let (col_names, col_types) = Self::parse_col_names_types(&schema)?; + return_types.extend(col_types); + return_names.extend(col_names); + let input_data = ReadCSVInputDataBuilder::default() + .filename(filename) + .schema(schema) + .option(option) + .build() + .unwrap(); + Ok(Some(FunctionData::ReadCSVInputData(Box::new(input_data)))) + } else { + Err(FunctionError::InternalError( + "unexpected bind data type".to_string(), + )) + } + } + + fn scan_func( + _context: Arc, + input: TableFunctionInput, + ) -> FunctionResult>> { + if let Some(FunctionData::ReadCSVInputData(data)) = input.bind_data { + let mut reader = Self::create_reader(*data)?; + let stream = Box::pin(async_stream::try_stream! { + while let Some(batch) = reader.next().transpose()? { + yield batch; + } + }); + Ok(stream) + } else { + Err(FunctionError::InternalError( + "unexpected bind data type".to_string(), + )) + } + } + + pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> { + set.add_table_functions(TableFunction::new( + "read_csv".to_string(), + Some(Self::bind_func), + Self::scan_func, + ))?; + Ok(()) + } +} diff --git a/src/function/table/seq_table_scan.rs b/src/function/table/seq_table_scan.rs index 7537a71..e0e897b 100644 --- a/src/function/table/seq_table_scan.rs +++ b/src/function/table/seq_table_scan.rs @@ -40,9 +40,9 @@ impl SeqTableScan { fn scan_func( context: Arc, - input: &TableFunctionInput, + input: TableFunctionInput, ) -> FunctionResult>> { - if let Some(FunctionData::SeqTableScanInputData(data)) = &input.bind_data { + if let Some(FunctionData::SeqTableScanInputData(data)) = input.bind_data { let mut reader = LocalStorage::create_reader(&data.bind_table.storage); let stream = Box::pin(async_stream::try_stream! { while let Some(batch) = reader.next_batch(context.clone()){ diff --git a/src/function/table/sqlrs_columns.rs b/src/function/table/sqlrs_columns.rs index bbdf0d7..b1ceddc 100644 --- a/src/function/table/sqlrs_columns.rs +++ b/src/function/table/sqlrs_columns.rs @@ -59,9 +59,9 @@ impl SqlrsColumnsFunc { fn tables_func( _context: Arc, - input: &TableFunctionInput, + input: TableFunctionInput, ) -> FunctionResult>> { - if let Some(FunctionData::SqlrsColumnsData(data)) = &input.bind_data { + if let Some(FunctionData::SqlrsColumnsData(data)) = input.bind_data { let schema = SchemaUtil::new_schema_ref(&data.return_names, &data.return_types); let mut table_name = ScalarValue::new_builder(&LogicalType::Varchar)?; let mut column_names = ScalarValue::new_builder(&LogicalType::Varchar)?; diff --git a/src/function/table/sqlrs_tables.rs b/src/function/table/sqlrs_tables.rs index 011fa6a..b862704 100644 --- a/src/function/table/sqlrs_tables.rs +++ b/src/function/table/sqlrs_tables.rs @@ -60,9 +60,9 @@ impl SqlrsTablesFunc { fn tables_func( _context: Arc, - input: &TableFunctionInput, + input: TableFunctionInput, ) -> FunctionResult>> { - if let Some(FunctionData::SqlrsTablesData(data)) = &input.bind_data { + if let Some(FunctionData::SqlrsTablesData(data)) = input.bind_data { let schema = SchemaUtil::new_schema_ref(&data.return_names, &data.return_types); let mut schema_names = ScalarValue::new_builder(&LogicalType::Varchar)?; let mut schema_oids = ScalarValue::new_builder(&LogicalType::Integer)?; diff --git a/src/function/table/table_function.rs b/src/function/table/table_function.rs index 2d81f09..8745f0e 100644 --- a/src/function/table/table_function.rs +++ b/src/function/table/table_function.rs @@ -32,7 +32,7 @@ pub type TableFunctionBindFunc = fn( pub type TableFunc = fn( Arc, - &TableFunctionInput, + TableFunctionInput, ) -> FunctionResult>>; #[derive(new, Clone)] diff --git a/src/util/tree_render.rs b/src/util/tree_render.rs index f2c1f7d..2220574 100644 --- a/src/util/tree_render.rs +++ b/src/util/tree_render.rs @@ -97,7 +97,7 @@ impl TreeRender { } FunctionData::SqlrsColumnsData(_) => "sqlrs_columns".to_string(), FunctionData::SqlrsTablesData(_) => "sqlrs_tables".to_string(), - // FunctionData::ReadCSVInputData(_) => "read_csv".to_string(), + FunctionData::ReadCSVInputData(_) => "read_csv".to_string(), }, None => "None".to_string(), }; diff --git a/tests/slt/table_function.slt b/tests/slt/table_function.slt index 8dea32f..da27a74 100644 --- a/tests/slt/table_function.slt +++ b/tests/slt/table_function.slt @@ -15,3 +15,20 @@ query III select * from sqlrs_columns() where table_name = 't1'; ---- t1 [v1, v2, v3] [Integer, Integer, Integer] + + +onlyif sqlrs_v2 +query III +select column_1 from read_csv('../csv/t1.csv', header=>false); +---- +a +0 +1 +2 +2 + +onlyif sqlrs_v2 +query III +select a from read_csv('../csv/t1.csv', header=>true, delim=>',') where a = 1; +---- +1 From c43045f03fa93fd5a9a5cf0992fa0dfeaabe2543 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Fri, 30 Dec 2022 15:43:54 +0800 Subject: [PATCH 3/4] feat(planner): support replaced read csv function when bind table Signed-off-by: Fedomn --- src/planner_v2/binder/sqlparser_util.rs | 43 +++++++++++++++++- .../binder/tableref/bind_base_table_ref.rs | 44 ++++++++++++++++++- tests/slt/table_function.slt | 18 ++++++++ 3 files changed, 101 insertions(+), 4 deletions(-) diff --git a/src/planner_v2/binder/sqlparser_util.rs b/src/planner_v2/binder/sqlparser_util.rs index 24bc88f..2ea6ad1 100644 --- a/src/planner_v2/binder/sqlparser_util.rs +++ b/src/planner_v2/binder/sqlparser_util.rs @@ -1,7 +1,10 @@ +use std::collections::HashMap; + use itertools::Itertools; use sqlparser::ast::{ - BinaryOperator, ColumnDef, Expr, FunctionArgExpr, Ident, ObjectName, Query, Select, SelectItem, - SetExpr, TableFactor, TableWithJoins, Value, WildcardAdditionalOptions, + BinaryOperator, ColumnDef, Expr, FunctionArg, FunctionArgExpr, Ident, ObjectName, Query, + Select, SelectItem, SetExpr, TableAlias, TableFactor, TableWithJoins, Value, + WildcardAdditionalOptions, }; use super::BindError; @@ -215,3 +218,39 @@ impl SqlparserQueryBuilder { } } } + +pub struct SqlparserTableFactorBuilder; + +impl SqlparserTableFactorBuilder { + pub fn build_table_func( + func_name: &str, + alias: String, + unamed_arges: Vec, + uamed_args: HashMap, + ) -> TableFactor { + let unamed_arges = unamed_arges + .into_iter() + .map(|arg| { + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString(arg), + ))) + }) + .collect::>(); + let uamed_args = uamed_args + .into_iter() + .map(|(k, v)| FunctionArg::Named { + name: Ident::new(k), + arg: FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString(v))), + }) + .collect::>(); + TableFactor::Table { + name: ObjectName(vec![Ident::new(func_name)]), + alias: Some(TableAlias { + name: Ident::new(alias), + columns: vec![], + }), + args: Some([unamed_arges, uamed_args].concat()), + with_hints: vec![], + } + } +} diff --git a/src/planner_v2/binder/tableref/bind_base_table_ref.rs b/src/planner_v2/binder/tableref/bind_base_table_ref.rs index 58b4855..805bd7f 100644 --- a/src/planner_v2/binder/tableref/bind_base_table_ref.rs +++ b/src/planner_v2/binder/tableref/bind_base_table_ref.rs @@ -1,10 +1,15 @@ +use std::collections::HashMap; +use std::path::Path; + use derive_new::new; +use sqlparser::ast::TableFactor; use super::BoundTableRef; use crate::catalog_v2::{Catalog, CatalogEntry, TableCatalogEntry}; use crate::function::{SeqTableScan, TableFunctionBindInput}; use crate::planner_v2::{ BindError, Binder, LogicalGet, LogicalOperator, LogicalOperatorBase, SqlparserResolver, + SqlparserTableFactorBuilder, }; /// Represents a TableReference to a base table in the schema @@ -32,10 +37,14 @@ impl Binder { .map(|a| a.to_string()) .unwrap_or_else(|| table.clone()); - let table_res = Catalog::get_table(self.clone_client_context(), schema, table); + let table_res = + Catalog::get_table(self.clone_client_context(), schema, table.clone()); if table_res.is_err() { // table could not be found: try to bind a replacement scan - return Err(BindError::CatalogError(table_res.err().unwrap())); + match self.bind_replacement_table_factor(table, alias) { + Some(replaced_table) => return self.bind_table_function(replaced_table), + None => return Err(BindError::CatalogError(table_res.err().unwrap())), + } } let table = table_res.unwrap(); @@ -84,4 +93,35 @@ impl Binder { ))), } } + + /// Replacement table scans are automatically attempted when a table name cannot be found in the + /// schema. This allows you to do e.g. SELECT * FROM 'filename.csv', and automatically + /// convert this into a CSV scan + fn bind_replacement_table_factor( + &mut self, + table_name: String, + alias: String, + ) -> Option { + let table_name = table_name.to_lowercase(); + let mut alias = alias.to_lowercase(); + if table_name.ends_with(".csv") { + if table_name == alias { + // which means the alias is not set, so we simply use the filename + alias = Path::new(table_name.as_str()) + .file_stem() + .unwrap() + .to_str() + .unwrap() + .to_string(); + } + Some(SqlparserTableFactorBuilder::build_table_func( + "read_csv", + alias, + vec![table_name], + HashMap::new(), + )) + } else { + None + } + } } diff --git a/tests/slt/table_function.slt b/tests/slt/table_function.slt index da27a74..d023a43 100644 --- a/tests/slt/table_function.slt +++ b/tests/slt/table_function.slt @@ -32,3 +32,21 @@ query III select a from read_csv('../csv/t1.csv', header=>true, delim=>',') where a = 1; ---- 1 + +onlyif sqlrs_v2 +query III +select t1.a from '../csv/t1.csv'; +---- +0 +1 +2 +2 + +onlyif sqlrs_v2 +query III +select tt.a from '../csv/t1.csv' tt; +---- +0 +1 +2 +2 From b10a33d9ee53b814bbab188a09ac602f27ef7fcc Mon Sep 17 00:00:00 2001 From: Fedomn Date: Fri, 30 Dec 2022 16:05:49 +0800 Subject: [PATCH 4/4] feat(planner): add table_functions in doc Signed-off-by: Fedomn --- Makefile | 3 +++ README.md | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/Makefile b/Makefile index 5f48f1f..7309863 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,9 @@ clean: run: cargo run --release +run_v2: + ENABLE_V2=1 cargo run --release + debug: RUST_BACKTRACE=1 cargo run diff --git a/README.md b/README.md index eee2945..8be9ae8 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,12 @@ describe t1; -- previous SQL statements select v1+1 as a from t1 where a >= 2; select v1 from t1 limit 2 offset 1; +-- table functions +select * from sqlrs_tables(); +select * from sqlrs_columns(); +select * from read_csv('t1.csv'); +select * from read_csv('t1.csv', header=>true, delim=>','); +select * from 't1.csv'; ```