From 8aa1eeb4da3a96ad2e19f9d9af733c95de2afa2c Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 30 Jan 2023 21:27:10 -0500 Subject: [PATCH 01/16] type bindings --- Cargo.lock | 45 ++-- Cargo.toml | 3 +- src/common/dffield.rs | 36 ++++ src/common/mod.rs | 18 ++ src/context.rs | 2 +- src/lib.rs | 2 + src/sql.rs | 19 ++ src/sql/exceptions.rs | 25 +++ src/sql/types.rs | 344 +++++++++++++++++++++++++++++++ src/sql/types/arrow_type.rs | 0 src/sql/types/data_type.rs | 242 ++++++++++++++++++++++ src/sql/types/datafusion_type.rs | 0 src/sql/types/sql_type.rs | 75 +++++++ 13 files changed, 787 insertions(+), 24 deletions(-) create mode 100644 src/common/dffield.rs create mode 100644 src/common/mod.rs create mode 100644 src/sql.rs create mode 100644 src/sql/exceptions.rs create mode 100644 src/sql/types.rs create mode 100644 src/sql/types/arrow_type.rs create mode 100644 src/sql/types/data_type.rs create mode 100644 src/sql/types/datafusion_type.rs create mode 100644 src/sql/types/sql_type.rs diff --git a/Cargo.lock b/Cargo.lock index 4ab865681..981db1239 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16,9 +16,9 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf6ccdb167abbf410dcb915cabd428929d7f6a04980b54a11f26a39f1c7f7107" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", "const-random", @@ -342,9 +342,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.62" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "689894c2db1ea643a50834b999abf1c110887402542955ff5451dab8f861f9ed" +checksum = "eff18d764974428cf3a9328e23fc5c986f5fbed46e6cd4cdf42544df5d297ec1" dependencies = [ "proc-macro2", "quote", @@ -620,9 +620,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d1075c37807dcf850c379432f0df05ba52cc30f279c5cfc43cc221ce7f8579" +checksum = "b61a7545f753a88bcbe0a70de1fcc0221e10bfc752f576754fa91e663db1622e" dependencies = [ "cc", "cxxbridge-flags", @@ -632,9 +632,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5044281f61b27bc598f2f6647d480aed48d2bf52d6eb0b627d84c0361b17aa70" +checksum = "f464457d494b5ed6905c63b0c4704842aba319084a0a3561cdc1359536b53200" dependencies = [ "cc", "codespan-reporting", @@ -647,15 +647,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61b50bc93ba22c27b0d31128d2d130a0a6b3d267ae27ef7e4fae2167dfe8781c" +checksum = "43c7119ce3a3701ed81aca8410b9acf6fc399d2629d057b87e2efa4e63a3aaea" [[package]] name = "cxxbridge-macro" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e61fda7e62115119469c7b3591fd913ecca96fb766cfd3f2e2502ab7bc87a5" +checksum = "65e07508b90551e610910fa648a1878991d367064997a596135b86df30daf07e" dependencies = [ "proc-macro2", "quote", @@ -806,6 +806,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-optimizer", + "datafusion-sql", "datafusion-substrait", "futures", "mimalloc", @@ -1754,9 +1755,9 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pest" -version = "2.5.3" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4257b4a04d91f7e9e6290be5d3da4804dd5784fafde3a497d73eb2b4a158c30a" +checksum = "4ab62d2fa33726dbe6321cc97ef96d8cde531e3eeaf858a058de53a8a6d40d8f" dependencies = [ "thiserror", "ucd-trie", @@ -2336,9 +2337,9 @@ dependencies = [ [[package]] name = "serde_yaml" -version = "0.9.16" +version = "0.9.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b5b431e8907b50339b51223b97d102db8d987ced36f6e4d03621db9316c834" +checksum = "8fb06d4b6cdaef0e0c51fa881acb721bed3c924cfaa71d9c94a3b771dfdf6567" dependencies = [ "indexmap", "itoa 1.0.5", @@ -2650,9 +2651,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.5.10" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1333c76748e868a4d9d1017b5ab53171dfd095f70c712fdb4653a406547f598f" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" dependencies = [ "serde", ] @@ -2794,9 +2795,9 @@ checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" [[package]] name = "unicode-bidi" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0046be40136ef78dc325e0edefccf84ccddacd0afcc1ca54103fa3c61bbdab1d" +checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" [[package]] name = "unicode-ident" @@ -3007,9 +3008,9 @@ dependencies = [ [[package]] name = "which" -version = "4.3.0" +version = "4.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c831fbbee9e129a8cf93e7747a82da9d95ba8e16621cae60ec2cdc849bacb7b" +checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" dependencies = [ "either", "libc", diff --git a/Cargo.toml b/Cargo.toml index fddeabbd9..ecd2260d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,11 +33,12 @@ default = ["mimalloc"] [dependencies] tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.8" -pyo3 = { version = "~0.17.3", features = ["extension-module", "abi3", "abi3-py37"] } +pyo3 = { version = "^0.17.3", features = ["extension-module", "abi3", "abi3-py37"] } datafusion = { git = "https://github.com/apache/arrow-datafusion", rev = "6dce728a3c7130ca3590a16f413c7c6ccb7209b7", features = ["pyarrow", "avro"] } datafusion-expr = { git = "https://github.com/apache/arrow-datafusion", rev = "6dce728a3c7130ca3590a16f413c7c6ccb7209b7" } datafusion-optimizer = { git = "https://github.com/apache/arrow-datafusion", rev = "6dce728a3c7130ca3590a16f413c7c6ccb7209b7" } datafusion-common = { git = "https://github.com/apache/arrow-datafusion", rev = "6dce728a3c7130ca3590a16f413c7c6ccb7209b7", features = ["pyarrow"] } +datafusion-sql = { git = "https://github.com/apache/arrow-datafusion", rev = "6dce728a3c7130ca3590a16f413c7c6ccb7209b7" } datafusion-substrait = { git = "https://github.com/apache/arrow-datafusion", rev = "6dce728a3c7130ca3590a16f413c7c6ccb7209b7" } uuid = { version = "1.2", features = ["v4"] } mimalloc = { version = "*", optional = true, default-features = false } diff --git a/src/common/dffield.rs b/src/common/dffield.rs new file mode 100644 index 000000000..78dbdbbb7 --- /dev/null +++ b/src/common/dffield.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::arrow::datatypes::Field; +use pyo3::prelude::*; + +use crate::sql::types::DataTypeMap; + +/// PyDFField wraps an arrow-datafusion `DFField` struct type +/// and also supplies convenience methods for interacting +/// with the `DFField` instance in the context of Python +#[pyclass(name = "DFField", module = "datafusion", subclass)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct PyDFField { + /// Optional qualifier (usually a table or relation name) + qualifier: Option, + name: String, + data_type: DataTypeMap, + /// Arrow field definition + field: Field, + index: usize, +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 000000000..4c66ffd44 --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod dffield; diff --git a/src/context.rs b/src/context.rs index 71f99f55a..6021ba0c4 100644 --- a/src/context.rs +++ b/src/context.rs @@ -162,7 +162,7 @@ impl PySessionContext { // table name cannot start with numeric digit let name = "c".to_owned() + Uuid::new_v4() - .to_simple() + .simple() .encode_lower(&mut Uuid::encode_buffer()); self.ctx diff --git a/src/lib.rs b/src/lib.rs index 21b47f449..4c2df0cea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,8 @@ mod udaf; #[allow(clippy::borrow_deref_ref)] mod udf; pub mod utils; +pub mod sql; +pub mod common; #[cfg(feature = "mimalloc")] #[global_allocator] diff --git a/src/sql.rs b/src/sql.rs new file mode 100644 index 000000000..e468ad8cc --- /dev/null +++ b/src/sql.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod exceptions; +pub mod types; diff --git a/src/sql/exceptions.rs b/src/sql/exceptions.rs new file mode 100644 index 000000000..871402279 --- /dev/null +++ b/src/sql/exceptions.rs @@ -0,0 +1,25 @@ +use std::fmt::Debug; + +use pyo3::{create_exception, PyErr}; + +// Identifies exceptions that occur while attempting to generate a `LogicalPlan` from a SQL string +create_exception!(rust, ParsingException, pyo3::exceptions::PyException); + +// Identifies exceptions that occur during attempts to optimization an existing `LogicalPlan` +create_exception!(rust, OptimizationException, pyo3::exceptions::PyException); + +pub fn py_type_err(e: impl Debug) -> PyErr { + PyErr::new::(format!("{:?}", e)) +} + +pub fn py_runtime_err(e: impl Debug) -> PyErr { + PyErr::new::(format!("{:?}", e)) +} + +pub fn py_parsing_exp(e: impl Debug) -> PyErr { + PyErr::new::(format!("{:?}", e)) +} + +pub fn py_optimization_exp(e: impl Debug) -> PyErr { + PyErr::new::(format!("{:?}", e)) +} diff --git a/src/sql/types.rs b/src/sql/types.rs new file mode 100644 index 000000000..3a5d1e884 --- /dev/null +++ b/src/sql/types.rs @@ -0,0 +1,344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod data_type; +pub mod sql_type; + +use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; +use datafusion_sql::sqlparser::{ast::DataType as SQLType, parser::Parser, tokenizer::Tokenizer}; +use pyo3::{prelude::*, types::PyDict}; + +use self::sql_type::SqlType; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "RexType", module = "datafusion")] +pub enum RexType { + Literal, + Call, + Reference, + SubqueryAlias, + Other, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "DataTypeMap", module = "datafusion", subclass)] +pub struct DataTypeMap { + sql_type: SqlType, + arrow_type: DataType, +} + + +// /// Functions not exposed to Python +// impl DaskTypeMap { +// pub fn from(sql_type: SqlTypeName, data_type: PyDataType) -> Self { +// DaskTypeMap { +// sql_type, +// data_type, +// } +// } +// } + +// #[pymethods] +// impl DaskTypeMap { +// #[new] +// #[pyo3(signature = (sql_type, **py_kwargs))] +// fn new(sql_type: SqlTypeName, py_kwargs: Option<&PyDict>) -> PyResult { +// let d_type: DataType = match sql_type { +// SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE => { +// let (unit, tz) = match py_kwargs { +// Some(dict) => { +// let tz: Option = match dict.get_item("tz") { +// Some(e) => { +// let res: PyResult = e.extract(); +// Some(res.unwrap()) +// } +// None => None, +// }; +// let unit: TimeUnit = match dict.get_item("unit") { +// Some(e) => { +// let res: PyResult<&str> = e.extract(); +// match res.unwrap() { +// "Second" => TimeUnit::Second, +// "Millisecond" => TimeUnit::Millisecond, +// "Microsecond" => TimeUnit::Microsecond, +// "Nanosecond" => TimeUnit::Nanosecond, +// _ => TimeUnit::Nanosecond, +// } +// } +// // Default to Nanosecond which is common if not present +// None => TimeUnit::Nanosecond, +// }; +// (unit, tz) +// } +// // Default to Nanosecond and None for tz which is common if not present +// None => (TimeUnit::Nanosecond, None), +// }; +// DataType::Timestamp(unit, tz) +// } +// SqlTypeName::TIMESTAMP => { +// let (unit, tz) = match py_kwargs { +// Some(dict) => { +// let tz: Option = match dict.get_item("tz") { +// Some(e) => { +// let res: PyResult = e.extract(); +// Some(res.unwrap()) +// } +// None => None, +// }; +// let unit: TimeUnit = match dict.get_item("unit") { +// Some(e) => { +// let res: PyResult<&str> = e.extract(); +// match res.unwrap() { +// "Second" => TimeUnit::Second, +// "Millisecond" => TimeUnit::Millisecond, +// "Microsecond" => TimeUnit::Microsecond, +// "Nanosecond" => TimeUnit::Nanosecond, +// _ => TimeUnit::Nanosecond, +// } +// } +// // Default to Nanosecond which is common if not present +// None => TimeUnit::Nanosecond, +// }; +// (unit, tz) +// } +// // Default to Nanosecond and None for tz which is common if not present +// None => (TimeUnit::Nanosecond, None), +// }; +// DataType::Timestamp(unit, tz) +// } +// _ => sql_type.to_arrow()?, +// }; + +// Ok(DaskTypeMap { +// sql_type, +// data_type: d_type.into(), +// }) +// } + +// #[pyo3(name = "getSqlType")] +// pub fn sql_type(&self) -> SqlTypeName { +// self.sql_type.clone() +// } + +// #[pyo3(name = "getDataType")] +// pub fn data_type(&self) -> PyDataType { +// self.data_type.clone() +// } +// } + +// #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +// #[pyclass(name = "PyDataType", module = "datafusion", subclass)] +// pub struct PyDataType { +// data_type: DataType, +// } + +// impl From for DataType { +// fn from(data_type: PyDataType) -> DataType { +// data_type.data_type +// } +// } + +// impl From for PyDataType { +// fn from(data_type: DataType) -> PyDataType { +// PyDataType { data_type } +// } +// } + + +// impl SqlTypeName { +// pub fn to_arrow(&self) -> Result { +// match self { +// SqlTypeName::NULL => Ok(DataType::Null), +// SqlTypeName::BOOLEAN => Ok(DataType::Boolean), +// SqlTypeName::TINYINT => Ok(DataType::Int8), +// SqlTypeName::SMALLINT => Ok(DataType::Int16), +// SqlTypeName::INTEGER => Ok(DataType::Int32), +// SqlTypeName::BIGINT => Ok(DataType::Int64), +// SqlTypeName::REAL => Ok(DataType::Float16), +// SqlTypeName::FLOAT => Ok(DataType::Float32), +// SqlTypeName::DOUBLE => Ok(DataType::Float64), +// SqlTypeName::DATE => Ok(DataType::Date64), +// SqlTypeName::VARCHAR => Ok(DataType::Utf8), +// _ => Err(DaskPlannerError::Internal(format!( +// "Cannot determine Arrow type for Dask SQL type '{:?}'", +// self +// ))), +// } +// } + +// pub fn from_arrow(arrow_type: &DataType) -> Result { +// match arrow_type { +// DataType::Null => Ok(SqlTypeName::NULL), +// DataType::Boolean => Ok(SqlTypeName::BOOLEAN), +// DataType::Int8 => Ok(SqlTypeName::TINYINT), +// DataType::Int16 => Ok(SqlTypeName::SMALLINT), +// DataType::Int32 => Ok(SqlTypeName::INTEGER), +// DataType::Int64 => Ok(SqlTypeName::BIGINT), +// DataType::UInt8 => Ok(SqlTypeName::TINYINT), +// DataType::UInt16 => Ok(SqlTypeName::SMALLINT), +// DataType::UInt32 => Ok(SqlTypeName::INTEGER), +// DataType::UInt64 => Ok(SqlTypeName::BIGINT), +// DataType::Float16 => Ok(SqlTypeName::REAL), +// DataType::Float32 => Ok(SqlTypeName::FLOAT), +// DataType::Float64 => Ok(SqlTypeName::DOUBLE), +// DataType::Time32(_) | DataType::Time64(_) => Ok(SqlTypeName::TIME), +// DataType::Timestamp(_unit, tz) => match tz { +// Some(_) => Ok(SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE), +// None => Ok(SqlTypeName::TIMESTAMP), +// }, +// DataType::Date32 => Ok(SqlTypeName::DATE), +// DataType::Date64 => Ok(SqlTypeName::DATE), +// DataType::Interval(unit) => match unit { +// IntervalUnit::DayTime => Ok(SqlTypeName::INTERVAL_DAY), +// IntervalUnit::YearMonth => Ok(SqlTypeName::INTERVAL_YEAR_MONTH), +// IntervalUnit::MonthDayNano => Ok(SqlTypeName::INTERVAL_MONTH), +// }, +// DataType::Binary => Ok(SqlTypeName::BINARY), +// DataType::FixedSizeBinary(_size) => Ok(SqlTypeName::VARBINARY), +// DataType::Utf8 => Ok(SqlTypeName::CHAR), +// DataType::LargeUtf8 => Ok(SqlTypeName::VARCHAR), +// DataType::Struct(_fields) => Ok(SqlTypeName::STRUCTURED), +// DataType::Decimal128(_precision, _scale) => Ok(SqlTypeName::DECIMAL), +// DataType::Decimal256(_precision, _scale) => Ok(SqlTypeName::DECIMAL), +// DataType::Map(_field, _bool) => Ok(SqlTypeName::MAP), +// _ => Err(DaskPlannerError::Internal(format!( +// "Cannot determine Dask SQL type for Arrow type '{:?}'", +// arrow_type +// ))), +// } +// } +// } + +// #[pymethods] +// impl SqlTypeName { +// #[pyo3(name = "fromString")] +// #[staticmethod] +// pub fn py_from_string(input_type: &str) -> PyResult { +// SqlTypeName::from_string(input_type).map_err(|e| e.into()) +// } +// } + +// impl SqlTypeName { +// pub fn from_string(input_type: &str) -> Result { +// match input_type.to_uppercase().as_ref() { +// "ANY" => Ok(SqlTypeName::ANY), +// "ARRAY" => Ok(SqlTypeName::ARRAY), +// "NULL" => Ok(SqlTypeName::NULL), +// "BOOLEAN" => Ok(SqlTypeName::BOOLEAN), +// "COLUMN_LIST" => Ok(SqlTypeName::COLUMN_LIST), +// "DISTINCT" => Ok(SqlTypeName::DISTINCT), +// "CURSOR" => Ok(SqlTypeName::CURSOR), +// "TINYINT" => Ok(SqlTypeName::TINYINT), +// "SMALLINT" => Ok(SqlTypeName::SMALLINT), +// "INT" => Ok(SqlTypeName::INTEGER), +// "INTEGER" => Ok(SqlTypeName::INTEGER), +// "BIGINT" => Ok(SqlTypeName::BIGINT), +// "REAL" => Ok(SqlTypeName::REAL), +// "FLOAT" => Ok(SqlTypeName::FLOAT), +// "GEOMETRY" => Ok(SqlTypeName::GEOMETRY), +// "DOUBLE" => Ok(SqlTypeName::DOUBLE), +// "TIME" => Ok(SqlTypeName::TIME), +// "TIME_WITH_LOCAL_TIME_ZONE" => Ok(SqlTypeName::TIME_WITH_LOCAL_TIME_ZONE), +// "TIMESTAMP" => Ok(SqlTypeName::TIMESTAMP), +// "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => Ok(SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE), +// "DATE" => Ok(SqlTypeName::DATE), +// "INTERVAL" => Ok(SqlTypeName::INTERVAL), +// "INTERVAL_DAY" => Ok(SqlTypeName::INTERVAL_DAY), +// "INTERVAL_DAY_HOUR" => Ok(SqlTypeName::INTERVAL_DAY_HOUR), +// "INTERVAL_DAY_MINUTE" => Ok(SqlTypeName::INTERVAL_DAY_MINUTE), +// "INTERVAL_DAY_SECOND" => Ok(SqlTypeName::INTERVAL_DAY_SECOND), +// "INTERVAL_HOUR" => Ok(SqlTypeName::INTERVAL_HOUR), +// "INTERVAL_HOUR_MINUTE" => Ok(SqlTypeName::INTERVAL_HOUR_MINUTE), +// "INTERVAL_HOUR_SECOND" => Ok(SqlTypeName::INTERVAL_HOUR_SECOND), +// "INTERVAL_MINUTE" => Ok(SqlTypeName::INTERVAL_MINUTE), +// "INTERVAL_MINUTE_SECOND" => Ok(SqlTypeName::INTERVAL_MINUTE_SECOND), +// "INTERVAL_MONTH" => Ok(SqlTypeName::INTERVAL_MONTH), +// "INTERVAL_SECOND" => Ok(SqlTypeName::INTERVAL_SECOND), +// "INTERVAL_YEAR" => Ok(SqlTypeName::INTERVAL_YEAR), +// "INTERVAL_YEAR_MONTH" => Ok(SqlTypeName::INTERVAL_YEAR_MONTH), +// "MAP" => Ok(SqlTypeName::MAP), +// "MULTISET" => Ok(SqlTypeName::MULTISET), +// "OTHER" => Ok(SqlTypeName::OTHER), +// "ROW" => Ok(SqlTypeName::ROW), +// "SARG" => Ok(SqlTypeName::SARG), +// "BINARY" => Ok(SqlTypeName::BINARY), +// "VARBINARY" => Ok(SqlTypeName::VARBINARY), +// "CHAR" => Ok(SqlTypeName::CHAR), +// "VARCHAR" | "STRING" => Ok(SqlTypeName::VARCHAR), +// "STRUCTURED" => Ok(SqlTypeName::STRUCTURED), +// "SYMBOL" => Ok(SqlTypeName::SYMBOL), +// "DECIMAL" => Ok(SqlTypeName::DECIMAL), +// "DYNAMIC_STAT" => Ok(SqlTypeName::DYNAMIC_STAR), +// "UNKNOWN" => Ok(SqlTypeName::UNKNOWN), +// _ => { +// // complex data type name so use the sqlparser +// let dialect = DaskDialect {}; +// let mut tokenizer = Tokenizer::new(&dialect, input_type); +// let tokens = tokenizer.tokenize().map_err(DaskPlannerError::from)?; +// let mut parser = Parser::new(tokens, &dialect); +// match parser.parse_data_type().map_err(DaskPlannerError::from)? { +// SQLType::Decimal(_) => Ok(SqlTypeName::DECIMAL), +// SQLType::Binary(_) => Ok(SqlTypeName::BINARY), +// SQLType::Varbinary(_) => Ok(SqlTypeName::VARBINARY), +// SQLType::Varchar(_) | SQLType::Nvarchar(_) => Ok(SqlTypeName::VARCHAR), +// SQLType::Char(_) => Ok(SqlTypeName::CHAR), +// _ => Err(DaskPlannerError::Internal(format!( +// "Cannot determine Dask SQL type for '{}'", +// input_type +// ))), +// } +// } +// } +// } +// } + +// #[cfg(test)] +// mod test { +// use crate::sql::types::SqlTypeName; + +// #[test] +// fn invalid_type_name() { +// assert_eq!( +// "Internal Error: Cannot determine Dask SQL type for 'bob'", +// SqlTypeName::from_string("bob") +// .expect_err("invalid type name") +// .to_string() +// ); +// } + +// #[test] +// fn string() { +// assert_expected("VARCHAR", "string"); +// } + +// #[test] +// fn varchar_n() { +// assert_expected("VARCHAR", "VARCHAR(10)"); +// } + +// #[test] +// fn decimal_p_s() { +// assert_expected("DECIMAL", "DECIMAL(10, 2)"); +// } + +// fn assert_expected(expected: &str, input: &str) { +// assert_eq!( +// expected, +// &format!("{:?}", SqlTypeName::from_string(input).unwrap()) +// ); +// } +// } diff --git a/src/sql/types/arrow_type.rs b/src/sql/types/arrow_type.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/sql/types/data_type.rs b/src/sql/types/data_type.rs new file mode 100644 index 000000000..da4969677 --- /dev/null +++ b/src/sql/types/data_type.rs @@ -0,0 +1,242 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::fmt; + +use datafusion_common::{DFField, DFSchema}; +use pyo3::prelude::*; + +use super::DataTypeMap; + +const PRECISION_NOT_SPECIFIED: i32 = i32::MIN; +const SCALE_NOT_SPECIFIED: i32 = -1; + +/// RelDataTypeField represents the definition of a field in a structured RelDataType. +#[pyclass(name = "RelDataTypeField", module = "dask_planner", subclass)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct RelDataTypeField { + qualifier: Option, + name: String, + data_type: DataTypeMap, + index: usize, +} + +// // Functions that should not be presented to Python are placed here +// impl RelDataTypeField { +// pub fn from(field: &DFField, schema: &DFSchema) -> Result { +// let qualifier: Option<&str> = match field.qualifier() { +// Some(qualifier) => Some(qualifier), +// None => None, +// }; +// Ok(RelDataTypeField { +// qualifier: qualifier.map(|qualifier| qualifier.to_string()), +// name: field.name().clone(), +// data_type: DaskTypeMap { +// sql_type: SqlTypeName::from_arrow(field.data_type())?, +// data_type: field.data_type().clone().into(), +// }, +// index: schema.index_of_column_by_name(qualifier, field.name())?, +// }) +// } +// } + +// #[pymethods] +// impl RelDataTypeField { +// #[new] +// pub fn new(name: &str, type_map: DaskTypeMap, index: usize) -> Self { +// Self { +// qualifier: None, +// name: name.to_owned(), +// data_type: type_map, +// index, +// } +// } + +// #[pyo3(name = "getQualifier")] +// pub fn qualifier(&self) -> Option { +// self.qualifier.clone() +// } + +// #[pyo3(name = "getName")] +// pub fn name(&self) -> &str { +// &self.name +// } + +// #[pyo3(name = "getQualifiedName")] +// pub fn qualified_name(&self) -> String { +// match &self.qualifier() { +// Some(qualifier) => format!("{}.{}", &qualifier, self.name()), +// None => self.name().to_string(), +// } +// } + +// #[pyo3(name = "getIndex")] +// pub fn index(&self) -> usize { +// self.index +// } + +// #[pyo3(name = "getType")] +// pub fn data_type(&self) -> DaskTypeMap { +// self.data_type.clone() +// } + +// /// Since this logic is being ported from Java getKey is synonymous with getName. +// /// Alas it is used in certain places so it is implemented here to allow other +// /// places in the code base to not have to change. +// #[pyo3(name = "getKey")] +// pub fn get_key(&self) -> &str { +// self.name() +// } + +// /// Since this logic is being ported from Java getValue is synonymous with getType. +// /// Alas it is used in certain places so it is implemented here to allow other +// /// places in the code base to not have to change. +// #[pyo3(name = "getValue")] +// pub fn get_value(&self) -> DaskTypeMap { +// self.data_type() +// } + +// #[pyo3(name = "setValue")] +// pub fn set_value(&mut self, data_type: DaskTypeMap) { +// self.data_type = data_type +// } + +// // TODO: Uncomment after implementing in RelDataType +// // #[pyo3(name = "isDynamicStar")] +// // pub fn is_dynamic_star(&self) -> bool { +// // self.data_type.getSqlTypeName() == SqlTypeName.DYNAMIC_STAR +// // } +// } + +// impl fmt::Display for RelDataTypeField { +// fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { +// fmt.write_str("Field: ")?; +// fmt.write_str(&self.name)?; +// fmt.write_str(" - Index: ")?; +// fmt.write_str(&self.index.to_string())?; +// // TODO: Uncomment this after implementing the Display trait in RelDataType +// // fmt.write_str(" - DataType: ")?; +// // fmt.write_str(self.data_type.to_string())?; +// Ok(()) +// } +// } + +// /// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. +// #[pyclass(name = "RelDataType", module = "dask_planner", subclass)] +// #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +// pub struct RelDataType { +// nullable: bool, +// field_list: Vec, +// } + +// /// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. +// #[pymethods] +// impl RelDataType { +// #[new] +// pub fn new(nullable: bool, fields: Vec) -> Self { +// Self { +// nullable, +// field_list: fields, +// } +// } + +// /// Looks up a field by name. +// /// +// /// # Arguments +// /// +// /// * `field_name` - A String containing the name of the field to find +// /// * `case_sensitive` - True if column name matching should be case sensitive and false otherwise +// #[pyo3(name = "getField")] +// pub fn field(&self, field_name: &str, case_sensitive: bool) -> PyResult { +// let field_map: HashMap = self.field_map(); +// if case_sensitive && !field_map.is_empty() { +// Ok(field_map.get(field_name).unwrap().clone()) +// } else { +// for field in &self.field_list { +// if (case_sensitive && field.name().eq(field_name)) +// || (!case_sensitive && field.name().eq_ignore_ascii_case(field_name)) +// { +// return Ok(field.clone()); +// } +// } + +// // TODO: Throw a proper error here +// Err(py_runtime_err(format!( +// "Unable to find RelDataTypeField with name {:?} in the RelDataType field_list", +// field_name, +// ))) +// } +// } + +// /// Returns a map from field names to fields. +// /// +// /// # Notes +// /// +// /// * If several fields have the same name, the map contains the first. +// #[pyo3(name = "getFieldMap")] +// pub fn field_map(&self) -> HashMap { +// let mut fields: HashMap = HashMap::new(); +// for field in &self.field_list { +// fields.insert(String::from(field.name()), field.clone()); +// } +// fields +// } + +// /// Gets the fields in a struct type. The field count is equal to the size of the returned list. +// #[pyo3(name = "getFieldList")] +// pub fn field_list(&self) -> Vec { +// self.field_list.clone() +// } + +// /// Returns the names of all of the columns in a given DaskTable +// #[pyo3(name = "getFieldNames")] +// pub fn field_names(&self) -> Vec { +// let mut field_names: Vec = Vec::new(); +// for field in &self.field_list { +// field_names.push(field.qualified_name()); +// } +// field_names +// } + +// /// Returns the number of fields in a struct type. +// #[pyo3(name = "getFieldCount")] +// pub fn field_count(&self) -> usize { +// self.field_list.len() +// } + +// #[pyo3(name = "isStruct")] +// pub fn is_struct(&self) -> bool { +// !self.field_list.is_empty() +// } + +// /// Queries whether this type allows null values. +// #[pyo3(name = "isNullable")] +// pub fn is_nullable(&self) -> bool { +// self.nullable +// } + +// #[pyo3(name = "getPrecision")] +// pub fn precision(&self) -> i32 { +// PRECISION_NOT_SPECIFIED +// } + +// #[pyo3(name = "getScale")] +// pub fn scale(&self) -> i32 { +// SCALE_NOT_SPECIFIED +// } +// } diff --git a/src/sql/types/datafusion_type.rs b/src/sql/types/datafusion_type.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/sql/types/sql_type.rs b/src/sql/types/sql_type.rs new file mode 100644 index 000000000..333c6420e --- /dev/null +++ b/src/sql/types/sql_type.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Enumeration of the type names which can be used to construct a SQL type. Since +/// several SQL types do not exist as Rust types and also because the Enum +/// `SqlTypeName` is already used in the Python Dask-SQL code base this enum is used +/// in place of just using the built-in Rust types. +#[allow(non_camel_case_types)] +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "SqlType", module = "datafusion")] +pub enum SqlType { + ANY, + ARRAY, + BIGINT, + BINARY, + BOOLEAN, + CHAR, + COLUMN_LIST, + CURSOR, + DATE, + DECIMAL, + DISTINCT, + DOUBLE, + DYNAMIC_STAR, + FLOAT, + GEOMETRY, + INTEGER, + INTERVAL, + INTERVAL_DAY, + INTERVAL_DAY_HOUR, + INTERVAL_DAY_MINUTE, + INTERVAL_DAY_SECOND, + INTERVAL_HOUR, + INTERVAL_HOUR_MINUTE, + INTERVAL_HOUR_SECOND, + INTERVAL_MINUTE, + INTERVAL_MINUTE_SECOND, + INTERVAL_MONTH, + INTERVAL_SECOND, + INTERVAL_YEAR, + INTERVAL_YEAR_MONTH, + MAP, + MULTISET, + NULL, + OTHER, + REAL, + ROW, + SARG, + SMALLINT, + STRUCTURED, + SYMBOL, + TIME, + TIME_WITH_LOCAL_TIME_ZONE, + TIMESTAMP, + TIMESTAMP_WITH_LOCAL_TIME_ZONE, + TINYINT, + UNKNOWN, + VARBINARY, + VARCHAR, +} From de52692c093e61fdf73ab5f0ef4cda5122368b12 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sat, 4 Feb 2023 16:00:02 -0500 Subject: [PATCH 02/16] checkpoint commit --- src/common/data_type.rs | 265 +++++++++++++++++++ src/common/{dffield.rs => df_field.rs} | 4 +- src/common/mod.rs | 3 +- src/sql.rs | 1 - src/sql/types.rs | 344 ------------------------- src/sql/types/arrow_type.rs | 0 src/sql/types/data_type.rs | 242 ----------------- src/sql/types/datafusion_type.rs | 0 src/sql/types/sql_type.rs | 75 ------ 9 files changed, 269 insertions(+), 665 deletions(-) create mode 100644 src/common/data_type.rs rename src/common/{dffield.rs => df_field.rs} (93%) delete mode 100644 src/sql/types.rs delete mode 100644 src/sql/types/arrow_type.rs delete mode 100644 src/sql/types/data_type.rs delete mode 100644 src/sql/types/datafusion_type.rs delete mode 100644 src/sql/types/sql_type.rs diff --git a/src/common/data_type.rs b/src/common/data_type.rs new file mode 100644 index 000000000..174ae57fe --- /dev/null +++ b/src/common/data_type.rs @@ -0,0 +1,265 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + + +use datafusion::arrow::datatypes::DataType; +use pyo3::prelude::*; + + +/// These bindings are tying together several disparate systems. +/// You have SQL types for the SQL strings and RDBMS systems itself. +/// Rust types for the DataFusion code +/// Arrow types which represents the underlying arrow format +/// Python types which represent the type in Python +/// It is important to keep all of those types in a single +/// and managable location. Therefore this structure exists +/// to map those types and provide a simple place for developers +/// to map types from one system to another. +#[derive(Debug, Clone)] +#[pyclass(name = "DataTypeMap", module = "datafusion", subclass)] +pub struct DataTypeMap { + arrow_type: PyDataType, + python_type: PythonType, + sql_type: SqlType, +} + +#[pymethods] +impl DataTypeMap { + + #[staticmethod] + #[pyo3(name = "arrow")] + pub fn map_from_arrow_type(arrow_type: &PyDataType) -> PyResult { + match arrow_type.data_type { + DataType::Null => todo!(), + DataType::Boolean => todo!(), + DataType::Int8 => todo!(), + DataType::Int16 => todo!(), + DataType::Int32 => todo!(), + DataType::Int64 => todo!(), + DataType::UInt8 => todo!(), + DataType::UInt16 => todo!(), + DataType::UInt32 => todo!(), + DataType::UInt64 => todo!(), + DataType::Float16 => todo!(), + DataType::Float32 => todo!(), + DataType::Float64 => todo!(), + DataType::Timestamp(_, _) => todo!(), + DataType::Date32 => todo!(), + DataType::Date64 => todo!(), + DataType::Time32(_) => todo!(), + DataType::Time64(_) => todo!(), + DataType::Duration(_) => todo!(), + DataType::Interval(_) => todo!(), + DataType::Binary => todo!(), + DataType::FixedSizeBinary(_) => todo!(), + DataType::LargeBinary => todo!(), + DataType::Utf8 => todo!(), + DataType::LargeUtf8 => todo!(), + DataType::List(_) => todo!(), + DataType::FixedSizeList(_, _) => todo!(), + DataType::LargeList(_) => todo!(), + DataType::Struct(_) => todo!(), + DataType::Union(_, _, _) => todo!(), + DataType::Dictionary(_, _) => todo!(), + DataType::Decimal128(_, _) => todo!(), + DataType::Decimal256(_, _) => todo!(), + DataType::Map(_, _) => todo!(), + } + } + + #[staticmethod] + #[pyo3(name = "sql")] + pub fn map_from_sql_type(sql_type: &SqlType) -> PyResult { + + let data_type: DataTypeMap = match sql_type { + SqlType::ANY => unimplemented!(), + SqlType::ARRAY => todo!(), // unsure which type to use for DataType in this situation? + SqlType::BIGINT => DataTypeMap { + arrow_type: PyDataType { data_type: DataType::Int64 }, + python_type: PythonType::Float64, // According to https://learn.microsoft.com/en-us/sql/machine-learning/python/python-libraries-and-data-types?view=sql-server-ver16 should be float + sql_type: SqlType::BIGINT + }, + SqlType::BINARY => DataTypeMap { + arrow_type: PyDataType { data_type: DataType::Binary }, + python_type: PythonType::Bytes, + sql_type: SqlType::BINARY + }, + SqlType::BOOLEAN => DataTypeMap { + arrow_type: PyDataType { data_type: DataType::Boolean }, + python_type: PythonType::Bool, + sql_type: SqlType::BOOLEAN + }, + SqlType::CHAR => DataTypeMap { + arrow_type: PyDataType { data_type: DataType::UInt8 }, + python_type: PythonType::Int32, + sql_type: SqlType::CHAR + }, + SqlType::COLUMN_LIST => unimplemented!(), + SqlType::CURSOR => unimplemented!(), + SqlType::DATE => DataTypeMap { + arrow_type: PyDataType { data_type: DataType::Date64 }, + python_type: PythonType::Datetime, + sql_type: SqlType::DATE + }, + SqlType::DECIMAL => todo!(), + SqlType::DISTINCT => unimplemented!(), + SqlType::DOUBLE => todo!(), + SqlType::DYNAMIC_STAR => unimplemented!(), + SqlType::FLOAT => todo!(), + SqlType::GEOMETRY => unimplemented!(), + SqlType::INTEGER => DataTypeMap { + arrow_type: PyDataType { data_type: DataType::Int8 }, + python_type: PythonType::Int32, + sql_type: SqlType::INTEGER + }, + SqlType::INTERVAL => todo!(), + SqlType::INTERVAL_DAY => todo!(), + SqlType::INTERVAL_DAY_HOUR => todo!(), + SqlType::INTERVAL_DAY_MINUTE => todo!(), + SqlType::INTERVAL_DAY_SECOND => todo!(), + SqlType::INTERVAL_HOUR => todo!(), + SqlType::INTERVAL_HOUR_MINUTE => todo!(), + SqlType::INTERVAL_HOUR_SECOND => todo!(), + SqlType::INTERVAL_MINUTE => todo!(), + SqlType::INTERVAL_MINUTE_SECOND => todo!(), + SqlType::INTERVAL_MONTH => todo!(), + SqlType::INTERVAL_SECOND => todo!(), + SqlType::INTERVAL_YEAR => todo!(), + SqlType::INTERVAL_YEAR_MONTH => todo!(), + SqlType::MAP => todo!(), + SqlType::MULTISET => unimplemented!(), + SqlType::NULL => todo!(), + SqlType::OTHER => unimplemented!(), + SqlType::REAL => todo!(), + SqlType::ROW => todo!(), + SqlType::SARG => unimplemented!(), + SqlType::SMALLINT => todo!(), + SqlType::STRUCTURED => unimplemented!(), + SqlType::SYMBOL => unimplemented!(), + SqlType::TIME => todo!(), + SqlType::TIME_WITH_LOCAL_TIME_ZONE => todo!(), + SqlType::TIMESTAMP => todo!(), + SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => todo!(), + SqlType::TINYINT => todo!(), + SqlType::UNKNOWN => todo!(), + SqlType::VARBINARY => todo!(), + SqlType::VARCHAR => DataTypeMap { + arrow_type: PyDataType { data_type: DataType::Utf8 }, + python_type: PythonType::Str, + sql_type: SqlType::VARCHAR + }, + }; + + Ok(data_type) + } +} + + +/// PyO3 requires that objects passed between Rust and Python implement the trait `PyClass` +/// Since `DataType` exists in another package we cannot make that happen here so we wrap +/// `DataType` as `PyDataType` This exists solely to satisfy those constraints. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "DataType", module = "datafusion")] +pub struct PyDataType { + data_type: DataType, +} + +impl From for DataType { + fn from(data_type: PyDataType) -> DataType { + data_type.data_type + } +} + +impl From for PyDataType { + fn from(data_type: DataType) -> PyDataType { + PyDataType { data_type } + } +} + +/// Represents the possible Python types that can be mapped to the SQL types +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "PythonType", module = "datafusion")] +pub enum PythonType { + Array, + Bool, + Bytes, + Datetime, + Float64, + Int32, + List, + None, + Object, + Str, +} + +/// Represents the types that are possible for DataFusion to parse +/// from a SQL query. Aka "SqlType" and are valid values for +/// ANSI SQL +#[allow(non_camel_case_types)] +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[pyclass(name = "SqlType", module = "datafusion")] +pub enum SqlType { + ANY, + ARRAY, + BIGINT, + BINARY, + BOOLEAN, + CHAR, + COLUMN_LIST, + CURSOR, + DATE, + DECIMAL, + DISTINCT, + DOUBLE, + DYNAMIC_STAR, + FLOAT, + GEOMETRY, + INTEGER, + INTERVAL, + INTERVAL_DAY, + INTERVAL_DAY_HOUR, + INTERVAL_DAY_MINUTE, + INTERVAL_DAY_SECOND, + INTERVAL_HOUR, + INTERVAL_HOUR_MINUTE, + INTERVAL_HOUR_SECOND, + INTERVAL_MINUTE, + INTERVAL_MINUTE_SECOND, + INTERVAL_MONTH, + INTERVAL_SECOND, + INTERVAL_YEAR, + INTERVAL_YEAR_MONTH, + MAP, + MULTISET, + NULL, + OTHER, + REAL, + ROW, + SARG, + SMALLINT, + STRUCTURED, + SYMBOL, + TIME, + TIME_WITH_LOCAL_TIME_ZONE, + TIMESTAMP, + TIMESTAMP_WITH_LOCAL_TIME_ZONE, + TINYINT, + UNKNOWN, + VARBINARY, + VARCHAR, +} diff --git a/src/common/dffield.rs b/src/common/df_field.rs similarity index 93% rename from src/common/dffield.rs rename to src/common/df_field.rs index 78dbdbbb7..1c1aecc6c 100644 --- a/src/common/dffield.rs +++ b/src/common/df_field.rs @@ -18,13 +18,13 @@ use datafusion::arrow::datatypes::Field; use pyo3::prelude::*; -use crate::sql::types::DataTypeMap; +use crate::common::data_type::DataTypeMap; /// PyDFField wraps an arrow-datafusion `DFField` struct type /// and also supplies convenience methods for interacting /// with the `DFField` instance in the context of Python #[pyclass(name = "DFField", module = "datafusion", subclass)] -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone)] pub struct PyDFField { /// Optional qualifier (usually a table or relation name) qualifier: Option, diff --git a/src/common/mod.rs b/src/common/mod.rs index 4c66ffd44..ce8612bea 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -15,4 +15,5 @@ // specific language governing permissions and limitations // under the License. -pub mod dffield; +pub mod df_field; +pub mod data_type; diff --git a/src/sql.rs b/src/sql.rs index 2ba2d1011..9f1fe81be 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -17,4 +17,3 @@ pub mod exceptions; pub mod logical; -pub mod types; diff --git a/src/sql/types.rs b/src/sql/types.rs deleted file mode 100644 index 3a5d1e884..000000000 --- a/src/sql/types.rs +++ /dev/null @@ -1,344 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -pub mod data_type; -pub mod sql_type; - -use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_sql::sqlparser::{ast::DataType as SQLType, parser::Parser, tokenizer::Tokenizer}; -use pyo3::{prelude::*, types::PyDict}; - -use self::sql_type::SqlType; - -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(name = "RexType", module = "datafusion")] -pub enum RexType { - Literal, - Call, - Reference, - SubqueryAlias, - Other, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(name = "DataTypeMap", module = "datafusion", subclass)] -pub struct DataTypeMap { - sql_type: SqlType, - arrow_type: DataType, -} - - -// /// Functions not exposed to Python -// impl DaskTypeMap { -// pub fn from(sql_type: SqlTypeName, data_type: PyDataType) -> Self { -// DaskTypeMap { -// sql_type, -// data_type, -// } -// } -// } - -// #[pymethods] -// impl DaskTypeMap { -// #[new] -// #[pyo3(signature = (sql_type, **py_kwargs))] -// fn new(sql_type: SqlTypeName, py_kwargs: Option<&PyDict>) -> PyResult { -// let d_type: DataType = match sql_type { -// SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE => { -// let (unit, tz) = match py_kwargs { -// Some(dict) => { -// let tz: Option = match dict.get_item("tz") { -// Some(e) => { -// let res: PyResult = e.extract(); -// Some(res.unwrap()) -// } -// None => None, -// }; -// let unit: TimeUnit = match dict.get_item("unit") { -// Some(e) => { -// let res: PyResult<&str> = e.extract(); -// match res.unwrap() { -// "Second" => TimeUnit::Second, -// "Millisecond" => TimeUnit::Millisecond, -// "Microsecond" => TimeUnit::Microsecond, -// "Nanosecond" => TimeUnit::Nanosecond, -// _ => TimeUnit::Nanosecond, -// } -// } -// // Default to Nanosecond which is common if not present -// None => TimeUnit::Nanosecond, -// }; -// (unit, tz) -// } -// // Default to Nanosecond and None for tz which is common if not present -// None => (TimeUnit::Nanosecond, None), -// }; -// DataType::Timestamp(unit, tz) -// } -// SqlTypeName::TIMESTAMP => { -// let (unit, tz) = match py_kwargs { -// Some(dict) => { -// let tz: Option = match dict.get_item("tz") { -// Some(e) => { -// let res: PyResult = e.extract(); -// Some(res.unwrap()) -// } -// None => None, -// }; -// let unit: TimeUnit = match dict.get_item("unit") { -// Some(e) => { -// let res: PyResult<&str> = e.extract(); -// match res.unwrap() { -// "Second" => TimeUnit::Second, -// "Millisecond" => TimeUnit::Millisecond, -// "Microsecond" => TimeUnit::Microsecond, -// "Nanosecond" => TimeUnit::Nanosecond, -// _ => TimeUnit::Nanosecond, -// } -// } -// // Default to Nanosecond which is common if not present -// None => TimeUnit::Nanosecond, -// }; -// (unit, tz) -// } -// // Default to Nanosecond and None for tz which is common if not present -// None => (TimeUnit::Nanosecond, None), -// }; -// DataType::Timestamp(unit, tz) -// } -// _ => sql_type.to_arrow()?, -// }; - -// Ok(DaskTypeMap { -// sql_type, -// data_type: d_type.into(), -// }) -// } - -// #[pyo3(name = "getSqlType")] -// pub fn sql_type(&self) -> SqlTypeName { -// self.sql_type.clone() -// } - -// #[pyo3(name = "getDataType")] -// pub fn data_type(&self) -> PyDataType { -// self.data_type.clone() -// } -// } - -// #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -// #[pyclass(name = "PyDataType", module = "datafusion", subclass)] -// pub struct PyDataType { -// data_type: DataType, -// } - -// impl From for DataType { -// fn from(data_type: PyDataType) -> DataType { -// data_type.data_type -// } -// } - -// impl From for PyDataType { -// fn from(data_type: DataType) -> PyDataType { -// PyDataType { data_type } -// } -// } - - -// impl SqlTypeName { -// pub fn to_arrow(&self) -> Result { -// match self { -// SqlTypeName::NULL => Ok(DataType::Null), -// SqlTypeName::BOOLEAN => Ok(DataType::Boolean), -// SqlTypeName::TINYINT => Ok(DataType::Int8), -// SqlTypeName::SMALLINT => Ok(DataType::Int16), -// SqlTypeName::INTEGER => Ok(DataType::Int32), -// SqlTypeName::BIGINT => Ok(DataType::Int64), -// SqlTypeName::REAL => Ok(DataType::Float16), -// SqlTypeName::FLOAT => Ok(DataType::Float32), -// SqlTypeName::DOUBLE => Ok(DataType::Float64), -// SqlTypeName::DATE => Ok(DataType::Date64), -// SqlTypeName::VARCHAR => Ok(DataType::Utf8), -// _ => Err(DaskPlannerError::Internal(format!( -// "Cannot determine Arrow type for Dask SQL type '{:?}'", -// self -// ))), -// } -// } - -// pub fn from_arrow(arrow_type: &DataType) -> Result { -// match arrow_type { -// DataType::Null => Ok(SqlTypeName::NULL), -// DataType::Boolean => Ok(SqlTypeName::BOOLEAN), -// DataType::Int8 => Ok(SqlTypeName::TINYINT), -// DataType::Int16 => Ok(SqlTypeName::SMALLINT), -// DataType::Int32 => Ok(SqlTypeName::INTEGER), -// DataType::Int64 => Ok(SqlTypeName::BIGINT), -// DataType::UInt8 => Ok(SqlTypeName::TINYINT), -// DataType::UInt16 => Ok(SqlTypeName::SMALLINT), -// DataType::UInt32 => Ok(SqlTypeName::INTEGER), -// DataType::UInt64 => Ok(SqlTypeName::BIGINT), -// DataType::Float16 => Ok(SqlTypeName::REAL), -// DataType::Float32 => Ok(SqlTypeName::FLOAT), -// DataType::Float64 => Ok(SqlTypeName::DOUBLE), -// DataType::Time32(_) | DataType::Time64(_) => Ok(SqlTypeName::TIME), -// DataType::Timestamp(_unit, tz) => match tz { -// Some(_) => Ok(SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE), -// None => Ok(SqlTypeName::TIMESTAMP), -// }, -// DataType::Date32 => Ok(SqlTypeName::DATE), -// DataType::Date64 => Ok(SqlTypeName::DATE), -// DataType::Interval(unit) => match unit { -// IntervalUnit::DayTime => Ok(SqlTypeName::INTERVAL_DAY), -// IntervalUnit::YearMonth => Ok(SqlTypeName::INTERVAL_YEAR_MONTH), -// IntervalUnit::MonthDayNano => Ok(SqlTypeName::INTERVAL_MONTH), -// }, -// DataType::Binary => Ok(SqlTypeName::BINARY), -// DataType::FixedSizeBinary(_size) => Ok(SqlTypeName::VARBINARY), -// DataType::Utf8 => Ok(SqlTypeName::CHAR), -// DataType::LargeUtf8 => Ok(SqlTypeName::VARCHAR), -// DataType::Struct(_fields) => Ok(SqlTypeName::STRUCTURED), -// DataType::Decimal128(_precision, _scale) => Ok(SqlTypeName::DECIMAL), -// DataType::Decimal256(_precision, _scale) => Ok(SqlTypeName::DECIMAL), -// DataType::Map(_field, _bool) => Ok(SqlTypeName::MAP), -// _ => Err(DaskPlannerError::Internal(format!( -// "Cannot determine Dask SQL type for Arrow type '{:?}'", -// arrow_type -// ))), -// } -// } -// } - -// #[pymethods] -// impl SqlTypeName { -// #[pyo3(name = "fromString")] -// #[staticmethod] -// pub fn py_from_string(input_type: &str) -> PyResult { -// SqlTypeName::from_string(input_type).map_err(|e| e.into()) -// } -// } - -// impl SqlTypeName { -// pub fn from_string(input_type: &str) -> Result { -// match input_type.to_uppercase().as_ref() { -// "ANY" => Ok(SqlTypeName::ANY), -// "ARRAY" => Ok(SqlTypeName::ARRAY), -// "NULL" => Ok(SqlTypeName::NULL), -// "BOOLEAN" => Ok(SqlTypeName::BOOLEAN), -// "COLUMN_LIST" => Ok(SqlTypeName::COLUMN_LIST), -// "DISTINCT" => Ok(SqlTypeName::DISTINCT), -// "CURSOR" => Ok(SqlTypeName::CURSOR), -// "TINYINT" => Ok(SqlTypeName::TINYINT), -// "SMALLINT" => Ok(SqlTypeName::SMALLINT), -// "INT" => Ok(SqlTypeName::INTEGER), -// "INTEGER" => Ok(SqlTypeName::INTEGER), -// "BIGINT" => Ok(SqlTypeName::BIGINT), -// "REAL" => Ok(SqlTypeName::REAL), -// "FLOAT" => Ok(SqlTypeName::FLOAT), -// "GEOMETRY" => Ok(SqlTypeName::GEOMETRY), -// "DOUBLE" => Ok(SqlTypeName::DOUBLE), -// "TIME" => Ok(SqlTypeName::TIME), -// "TIME_WITH_LOCAL_TIME_ZONE" => Ok(SqlTypeName::TIME_WITH_LOCAL_TIME_ZONE), -// "TIMESTAMP" => Ok(SqlTypeName::TIMESTAMP), -// "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => Ok(SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE), -// "DATE" => Ok(SqlTypeName::DATE), -// "INTERVAL" => Ok(SqlTypeName::INTERVAL), -// "INTERVAL_DAY" => Ok(SqlTypeName::INTERVAL_DAY), -// "INTERVAL_DAY_HOUR" => Ok(SqlTypeName::INTERVAL_DAY_HOUR), -// "INTERVAL_DAY_MINUTE" => Ok(SqlTypeName::INTERVAL_DAY_MINUTE), -// "INTERVAL_DAY_SECOND" => Ok(SqlTypeName::INTERVAL_DAY_SECOND), -// "INTERVAL_HOUR" => Ok(SqlTypeName::INTERVAL_HOUR), -// "INTERVAL_HOUR_MINUTE" => Ok(SqlTypeName::INTERVAL_HOUR_MINUTE), -// "INTERVAL_HOUR_SECOND" => Ok(SqlTypeName::INTERVAL_HOUR_SECOND), -// "INTERVAL_MINUTE" => Ok(SqlTypeName::INTERVAL_MINUTE), -// "INTERVAL_MINUTE_SECOND" => Ok(SqlTypeName::INTERVAL_MINUTE_SECOND), -// "INTERVAL_MONTH" => Ok(SqlTypeName::INTERVAL_MONTH), -// "INTERVAL_SECOND" => Ok(SqlTypeName::INTERVAL_SECOND), -// "INTERVAL_YEAR" => Ok(SqlTypeName::INTERVAL_YEAR), -// "INTERVAL_YEAR_MONTH" => Ok(SqlTypeName::INTERVAL_YEAR_MONTH), -// "MAP" => Ok(SqlTypeName::MAP), -// "MULTISET" => Ok(SqlTypeName::MULTISET), -// "OTHER" => Ok(SqlTypeName::OTHER), -// "ROW" => Ok(SqlTypeName::ROW), -// "SARG" => Ok(SqlTypeName::SARG), -// "BINARY" => Ok(SqlTypeName::BINARY), -// "VARBINARY" => Ok(SqlTypeName::VARBINARY), -// "CHAR" => Ok(SqlTypeName::CHAR), -// "VARCHAR" | "STRING" => Ok(SqlTypeName::VARCHAR), -// "STRUCTURED" => Ok(SqlTypeName::STRUCTURED), -// "SYMBOL" => Ok(SqlTypeName::SYMBOL), -// "DECIMAL" => Ok(SqlTypeName::DECIMAL), -// "DYNAMIC_STAT" => Ok(SqlTypeName::DYNAMIC_STAR), -// "UNKNOWN" => Ok(SqlTypeName::UNKNOWN), -// _ => { -// // complex data type name so use the sqlparser -// let dialect = DaskDialect {}; -// let mut tokenizer = Tokenizer::new(&dialect, input_type); -// let tokens = tokenizer.tokenize().map_err(DaskPlannerError::from)?; -// let mut parser = Parser::new(tokens, &dialect); -// match parser.parse_data_type().map_err(DaskPlannerError::from)? { -// SQLType::Decimal(_) => Ok(SqlTypeName::DECIMAL), -// SQLType::Binary(_) => Ok(SqlTypeName::BINARY), -// SQLType::Varbinary(_) => Ok(SqlTypeName::VARBINARY), -// SQLType::Varchar(_) | SQLType::Nvarchar(_) => Ok(SqlTypeName::VARCHAR), -// SQLType::Char(_) => Ok(SqlTypeName::CHAR), -// _ => Err(DaskPlannerError::Internal(format!( -// "Cannot determine Dask SQL type for '{}'", -// input_type -// ))), -// } -// } -// } -// } -// } - -// #[cfg(test)] -// mod test { -// use crate::sql::types::SqlTypeName; - -// #[test] -// fn invalid_type_name() { -// assert_eq!( -// "Internal Error: Cannot determine Dask SQL type for 'bob'", -// SqlTypeName::from_string("bob") -// .expect_err("invalid type name") -// .to_string() -// ); -// } - -// #[test] -// fn string() { -// assert_expected("VARCHAR", "string"); -// } - -// #[test] -// fn varchar_n() { -// assert_expected("VARCHAR", "VARCHAR(10)"); -// } - -// #[test] -// fn decimal_p_s() { -// assert_expected("DECIMAL", "DECIMAL(10, 2)"); -// } - -// fn assert_expected(expected: &str, input: &str) { -// assert_eq!( -// expected, -// &format!("{:?}", SqlTypeName::from_string(input).unwrap()) -// ); -// } -// } diff --git a/src/sql/types/arrow_type.rs b/src/sql/types/arrow_type.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/sql/types/data_type.rs b/src/sql/types/data_type.rs deleted file mode 100644 index da4969677..000000000 --- a/src/sql/types/data_type.rs +++ /dev/null @@ -1,242 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::HashMap; -use std::fmt; - -use datafusion_common::{DFField, DFSchema}; -use pyo3::prelude::*; - -use super::DataTypeMap; - -const PRECISION_NOT_SPECIFIED: i32 = i32::MIN; -const SCALE_NOT_SPECIFIED: i32 = -1; - -/// RelDataTypeField represents the definition of a field in a structured RelDataType. -#[pyclass(name = "RelDataTypeField", module = "dask_planner", subclass)] -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct RelDataTypeField { - qualifier: Option, - name: String, - data_type: DataTypeMap, - index: usize, -} - -// // Functions that should not be presented to Python are placed here -// impl RelDataTypeField { -// pub fn from(field: &DFField, schema: &DFSchema) -> Result { -// let qualifier: Option<&str> = match field.qualifier() { -// Some(qualifier) => Some(qualifier), -// None => None, -// }; -// Ok(RelDataTypeField { -// qualifier: qualifier.map(|qualifier| qualifier.to_string()), -// name: field.name().clone(), -// data_type: DaskTypeMap { -// sql_type: SqlTypeName::from_arrow(field.data_type())?, -// data_type: field.data_type().clone().into(), -// }, -// index: schema.index_of_column_by_name(qualifier, field.name())?, -// }) -// } -// } - -// #[pymethods] -// impl RelDataTypeField { -// #[new] -// pub fn new(name: &str, type_map: DaskTypeMap, index: usize) -> Self { -// Self { -// qualifier: None, -// name: name.to_owned(), -// data_type: type_map, -// index, -// } -// } - -// #[pyo3(name = "getQualifier")] -// pub fn qualifier(&self) -> Option { -// self.qualifier.clone() -// } - -// #[pyo3(name = "getName")] -// pub fn name(&self) -> &str { -// &self.name -// } - -// #[pyo3(name = "getQualifiedName")] -// pub fn qualified_name(&self) -> String { -// match &self.qualifier() { -// Some(qualifier) => format!("{}.{}", &qualifier, self.name()), -// None => self.name().to_string(), -// } -// } - -// #[pyo3(name = "getIndex")] -// pub fn index(&self) -> usize { -// self.index -// } - -// #[pyo3(name = "getType")] -// pub fn data_type(&self) -> DaskTypeMap { -// self.data_type.clone() -// } - -// /// Since this logic is being ported from Java getKey is synonymous with getName. -// /// Alas it is used in certain places so it is implemented here to allow other -// /// places in the code base to not have to change. -// #[pyo3(name = "getKey")] -// pub fn get_key(&self) -> &str { -// self.name() -// } - -// /// Since this logic is being ported from Java getValue is synonymous with getType. -// /// Alas it is used in certain places so it is implemented here to allow other -// /// places in the code base to not have to change. -// #[pyo3(name = "getValue")] -// pub fn get_value(&self) -> DaskTypeMap { -// self.data_type() -// } - -// #[pyo3(name = "setValue")] -// pub fn set_value(&mut self, data_type: DaskTypeMap) { -// self.data_type = data_type -// } - -// // TODO: Uncomment after implementing in RelDataType -// // #[pyo3(name = "isDynamicStar")] -// // pub fn is_dynamic_star(&self) -> bool { -// // self.data_type.getSqlTypeName() == SqlTypeName.DYNAMIC_STAR -// // } -// } - -// impl fmt::Display for RelDataTypeField { -// fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { -// fmt.write_str("Field: ")?; -// fmt.write_str(&self.name)?; -// fmt.write_str(" - Index: ")?; -// fmt.write_str(&self.index.to_string())?; -// // TODO: Uncomment this after implementing the Display trait in RelDataType -// // fmt.write_str(" - DataType: ")?; -// // fmt.write_str(self.data_type.to_string())?; -// Ok(()) -// } -// } - -// /// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. -// #[pyclass(name = "RelDataType", module = "dask_planner", subclass)] -// #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -// pub struct RelDataType { -// nullable: bool, -// field_list: Vec, -// } - -// /// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. -// #[pymethods] -// impl RelDataType { -// #[new] -// pub fn new(nullable: bool, fields: Vec) -> Self { -// Self { -// nullable, -// field_list: fields, -// } -// } - -// /// Looks up a field by name. -// /// -// /// # Arguments -// /// -// /// * `field_name` - A String containing the name of the field to find -// /// * `case_sensitive` - True if column name matching should be case sensitive and false otherwise -// #[pyo3(name = "getField")] -// pub fn field(&self, field_name: &str, case_sensitive: bool) -> PyResult { -// let field_map: HashMap = self.field_map(); -// if case_sensitive && !field_map.is_empty() { -// Ok(field_map.get(field_name).unwrap().clone()) -// } else { -// for field in &self.field_list { -// if (case_sensitive && field.name().eq(field_name)) -// || (!case_sensitive && field.name().eq_ignore_ascii_case(field_name)) -// { -// return Ok(field.clone()); -// } -// } - -// // TODO: Throw a proper error here -// Err(py_runtime_err(format!( -// "Unable to find RelDataTypeField with name {:?} in the RelDataType field_list", -// field_name, -// ))) -// } -// } - -// /// Returns a map from field names to fields. -// /// -// /// # Notes -// /// -// /// * If several fields have the same name, the map contains the first. -// #[pyo3(name = "getFieldMap")] -// pub fn field_map(&self) -> HashMap { -// let mut fields: HashMap = HashMap::new(); -// for field in &self.field_list { -// fields.insert(String::from(field.name()), field.clone()); -// } -// fields -// } - -// /// Gets the fields in a struct type. The field count is equal to the size of the returned list. -// #[pyo3(name = "getFieldList")] -// pub fn field_list(&self) -> Vec { -// self.field_list.clone() -// } - -// /// Returns the names of all of the columns in a given DaskTable -// #[pyo3(name = "getFieldNames")] -// pub fn field_names(&self) -> Vec { -// let mut field_names: Vec = Vec::new(); -// for field in &self.field_list { -// field_names.push(field.qualified_name()); -// } -// field_names -// } - -// /// Returns the number of fields in a struct type. -// #[pyo3(name = "getFieldCount")] -// pub fn field_count(&self) -> usize { -// self.field_list.len() -// } - -// #[pyo3(name = "isStruct")] -// pub fn is_struct(&self) -> bool { -// !self.field_list.is_empty() -// } - -// /// Queries whether this type allows null values. -// #[pyo3(name = "isNullable")] -// pub fn is_nullable(&self) -> bool { -// self.nullable -// } - -// #[pyo3(name = "getPrecision")] -// pub fn precision(&self) -> i32 { -// PRECISION_NOT_SPECIFIED -// } - -// #[pyo3(name = "getScale")] -// pub fn scale(&self) -> i32 { -// SCALE_NOT_SPECIFIED -// } -// } diff --git a/src/sql/types/datafusion_type.rs b/src/sql/types/datafusion_type.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/sql/types/sql_type.rs b/src/sql/types/sql_type.rs deleted file mode 100644 index 333c6420e..000000000 --- a/src/sql/types/sql_type.rs +++ /dev/null @@ -1,75 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -/// Enumeration of the type names which can be used to construct a SQL type. Since -/// several SQL types do not exist as Rust types and also because the Enum -/// `SqlTypeName` is already used in the Python Dask-SQL code base this enum is used -/// in place of just using the built-in Rust types. -#[allow(non_camel_case_types)] -#[allow(clippy::upper_case_acronyms)] -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(name = "SqlType", module = "datafusion")] -pub enum SqlType { - ANY, - ARRAY, - BIGINT, - BINARY, - BOOLEAN, - CHAR, - COLUMN_LIST, - CURSOR, - DATE, - DECIMAL, - DISTINCT, - DOUBLE, - DYNAMIC_STAR, - FLOAT, - GEOMETRY, - INTEGER, - INTERVAL, - INTERVAL_DAY, - INTERVAL_DAY_HOUR, - INTERVAL_DAY_MINUTE, - INTERVAL_DAY_SECOND, - INTERVAL_HOUR, - INTERVAL_HOUR_MINUTE, - INTERVAL_HOUR_SECOND, - INTERVAL_MINUTE, - INTERVAL_MINUTE_SECOND, - INTERVAL_MONTH, - INTERVAL_SECOND, - INTERVAL_YEAR, - INTERVAL_YEAR_MONTH, - MAP, - MULTISET, - NULL, - OTHER, - REAL, - ROW, - SARG, - SMALLINT, - STRUCTURED, - SYMBOL, - TIME, - TIME_WITH_LOCAL_TIME_ZONE, - TIMESTAMP, - TIMESTAMP_WITH_LOCAL_TIME_ZONE, - TINYINT, - UNKNOWN, - VARBINARY, - VARCHAR, -} From 1465871a6e7e9ec4474e2619d04be7a321b82ffd Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sun, 5 Feb 2023 20:57:03 -0500 Subject: [PATCH 03/16] First mappings of types --- src/common/data_type.rs | 130 ++++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 71 deletions(-) diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 174ae57fe..06697351a 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -37,38 +37,57 @@ pub struct DataTypeMap { sql_type: SqlType, } +impl DataTypeMap { + fn new(arrow_type: DataType, python_type: PythonType, sql_type: SqlType) -> Self { + DataTypeMap { + arrow_type: PyDataType { data_type: arrow_type}, + python_type, + sql_type + } + } +} + #[pymethods] impl DataTypeMap { + + #[new] + pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self { + DataTypeMap { + arrow_type: arrow_type, + python_type, + sql_type + } + } #[staticmethod] #[pyo3(name = "arrow")] pub fn map_from_arrow_type(arrow_type: &PyDataType) -> PyResult { - match arrow_type.data_type { - DataType::Null => todo!(), - DataType::Boolean => todo!(), - DataType::Int8 => todo!(), - DataType::Int16 => todo!(), - DataType::Int32 => todo!(), - DataType::Int64 => todo!(), - DataType::UInt8 => todo!(), - DataType::UInt16 => todo!(), - DataType::UInt32 => todo!(), - DataType::UInt64 => todo!(), - DataType::Float16 => todo!(), - DataType::Float32 => todo!(), - DataType::Float64 => todo!(), + Ok(match arrow_type.data_type { + DataType::Null => DataTypeMap::new(DataType::Null, PythonType::None, SqlType::NULL), + DataType::Boolean => DataTypeMap::new(DataType::Boolean, PythonType::Bool, SqlType::BOOLEAN), + DataType::Int8 => DataTypeMap::new(DataType::Int8, PythonType::Int, SqlType::TINYINT), + DataType::Int16 => DataTypeMap::new(DataType::Int16, PythonType::Int, SqlType::SMALLINT), + DataType::Int32 => DataTypeMap::new(DataType::Int32, PythonType::Int, SqlType::INTEGER), + DataType::Int64 => DataTypeMap::new(DataType::Int64, PythonType::Int, SqlType::BIGINT), + DataType::UInt8 => DataTypeMap::new(DataType::UInt8, PythonType::Int, SqlType::TINYINT), + DataType::UInt16 => DataTypeMap::new(DataType::UInt16, PythonType::Int, SqlType::SMALLINT), + DataType::UInt32 => DataTypeMap::new(DataType::UInt32, PythonType::Int, SqlType::INTEGER), + DataType::UInt64 => DataTypeMap::new(DataType::UInt64, PythonType::Int, SqlType::BIGINT), + DataType::Float16 => DataTypeMap::new(DataType::Float16, PythonType::Float, SqlType::FLOAT), + DataType::Float32 => DataTypeMap::new(DataType::Float32, PythonType::Float, SqlType::FLOAT), + DataType::Float64 => DataTypeMap::new(DataType::Float64, PythonType::Float, SqlType::FLOAT), DataType::Timestamp(_, _) => todo!(), - DataType::Date32 => todo!(), - DataType::Date64 => todo!(), + DataType::Date32 => DataTypeMap::new(DataType::Date32, PythonType::Datetime, SqlType::DATE), + DataType::Date64 => DataTypeMap::new(DataType::Date64, PythonType::Datetime, SqlType::DATE), DataType::Time32(_) => todo!(), DataType::Time64(_) => todo!(), DataType::Duration(_) => todo!(), DataType::Interval(_) => todo!(), - DataType::Binary => todo!(), + DataType::Binary => DataTypeMap::new(DataType::Binary, PythonType::Bytes, SqlType::BINARY), DataType::FixedSizeBinary(_) => todo!(), - DataType::LargeBinary => todo!(), - DataType::Utf8 => todo!(), - DataType::LargeUtf8 => todo!(), + DataType::LargeBinary => DataTypeMap::new(DataType::LargeBinary, PythonType::Bytes, SqlType::BINARY), + DataType::Utf8 => DataTypeMap::new(DataType::Utf8, PythonType::Str, SqlType::VARCHAR), + DataType::LargeUtf8 => DataTypeMap::new(DataType::LargeUtf8, PythonType::Str, SqlType::VARCHAR), DataType::List(_) => todo!(), DataType::FixedSizeList(_, _) => todo!(), DataType::LargeList(_) => todo!(), @@ -78,54 +97,29 @@ impl DataTypeMap { DataType::Decimal128(_, _) => todo!(), DataType::Decimal256(_, _) => todo!(), DataType::Map(_, _) => todo!(), - } + }) } #[staticmethod] #[pyo3(name = "sql")] pub fn map_from_sql_type(sql_type: &SqlType) -> PyResult { - - let data_type: DataTypeMap = match sql_type { + Ok(match sql_type { SqlType::ANY => unimplemented!(), SqlType::ARRAY => todo!(), // unsure which type to use for DataType in this situation? - SqlType::BIGINT => DataTypeMap { - arrow_type: PyDataType { data_type: DataType::Int64 }, - python_type: PythonType::Float64, // According to https://learn.microsoft.com/en-us/sql/machine-learning/python/python-libraries-and-data-types?view=sql-server-ver16 should be float - sql_type: SqlType::BIGINT - }, - SqlType::BINARY => DataTypeMap { - arrow_type: PyDataType { data_type: DataType::Binary }, - python_type: PythonType::Bytes, - sql_type: SqlType::BINARY - }, - SqlType::BOOLEAN => DataTypeMap { - arrow_type: PyDataType { data_type: DataType::Boolean }, - python_type: PythonType::Bool, - sql_type: SqlType::BOOLEAN - }, - SqlType::CHAR => DataTypeMap { - arrow_type: PyDataType { data_type: DataType::UInt8 }, - python_type: PythonType::Int32, - sql_type: SqlType::CHAR - }, + SqlType::BIGINT => DataTypeMap::new(DataType::Int64, PythonType::Float, SqlType::BIGINT), + SqlType::BINARY => DataTypeMap::new(DataType::Binary, PythonType::Bytes, SqlType::BINARY), + SqlType::BOOLEAN => DataTypeMap::new(DataType::Boolean, PythonType::Bool, SqlType::BOOLEAN), + SqlType::CHAR => DataTypeMap::new(DataType::UInt8, PythonType::Int, SqlType::CHAR), SqlType::COLUMN_LIST => unimplemented!(), SqlType::CURSOR => unimplemented!(), - SqlType::DATE => DataTypeMap { - arrow_type: PyDataType { data_type: DataType::Date64 }, - python_type: PythonType::Datetime, - sql_type: SqlType::DATE - }, - SqlType::DECIMAL => todo!(), + SqlType::DATE => DataTypeMap::new(DataType::Date64, PythonType::Datetime, SqlType::DATE), + SqlType::DECIMAL => DataTypeMap::new(DataType::Decimal128(1, 1), PythonType::Float, SqlType::DECIMAL), SqlType::DISTINCT => unimplemented!(), - SqlType::DOUBLE => todo!(), + SqlType::DOUBLE => DataTypeMap::new(DataType::Decimal256(1, 1), PythonType::Float, SqlType::DOUBLE), SqlType::DYNAMIC_STAR => unimplemented!(), - SqlType::FLOAT => todo!(), + SqlType::FLOAT => DataTypeMap::new(DataType::Decimal128(1, 1), PythonType::Float, SqlType::FLOAT), SqlType::GEOMETRY => unimplemented!(), - SqlType::INTEGER => DataTypeMap { - arrow_type: PyDataType { data_type: DataType::Int8 }, - python_type: PythonType::Int32, - sql_type: SqlType::INTEGER - }, + SqlType::INTEGER => DataTypeMap::new(DataType::Int8, PythonType::Int, SqlType::INTEGER), SqlType::INTERVAL => todo!(), SqlType::INTERVAL_DAY => todo!(), SqlType::INTERVAL_DAY_HOUR => todo!(), @@ -142,29 +136,23 @@ impl DataTypeMap { SqlType::INTERVAL_YEAR_MONTH => todo!(), SqlType::MAP => todo!(), SqlType::MULTISET => unimplemented!(), - SqlType::NULL => todo!(), + SqlType::NULL => DataTypeMap::new(DataType::Null, PythonType::None, SqlType::NULL), SqlType::OTHER => unimplemented!(), SqlType::REAL => todo!(), SqlType::ROW => todo!(), SqlType::SARG => unimplemented!(), - SqlType::SMALLINT => todo!(), + SqlType::SMALLINT => DataTypeMap::new(DataType::Int16, PythonType::Int, SqlType::SMALLINT), SqlType::STRUCTURED => unimplemented!(), SqlType::SYMBOL => unimplemented!(), SqlType::TIME => todo!(), SqlType::TIME_WITH_LOCAL_TIME_ZONE => todo!(), SqlType::TIMESTAMP => todo!(), SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => todo!(), - SqlType::TINYINT => todo!(), - SqlType::UNKNOWN => todo!(), - SqlType::VARBINARY => todo!(), - SqlType::VARCHAR => DataTypeMap { - arrow_type: PyDataType { data_type: DataType::Utf8 }, - python_type: PythonType::Str, - sql_type: SqlType::VARCHAR - }, - }; - - Ok(data_type) + SqlType::TINYINT => DataTypeMap::new(DataType::Int8, PythonType::Int, SqlType::TINYINT), + SqlType::UNKNOWN => unimplemented!(), + SqlType::VARBINARY => DataTypeMap::new(DataType::LargeBinary, PythonType::Bytes, SqlType::VARBINARY), + SqlType::VARCHAR => DataTypeMap::new(DataType::Utf8, PythonType::Str, SqlType::VARCHAR), + }) } } @@ -198,8 +186,8 @@ pub enum PythonType { Bool, Bytes, Datetime, - Float64, - Int32, + Float, + Int, List, None, Object, From 113f33b8f76d4fcc70cbba2a07d677d12203e749 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Sun, 5 Feb 2023 22:16:27 -0500 Subject: [PATCH 04/16] map_from_arrow_type --- src/common/data_type.rs | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 06697351a..6da1f6e3b 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -45,24 +45,9 @@ impl DataTypeMap { sql_type } } -} -#[pymethods] -impl DataTypeMap { - - #[new] - pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self { - DataTypeMap { - arrow_type: arrow_type, - python_type, - sql_type - } - } - - #[staticmethod] - #[pyo3(name = "arrow")] - pub fn map_from_arrow_type(arrow_type: &PyDataType) -> PyResult { - Ok(match arrow_type.data_type { + pub fn map_from_arrow_type(arrow_type: &DataType) -> DataTypeMap { + match arrow_type { DataType::Null => DataTypeMap::new(DataType::Null, PythonType::None, SqlType::NULL), DataType::Boolean => DataTypeMap::new(DataType::Boolean, PythonType::Bool, SqlType::BOOLEAN), DataType::Int8 => DataTypeMap::new(DataType::Int8, PythonType::Int, SqlType::TINYINT), @@ -97,7 +82,26 @@ impl DataTypeMap { DataType::Decimal128(_, _) => todo!(), DataType::Decimal256(_, _) => todo!(), DataType::Map(_, _) => todo!(), - }) + } + } +} + +#[pymethods] +impl DataTypeMap { + + #[new] + pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self { + DataTypeMap { + arrow_type: arrow_type, + python_type, + sql_type + } + } + + #[staticmethod] + #[pyo3(name = "arrow")] + pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult { + Ok(DataTypeMap::map_from_arrow_type(&arrow_type.data_type)) } #[staticmethod] From a93f8602c5565601acfd55c8cb6f005f21d3e226 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 6 Feb 2023 09:46:18 -0500 Subject: [PATCH 05/16] update LogicalPlan crate location --- src/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context.rs b/src/context.rs index cf19b8e1f..7d9f1c570 100644 --- a/src/context.rs +++ b/src/context.rs @@ -29,7 +29,7 @@ use crate::catalog::{PyCatalog, PyTable}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::DataFusionError; -use crate::logical::PyLogicalPlan; +use crate::sql::logical::PyLogicalPlan; use crate::store::StorageContexts; use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; From 895c72544bd874f194129e138492efa94b19ab06 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 6 Feb 2023 11:20:22 -0500 Subject: [PATCH 06/16] add missing apache license to src/sql/exceptions.rs --- src/sql/exceptions.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/sql/exceptions.rs b/src/sql/exceptions.rs index 871402279..912c54db9 100644 --- a/src/sql/exceptions.rs +++ b/src/sql/exceptions.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use std::fmt::Debug; use pyo3::{create_exception, PyErr}; From 59bf409dd65f60d340605e6c4bba8ad5907c0d51 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 6 Feb 2023 12:27:06 -0500 Subject: [PATCH 07/16] cargo fmt --- src/common/data_type.rs | 112 ++++++++++++++++++++++++++++------------ src/common/mod.rs | 2 +- src/dataframe.rs | 2 +- src/lib.rs | 4 +- src/substrait.rs | 2 +- 5 files changed, 85 insertions(+), 37 deletions(-) diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 6da1f6e3b..62d52099e 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. - use datafusion::arrow::datatypes::DataType; use pyo3::prelude::*; - /// These bindings are tying together several disparate systems. /// You have SQL types for the SQL strings and RDBMS systems itself. /// Rust types for the DataFusion code @@ -40,39 +38,67 @@ pub struct DataTypeMap { impl DataTypeMap { fn new(arrow_type: DataType, python_type: PythonType, sql_type: SqlType) -> Self { DataTypeMap { - arrow_type: PyDataType { data_type: arrow_type}, + arrow_type: PyDataType { + data_type: arrow_type, + }, python_type, - sql_type + sql_type, } } pub fn map_from_arrow_type(arrow_type: &DataType) -> DataTypeMap { match arrow_type { DataType::Null => DataTypeMap::new(DataType::Null, PythonType::None, SqlType::NULL), - DataType::Boolean => DataTypeMap::new(DataType::Boolean, PythonType::Bool, SqlType::BOOLEAN), + DataType::Boolean => { + DataTypeMap::new(DataType::Boolean, PythonType::Bool, SqlType::BOOLEAN) + } DataType::Int8 => DataTypeMap::new(DataType::Int8, PythonType::Int, SqlType::TINYINT), - DataType::Int16 => DataTypeMap::new(DataType::Int16, PythonType::Int, SqlType::SMALLINT), + DataType::Int16 => { + DataTypeMap::new(DataType::Int16, PythonType::Int, SqlType::SMALLINT) + } DataType::Int32 => DataTypeMap::new(DataType::Int32, PythonType::Int, SqlType::INTEGER), DataType::Int64 => DataTypeMap::new(DataType::Int64, PythonType::Int, SqlType::BIGINT), DataType::UInt8 => DataTypeMap::new(DataType::UInt8, PythonType::Int, SqlType::TINYINT), - DataType::UInt16 => DataTypeMap::new(DataType::UInt16, PythonType::Int, SqlType::SMALLINT), - DataType::UInt32 => DataTypeMap::new(DataType::UInt32, PythonType::Int, SqlType::INTEGER), - DataType::UInt64 => DataTypeMap::new(DataType::UInt64, PythonType::Int, SqlType::BIGINT), - DataType::Float16 => DataTypeMap::new(DataType::Float16, PythonType::Float, SqlType::FLOAT), - DataType::Float32 => DataTypeMap::new(DataType::Float32, PythonType::Float, SqlType::FLOAT), - DataType::Float64 => DataTypeMap::new(DataType::Float64, PythonType::Float, SqlType::FLOAT), + DataType::UInt16 => { + DataTypeMap::new(DataType::UInt16, PythonType::Int, SqlType::SMALLINT) + } + DataType::UInt32 => { + DataTypeMap::new(DataType::UInt32, PythonType::Int, SqlType::INTEGER) + } + DataType::UInt64 => { + DataTypeMap::new(DataType::UInt64, PythonType::Int, SqlType::BIGINT) + } + DataType::Float16 => { + DataTypeMap::new(DataType::Float16, PythonType::Float, SqlType::FLOAT) + } + DataType::Float32 => { + DataTypeMap::new(DataType::Float32, PythonType::Float, SqlType::FLOAT) + } + DataType::Float64 => { + DataTypeMap::new(DataType::Float64, PythonType::Float, SqlType::FLOAT) + } DataType::Timestamp(_, _) => todo!(), - DataType::Date32 => DataTypeMap::new(DataType::Date32, PythonType::Datetime, SqlType::DATE), - DataType::Date64 => DataTypeMap::new(DataType::Date64, PythonType::Datetime, SqlType::DATE), + DataType::Date32 => { + DataTypeMap::new(DataType::Date32, PythonType::Datetime, SqlType::DATE) + } + DataType::Date64 => { + DataTypeMap::new(DataType::Date64, PythonType::Datetime, SqlType::DATE) + } DataType::Time32(_) => todo!(), DataType::Time64(_) => todo!(), DataType::Duration(_) => todo!(), DataType::Interval(_) => todo!(), - DataType::Binary => DataTypeMap::new(DataType::Binary, PythonType::Bytes, SqlType::BINARY), + DataType::Binary => { + DataTypeMap::new(DataType::Binary, PythonType::Bytes, SqlType::BINARY) + } DataType::FixedSizeBinary(_) => todo!(), - DataType::LargeBinary => DataTypeMap::new(DataType::LargeBinary, PythonType::Bytes, SqlType::BINARY), + DataType::LargeBinary => { + DataTypeMap::new(DataType::LargeBinary, PythonType::Bytes, SqlType::BINARY) + } DataType::Utf8 => DataTypeMap::new(DataType::Utf8, PythonType::Str, SqlType::VARCHAR), - DataType::LargeUtf8 => DataTypeMap::new(DataType::LargeUtf8, PythonType::Str, SqlType::VARCHAR), + DataType::LargeUtf8 => { + DataTypeMap::new(DataType::LargeUtf8, PythonType::Str, SqlType::VARCHAR) + } DataType::List(_) => todo!(), DataType::FixedSizeList(_, _) => todo!(), DataType::LargeList(_) => todo!(), @@ -88,16 +114,15 @@ impl DataTypeMap { #[pymethods] impl DataTypeMap { - #[new] pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self { DataTypeMap { arrow_type: arrow_type, python_type, - sql_type + sql_type, } } - + #[staticmethod] #[pyo3(name = "arrow")] pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult { @@ -110,18 +135,38 @@ impl DataTypeMap { Ok(match sql_type { SqlType::ANY => unimplemented!(), SqlType::ARRAY => todo!(), // unsure which type to use for DataType in this situation? - SqlType::BIGINT => DataTypeMap::new(DataType::Int64, PythonType::Float, SqlType::BIGINT), - SqlType::BINARY => DataTypeMap::new(DataType::Binary, PythonType::Bytes, SqlType::BINARY), - SqlType::BOOLEAN => DataTypeMap::new(DataType::Boolean, PythonType::Bool, SqlType::BOOLEAN), + SqlType::BIGINT => { + DataTypeMap::new(DataType::Int64, PythonType::Float, SqlType::BIGINT) + } + SqlType::BINARY => { + DataTypeMap::new(DataType::Binary, PythonType::Bytes, SqlType::BINARY) + } + SqlType::BOOLEAN => { + DataTypeMap::new(DataType::Boolean, PythonType::Bool, SqlType::BOOLEAN) + } SqlType::CHAR => DataTypeMap::new(DataType::UInt8, PythonType::Int, SqlType::CHAR), SqlType::COLUMN_LIST => unimplemented!(), SqlType::CURSOR => unimplemented!(), - SqlType::DATE => DataTypeMap::new(DataType::Date64, PythonType::Datetime, SqlType::DATE), - SqlType::DECIMAL => DataTypeMap::new(DataType::Decimal128(1, 1), PythonType::Float, SqlType::DECIMAL), + SqlType::DATE => { + DataTypeMap::new(DataType::Date64, PythonType::Datetime, SqlType::DATE) + } + SqlType::DECIMAL => DataTypeMap::new( + DataType::Decimal128(1, 1), + PythonType::Float, + SqlType::DECIMAL, + ), SqlType::DISTINCT => unimplemented!(), - SqlType::DOUBLE => DataTypeMap::new(DataType::Decimal256(1, 1), PythonType::Float, SqlType::DOUBLE), + SqlType::DOUBLE => DataTypeMap::new( + DataType::Decimal256(1, 1), + PythonType::Float, + SqlType::DOUBLE, + ), SqlType::DYNAMIC_STAR => unimplemented!(), - SqlType::FLOAT => DataTypeMap::new(DataType::Decimal128(1, 1), PythonType::Float, SqlType::FLOAT), + SqlType::FLOAT => DataTypeMap::new( + DataType::Decimal128(1, 1), + PythonType::Float, + SqlType::FLOAT, + ), SqlType::GEOMETRY => unimplemented!(), SqlType::INTEGER => DataTypeMap::new(DataType::Int8, PythonType::Int, SqlType::INTEGER), SqlType::INTERVAL => todo!(), @@ -145,7 +190,9 @@ impl DataTypeMap { SqlType::REAL => todo!(), SqlType::ROW => todo!(), SqlType::SARG => unimplemented!(), - SqlType::SMALLINT => DataTypeMap::new(DataType::Int16, PythonType::Int, SqlType::SMALLINT), + SqlType::SMALLINT => { + DataTypeMap::new(DataType::Int16, PythonType::Int, SqlType::SMALLINT) + } SqlType::STRUCTURED => unimplemented!(), SqlType::SYMBOL => unimplemented!(), SqlType::TIME => todo!(), @@ -154,15 +201,16 @@ impl DataTypeMap { SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => todo!(), SqlType::TINYINT => DataTypeMap::new(DataType::Int8, PythonType::Int, SqlType::TINYINT), SqlType::UNKNOWN => unimplemented!(), - SqlType::VARBINARY => DataTypeMap::new(DataType::LargeBinary, PythonType::Bytes, SqlType::VARBINARY), + SqlType::VARBINARY => { + DataTypeMap::new(DataType::LargeBinary, PythonType::Bytes, SqlType::VARBINARY) + } SqlType::VARCHAR => DataTypeMap::new(DataType::Utf8, PythonType::Str, SqlType::VARCHAR), }) } } - /// PyO3 requires that objects passed between Rust and Python implement the trait `PyClass` -/// Since `DataType` exists in another package we cannot make that happen here so we wrap +/// Since `DataType` exists in another package we cannot make that happen here so we wrap /// `DataType` as `PyDataType` This exists solely to satisfy those constraints. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass(name = "DataType", module = "datafusion")] @@ -199,7 +247,7 @@ pub enum PythonType { } /// Represents the types that are possible for DataFusion to parse -/// from a SQL query. Aka "SqlType" and are valid values for +/// from a SQL query. Aka "SqlType" and are valid values for /// ANSI SQL #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] diff --git a/src/common/mod.rs b/src/common/mod.rs index ce8612bea..2d37f6836 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -15,5 +15,5 @@ // specific language governing permissions and limitations // under the License. -pub mod df_field; pub mod data_type; +pub mod df_field; diff --git a/src/dataframe.rs b/src/dataframe.rs index da57ef95e..9c11b26f8 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::sql::logical::PyLogicalPlan; use crate::physical_plan::PyExecutionPlan; +use crate::sql::logical::PyLogicalPlan; use crate::utils::wait_for_future; use crate::{errors::DataFusionError, expression::PyExpr}; use datafusion::arrow::datatypes::Schema; diff --git a/src/lib.rs b/src/lib.rs index 875675613..be699d529 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ use pyo3::prelude::*; #[allow(clippy::borrow_deref_ref)] pub mod catalog; +pub mod common; #[allow(clippy::borrow_deref_ref)] mod config; #[allow(clippy::borrow_deref_ref)] @@ -36,6 +37,7 @@ mod expression; mod functions; pub mod physical_plan; mod pyarrow_filter_expression; +pub mod sql; pub mod store; pub mod substrait; #[allow(clippy::borrow_deref_ref)] @@ -43,8 +45,6 @@ mod udaf; #[allow(clippy::borrow_deref_ref)] mod udf; pub mod utils; -pub mod sql; -pub mod common; #[cfg(feature = "mimalloc")] #[global_allocator] diff --git a/src/substrait.rs b/src/substrait.rs index ae1f3023c..f50734932 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -18,7 +18,7 @@ use pyo3::prelude::*; use crate::context::PySessionContext; -use crate::errors::{DataFusionError, py_datafusion_err}; +use crate::errors::{py_datafusion_err, DataFusionError}; use crate::sql::logical::PyLogicalPlan; use crate::utils::wait_for_future; From ebf38c4ff8dc98cc85688794eb4c0c66f9c764d0 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 6 Feb 2023 19:49:34 -0500 Subject: [PATCH 08/16] clippy warnings --- src/common/data_type.rs | 5 ++++- src/common/df_field.rs | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 62d52099e..ff4f2543e 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -30,8 +30,11 @@ use pyo3::prelude::*; #[derive(Debug, Clone)] #[pyclass(name = "DataTypeMap", module = "datafusion", subclass)] pub struct DataTypeMap { + #[allow(dead_code)] arrow_type: PyDataType, + #[allow(dead_code)] python_type: PythonType, + #[allow(dead_code)] sql_type: SqlType, } @@ -117,7 +120,7 @@ impl DataTypeMap { #[new] pub fn py_new(arrow_type: PyDataType, python_type: PythonType, sql_type: SqlType) -> Self { DataTypeMap { - arrow_type: arrow_type, + arrow_type, python_type, sql_type, } diff --git a/src/common/df_field.rs b/src/common/df_field.rs index 1c1aecc6c..098df9bda 100644 --- a/src/common/df_field.rs +++ b/src/common/df_field.rs @@ -27,10 +27,15 @@ use crate::common::data_type::DataTypeMap; #[derive(Debug, Clone)] pub struct PyDFField { /// Optional qualifier (usually a table or relation name) + #[allow(dead_code)] qualifier: Option, + #[allow(dead_code)] name: String, + #[allow(dead_code)] data_type: DataTypeMap, /// Arrow field definition + #[allow(dead_code)] field: Field, + #[allow(dead_code)] index: usize, } From 112df24f93dc44518fda7d8dc1fdbba1ef7a842c Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 6 Feb 2023 19:58:16 -0500 Subject: [PATCH 09/16] format!() changes --- src/sql/exceptions.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sql/exceptions.rs b/src/sql/exceptions.rs index 912c54db9..d769adef1 100644 --- a/src/sql/exceptions.rs +++ b/src/sql/exceptions.rs @@ -25,18 +25,18 @@ create_exception!(rust, ParsingException, pyo3::exceptions::PyException); // Identifies exceptions that occur during attempts to optimization an existing `LogicalPlan` create_exception!(rust, OptimizationException, pyo3::exceptions::PyException); -pub fn py_type_err(e: impl Debug) -> PyErr { - PyErr::new::(format!("{:?}", e)) +pub fn py_type_err(e: impl Debug + std::fmt::Display) -> PyErr { + PyErr::new::(format!("{e}")) } -pub fn py_runtime_err(e: impl Debug) -> PyErr { - PyErr::new::(format!("{:?}", e)) +pub fn py_runtime_err(e: impl Debug + std::fmt::Display) -> PyErr { + PyErr::new::(format!("{e}")) } -pub fn py_parsing_exp(e: impl Debug) -> PyErr { - PyErr::new::(format!("{:?}", e)) +pub fn py_parsing_exp(e: impl Debug + std::fmt::Display) -> PyErr { + PyErr::new::(format!("{e}")) } -pub fn py_optimization_exp(e: impl Debug) -> PyErr { - PyErr::new::(format!("{:?}", e)) +pub fn py_optimization_exp(e: impl Debug + std::fmt::Display) -> PyErr { + PyErr::new::(format!("{e}")) } From b06c5e33faca1cfa9a3e36ab5cb25252bd7929b5 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Mon, 6 Feb 2023 20:12:00 -0500 Subject: [PATCH 10/16] Add Table_Scan instance --- Cargo.lock | 16 +++---- src/sql/logical.rs | 2 + src/sql/logical/table_scan.rs | 87 +++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 8 deletions(-) create mode 100644 src/sql/logical/table_scan.rs diff --git a/Cargo.lock b/Cargo.lock index d48c303a2..9bc8f5964 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,9 +62,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.68" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cb2f989d18dd141ab8ae82f64d1a8cdd37e0840f73a406896cf5e99502fab61" +checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" [[package]] name = "apache-avro" @@ -1773,9 +1773,9 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5014253a1331579ce62aa67443b4a658c5e7dd03d4bc6d302b94474888143" +checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" dependencies = [ "fixedbitset", "indexmap", @@ -1813,9 +1813,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.50" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ef7d57beacfaf2d8aee5937dab7b7f28de3cb8b1828479bb5de2a7106f2bae2" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" dependencies = [ "unicode-ident", ] @@ -2311,9 +2311,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.91" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883" +checksum = "7434af0dc1cbd59268aa98b4c22c131c0584d2232f6fb166efb993e2832e896a" dependencies = [ "itoa 1.0.5", "ryu", diff --git a/src/sql/logical.rs b/src/sql/logical.rs index dcd7baa58..3c284be66 100644 --- a/src/sql/logical.rs +++ b/src/sql/logical.rs @@ -20,6 +20,8 @@ use std::sync::Arc; use datafusion_expr::LogicalPlan; use pyo3::prelude::*; +pub mod table_scan; + #[pyclass(name = "LogicalPlan", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyLogicalPlan { diff --git a/src/sql/logical/table_scan.rs b/src/sql/logical/table_scan.rs new file mode 100644 index 000000000..6a9bd93ca --- /dev/null +++ b/src/sql/logical/table_scan.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_common::DFSchema; +use datafusion_expr::{logical_plan::TableScan, LogicalPlan}; +use pyo3::prelude::*; + +use crate::{ + expression::{py_expr_list, PyExpr}, + sql::exceptions::py_type_err, +}; + +#[pyclass(name = "TableScan", module = "dask_planner", subclass)] +#[derive(Clone)] +pub struct PyTableScan { + pub(crate) table_scan: TableScan, + input: Arc, +} + +#[pymethods] +impl PyTableScan { + #[pyo3(name = "getTableScanProjects")] + fn scan_projects(&mut self) -> PyResult> { + match &self.table_scan.projection { + Some(indices) => { + let schema = self.table_scan.source.schema(); + Ok(indices + .iter() + .map(|i| schema.field(*i).name().to_string()) + .collect()) + } + None => Ok(vec![]), + } + } + + /// If the 'TableScan' contains columns that should be projected during the + /// read return True, otherwise return False + #[pyo3(name = "containsProjections")] + fn contains_projections(&self) -> bool { + self.table_scan.projection.is_some() + } + + #[pyo3(name = "getFilters")] + fn scan_filters(&self) -> PyResult> { + py_expr_list(&self.input, &self.table_scan.filters) + } +} + +impl TryFrom for PyTableScan { + type Error = PyErr; + + fn try_from(logical_plan: LogicalPlan) -> Result { + match logical_plan { + LogicalPlan::TableScan(table_scan) => { + // Create an input logical plan that's identical to the table scan with schema from the table source + let mut input = table_scan.clone(); + input.projected_schema = DFSchema::try_from_qualified_schema( + &table_scan.table_name, + &table_scan.source.schema(), + ) + .map_or(input.projected_schema, Arc::new); + + Ok(PyTableScan { + table_scan, + input: Arc::new(LogicalPlan::TableScan(input)), + }) + } + _ => Err(py_type_err("unexpected plan")), + } + } +} From 1c2edd075b7fba06a606e3d8ed16b312e4e0ee94 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 7 Feb 2023 09:14:47 -0500 Subject: [PATCH 11/16] checkpoint commit --- src/sql/logical/table_scan.rs | 44 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/sql/logical/table_scan.rs b/src/sql/logical/table_scan.rs index 6a9bd93ca..ca491790e 100644 --- a/src/sql/logical/table_scan.rs +++ b/src/sql/logical/table_scan.rs @@ -27,7 +27,7 @@ use crate::{ }; #[pyclass(name = "TableScan", module = "dask_planner", subclass)] -#[derive(Clone)] +#[derive(Clone, FromPyObject)] pub struct PyTableScan { pub(crate) table_scan: TableScan, input: Arc, @@ -62,26 +62,26 @@ impl PyTableScan { } } -impl TryFrom for PyTableScan { - type Error = PyErr; +// impl TryFrom for PyTableScan { +// type Error = PyErr; - fn try_from(logical_plan: LogicalPlan) -> Result { - match logical_plan { - LogicalPlan::TableScan(table_scan) => { - // Create an input logical plan that's identical to the table scan with schema from the table source - let mut input = table_scan.clone(); - input.projected_schema = DFSchema::try_from_qualified_schema( - &table_scan.table_name, - &table_scan.source.schema(), - ) - .map_or(input.projected_schema, Arc::new); +// fn try_from(logical_plan: LogicalPlan) -> Result { +// match logical_plan { +// LogicalPlan::TableScan(table_scan) => { +// // Create an input logical plan that's identical to the table scan with schema from the table source +// let mut input = table_scan.clone(); +// input.projected_schema = DFSchema::try_from_qualified_schema( +// &table_scan.table_name, +// &table_scan.source.schema(), +// ) +// .map_or(input.projected_schema, Arc::new); - Ok(PyTableScan { - table_scan, - input: Arc::new(LogicalPlan::TableScan(input)), - }) - } - _ => Err(py_type_err("unexpected plan")), - } - } -} +// Ok(PyTableScan { +// table_scan, +// input: Arc::new(LogicalPlan::TableScan(input)), +// }) +// } +// _ => Err(py_type_err("unexpected plan")), +// } +// } +// } From 1f32717a8fd6126fb046feb5a4c11db1c0cab1d5 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 14 Feb 2023 13:02:04 -0500 Subject: [PATCH 12/16] table scan bindings --- .github/workflows/test.yaml | 2 +- Cargo.lock | 283 +++++++++++++-------------- Cargo.toml | 14 +- datafusion/tests/test_aggregation.py | 99 +++++++++- datafusion/tests/test_sql.py | 4 +- src/catalog.rs | 2 +- src/context.rs | 89 ++++----- src/dataframe.rs | 14 +- src/dataset_exec.rs | 7 +- src/expression.rs | 4 +- src/functions.rs | 63 +++++- src/sql/logical/table_scan.rs | 128 +++++++----- src/store.rs | 25 +-- src/substrait.rs | 2 +- src/udaf.rs | 2 +- src/udf.rs | 2 +- 16 files changed, 435 insertions(+), 305 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 0f4681ac5..2142aceb7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -81,7 +81,7 @@ jobs: if: ${{ matrix.python-version == '3.10' && matrix.toolchain == 'stable' }} with: command: clippy - args: --all-targets --all-features -- -D clippy::all + args: --all-targets --all-features -- -D clippy::all -A clippy::redundant_closure - name: Create Virtualenv (3.10) if: ${{ matrix.python-version == '3.10' }} diff --git a/Cargo.lock b/Cargo.lock index 9bc8f5964..6bbda948c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -107,9 +107,9 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "arrow" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b556d39f9d19e363833a0fe65d591cd0e2ecc0977589a78179b592bea8dc945" +checksum = "87d948f553cf556656eb89265700258e1032d26fec9b7920cd20319336e06afd" dependencies = [ "ahash", "arrow-arith", @@ -132,9 +132,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85c61b9235694b48f60d89e0e8d6cb478f39c65dd14b0fe1c3f04379b7d50068" +checksum = "cf30d4ebc3df9dfd8bd26883aa30687d4ddcfd7b2443e62bd7c8fedf153b8e45" dependencies = [ "arrow-array", "arrow-buffer", @@ -147,9 +147,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e6e839764618a911cc460a58ebee5ad3d42bc12d9a5e96a29b7cc296303aa1" +checksum = "9fe66ec388d882a61fff3eb613b5266af133aa08a3318e5e493daf0f5c1696cb" dependencies = [ "ahash", "arrow-buffer", @@ -163,9 +163,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a21d232b1bc1190a3fdd2f9c1e39b7cd41235e95a0d44dd4f522bc5f495748" +checksum = "4ef967dadbccd4586ec8d7aab27d7033ecb5dfae8a605c839613039eac227bda" dependencies = [ "half", "num", @@ -173,9 +173,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83dcdb1436cac574f1c1b30fda91c53c467534337bef4064bbd4ea2d6fbc6e04" +checksum = "491a7979ea9e76dc218f532896e2d245fde5235e2e6420ce80d27cf6395dda84" dependencies = [ "arrow-array", "arrow-buffer", @@ -189,9 +189,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a01677ae9458f5af9e35e1aa6ba97502f539e621db0c6672566403f97edd0448" +checksum = "4b1d4fc91078dbe843c2c50d90f8119c96e8dfac2f78d30f7a8cb9397399c61d" dependencies = [ "arrow-array", "arrow-buffer", @@ -208,9 +208,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14e3e69c9fd98357eeeab4aa0f626ecf7ecf663e68e8fc04eac87c424a414477" +checksum = "ee0c0e3c5d3b80be8f267f4b2af714c08cad630569be01a8379cfe27b4866495" dependencies = [ "arrow-buffer", "arrow-schema", @@ -220,9 +220,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64cac2706acbd796965b6eaf0da30204fe44aacf70273f8cb3c9b7d7f3d4c190" +checksum = "0a3ca7eb8d23c83fe40805cbafec70a6a31df72de47355545ff34c850f715403" dependencies = [ "arrow-array", "arrow-buffer", @@ -234,9 +234,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7790e8b7df2d8ef5ac802377ac256cf2fb80cbf7d44b82d6464e20ace6232a5a" +checksum = "bf65aff76d2e340d827d5cab14759e7dd90891a288347e2202e4ee28453d9bed" dependencies = [ "arrow-array", "arrow-buffer", @@ -246,15 +246,16 @@ dependencies = [ "chrono", "half", "indexmap", + "lexical-core", "num", "serde_json", ] [[package]] name = "arrow-ord" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7ee6e1b761dfffaaf7b5bbe68c113a576a3a802146c5c0b9fcec781e30d80a3" +checksum = "074a5a55c37ae4750af4811c8861c0378d8ab2ff6c262622ad24efae6e0b73b3" dependencies = [ "arrow-array", "arrow-buffer", @@ -266,9 +267,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e65bfedf782fc92721e796fdd26ae7343c98ba9a9243d62def9e4e1c4c1cf0b" +checksum = "e064ac4e64960ebfbe35f218f5e7d9dc9803b59c2e56f611da28ce6d008f839e" dependencies = [ "ahash", "arrow-array", @@ -281,15 +282,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73ca49d010b27e2d73f70c1d1f90c1b378550ed0f4ad379c4dea0c997d97d723" +checksum = "ead3f373b9173af52f2fdefcb5a7dd89f453fbc40056f574a8aeb23382a4ef81" [[package]] name = "arrow-select" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976cbaeb1a85c09eea81f3f9c149c758630ff422ed0238624c5c3f4704b6a53c" +checksum = "646b4f15b5a77c970059e748aeb1539705c68cd397ecf0f0264c4ef3737d35f3" dependencies = [ "arrow-array", "arrow-buffer", @@ -300,9 +301,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d4882762f8f48a9218946c016553d38b04b4fe8202038dad4141b3b887b7da8" +checksum = "c8b8bf150caaeca03f39f1a91069701387d93f7cfd256d27f423ac8496d99a51" dependencies = [ "arrow-array", "arrow-buffer", @@ -357,12 +358,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "base64" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5" - [[package]] name = "base64" version = "0.21.0" @@ -428,18 +423,6 @@ dependencies = [ "alloc-stdlib", ] -[[package]] -name = "bstr" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" -dependencies = [ - "lazy_static", - "memchr", - "regex-automata", - "serde", -] - [[package]] name = "bumpalo" version = "3.12.0" @@ -501,9 +484,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" dependencies = [ "iana-time-zone", + "js-sys", "num-integer", "num-traits", "serde", + "time", + "wasm-bindgen", "winapi", ] @@ -598,13 +584,12 @@ dependencies = [ [[package]] name = "csv" -version = "1.1.6" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +checksum = "af91f40b7355f82b0a891f50e70399475945bb0b0da4f1700ce60761c9d3e359" dependencies = [ - "bstr", "csv-core", - "itoa 0.4.8", + "itoa", "ryu", "serde", ] @@ -620,9 +605,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.89" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc831ee6a32dd495436e317595e639a587aa9907bef96fe6e6abc290ab6204e9" +checksum = "90d59d9acd2a682b4e40605a242f6670eaa58c5957471cbf85e8aa6a0b97a5e8" dependencies = [ "cc", "cxxbridge-flags", @@ -632,9 +617,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.89" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94331d54f1b1a8895cd81049f7eaaaef9d05a7dcb4d1fd08bf3ff0806246789d" +checksum = "ebfa40bda659dd5c864e65f4c9a2b0aff19bea56b017b9b77c73d3766a453a38" dependencies = [ "cc", "codespan-reporting", @@ -647,15 +632,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.89" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48dcd35ba14ca9b40d6e4b4b39961f23d835dbb8eed74565ded361d93e1feb8a" +checksum = "457ce6757c5c70dc6ecdbda6925b958aae7f959bda7d8fb9bde889e34a09dc03" [[package]] name = "cxxbridge-macro" -version = "1.0.89" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bbeb29798b407ccd82a3324ade1a7286e0d29851475990b612670f6f5124d2" +checksum = "ebf883b7aacd7b2aeb2a7b338648ee19f57c140d4ee8e52c68979c6b2f7f2263" dependencies = [ "proc-macro2", "quote", @@ -677,9 +662,8 @@ dependencies = [ [[package]] name = "datafusion" -version = "17.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6d90cae91414aaeda37ae8022a23ef1124ca8efc08ac7d7770274249f7cf148" +version = "18.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=3da790214ea479626eb4114c53440dc17b737d54#3da790214ea479626eb4114c53440dc17b737d54" dependencies = [ "ahash", "apache-avro", @@ -726,9 +710,8 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "17.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b21c4b8e8b7815e86d79d25da16854fee6d4d1b386572e802a248b7d43188e86" +version = "18.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=3da790214ea479626eb4114c53440dc17b737d54#3da790214ea479626eb4114c53440dc17b737d54" dependencies = [ "apache-avro", "arrow", @@ -742,9 +725,8 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "17.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db8c07b051fbaf01657a3eb910a76b042ecfed0350a40412f70cf6b949bd5328" +version = "18.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=3da790214ea479626eb4114c53440dc17b737d54#3da790214ea479626eb4114c53440dc17b737d54" dependencies = [ "ahash", "arrow", @@ -755,9 +737,8 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "17.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2ce4d34a808cd2e4c4864cdc759dd1bd22dcac2b8af38aa570e30fd54577c4d" +version = "18.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=3da790214ea479626eb4114c53440dc17b737d54#3da790214ea479626eb4114c53440dc17b737d54" dependencies = [ "arrow", "async-trait", @@ -772,9 +753,8 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "17.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a38afa11a09505c24bd7e595039d7914ec39329ba490209413ef2d37895c8220" +version = "18.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=3da790214ea479626eb4114c53440dc17b737d54#3da790214ea479626eb4114c53440dc17b737d54" dependencies = [ "ahash", "arrow", @@ -825,9 +805,8 @@ dependencies = [ [[package]] name = "datafusion-row" -version = "17.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9172411b25ff4aa97f8e99884898595a581636d93cc96c12f96dbe3bf51cd7e5" +version = "18.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=3da790214ea479626eb4114c53440dc17b737d54#3da790214ea479626eb4114c53440dc17b737d54" dependencies = [ "arrow", "datafusion-common", @@ -837,9 +816,8 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "17.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fbe5e61563ced2f6992a60afea568ff3de69e32940bbf07db06fc5c9d8cd866" +version = "18.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=3da790214ea479626eb4114c53440dc17b737d54#3da790214ea479626eb4114c53440dc17b737d54" dependencies = [ "arrow-schema", "datafusion-common", @@ -850,13 +828,14 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "17.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e5af8bc23708f6d9d1721947c8486c96153ce671269522d7d917bb428d2fa73" +version = "18.0.0" +source = "git+https://github.com/apache/arrow-datafusion.git?rev=3da790214ea479626eb4114c53440dc17b737d54#3da790214ea479626eb4114c53440dc17b737d54" dependencies = [ "async-recursion", + "chrono", "datafusion", "itertools", + "object_store", "prost 0.11.6", "prost-build 0.9.0", "prost-types 0.11.6", @@ -904,9 +883,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" dependencies = [ "instant", ] @@ -919,12 +898,12 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flatbuffers" -version = "22.9.29" +version = "23.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce016b9901aef3579617931fbb2df8fc9a9f7cb95a16eb8acc8148209bb9e70" +checksum = "77f5399c2c9c50ae9418e522842ad362f61ee48b346ac106807bd355a8a7c619" dependencies = [ "bitflags", - "thiserror", + "rustc_version", ] [[package]] @@ -1059,7 +1038,7 @@ checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", ] [[package]] @@ -1153,7 +1132,7 @@ checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" dependencies = [ "bytes", "fnv", - "itoa 1.0.5", + "itoa", ] [[package]] @@ -1194,7 +1173,7 @@ dependencies = [ "http-body", "httparse", "httpdate", - "itoa 1.0.5", + "itoa", "pin-project-lite", "socket2", "tokio", @@ -1296,12 +1275,6 @@ dependencies = [ "either", ] -[[package]] -name = "itoa" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" - [[package]] name = "itoa" version = "1.0.5" @@ -1514,9 +1487,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "memoffset" -version = "0.6.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" dependencies = [ "autocfg", ] @@ -1547,14 +1520,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ "libc", "log", - "wasi", - "windows-sys 0.42.0", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.45.0", ] [[package]] @@ -1652,12 +1625,12 @@ dependencies = [ [[package]] name = "object_store" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4201837dc4c27a8670f0363b1255cd3845a4f0c521211cced1ed14c1d0cc6d2" +checksum = "1f344e51ec9584d2f51199c0c29c6f73dddd04ade986497875bf8fa2f178caf0" dependencies = [ "async-trait", - "base64 0.20.0", + "base64", "bytes", "chrono", "futures", @@ -1718,9 +1691,9 @@ dependencies = [ [[package]] name = "parquet" -version = "31.0.0" +version = "32.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b4ee1ffc0778395c9783a5c74f2cad2fb1a128ade95a965212d31b7b13e3d45" +checksum = "23b3d4917209e17e1da5fb07d276da237a42465f0def2b8d5fa5ce0e85855b4c" dependencies = [ "ahash", "arrow-array", @@ -1730,7 +1703,7 @@ dependencies = [ "arrow-ipc", "arrow-schema", "arrow-select", - "base64 0.21.0", + "base64", "brotli", "bytes", "chrono", @@ -1763,9 +1736,9 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pest" -version = "2.5.4" +version = "2.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab62d2fa33726dbe6321cc97ef96d8cde531e3eeaf858a058de53a8a6d40d8f" +checksum = "028accff104c4e513bad663bbcd2ad7cfd5304144404c31ed0a77ac103d00660" dependencies = [ "thiserror", "ucd-trie", @@ -1928,9 +1901,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "268be0c73583c183f2b14052337465768c07726936a260f480f0857cb95ba543" +checksum = "06a3d8e8a46ab2738109347433cb7b96dffda2e4a218b03ef27090238886b147" dependencies = [ "cfg-if", "indoc", @@ -1945,9 +1918,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28fcd1e73f06ec85bf3280c48c67e731d8290ad3d730f8be9dc07946923005c8" +checksum = "75439f995d07ddfad42b192dfcf3bc66a7ecfd8b4a1f5f6f046aa5c2c5d7677d" dependencies = [ "once_cell", "target-lexicon", @@ -1955,9 +1928,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f6cb136e222e49115b3c51c32792886defbfb0adead26a688142b346a0b9ffc" +checksum = "839526a5c07a17ff44823679b68add4a58004de00512a95b6c1c98a6dcac0ee5" dependencies = [ "libc", "pyo3-build-config", @@ -1965,9 +1938,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94144a1266e236b1c932682136dc35a9dee8d3589728f68130c7c3861ef96b28" +checksum = "bd44cf207476c6a9760c4653559be4f206efafb924d3e4cbf2721475fc0d6cc5" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1977,9 +1950,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8df9be978a2d2f0cdebabb03206ed73b11314701a5bfe71b0d753b81997777f" +checksum = "dc1f43d8e30460f36350d18631ccf85ded64c059829208fe680904c65bcd0a4c" dependencies = [ "proc-macro2", "quote", @@ -2061,12 +2034,6 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" - [[package]] name = "regex-syntax" version = "0.6.28" @@ -2097,7 +2064,7 @@ version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21eed90ec8570952d53b772ecf8f206aa1ec9a3d76b2521c56c42973f2d91ee9" dependencies = [ - "base64 0.21.0", + "base64", "bytes", "encoding_rs", "futures-core", @@ -2153,6 +2120,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver 1.0.16", +] + [[package]] name = "rustfmt-wrapper" version = "0.2.0" @@ -2184,7 +2160,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.21.0", + "base64", ] [[package]] @@ -2263,6 +2239,12 @@ dependencies = [ "semver-parser", ] +[[package]] +name = "semver" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" + [[package]] name = "semver-parser" version = "0.10.2" @@ -2311,11 +2293,11 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7434af0dc1cbd59268aa98b4c22c131c0584d2232f6fb166efb993e2832e896a" +checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" dependencies = [ - "itoa 1.0.5", + "itoa", "ryu", "serde", ] @@ -2338,7 +2320,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" dependencies = [ "form_urlencoded", - "itoa 1.0.5", + "itoa", "ryu", "serde", ] @@ -2350,7 +2332,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fb06d4b6cdaef0e0c51fa881acb721bed3c924cfaa71d9c94a3b771dfdf6567" dependencies = [ "indexmap", - "itoa 1.0.5", + "itoa", "ryu", "serde", "unsafe-libyaml", @@ -2509,9 +2491,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.5" +version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9410d0f6853b1d94f0e519fb95df60f29d2c1eff2d921ffdf01a4c8a3b54f12d" +checksum = "8ae9980cab1db3fceee2f6c6f643d5d8de2997c58ee8d25fb0cc8a9e9e7348e5" [[package]] name = "tempfile" @@ -2567,6 +2549,17 @@ dependencies = [ "ordered-float", ] +[[package]] +name = "time" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" +dependencies = [ + "libc", + "wasi 0.10.0+wasi-snapshot-preview1", + "winapi", +] + [[package]] name = "tiny-keccak" version = "2.0.2" @@ -2645,9 +2638,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.4" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bb2e075f03b3d66d8d8785356224ba688d2906a371015e225beeb65ca92c740" +checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" dependencies = [ "bytes", "futures-core", @@ -2675,7 +2668,7 @@ dependencies = [ "home", "lazy_static", "regex", - "semver", + "semver 0.11.0", "walkdir", ] @@ -2900,6 +2893,12 @@ dependencies = [ "try-lock", ] +[[package]] +name = "wasi" +version = "0.10.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3187,9 +3186,9 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "6.0.3+zstd.1.5.2" +version = "6.0.4+zstd.1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68e4a3f57d13d0ab7e478665c60f35e2a613dcd527851c2c7287ce5c787e134a" +checksum = "7afb4b54b8910cf5447638cb54bf4e8a65cbedd783af98b98c62ffe91f185543" dependencies = [ "libc", "zstd-sys", @@ -3197,9 +3196,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.6+zstd.1.5.2" +version = "2.0.7+zstd.1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68a3f9792c0c3dc6c165840a75f47ae1f4da402c2d006881129579f6597e801b" +checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" dependencies = [ "cc", "libc", diff --git a/Cargo.toml b/Cargo.toml index 7116892e3..960e7af8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,13 +33,13 @@ default = ["mimalloc"] [dependencies] tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.8" -pyo3 = { version = "~0.17.3", features = ["extension-module", "abi3", "abi3-py37"] } -datafusion = { version = "17.0.0", features = ["pyarrow", "avro"] } -datafusion-expr = "17.0.0" -datafusion-optimizer = "17.0.0" -datafusion-common = { version = "17.0.0", features = ["pyarrow"] } -datafusion-sql = "17.0.0" -datafusion-substrait = "17.0.0" +pyo3 = { version = "0.18.0", features = ["extension-module", "abi3", "abi3-py37"] } +datafusion = { git="https://github.com/apache/arrow-datafusion.git", rev="3da790214ea479626eb4114c53440dc17b737d54", features = ["pyarrow", "avro"] } +datafusion-expr = { git="https://github.com/apache/arrow-datafusion.git", rev="3da790214ea479626eb4114c53440dc17b737d54" } +datafusion-optimizer = { git="https://github.com/apache/arrow-datafusion.git", rev="3da790214ea479626eb4114c53440dc17b737d54" } +datafusion-common = { git="https://github.com/apache/arrow-datafusion.git", rev="3da790214ea479626eb4114c53440dc17b737d54", features = ["pyarrow"] } +datafusion-sql = { git="https://github.com/apache/arrow-datafusion.git", rev="3da790214ea479626eb4114c53440dc17b737d54" } +datafusion-substrait = { git="https://github.com/apache/arrow-datafusion.git", rev="3da790214ea479626eb4114c53440dc17b737d54" } uuid = { version = "1.2", features = ["v4"] } mimalloc = { version = "*", optional = true, default-features = false } async-trait = "0.1" diff --git a/datafusion/tests/test_aggregation.py b/datafusion/tests/test_aggregation.py index b274e18cf..2c8c064b1 100644 --- a/datafusion/tests/test_aggregation.py +++ b/datafusion/tests/test_aggregation.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. +import numpy as np import pyarrow as pa import pytest -from datafusion import SessionContext, column +from datafusion import SessionContext, column, lit from datafusion import functions as f @@ -28,8 +29,12 @@ def df(): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 4, 6])], - names=["a", "b"], + [ + pa.array([1, 2, 3]), + pa.array([4, 4, 6]), + pa.array([9, 8, 5]), + ], + names=["a", "b", "c"], ) return ctx.create_dataframe([[batch]]) @@ -37,12 +42,86 @@ def df(): def test_built_in_aggregation(df): col_a = column("a") col_b = column("b") - df = df.aggregate( + col_c = column("c") + + agg_df = df.aggregate( [], - [f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)], + [ + f.approx_distinct(col_b), + f.approx_median(col_b), + f.approx_percentile_cont(col_b, lit(0.5)), + f.approx_percentile_cont_with_weight(col_b, lit(0.6), lit(0.5)), + f.array_agg(col_b), + f.avg(col_a), + f.corr(col_a, col_b), + f.count(col_a), + f.covar(col_a, col_b), + f.covar_pop(col_a, col_c), + f.covar_samp(col_b, col_c), + # f.grouping(col_a), # No physical plan implemented yet + f.max(col_a), + f.mean(col_b), + f.median(col_b), + f.min(col_a), + f.sum(col_b), + f.stddev(col_a), + f.stddev_pop(col_b), + f.stddev_samp(col_c), + f.var(col_a), + f.var_pop(col_b), + f.var_samp(col_c), + ], + ) + result = agg_df.collect()[0] + values_a, values_b, values_c = df.collect()[0] + + assert result.column(0) == pa.array([2], type=pa.uint64()) + assert result.column(1) == pa.array([4]) + assert result.column(2) == pa.array([4]) + assert result.column(3) == pa.array([6]) + assert result.column(4) == pa.array([[4, 4, 6]]) + np.testing.assert_array_almost_equal( + result.column(5), np.average(values_a) + ) + np.testing.assert_array_almost_equal( + result.column(6), np.corrcoef(values_a, values_b)[0][1] + ) + assert result.column(7) == pa.array([len(values_a)]) + # Sample (co)variance -> ddof=1 + # Population (co)variance -> ddof=0 + np.testing.assert_array_almost_equal( + result.column(8), np.cov(values_a, values_b, ddof=1)[0][1] + ) + np.testing.assert_array_almost_equal( + result.column(9), np.cov(values_a, values_c, ddof=0)[0][1] + ) + np.testing.assert_array_almost_equal( + result.column(10), np.cov(values_b, values_c, ddof=1)[0][1] + ) + np.testing.assert_array_almost_equal(result.column(11), np.max(values_a)) + np.testing.assert_array_almost_equal(result.column(12), np.mean(values_b)) + np.testing.assert_array_almost_equal( + result.column(13), np.median(values_b) + ) + np.testing.assert_array_almost_equal(result.column(14), np.min(values_a)) + np.testing.assert_array_almost_equal( + result.column(15), np.sum(values_b.to_pylist()) + ) + np.testing.assert_array_almost_equal( + result.column(16), np.std(values_a, ddof=1) + ) + np.testing.assert_array_almost_equal( + result.column(17), np.std(values_b, ddof=0) + ) + np.testing.assert_array_almost_equal( + result.column(18), np.std(values_c, ddof=1) + ) + np.testing.assert_array_almost_equal( + result.column(19), np.var(values_a, ddof=1) + ) + np.testing.assert_array_almost_equal( + result.column(20), np.var(values_b, ddof=0) + ) + np.testing.assert_array_almost_equal( + result.column(21), np.var(values_c, ddof=1) ) - result = df.collect()[0] - assert result.column(0) == pa.array([3]) - assert result.column(1) == pa.array([1]) - assert result.column(2) == pa.array([3], type=pa.int64()) - assert result.column(3) == pa.array([2], type=pa.uint64()) diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py index 19c27665d..638a222dc 100644 --- a/datafusion/tests/test_sql.py +++ b/datafusion/tests/test_sql.py @@ -145,7 +145,9 @@ def test_execute(ctx, tmp_path): assert ctx.tables() == {"t"} # count - result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() + result = ctx.sql( + "SELECT COUNT(a) AS cnt FROM t WHERE a IS NOT NULL" + ).collect() expected = pa.array([7], pa.int64()) expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] diff --git a/src/catalog.rs b/src/catalog.rs index 4dd431fcb..76521e9af 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -72,7 +72,7 @@ impl PyCatalog { self.catalog.schema_names() } - #[args(name = "\"public\"")] + #[pyo3(signature = (name="public"))] fn database(&self, name: &str) -> PyResult { match self.catalog.schema(name) { Some(database) => Ok(PyDatabase::new(database)), diff --git a/src/context.rs b/src/context.rs index 7d9f1c570..c50d0392a 100644 --- a/src/context.rs +++ b/src/context.rs @@ -57,18 +57,16 @@ pub(crate) struct PySessionContext { #[pymethods] impl PySessionContext { #[allow(clippy::too_many_arguments)] - #[args( - default_catalog = "\"datafusion\"", - default_schema = "\"public\"", - create_default_catalog_and_schema = "true", - information_schema = "false", - repartition_joins = "true", - repartition_aggregations = "true", - repartition_windows = "true", - parquet_pruning = "true", - target_partitions = "None", - config_options = "None" - )] + #[pyo3(signature = (default_catalog="datafusion", + default_schema="public", + create_default_catalog_and_schema=true, + information_schema=false, + repartition_joins=true, + repartition_aggregations=true, + repartition_windows=true, + parquet_pruning=true, + target_partitions=None, + config_options=None))] #[new] fn new( default_catalog: &str, @@ -209,11 +207,9 @@ impl PySessionContext { } #[allow(clippy::too_many_arguments)] - #[args( - table_partition_cols = "vec![]", - parquet_pruning = "true", - file_extension = "\".parquet\"" - )] + #[pyo3(signature = (name, path, table_partition_cols=vec![], + parquet_pruning=true, + file_extension=".parquet"))] fn register_parquet( &mut self, name: &str, @@ -233,13 +229,13 @@ impl PySessionContext { } #[allow(clippy::too_many_arguments)] - #[args( - schema = "None", - has_header = "true", - delimiter = "\",\"", - schema_infer_max_records = "1000", - file_extension = "\".csv\"" - )] + #[pyo3(signature = (name, + path, + schema=None, + has_header=true, + delimiter=",", + schema_infer_max_records=1000, + file_extension=".csv"))] fn register_csv( &mut self, name: &str, @@ -295,13 +291,13 @@ impl PySessionContext { Ok(()) } - #[args(name = "\"datafusion\"")] + #[pyo3(signature = (name="datafusion"))] fn catalog(&self, name: &str) -> PyResult { match self.ctx.catalog(name) { Some(catalog) => Ok(PyCatalog::new(catalog)), None => Err(PyKeyError::new_err(format!( "Catalog with name {} doesn't exist.", - &name + &name, ))), } } @@ -329,12 +325,7 @@ impl PySessionContext { } #[allow(clippy::too_many_arguments)] - #[args( - schema = "None", - schema_infer_max_records = "1000", - file_extension = "\".json\"", - table_partition_cols = "vec![]" - )] + #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![]))] fn read_json( &mut self, path: PathBuf, @@ -363,14 +354,14 @@ impl PySessionContext { } #[allow(clippy::too_many_arguments)] - #[args( - schema = "None", - has_header = "true", - delimiter = "\",\"", - schema_infer_max_records = "1000", - file_extension = "\".csv\"", - table_partition_cols = "vec![]" - )] + #[pyo3(signature = ( + path, + schema=None, + has_header=true, + delimiter=",", + schema_infer_max_records=1000, + file_extension=".csv", + table_partition_cols=vec![]))] fn read_csv( &self, path: PathBuf, @@ -413,12 +404,12 @@ impl PySessionContext { } #[allow(clippy::too_many_arguments)] - #[args( - parquet_pruning = "true", - file_extension = "\".parquet\"", - table_partition_cols = "vec![]", - skip_metadata = "true" - )] + #[pyo3(signature = ( + path, + table_partition_cols=vec![], + parquet_pruning=true, + file_extension=".parquet", + skip_metadata=true))] fn read_parquet( &self, path: &str, @@ -440,11 +431,7 @@ impl PySessionContext { } #[allow(clippy::too_many_arguments)] - #[args( - schema = "None", - file_extension = "\".avro\"", - table_partition_cols = "vec![]" - )] + #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))] fn read_avro( &self, path: &str, diff --git a/src/dataframe.rs b/src/dataframe.rs index 9c11b26f8..0f7375736 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -81,13 +81,13 @@ impl PyDataFrame { PyArrowType(self.df.schema().into()) } - #[args(args = "*")] + #[pyo3(signature = (*args))] fn select_columns(&self, args: Vec<&str>) -> PyResult { let df = self.df.as_ref().clone().select_columns(&args)?; Ok(Self::new(df)) } - #[args(args = "*")] + #[pyo3(signature = (*args))] fn select(&self, args: Vec) -> PyResult { let expr = args.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().select(expr)?; @@ -122,7 +122,7 @@ impl PyDataFrame { Ok(Self::new(df)) } - #[args(exprs = "*")] + #[pyo3(signature = (*exprs))] fn sort(&self, exprs: Vec) -> PyResult { let exprs = exprs.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().sort(exprs)?; @@ -162,7 +162,7 @@ impl PyDataFrame { } /// Print the result, 20 lines by default - #[args(num = "20")] + #[pyo3(signature = (num=20))] fn show(&self, py: Python, num: usize) -> PyResult<()> { let df = self.df.as_ref().clone().limit(0, Some(num))?; let batches = wait_for_future(py, df.collect())?; @@ -207,7 +207,7 @@ impl PyDataFrame { } /// Print the query plan - #[args(verbose = false, analyze = false)] + #[pyo3(signature = (verbose=false, analyze=false))] fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> { let df = self.df.as_ref().clone().explain(verbose, analyze)?; let batches = wait_for_future(py, df.collect())?; @@ -241,7 +241,7 @@ impl PyDataFrame { } /// Repartition a `DataFrame` based on a logical partitioning scheme. - #[args(args = "*", num)] + #[pyo3(signature = (*args, num))] fn repartition_by_hash(&self, args: Vec, num: usize) -> PyResult { let expr = args.into_iter().map(|py_expr| py_expr.into()).collect(); let new_df = self @@ -254,7 +254,7 @@ impl PyDataFrame { /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The /// two `DataFrame`s must have exactly the same schema - #[args(distinct = false)] + #[pyo3(signature = (py_df, distinct=false))] fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyResult { let new_df = if distinct { self.df diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs index be6dc1eed..859678856 100644 --- a/src/dataset_exec.rs +++ b/src/dataset_exec.rs @@ -23,7 +23,7 @@ use pyo3::types::{PyDict, PyIterator, PyList}; use std::any::Any; use std::sync::Arc; -use futures::stream; +use futures::{stream, TryStreamExt}; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::error::ArrowError; @@ -228,8 +228,9 @@ impl ExecutionPlan for DatasetExec { }; let record_batch_stream = stream::iter(record_batches); - let record_batch_stream: SendableRecordBatchStream = - Box::pin(RecordBatchStreamAdapter::new(schema, record_batch_stream)); + let record_batch_stream: SendableRecordBatchStream = Box::pin( + RecordBatchStreamAdapter::new(schema, record_batch_stream.map_err(|e| e.into())), + ); Ok(record_batch_stream) }) } diff --git a/src/expression.rs b/src/expression.rs index 1eb7813ed..ccc6dd873 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -24,7 +24,7 @@ use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField}; use datafusion::scalar::ScalarValue; -/// An PyExpr that can be used on a DataFrame +/// A PyExpr that can be used on a DataFrame #[pyclass(name = "Expression", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub(crate) struct PyExpr { @@ -117,7 +117,7 @@ impl PyExpr { } /// Create a sort PyExpr from an existing PyExpr. - #[args(ascending = true, nulls_first = true)] + #[pyo3(signature = (ascending=true, nulls_first=true))] pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyExpr { self.expr.clone().sort(ascending, nulls_first).into() } diff --git a/src/functions.rs b/src/functions.rs index ac1077ea5..5cabb1225 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -38,7 +38,9 @@ fn in_list(expr: PyExpr, value: Vec, negated: bool) -> PyExpr { /// Computes a binary hash of the given data. type is the algorithm to use. /// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3. -#[pyfunction(value, method)] +// #[pyfunction(value, method)] +#[pyfunction] +#[pyo3(signature = (value, method))] fn digest(value: PyExpr, method: PyExpr) -> PyExpr { PyExpr { expr: datafusion_expr::digest(value.expr, method.expr), @@ -47,7 +49,8 @@ fn digest(value: PyExpr, method: PyExpr) -> PyExpr { /// Concatenates the text representations of all the arguments. /// NULL arguments are ignored. -#[pyfunction(args = "*")] +#[pyfunction] +#[pyo3(signature = (*args))] fn concat(args: Vec) -> PyResult { let args = args.into_iter().map(|e| e.expr).collect::>(); Ok(datafusion_expr::concat(&args).into()) @@ -56,7 +59,8 @@ fn concat(args: Vec) -> PyResult { /// Concatenates all but the first argument, with separators. /// The first argument is used as the separator string, and should not be NULL. /// Other NULL arguments are ignored. -#[pyfunction(sep, args = "*")] +#[pyfunction] +#[pyo3(signature = (sep, *args))] fn concat_ws(sep: String, args: Vec) -> PyResult { let args = args.into_iter().map(|e| e.expr).collect::>(); Ok(datafusion_expr::concat_ws(lit(sep), args).into()) @@ -146,7 +150,8 @@ macro_rules! scalar_function { ($NAME: ident, $FUNC: ident, $DOC: expr) => { #[doc = $DOC] - #[pyfunction(args = "*")] + #[pyfunction] + #[pyo3(signature = (*args))] fn $NAME(args: Vec) -> PyExpr { let expr = datafusion_expr::Expr::ScalarFunction { fun: BuiltinScalarFunction::$FUNC, @@ -163,7 +168,8 @@ macro_rules! aggregate_function { }; ($NAME: ident, $FUNC: ident, $DOC: expr) => { #[doc = $DOC] - #[pyfunction(args = "*", distinct = "false")] + #[pyfunction] + #[pyo3(signature = (*args, distinct=false))] fn $NAME(args: Vec, distinct: bool) -> PyExpr { let expr = datafusion_expr::Expr::AggregateFunction(AggregateFunction { fun: datafusion_expr::aggregate_function::AggregateFunction::$FUNC, @@ -287,25 +293,49 @@ scalar_function!(upper, Upper, "Converts the string to all upper case."); scalar_function!(make_array, MakeArray); scalar_function!(array, MakeArray); scalar_function!(nullif, NullIf); -//scalar_function!(uuid, Uuid); -//scalar_function!(struct, Struct); +scalar_function!(uuid, Uuid); +scalar_function!(r#struct, Struct); // Use raw identifier since struct is a keyword scalar_function!(from_unixtime, FromUnixtime); scalar_function!(arrow_typeof, ArrowTypeof); scalar_function!(random, Random); +aggregate_function!(approx_distinct, ApproxDistinct); +aggregate_function!(approx_median, ApproxMedian); +aggregate_function!(approx_percentile_cont, ApproxPercentileCont); +aggregate_function!( + approx_percentile_cont_with_weight, + ApproxPercentileContWithWeight +); +aggregate_function!(array_agg, ArrayAgg); aggregate_function!(avg, Avg); +aggregate_function!(corr, Correlation); aggregate_function!(count, Count); +aggregate_function!(covar, Covariance); +aggregate_function!(covar_pop, CovariancePop); +aggregate_function!(covar_samp, Covariance); +aggregate_function!(grouping, Grouping); aggregate_function!(max, Max); +aggregate_function!(mean, Avg); +aggregate_function!(median, Median); aggregate_function!(min, Min); aggregate_function!(sum, Sum); -aggregate_function!(approx_distinct, ApproxDistinct); +aggregate_function!(stddev, Stddev); +aggregate_function!(stddev_pop, StddevPop); +aggregate_function!(stddev_samp, Stddev); +aggregate_function!(var, Variance); +aggregate_function!(var_pop, VariancePop); +aggregate_function!(var_samp, Variance); pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(abs))?; m.add_wrapped(wrap_pyfunction!(acos))?; m.add_wrapped(wrap_pyfunction!(approx_distinct))?; m.add_wrapped(wrap_pyfunction!(alias))?; + m.add_wrapped(wrap_pyfunction!(approx_median))?; + m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?; + m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?; m.add_wrapped(wrap_pyfunction!(array))?; + m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; @@ -322,9 +352,13 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(col))?; m.add_wrapped(wrap_pyfunction!(concat_ws))?; m.add_wrapped(wrap_pyfunction!(concat))?; + m.add_wrapped(wrap_pyfunction!(corr))?; m.add_wrapped(wrap_pyfunction!(cos))?; m.add_wrapped(wrap_pyfunction!(count))?; m.add_wrapped(wrap_pyfunction!(count_star))?; + m.add_wrapped(wrap_pyfunction!(covar))?; + m.add_wrapped(wrap_pyfunction!(covar_pop))?; + m.add_wrapped(wrap_pyfunction!(covar_samp))?; m.add_wrapped(wrap_pyfunction!(current_date))?; m.add_wrapped(wrap_pyfunction!(current_time))?; m.add_wrapped(wrap_pyfunction!(date_bin))?; @@ -336,6 +370,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(exp))?; m.add_wrapped(wrap_pyfunction!(floor))?; m.add_wrapped(wrap_pyfunction!(from_unixtime))?; + m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; m.add_wrapped(wrap_pyfunction!(left))?; @@ -350,6 +385,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(max))?; m.add_wrapped(wrap_pyfunction!(make_array))?; m.add_wrapped(wrap_pyfunction!(md5))?; + m.add_wrapped(wrap_pyfunction!(mean))?; + m.add_wrapped(wrap_pyfunction!(median))?; m.add_wrapped(wrap_pyfunction!(min))?; m.add_wrapped(wrap_pyfunction!(now))?; m.add_wrapped(wrap_pyfunction!(nullif))?; @@ -376,8 +413,11 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(split_part))?; m.add_wrapped(wrap_pyfunction!(sqrt))?; m.add_wrapped(wrap_pyfunction!(starts_with))?; + m.add_wrapped(wrap_pyfunction!(stddev))?; + m.add_wrapped(wrap_pyfunction!(stddev_pop))?; + m.add_wrapped(wrap_pyfunction!(stddev_samp))?; m.add_wrapped(wrap_pyfunction!(strpos))?; - //m.add_wrapped(wrap_pyfunction!(struct))?; + m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword m.add_wrapped(wrap_pyfunction!(substr))?; m.add_wrapped(wrap_pyfunction!(sum))?; m.add_wrapped(wrap_pyfunction!(tan))?; @@ -390,7 +430,10 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?; m.add_wrapped(wrap_pyfunction!(upper))?; - //m.add_wrapped(wrap_pyfunction!(uuid))?; + m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision + m.add_wrapped(wrap_pyfunction!(var))?; + m.add_wrapped(wrap_pyfunction!(var_pop))?; + m.add_wrapped(wrap_pyfunction!(var_samp))?; m.add_wrapped(wrap_pyfunction!(window))?; Ok(()) } diff --git a/src/sql/logical/table_scan.rs b/src/sql/logical/table_scan.rs index ca491790e..5435eb9ba 100644 --- a/src/sql/logical/table_scan.rs +++ b/src/sql/logical/table_scan.rs @@ -15,73 +15,109 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use datafusion_common::DFSchema; -use datafusion_expr::{logical_plan::TableScan, LogicalPlan}; +use std::fmt::{self, Display, Formatter}; +use datafusion_expr::logical_plan::TableScan; use pyo3::prelude::*; -use crate::{ - expression::{py_expr_list, PyExpr}, - sql::exceptions::py_type_err, -}; +use crate::expression::PyExpr; + + #[pyclass(name = "TableScan", module = "dask_planner", subclass)] -#[derive(Clone, FromPyObject)] +#[derive(Clone)] pub struct PyTableScan { - pub(crate) table_scan: TableScan, - input: Arc, + table_scan: TableScan, +} + +impl From for TableScan { + fn from(tbl_scan: PyTableScan) -> TableScan { + tbl_scan.table_scan + } +} + +impl From for PyTableScan { + fn from(table_scan: TableScan) -> PyTableScan { + PyTableScan { table_scan } + } +} + +impl Display for PyTableScan { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "TableScan\nTable Name: {} + \nProjections: {:?} + \nProjected Schema: {:?} + \nFilters: {:?}", + &self.table_scan.table_name, + &self.py_projections(), + self.table_scan.projected_schema, + self.py_filters(), + ) + } } #[pymethods] impl PyTableScan { - #[pyo3(name = "getTableScanProjects")] - fn scan_projects(&mut self) -> PyResult> { + + /// Retrieves the name of the table represented by this `TableScan` instance + #[pyo3(name = "table_name")] + fn py_table_name(&self) -> PyResult<&str> { + Ok(&self.table_scan.table_name) + } + + /// TODO: Bindings for `TableSource` need to exist first. Left as a + /// placeholder to display intention to add when able to. + // #[pyo3(name = "source")] + // fn py_source(&self) -> PyResult> { + // Ok(self.table_scan.source) + // } + + /// The column indexes that should be. Note if this is empty then + /// all columns should be read by the `TableProvider`. This function + /// provides a Tuple of the (index, column_name) to make things simplier + /// for the calling code since often times the name is preferred to + /// the index which is a lower level abstraction. + #[pyo3(name = "projection")] + fn py_projections(&self) -> PyResult> { match &self.table_scan.projection { Some(indices) => { let schema = self.table_scan.source.schema(); Ok(indices .iter() - .map(|i| schema.field(*i).name().to_string()) + .map(|i| (*i, schema.field(*i).name().to_string())) .collect()) } None => Ok(vec![]), } } - /// If the 'TableScan' contains columns that should be projected during the - /// read return True, otherwise return False - #[pyo3(name = "containsProjections")] - fn contains_projections(&self) -> bool { - self.table_scan.projection.is_some() + /// TODO: Bindings for `DFSchema` need to exist first. Left as a + /// placeholder to display intention to add when able to. + // /// Resulting schema from the `TableScan` operation + // #[pyo3(name = "projectedSchema")] + // fn py_projected_schema(&self) -> PyResult { + // Ok(self.table_scan.projected_schema) + // } + + /// Certain `TableProvider` physical readers offer the capability to filter rows that + /// are read at read time. These `filters` are contained here. + #[pyo3(name = "filters")] + fn py_filters(&self) -> PyResult> { + Ok( + self.table_scan.filters + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect() + ) } - #[pyo3(name = "getFilters")] - fn scan_filters(&self) -> PyResult> { - py_expr_list(&self.input, &self.table_scan.filters) + /// Optional number of rows that should be read at read time by the `TableProvider` + #[pyo3(name = "fetch")] + fn py_fetch(&self) -> PyResult> { + Ok(self.table_scan.fetch) } -} -// impl TryFrom for PyTableScan { -// type Error = PyErr; - -// fn try_from(logical_plan: LogicalPlan) -> Result { -// match logical_plan { -// LogicalPlan::TableScan(table_scan) => { -// // Create an input logical plan that's identical to the table scan with schema from the table source -// let mut input = table_scan.clone(); -// input.projected_schema = DFSchema::try_from_qualified_schema( -// &table_scan.table_name, -// &table_scan.source.schema(), -// ) -// .map_or(input.projected_schema, Arc::new); - -// Ok(PyTableScan { -// table_scan, -// input: Arc::new(LogicalPlan::TableScan(input)), -// }) -// } -// _ => Err(py_type_err("unexpected plan")), -// } -// } -// } + fn __repr__(&self) -> PyResult { + Ok(format!("TableScan({})", self)) + } + +} diff --git a/src/store.rs b/src/store.rs index 2e8c9eb25..7d9bb7518 100644 --- a/src/store.rs +++ b/src/store.rs @@ -45,7 +45,7 @@ pub struct PyLocalFileSystemContext { #[pymethods] impl PyLocalFileSystemContext { - #[args(prefix = "None")] + #[pyo3(signature = (prefix=None))] #[new] fn new(prefix: Option) -> Self { if let Some(prefix) = prefix { @@ -78,17 +78,7 @@ pub struct PyMicrosoftAzureContext { #[pymethods] impl PyMicrosoftAzureContext { #[allow(clippy::too_many_arguments)] - #[args( - account = "None", - access_key = "None", - bearer_token = "None", - client_id = "None", - client_secret = "None", - tenant_id = "None", - sas_query_pairs = "None", - use_emulator = "None", - allow_http = "None" - )] + #[pyo3(signature = (container_name, account=None, access_key=None, bearer_token=None, client_id=None, client_secret=None, tenant_id=None, sas_query_pairs=None, use_emulator=None, allow_http=None))] #[new] fn new( container_name: String, @@ -165,7 +155,7 @@ pub struct PyGoogleCloudContext { #[pymethods] impl PyGoogleCloudContext { #[allow(clippy::too_many_arguments)] - #[args(service_account_path = "None")] + #[pyo3(signature = (bucket_name, service_account_path=None))] #[new] fn new(bucket_name: String, service_account_path: Option) -> Self { let mut builder = GoogleCloudStorageBuilder::new().with_bucket_name(&bucket_name); @@ -195,14 +185,7 @@ pub struct PyAmazonS3Context { #[pymethods] impl PyAmazonS3Context { #[allow(clippy::too_many_arguments)] - #[args( - region = "None", - access_key_id = "None", - secret_access_key = "None", - endpoint = "None", - imdsv1_fallback = "false", - allow_http = "false" - )] + #[pyo3(signature = (bucket_name, region=None, access_key_id=None, secret_access_key=None, endpoint=None, allow_http=false, imdsv1_fallback=false))] #[new] fn new( bucket_name: String, diff --git a/src/substrait.rs b/src/substrait.rs index f50734932..b52742d55 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -22,7 +22,7 @@ use crate::errors::{py_datafusion_err, DataFusionError}; use crate::sql::logical::PyLogicalPlan; use crate::utils::wait_for_future; -use datafusion_substrait::{consumer, producer, serializer, substrait::proto::Plan}; +use datafusion_substrait::{logical_plan::{consumer, producer}, serializer, substrait::proto::Plan}; #[pyclass(name = "plan", module = "datafusion.substrait", subclass, unsendable)] #[derive(Debug, Clone)] diff --git a/src/udaf.rs b/src/udaf.rs index a623de6b0..7a8bd493b 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -136,7 +136,7 @@ impl PyAggregateUDF { } /// creates a new PyExpr with the call of the udf - #[args(args = "*")] + #[pyo3(signature = (*args))] fn __call__(&self, args: Vec) -> PyResult { let args = args.iter().map(|e| e.expr.clone()).collect(); Ok(self.function.call(args).into()) diff --git a/src/udf.rs b/src/udf.rs index 10a8782b2..1849711f3 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -87,7 +87,7 @@ impl PyScalarUDF { } /// creates a new PyExpr with the call of the udf - #[args(args = "*")] + #[pyo3(signature = (*args))] fn __call__(&self, args: Vec) -> PyResult { let args = args.iter().map(|e| e.expr.clone()).collect(); Ok(self.function.call(args).into()) From 2c536a368c878d4b424de47d2d4bceedfab6293f Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 14 Feb 2023 13:27:36 -0500 Subject: [PATCH 13/16] Add new module to match datafusion-expr --- src/common/data_type.rs | 2 +- src/dataframe.rs | 2 +- src/{expression.rs => expr.rs} | 12 +++++++++++- src/{sql/logical => expr}/table_scan.rs | 4 ++-- src/functions.rs | 2 +- src/lib.rs | 8 ++++++-- src/sql/logical.rs | 2 -- src/substrait.rs | 4 ---- src/udaf.rs | 2 +- src/udf.rs | 2 +- 10 files changed, 24 insertions(+), 16 deletions(-) rename src/{expression.rs => expr.rs} (91%) rename src/{sql/logical => expr}/table_scan.rs (97%) diff --git a/src/common/data_type.rs b/src/common/data_type.rs index ed0665cb7..8ada1c756 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -216,7 +216,7 @@ impl DataTypeMap { #[staticmethod] #[pyo3(name = "arrow")] pub fn py_map_from_arrow_type(arrow_type: &PyDataType) -> PyResult { - Ok(DataTypeMap::map_from_arrow_type(&arrow_type.data_type)) + DataTypeMap::map_from_arrow_type(&arrow_type.data_type) } #[staticmethod] diff --git a/src/dataframe.rs b/src/dataframe.rs index 0f7375736..4b9fbca6c 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -18,7 +18,7 @@ use crate::physical_plan::PyExecutionPlan; use crate::sql::logical::PyLogicalPlan; use crate::utils::wait_for_future; -use crate::{errors::DataFusionError, expression::PyExpr}; +use crate::{errors::DataFusionError, expr::PyExpr}; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::{PyArrowConvert, PyArrowException, PyArrowType}; use datafusion::arrow::util::pretty; diff --git a/src/expression.rs b/src/expr.rs similarity index 91% rename from src/expression.rs rename to src/expr.rs index ccc6dd873..132e0e085 100644 --- a/src/expression.rs +++ b/src/expr.rs @@ -24,8 +24,10 @@ use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField}; use datafusion::scalar::ScalarValue; +pub mod table_scan; + /// A PyExpr that can be used on a DataFrame -#[pyclass(name = "Expression", module = "datafusion", subclass)] +#[pyclass(name = "Expr", module = "datafusion.expr", subclass)] #[derive(Debug, Clone)] pub(crate) struct PyExpr { pub(crate) expr: Expr, @@ -133,3 +135,11 @@ impl PyExpr { expr.into() } } + + +/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/ +pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/src/sql/logical/table_scan.rs b/src/expr/table_scan.rs similarity index 97% rename from src/sql/logical/table_scan.rs rename to src/expr/table_scan.rs index 5435eb9ba..043061d68 100644 --- a/src/sql/logical/table_scan.rs +++ b/src/expr/table_scan.rs @@ -19,11 +19,11 @@ use std::fmt::{self, Display, Formatter}; use datafusion_expr::logical_plan::TableScan; use pyo3::prelude::*; -use crate::expression::PyExpr; +use crate::expr::PyExpr; -#[pyclass(name = "TableScan", module = "dask_planner", subclass)] +#[pyclass(name = "TableScan", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyTableScan { table_scan: TableScan, diff --git a/src/functions.rs b/src/functions.rs index 5cabb1225..e9b43340c 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -24,7 +24,7 @@ use datafusion_expr::window_function::find_df_window_func; use datafusion_expr::{aggregate_function, lit, BuiltinScalarFunction, Expr, WindowFrame}; use crate::errors::DataFusionError; -use crate::expression::PyExpr; +use crate::expr::PyExpr; #[pyfunction] fn in_list(expr: PyExpr, value: Vec, negated: bool) -> PyExpr { diff --git a/src/lib.rs b/src/lib.rs index 5391de57c..b16ef753f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ mod dataset; mod dataset_exec; pub mod errors; #[allow(clippy::borrow_deref_ref)] -mod expression; +mod expr; #[allow(clippy::borrow_deref_ref)] mod functions; pub mod physical_plan; @@ -64,13 +64,17 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Register `expr` as a submodule. Matching `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/ + let expr = PyModule::new(py, "expr")?; + expr::init_module(expr)?; + m.add_submodule(expr)?; + // Register the functions as a submodule let funcs = PyModule::new(py, "functions")?; functions::init_module(funcs)?; diff --git a/src/sql/logical.rs b/src/sql/logical.rs index 3c284be66..dcd7baa58 100644 --- a/src/sql/logical.rs +++ b/src/sql/logical.rs @@ -20,8 +20,6 @@ use std::sync::Arc; use datafusion_expr::LogicalPlan; use pyo3::prelude::*; -pub mod table_scan; - #[pyclass(name = "LogicalPlan", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyLogicalPlan { diff --git a/src/substrait.rs b/src/substrait.rs index ec5ead185..2bde01123 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -22,13 +22,9 @@ use crate::errors::{py_datafusion_err, DataFusionError}; use crate::sql::logical::PyLogicalPlan; use crate::utils::wait_for_future; -<<<<<<< HEAD -use datafusion_substrait::{logical_plan::{consumer, producer}, serializer, substrait::proto::Plan}; -======= use datafusion_substrait::logical_plan::{consumer, producer}; use datafusion_substrait::serializer; use datafusion_substrait::substrait::proto::Plan; ->>>>>>> upstream/main #[pyclass(name = "plan", module = "datafusion.substrait", subclass, unsendable)] #[derive(Debug, Clone)] diff --git a/src/udaf.rs b/src/udaf.rs index 7a8bd493b..d5866f840 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -26,7 +26,7 @@ use datafusion::common::ScalarValue; use datafusion::error::{DataFusionError, Result}; use datafusion_expr::{create_udaf, Accumulator, AccumulatorFunctionImplementation, AggregateUDF}; -use crate::expression::PyExpr; +use crate::expr::PyExpr; use crate::utils::parse_volatility; #[derive(Debug)] diff --git a/src/udf.rs b/src/udf.rs index 1849711f3..f3e6cfb58 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -28,7 +28,7 @@ use datafusion::physical_plan::udf::ScalarUDF; use datafusion_expr::create_udf; use datafusion_expr::function::ScalarFunctionImplementation; -use crate::expression::PyExpr; +use crate::expr::PyExpr; use crate::utils::parse_volatility; /// Create a DataFusion's UDF implementation from a python function From 8922047c38ca982f48e1392fa864ce5b67021725 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 14 Feb 2023 13:33:40 -0500 Subject: [PATCH 14/16] cargo fmt fixes --- src/expr.rs | 1 - src/expr/table_scan.rs | 24 +++++++++++------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/expr.rs b/src/expr.rs index 132e0e085..dceedc1fc 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -136,7 +136,6 @@ impl PyExpr { } } - /// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; diff --git a/src/expr/table_scan.rs b/src/expr/table_scan.rs index 043061d68..bc7d68af5 100644 --- a/src/expr/table_scan.rs +++ b/src/expr/table_scan.rs @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. -use std::fmt::{self, Display, Formatter}; use datafusion_expr::logical_plan::TableScan; use pyo3::prelude::*; +use std::fmt::{self, Display, Formatter}; use crate::expr::PyExpr; - - #[pyclass(name = "TableScan", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyTableScan { @@ -43,7 +41,9 @@ impl From for PyTableScan { impl Display for PyTableScan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "TableScan\nTable Name: {} + write!( + f, + "TableScan\nTable Name: {} \nProjections: {:?} \nProjected Schema: {:?} \nFilters: {:?}", @@ -57,7 +57,6 @@ impl Display for PyTableScan { #[pymethods] impl PyTableScan { - /// Retrieves the name of the table represented by this `TableScan` instance #[pyo3(name = "table_name")] fn py_table_name(&self) -> PyResult<&str> { @@ -71,7 +70,7 @@ impl PyTableScan { // Ok(self.table_scan.source) // } - /// The column indexes that should be. Note if this is empty then + /// The column indexes that should be. Note if this is empty then /// all columns should be read by the `TableProvider`. This function /// provides a Tuple of the (index, column_name) to make things simplier /// for the calling code since often times the name is preferred to @@ -102,12 +101,12 @@ impl PyTableScan { /// are read at read time. These `filters` are contained here. #[pyo3(name = "filters")] fn py_filters(&self) -> PyResult> { - Ok( - self.table_scan.filters - .iter() - .map(|expr| PyExpr::from(expr.clone())) - .collect() - ) + Ok(self + .table_scan + .filters + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect()) } /// Optional number of rows that should be read at read time by the `TableProvider` @@ -119,5 +118,4 @@ impl PyTableScan { fn __repr__(&self) -> PyResult { Ok(format!("TableScan({})", self)) } - } From 433c1f19fe4b049f7a64e14c0dac7c0d42019423 Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 14 Feb 2023 14:25:10 -0500 Subject: [PATCH 15/16] Update pytest for refactoring --- datafusion/__init__.py | 12 ++++++++---- datafusion/expr.py | 23 +++++++++++++++++++++++ datafusion/tests/test_imports.py | 9 ++++++--- docs/source/api.rst | 2 +- docs/source/api/expression.rst | 4 ++-- src/functions.rs | 6 +++--- 6 files changed, 43 insertions(+), 13 deletions(-) create mode 100644 datafusion/expr.py diff --git a/datafusion/__init__.py b/datafusion/__init__.py index ddab950be..54dbc8b7c 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -32,10 +32,14 @@ SessionContext, SessionConfig, RuntimeConfig, - Expression, ScalarUDF, ) +from .expr import ( + Expr, + TableScan, +) + __version__ = importlib_metadata.version(__name__) __all__ = [ @@ -44,7 +48,7 @@ "SessionContext", "SessionConfig", "RuntimeConfig", - "Expression", + "Expr", "AggregateUDF", "ScalarUDF", "column", @@ -71,7 +75,7 @@ def evaluate(self) -> pa.Scalar: def column(value): - return Expression.column(value) + return Expr.column(value) col = column @@ -80,7 +84,7 @@ def column(value): def literal(value): if not isinstance(value, pa.Scalar): value = pa.scalar(value) - return Expression.literal(value) + return Expr.literal(value) lit = literal diff --git a/datafusion/expr.py b/datafusion/expr.py new file mode 100644 index 000000000..e914b85d7 --- /dev/null +++ b/datafusion/expr.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from ._internal import expr + + +def __getattr__(name): + return getattr(expr, name) diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py index f571bc4bb..cac0e6d53 100644 --- a/datafusion/tests/test_imports.py +++ b/datafusion/tests/test_imports.py @@ -22,11 +22,15 @@ AggregateUDF, DataFrame, SessionContext, - Expression, ScalarUDF, functions, ) +from datafusion.expr import ( + Expr, + TableScan, +) + def test_import_datafusion(): assert datafusion.__name__ == "datafusion" @@ -39,7 +43,6 @@ def test_datafusion_python_version(): def test_class_module_is_datafusion(): for klass in [ SessionContext, - Expression, DataFrame, ScalarUDF, AggregateUDF, @@ -62,7 +65,7 @@ def test_classes_are_inheritable(): class MyExecContext(SessionContext): pass - class MyExpression(Expression): + class MyExpression(Expr): pass class MyDataFrame(DataFrame): diff --git a/docs/source/api.rst b/docs/source/api.rst index a5d65433d..a3e7e24df 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -27,6 +27,6 @@ API Reference api/config api/dataframe api/execution_context - api/expression + api/expr api/functions api/object_store diff --git a/docs/source/api/expression.rst b/docs/source/api/expression.rst index 45923fb54..30137d135 100644 --- a/docs/source/api/expression.rst +++ b/docs/source/api/expression.rst @@ -18,10 +18,10 @@ .. _api.expression: .. currentmodule:: datafusion -Expression +Expr ========== .. autosummary:: :toctree: ../generated/ - Expression + Expr diff --git a/src/functions.rs b/src/functions.rs index e9b43340c..8acffeb55 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -66,7 +66,7 @@ fn concat_ws(sep: String, args: Vec) -> PyResult { Ok(datafusion_expr::concat_ws(lit(sep), args).into()) } -/// Creates a new Sort expression +/// Creates a new Sort Expr #[pyfunction] fn order_by(expr: PyExpr, asc: Option, nulls_first: Option) -> PyResult { Ok(PyExpr { @@ -78,7 +78,7 @@ fn order_by(expr: PyExpr, asc: Option, nulls_first: Option) -> PyRes }) } -/// Creates a new Alias expression +/// Creates a new Alias Expr #[pyfunction] fn alias(expr: PyExpr, name: &str) -> PyResult { Ok(PyExpr { @@ -86,7 +86,7 @@ fn alias(expr: PyExpr, name: &str) -> PyResult { }) } -/// Create a column reference expression +/// Create a column reference Expr #[pyfunction] fn col(name: &str) -> PyResult { Ok(PyExpr { From 89ab22639a642874bbafe828eeead4e0911a47bf Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 14 Feb 2023 14:49:00 -0500 Subject: [PATCH 16/16] Python linter changes --- datafusion/__init__.py | 1 + datafusion/tests/test_context.py | 3 ++- datafusion/tests/test_dataframe.py | 4 ++-- datafusion/tests/test_imports.py | 3 +++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/datafusion/__init__.py b/datafusion/__init__.py index 54dbc8b7c..0b4e89643 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -53,6 +53,7 @@ "ScalarUDF", "column", "literal", + "TableScan", ] diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py index 48d41c114..6faffaf5b 100644 --- a/datafusion/tests/test_context.py +++ b/datafusion/tests/test_context.py @@ -172,7 +172,8 @@ def test_dataset_filter_nested_data(ctx): df = ctx.table("t") - # This filter will not be pushed down to DatasetExec since it isn't supported + # This filter will not be pushed down to DatasetExec since it + # isn't supported df = df.select( column("nested_data")["a"] + column("nested_data")["b"], column("nested_data")["a"] - column("nested_data")["b"], diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 4d70845aa..30327ee0c 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -314,8 +314,8 @@ def test_execution_plan(aggregate_df): indent = plan.display_indent() - # indent plan will be different for everyone due to absolute path to filename, so - # we just check for some expected content + # indent plan will be different for everyone due to absolute path + # to filename, so we just check for some expected content assert "ProjectionExec:" in indent assert "AggregateExec:" in indent assert "CoalesceBatchesExec:" in indent diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py index cac0e6d53..1e8c796bb 100644 --- a/datafusion/tests/test_imports.py +++ b/datafusion/tests/test_imports.py @@ -49,6 +49,9 @@ def test_class_module_is_datafusion(): ]: assert klass.__module__ == "datafusion" + for klass in [Expr, TableScan]: + assert klass.__module__ == "datafusion.expr" + def test_import_from_functions_submodule(): from datafusion.functions import abs, sin # noqa