diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index bfb3a93e3249e..368f76774d696 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -45,7 +45,7 @@ unicode_expressions = ["unicode-segmentation"]
 
 [dependencies]
 ahash = "0.7"
-hashbrown = "0.11"
+hashbrown = { version = "0.11", features = ["raw"] }
 arrow = { version = "5.1", features = ["prettyprint"] }
 parquet = { version = "5.1", features = ["arrow"] }
 sqlparser = "0.9.0"
diff --git a/python/Cargo.toml b/python/Cargo.toml
index fe84e5234c333..83973cc25627b 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -30,16 +30,16 @@ edition = "2018"
 libc = "0.2"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 rand = "0.7"
-pyo3 = { version = "0.14.1", features = ["extension-module"] }
+pyo3 = { version = "0.14.2", features = ["extension-module"] }
 datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" }
 
 [lib]
-name = "datafusion"
+name = "internals"
 crate-type = ["cdylib"]
 
 [package.metadata.maturin]
+name = "datafusion.internals"
 requires-dist = ["pyarrow>=1"]
-
 classifier = [
     "Development Status :: 2 - Pre-Alpha",
     "Intended Audience :: Developers",
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
new file mode 100644
index 0000000000000..20bc3f22bfc5c
--- /dev/null
+++ b/python/datafusion/__init__.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from .internals import PyDataFrame as DataFrame
+from .internals import PyExecutionContext as ExecutionContext
+from .internals import PyExpr as Expr
+from .internals import functions
+
+__all__ = [
+    "DataFrame",
+    "ExecutionContext",
+    "Expr",
+    "functions"
+]
diff --git a/python/tests/__init__.py b/python/datafusion/tests/__init__.py
similarity index 100%
rename from python/tests/__init__.py
rename to python/datafusion/tests/__init__.py
diff --git a/python/tests/generic.py b/python/datafusion/tests/generic.py
similarity index 100%
rename from python/tests/generic.py
rename to python/datafusion/tests/generic.py
diff --git a/python/tests/test_df.py b/python/datafusion/tests/test_df.py
similarity index 99%
rename from python/tests/test_df.py
rename to python/datafusion/tests/test_df.py
index 5b6cbddbd74ba..b04eba53f6fdc 100644
--- a/python/tests/test_df.py
+++ b/python/datafusion/tests/test_df.py
@@ -17,6 +17,7 @@
 
 import pyarrow as pa
 import pytest
+
 from datafusion import ExecutionContext
 from datafusion import functions as f
diff --git a/python/tests/test_math_functions.py b/python/datafusion/tests/test_math_functions.py
similarity index 99%
rename from python/tests/test_math_functions.py
rename to python/datafusion/tests/test_math_functions.py
index 98656b8c4f422..4e473c3de16ac 100644
--- a/python/tests/test_math_functions.py
+++ b/python/datafusion/tests/test_math_functions.py
@@ -18,6 +18,7 @@
 import numpy as np
 import pyarrow as pa
 import pytest
+
 from datafusion import ExecutionContext
 from datafusion import functions as f
diff --git a/python/tests/test_pa_types.py b/python/datafusion/tests/test_pa_types.py
similarity index 100%
rename from python/tests/test_pa_types.py
rename to python/datafusion/tests/test_pa_types.py
diff --git a/python/tests/test_sql.py b/python/datafusion/tests/test_sql.py
similarity index 99%
rename from python/tests/test_sql.py
rename to python/datafusion/tests/test_sql.py
index 669f640529eb5..d6a16f23b6c85 100644
--- a/python/tests/test_sql.py
+++ b/python/datafusion/tests/test_sql.py
@@ -20,6 +20,7 @@
 import pytest
 
 from datafusion import ExecutionContext
+
 from . import generic as helpers
diff --git a/python/tests/test_string_functions.py b/python/datafusion/tests/test_string_functions.py
similarity index 99%
rename from python/tests/test_string_functions.py
rename to python/datafusion/tests/test_string_functions.py
index ea064a6b2e9f6..4255d34805a04 100644
--- a/python/tests/test_string_functions.py
+++ b/python/datafusion/tests/test_string_functions.py
@@ -17,6 +17,7 @@
 
 import pyarrow as pa
 import pytest
+
 from datafusion import ExecutionContext
 from datafusion import functions as f
diff --git a/python/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py
similarity index 99%
rename from python/tests/test_udaf.py
rename to python/datafusion/tests/test_udaf.py
index e7044d6119e38..aca1215a7cb24 100644
--- a/python/tests/test_udaf.py
+++ b/python/datafusion/tests/test_udaf.py
@@ -20,6 +20,7 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
+
 from datafusion import ExecutionContext
 from datafusion import functions as f
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 1482129897fae..ce33f58d29173 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -18,3 +18,12 @@
 [build-system]
 requires = ["maturin>=0.11,<0.12"]
 build-backend = "maturin"
+
+[project]
+name = "datafusion"
+dependencies = [
+    "pyarrow"
+]
+
+[tool.isort]
+profile = "black"
diff --git a/python/rust-toolchain b/python/rust-toolchain
index 6231a95e3036d..2bf5ad0447d33 100644
--- a/python/rust-toolchain
+++ b/python/rust-toolchain
@@ -1 +1 @@
-nightly-2021-05-10
+stable
diff --git a/python/src/context.rs b/python/src/context.rs
index 9acc14a5e2609..027e402c7c93e 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -24,41 +24,38 @@
 use rand::Rng;
 
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 
+use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::MemTable;
-use datafusion::execution::context::ExecutionContext as _ExecutionContext;
+use datafusion::execution::context::ExecutionContext;
 use datafusion::prelude::CsvReadOptions;
 
-use crate::dataframe;
-use crate::errors;
+use crate::dataframe::PyDataFrame;
+use crate::errors::DataFusionError;
 use crate::functions;
-use crate::to_rust;
-use crate::types::PyDataType;
+use crate::pyarrow::PyArrowConvert;
 
-/// `ExecutionContext` is able to plan and execute DataFusion plans.
+/// `PyExecutionContext` is able to plan and execute DataFusion plans.
 /// It has a powerful optimizer, a physical planner for local execution, and a
 /// multi-threaded execution engine to perform the execution.
 #[pyclass(unsendable)]
-pub(crate) struct ExecutionContext {
-    ctx: _ExecutionContext,
+pub(crate) struct PyExecutionContext {
+    ctx: ExecutionContext,
 }
 
 #[pymethods]
-impl ExecutionContext {
+impl PyExecutionContext {
     #[new]
     fn new() -> Self {
-        ExecutionContext {
-            ctx: _ExecutionContext::new(),
+        PyExecutionContext {
+            ctx: ExecutionContext::new(),
         }
     }
 
-    /// Returns a DataFrame whose plan corresponds to the SQL statement.
-    fn sql(&mut self, query: &str) -> PyResult<dataframe::DataFrame> {
-        let df = self
-            .ctx
-            .sql(query)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
-        Ok(dataframe::DataFrame::new(
+    /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
+    fn sql(&mut self, query: &str) -> PyResult<PyDataFrame> {
+        let df = self.ctx.sql(query).map_err(DataFusionError::from)?;
+        Ok(PyDataFrame::new(
             self.ctx.state.clone(),
             df.to_logical_plan(),
         ))
@@ -66,21 +63,20 @@ impl ExecutionContext {
 
     fn create_dataframe(
         &mut self,
-        partitions: Vec<Vec<PyObject>>,
-        py: Python,
-    ) -> PyResult<PyDataFrame> {
+        partitions: Vec<Vec<&PyAny>>,
+    ) -> PyResult<PyDataFrame> {
         let partitions: Vec<Vec<RecordBatch>> = partitions
-            .iter()
+            .into_iter()
            .map(|batches| {
                 batches
-                    .iter()
-                    .map(|batch| to_rust::to_rust_batch(batch.as_ref(py)))
-                    .collect()
+                    .into_iter()
+                    .map(RecordBatch::from_pyarrow)
+                    .collect::<PyResult<Vec<RecordBatch>>>()
             })
             .collect::<PyResult<_>>()?;
 
-        let table =
-            errors::wrap(MemTable::try_new(partitions[0][0].schema(), partitions))?;
+        let table = MemTable::try_new(partitions[0][0].schema(), partitions)
+            .map_err(DataFusionError::from)?;
 
         // generate a random (unique) name for this table
         let name = rand::thread_rng()
@@ -88,15 +84,19 @@
             .take(10)
             .collect::<String>();
 
-        errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?;
-        Ok(dataframe::DataFrame::new(
-            self.ctx.state.clone(),
-            errors::wrap(self.ctx.table(&*name))?.to_logical_plan(),
-        ))
+        self.ctx
+            .register_table(&*name, Arc::new(table))
+            .map_err(DataFusionError::from)?;
+        let table = self.ctx.table(&*name).map_err(DataFusionError::from)?;
+
+        let df = PyDataFrame::new(self.ctx.state.clone(), table.to_logical_plan());
+        Ok(df)
     }
 
     fn register_parquet(&mut self, name: &str, path: &str) -> PyResult<()> {
-        errors::wrap(self.ctx.register_parquet(name, path))?;
+        self.ctx
+            .register_parquet(name, path)
+            .map_err(DataFusionError::from)?;
         Ok(())
     }
 
@@ -121,7 +121,7 @@
             .to_str()
            .ok_or(PyValueError::new_err("Unable to convert path to a string"))?;
         let schema = match schema {
-            Some(s) => Some(to_rust::to_rust_schema(s)?),
+            Some(s) => Some(Schema::from_pyarrow(s)?),
             None => None,
         };
         let delimiter = delimiter.as_bytes();
@@ -138,7 +138,9 @@
             .file_extension(file_extension);
         options.schema = schema.as_ref();
 
-        errors::wrap(self.ctx.register_csv(name, path, options))?;
+        self.ctx
+            .register_csv(name, path, options)
+            .map_err(DataFusionError::from)?;
         Ok(())
     }
 
@@ -146,12 +148,12 @@
         &mut self,
         name: &str,
         func: PyObject,
-        args_types: Vec<PyDataType>,
-        return_type: PyDataType,
-    ) {
-        let function = functions::create_udf(func, args_types, return_type, name);
-
+        args_types: Vec<&PyAny>,
+        return_type: &PyAny,
+    ) -> PyResult<()> {
+        let function = functions::create_udf(func, args_types, return_type, name)?;
         self.ctx.register_udf(function.function);
+        Ok(())
     }
 
     fn tables(&self) -> HashSet<String> {
diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs
index 4a50262ec3292..2dd42b64faafb 100644
--- a/python/src/dataframe.rs
+++ b/python/src/dataframe.rs
@@ -17,37 +17,40 @@
 
 use std::sync::{Arc, Mutex};
 
-use logical_plan::LogicalPlan;
-use pyo3::{prelude::*, types::PyTuple};
+use pyo3::{
+    prelude::*,
+    types::{PyList, PyTuple},
+};
 use tokio::runtime::Runtime;
 
-use datafusion::execution::context::ExecutionContext as _ExecutionContext;
-use datafusion::logical_plan::{JoinType, LogicalPlanBuilder};
+use datafusion::execution::context::{ExecutionContext, ExecutionContextState};
+use datafusion::logical_plan::{JoinType, LogicalPlan, LogicalPlanBuilder};
 use datafusion::physical_plan::collect;
-use datafusion::{execution::context::ExecutionContextState, logical_plan};
 
-use crate::{errors, to_py};
-use crate::{errors::DataFusionError, expression};
+use crate::{
+    errors, errors::DataFusionError, expression,
+    expression::PyExpr,
+    pyarrow::PyArrowConvert,
+};
 
-/// A DataFrame is a representation of a logical plan and an API to compose statements.
+/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
 /// Use it to build a plan and `.collect()` to execute the plan and collect the result.
 /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
 #[pyclass]
-pub(crate) struct DataFrame {
+pub(crate) struct PyDataFrame {
     ctx_state: Arc<Mutex<ExecutionContextState>>,
     plan: LogicalPlan,
 }
 
-impl DataFrame {
-    /// creates a new DataFrame
+impl PyDataFrame {
+    /// creates a new PyDataFrame
     pub fn new(ctx_state: Arc<Mutex<ExecutionContextState>>, plan: LogicalPlan) -> Self {
         Self { ctx_state, plan }
     }
 }
 
 #[pymethods]
-impl DataFrame {
-    /// Select `expressions` from the existing DataFrame.
+impl PyDataFrame {
+    /// Select `expressions` from the existing PyDataFrame.
     #[args(args = "*")]
     fn select(&self, args: &PyTuple) -> PyResult<Self> {
         let expressions = expression::from_tuple(args)?;
@@ -56,30 +59,26 @@
         let builder =
             errors::wrap(builder.project(expressions.into_iter().map(|e| e.expr)))?;
         let plan = errors::wrap(builder.build())?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
     }
 
     /// Filter according to the `predicate` expression
-    fn filter(&self, predicate: expression::Expression) -> PyResult<Self> {
+    fn filter(&self, predicate: PyExpr) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
         let builder = errors::wrap(builder.filter(predicate.expr))?;
         let plan = errors::wrap(builder.build())?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
     }
 
     /// Aggregates using expressions
-    fn aggregate(
-        &self,
-        group_by: Vec<expression::Expression>,
-        aggs: Vec<expression::Expression>,
-    ) -> PyResult<Self> {
+    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
         let builder = errors::wrap(builder.aggregate(
             group_by.into_iter().map(|e| e.expr),
@@ -87,19 +86,19 @@
         ))?;
         let plan = errors::wrap(builder.build())?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
     }
 
     /// Sort by specified sorting expressions
-    fn sort(&self, exprs: Vec<expression::Expression>) -> PyResult<Self> {
+    fn sort(&self, exprs: Vec<PyExpr>) -> PyResult<Self> {
         let exprs = exprs.into_iter().map(|e| e.expr);
         let builder = LogicalPlanBuilder::from(self.plan.clone());
         let builder = errors::wrap(builder.sort(exprs))?;
         let plan = errors::wrap(builder.build())?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
@@ -111,7 +110,7 @@
         let builder = errors::wrap(builder.limit(count))?;
         let plan = errors::wrap(builder.build())?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
@@ -120,27 +119,27 @@
 
     /// Executes the plan, returning a list of `RecordBatch`es.
     /// Unless some order is specified in the plan, there is no guarantee of the order of the result
     fn collect(&self, py: Python) -> PyResult<PyObject> {
-        let ctx = _ExecutionContext::from(self.ctx_state.clone());
-        let plan = ctx
-            .optimize(&self.plan)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
+        let ctx = ExecutionContext::from(self.ctx_state.clone());
+        let plan = ctx.optimize(&self.plan).map_err(DataFusionError::from)?;
         let plan = ctx
             .create_physical_plan(&plan)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
+            .map_err(DataFusionError::from)?;
 
         let rt = Runtime::new().unwrap();
         let batches = py.allow_threads(|| {
-            rt.block_on(async {
-                collect(plan)
-                    .await
-                    .map_err(|e| -> errors::DataFusionError { e.into() })
-            })
+            rt.block_on(async { collect(plan).await.map_err(DataFusionError::from) })
         })?;
-        to_py::to_py(&batches)
+
+        let mut py_batches = vec![];
+        for batch in batches {
+            py_batches.push(batch.to_pyarrow(py)?);
+        }
+        let py_list = PyList::new(py, py_batches);
+        Ok(PyObject::from(py_list))
     }
 
-    /// Returns the join of two DataFrames `on`.
-    fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> {
+    /// Returns the join of two PyDataFrames `on`.
+    fn join(&self, right: &PyDataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
 
         let join_type = match how {
@@ -163,7 +162,7 @@
 
         let plan = errors::wrap(builder.build())?;
 
-        Ok(DataFrame {
+        Ok(PyDataFrame {
             ctx_state: self.ctx_state.clone(),
             plan,
         })
diff --git a/python/src/errors.rs b/python/src/errors.rs
index fbe98037a030f..cc181a98755d4 100644
--- a/python/src/errors.rs
+++ b/python/src/errors.rs
@@ -16,10 +16,11 @@
 // under the License.
 
 use core::fmt;
+//use std::result::Result;
 
 use datafusion::arrow::error::ArrowError;
 use datafusion::error::DataFusionError as InnerDataFusionError;
-use pyo3::{exceptions, PyErr};
+use pyo3::{exceptions::PyException, PyErr};
 
 #[derive(Debug)]
 pub enum DataFusionError {
@@ -38,9 +39,9 @@ impl fmt::Display for DataFusionError {
     }
 }
 
-impl From<DataFusionError> for PyErr {
-    fn from(err: DataFusionError) -> PyErr {
-        exceptions::PyException::new_err(err.to_string())
+impl From<ArrowError> for DataFusionError {
+    fn from(err: ArrowError) -> DataFusionError {
+        DataFusionError::ArrowError(err)
     }
 }
 
@@ -50,9 +51,9 @@ impl From<InnerDataFusionError> for DataFusionError {
     }
 }
 
-impl From<ArrowError> for DataFusionError {
-    fn from(err: ArrowError) -> DataFusionError {
-        DataFusionError::ArrowError(err)
+impl From<DataFusionError> for PyErr {
+    fn from(err: DataFusionError) -> PyErr {
+        PyException::new_err(err.to_string())
     }
 }
diff --git a/python/src/expression.rs b/python/src/expression.rs
index 4320b1d14c8b7..f1a12d7ba8028 100644
--- a/python/src/expression.rs
+++ b/python/src/expression.rs
@@ -19,90 +19,89 @@
 use pyo3::{
     basic::CompareOp, prelude::*, types::PyTuple, PyNumberProtocol, PyObjectProtocol,
 };
 
-use datafusion::logical_plan::Expr as _Expr;
-use datafusion::physical_plan::udaf::AggregateUDF as _AggregateUDF;
-use datafusion::physical_plan::udf::ScalarUDF as _ScalarUDF;
+use datafusion::logical_plan::Expr;
+use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF};
 
-/// An expression that can be used on a DataFrame
+/// A PyExpr that can be used on a DataFrame
 #[pyclass]
 #[derive(Debug, Clone)]
-pub(crate) struct Expression {
-    pub(crate) expr: _Expr,
+pub(crate) struct PyExpr {
+    pub(crate) expr: Expr,
 }
 
 /// converts a tuple of expressions into a vector of Expressions
-pub(crate) fn from_tuple(value: &PyTuple) -> PyResult<Vec<Expression>> {
+pub(crate) fn from_tuple(value: &PyTuple) -> PyResult<Vec<PyExpr>> {
     value
         .iter()
-        .map(|e| e.extract::<Expression>())
+        .map(|e| e.extract::<PyExpr>())
         .collect::<PyResult<_>>()
 }
 
 #[pyproto]
-impl PyNumberProtocol for Expression {
-    fn __add__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+impl PyNumberProtocol for PyExpr {
+    fn __add__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr + rhs.expr,
         })
     }
 
-    fn __sub__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __sub__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr - rhs.expr,
         })
     }
 
-    fn __truediv__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __truediv__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr / rhs.expr,
         })
     }
 
-    fn __mul__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __mul__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr * rhs.expr,
         })
     }
 
-    fn __and__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __and__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr.and(rhs.expr),
         })
     }
 
-    fn __or__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __or__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: lhs.expr.or(rhs.expr),
         })
     }
 
-    fn __invert__(&self) -> PyResult<Expression> {
-        Ok(Expression {
+    fn __invert__(&self) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: self.expr.clone().not(),
        })
     }
 }
 
 #[pyproto]
-impl PyObjectProtocol for Expression {
-    fn __richcmp__(&self, other: Expression, op: CompareOp) -> Expression {
+impl PyObjectProtocol for PyExpr {
+    fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
         match op {
-            CompareOp::Lt => Expression {
+            CompareOp::Lt => PyExpr {
                 expr: self.expr.clone().lt(other.expr),
             },
-            CompareOp::Le => Expression {
+            CompareOp::Le => PyExpr {
                 expr: self.expr.clone().lt_eq(other.expr),
             },
-            CompareOp::Eq => Expression {
+            CompareOp::Eq => PyExpr {
                 expr: self.expr.clone().eq(other.expr),
             },
-            CompareOp::Ne => Expression {
+            CompareOp::Ne => PyExpr {
                 expr: self.expr.clone().not_eq(other.expr),
             },
-            CompareOp::Gt => Expression {
+            CompareOp::Gt => PyExpr {
                 expr: self.expr.clone().gt(other.expr),
             },
-            CompareOp::Ge => Expression {
+            CompareOp::Ge => PyExpr {
                 expr: self.expr.clone().gt_eq(other.expr),
             },
         }
@@ -110,39 +109,39 @@ impl PyObjectProtocol for Expression {
     }
 }
 
 #[pymethods]
-impl Expression {
-    /// assign a name to the expression
-    pub fn alias(&self, name: &str) -> PyResult<Expression> {
-        Ok(Expression {
+impl PyExpr {
+    /// assign a name to the PyExpr
+    pub fn alias(&self, name: &str) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: self.expr.clone().alias(name),
         })
     }
 
-    /// Create a sort expression from an existing expression.
+    /// Create a sort PyExpr from an existing PyExpr.
     #[args(ascending = true, nulls_first = true)]
-    pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyResult<Expression> {
-        Ok(Expression {
+    pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyResult<PyExpr> {
+        Ok(PyExpr {
             expr: self.expr.clone().sort(ascending, nulls_first),
         })
     }
 }
 
-/// Represents a ScalarUDF
+/// Represents a PyScalarUDF
 #[pyclass]
 #[derive(Debug, Clone)]
-pub struct ScalarUDF {
-    pub(crate) function: _ScalarUDF,
+pub struct PyScalarUDF {
+    pub(crate) function: ScalarUDF,
 }
 
 #[pymethods]
-impl ScalarUDF {
-    /// creates a new expression with the call of the udf
+impl PyScalarUDF {
+    /// creates a new PyExpr with the call of the udf
     #[call]
     #[args(args = "*")]
-    fn __call__(&self, args: &PyTuple) -> PyResult<Expression> {
+    fn __call__(&self, args: &PyTuple) -> PyResult<PyExpr> {
         let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect();
 
-        Ok(Expression {
+        Ok(PyExpr {
             expr: self.function.call(args),
         })
     }
@@ -151,19 +150,19 @@
 
 /// Represents a AggregateUDF
 #[pyclass]
 #[derive(Debug, Clone)]
-pub struct AggregateUDF {
-    pub(crate) function: _AggregateUDF,
+pub struct PyAggregateUDF {
+    pub(crate) function: AggregateUDF,
 }
 
 #[pymethods]
-impl AggregateUDF {
-    /// creates a new expression with the call of the udf
+impl PyAggregateUDF {
+    /// creates a new PyExpr with the call of the udf
     #[call]
     #[args(args = "*")]
-    fn __call__(&self, args: &PyTuple) -> PyResult<Expression> {
+    fn __call__(&self, args: &PyTuple) -> PyResult<PyExpr> {
         let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect();
 
-        Ok(Expression {
+        Ok(PyExpr {
             expr: self.function.call(args),
         })
     }
diff --git a/python/src/functions.rs b/python/src/functions.rs
index 23f010a6ae45c..f283e2fd838f6 100644
--- a/python/src/functions.rs
+++ b/python/src/functions.rs
@@ -15,46 +15,47 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::udaf;
-use crate::udf;
-use crate::{expression, types::PyDataType};
+use std::sync::Arc;
+
 use datafusion::arrow::datatypes::DataType;
 use datafusion::logical_plan;
 use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction};
-use std::sync::Arc;
 
-/// Expression representing a column on the existing plan.
+use crate::{
+    expression,
+    expression::{PyAggregateUDF, PyExpr, PyScalarUDF},
+    pyarrow::PyArrowConvert,
+    udaf, udf,
+};
+
+/// PyExpr representing a column on the existing plan.
 #[pyfunction]
 #[pyo3(text_signature = "(name)")]
-fn col(name: &str) -> expression::Expression {
-    expression::Expression {
+fn col(name: &str) -> PyExpr {
+    PyExpr {
         expr: logical_plan::col(name),
     }
 }
 
-/// Expression representing a constant value
+/// PyExpr representing a constant value
 #[pyfunction]
 #[pyo3(text_signature = "(value)")]
-fn lit(value: i32) -> expression::Expression {
-    expression::Expression {
+fn lit(value: i32) -> PyExpr {
+    PyExpr {
         expr: logical_plan::lit(value),
     }
 }
 
 #[pyfunction]
-fn array(value: Vec<expression::Expression>) -> expression::Expression {
-    expression::Expression {
+fn array(value: Vec<PyExpr>) -> PyExpr {
+    PyExpr {
         expr: logical_plan::array(value.into_iter().map(|x| x.expr).collect::<Vec<_>>()),
     }
 }
 
 #[pyfunction]
-fn in_list(
-    expr: expression::Expression,
-    value: Vec<expression::Expression>,
-    negated: bool,
-) -> expression::Expression {
-    expression::Expression {
+fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
+    PyExpr {
         expr: logical_plan::in_list(
             expr.expr,
             value.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
@@ -65,8 +66,8 @@
 
 /// Current date and time
 #[pyfunction]
-fn now() -> expression::Expression {
-    expression::Expression {
+fn now() -> PyExpr {
+    PyExpr {
         // here lit(0) is a stub for conform to arity
         expr: logical_plan::now(logical_plan::lit(0)),
     }
@@ -74,8 +75,8 @@
 
 /// Returns a random value in the range 0.0 <= x < 1.0
 #[pyfunction]
-fn random() -> expression::Expression {
-    expression::Expression {
+fn random() -> PyExpr {
+    PyExpr {
         expr: logical_plan::random(),
     }
 }
@@ -83,10 +84,10 @@
 
 /// Concatenates the text representations of all the arguments.
 /// NULL arguments are ignored.
 #[pyfunction(args = "*")]
-fn concat(args: &PyTuple) -> PyResult<expression::Expression> {
+fn concat(args: &PyTuple) -> PyResult<PyExpr> {
     let expressions = expression::from_tuple(args)?;
     let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
-    Ok(expression::Expression {
+    Ok(PyExpr {
         expr: logical_plan::concat(&args),
     })
 }
@@ -95,10 +96,10 @@
 /// The first argument is used as the separator string, and should not be NULL.
 /// Other NULL arguments are ignored.
 #[pyfunction(sep, args = "*")]
-fn concat_ws(sep: String, args: &PyTuple) -> PyResult<expression::Expression> {
+fn concat_ws(sep: String, args: &PyTuple) -> PyResult<PyExpr> {
     let expressions = expression::from_tuple(args)?;
     let args = expressions.into_iter().map(|e| e.expr).collect::<Vec<_>>();
-    Ok(expression::Expression {
+    Ok(PyExpr {
         expr: logical_plan::concat_ws(sep, &args),
     })
 }
@@ -107,8 +108,8 @@ macro_rules! define_unary_function {
     ($NAME: ident) => {
         #[doc = "This function is not documented yet"]
         #[pyfunction]
-        fn $NAME(value: expression::Expression) -> expression::Expression {
-            expression::Expression {
+        fn $NAME(value: PyExpr) -> PyExpr {
+            PyExpr {
                 expr: logical_plan::$NAME(value.expr),
             }
         }
@@ -116,8 +117,8 @@ macro_rules! define_unary_function {
     ($NAME: ident, $DOC: expr) => {
         #[doc = $DOC]
         #[pyfunction]
-        fn $NAME(value: expression::Expression) -> expression::Expression {
-            expression::Expression {
+        fn $NAME(value: PyExpr) -> PyExpr {
+            PyExpr {
                 expr: logical_plan::$NAME(value.expr),
             }
         }
@@ -202,55 +203,62 @@ define_unary_function!(count);
 
 pub(crate) fn create_udf(
     fun: PyObject,
-    input_types: Vec<PyDataType>,
-    return_type: PyDataType,
+    input_types: Vec<&PyAny>,
+    return_type: &PyAny,
     name: &str,
-) -> expression::ScalarUDF {
-    let input_types: Vec<DataType> =
-        input_types.iter().map(|d| d.data_type.clone()).collect();
-    let return_type = Arc::new(return_type.data_type);
+) -> PyResult<PyScalarUDF> {
+    let input_types: Vec<DataType> = input_types
+        .into_iter()
+        .map(DataType::from_pyarrow)
+        .collect::<PyResult<_>>()?;
+    let return_type = Arc::new(DataType::from_pyarrow(return_type)?);
 
-    expression::ScalarUDF {
+    Ok(PyScalarUDF {
         function: logical_plan::create_udf(
             name,
             input_types,
             return_type,
             udf::array_udf(fun),
         ),
-    }
+    })
 }
 
 /// Creates a new udf.
 #[pyfunction]
 fn udf(
     fun: PyObject,
-    input_types: Vec<PyDataType>,
-    return_type: PyDataType,
+    input_types: Vec<&PyAny>,
+    return_type: &PyAny,
     py: Python,
-) -> PyResult<expression::ScalarUDF> {
+) -> PyResult<PyScalarUDF> {
     let name = fun.getattr(py, "__qualname__")?.extract::<String>(py)?;
-    Ok(create_udf(fun, input_types, return_type, &name))
+    create_udf(fun, input_types, return_type, &name)
 }
 
 /// Creates a new udf.
 #[pyfunction]
 fn udaf(
     accumulator: PyObject,
-    input_type: PyDataType,
-    return_type: PyDataType,
-    state_type: Vec<PyDataType>,
+    input_type: &PyAny,
+    return_type: &PyAny,
+    state_type: Vec<&PyAny>,
     py: Python,
-) -> PyResult<expression::AggregateUDF> {
+) -> PyResult<PyAggregateUDF> {
     let name = accumulator
         .getattr(py, "__qualname__")?
         .extract::<String>(py)?;
 
-    let input_type = input_type.data_type;
-    let return_type = Arc::new(return_type.data_type);
-    let state_type = Arc::new(state_type.into_iter().map(|t| t.data_type).collect());
+    let input_type = DataType::from_pyarrow(input_type)?;
+    let return_type = Arc::new(DataType::from_pyarrow(return_type)?);
+    let state_type = Arc::new(
+        state_type
+            .into_iter()
+            .map(DataType::from_pyarrow)
+            .collect::<PyResult<_>>()?,
+    );
 
-    Ok(expression::AggregateUDF {
+    Ok(PyAggregateUDF {
         function: logical_plan::create_udaf(
             &name,
             input_type,
diff --git a/python/src/lib.rs b/python/src/lib.rs
index aecfe9994cd1a..19b7f8a2d1ff9 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -22,19 +22,16 @@
 mod dataframe;
 mod errors;
 mod expression;
 mod functions;
-mod scalar;
-mod to_py;
-mod to_rust;
-mod types;
+mod pyarrow;
 mod udaf;
 mod udf;
 
 /// DataFusion.
 #[pymodule]
-fn datafusion(py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<context::ExecutionContext>()?;
-    m.add_class::<dataframe::DataFrame>()?;
-    m.add_class::<expression::Expression>()?;
+fn internals(py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<context::PyExecutionContext>()?;
+    m.add_class::<dataframe::PyDataFrame>()?;
+    m.add_class::<expression::PyExpr>()?;
 
     let functions = PyModule::new(py, "functions")?;
     functions::init(functions)?;
diff --git a/python/src/pyarrow.rs b/python/src/pyarrow.rs
new file mode 100644
index 0000000000000..81180aa11cd57
--- /dev/null
+++ b/python/src/pyarrow.rs
@@ -0,0 +1,205 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::convert::TryFrom;
+use std::sync::Arc;
+
+use libc::uintptr_t;
+use pyo3::exceptions::PyNotImplementedError;
+use pyo3::prelude::*;
+use pyo3::types::PyList;
+
+use datafusion::arrow::array::{make_array_from_raw, ArrayRef};
+use datafusion::arrow::datatypes::{DataType, Field, Schema};
+use datafusion::arrow::ffi;
+use datafusion::arrow::ffi::FFI_ArrowSchema;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::scalar::ScalarValue;
+
+use crate::errors::DataFusionError;
+
+pub trait PyArrowConvert: Sized {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self>;
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject>;
+}
+
+impl PyArrowConvert for DataType {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        let c_schema = FFI_ArrowSchema::empty();
+        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+        value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
+        let dtype = DataType::try_from(&c_schema).map_err(DataFusionError::from)?;
+        Ok(dtype)
+    }
+
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        let c_schema = FFI_ArrowSchema::try_from(self).map_err(DataFusionError::from)?;
+        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+        let module = py.import("pyarrow")?;
+        let class = module.getattr("DataType")?;
+        let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?;
+        Ok(dtype.into())
+    }
+}
+
+impl PyArrowConvert for Field {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        let c_schema = FFI_ArrowSchema::empty();
+        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+        value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
+        let field = Field::try_from(&c_schema).map_err(DataFusionError::from)?;
+        Ok(field)
+    }
+
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        let c_schema = FFI_ArrowSchema::try_from(self).map_err(DataFusionError::from)?;
+        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+        let module = py.import("pyarrow")?;
+        let class = module.getattr("Field")?;
+        let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?;
+        Ok(dtype.into())
+    }
+}
+
+impl PyArrowConvert for Schema {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        let c_schema = FFI_ArrowSchema::empty();
+        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+        value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
+        let schema = Schema::try_from(&c_schema).map_err(DataFusionError::from)?;
+        Ok(schema)
+    }
+
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        let c_schema = FFI_ArrowSchema::try_from(self).map_err(DataFusionError::from)?;
+        let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+        let module = py.import("pyarrow")?;
+        let class = module.getattr("Schema")?;
+        let schema =
+            class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?;
+        Ok(schema.into())
+    }
+}
+
+impl PyArrowConvert for ArrayRef {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // prepare a pointer to receive the Array struct
+        let (array_pointer, schema_pointer) =
+            ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() });
+
+        // make the conversion through PyArrow's private API
+        // this changes the pointer's memory and is thus unsafe.
+        // In particular, `_export_to_c` can go out of bounds
+        value.call_method1(
+            "_export_to_c",
+            (array_pointer as uintptr_t, schema_pointer as uintptr_t),
+        )?;
+
+        let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) }
+            .map_err(DataFusionError::from)?;
+        Ok(array)
+    }
+
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        let (array_pointer, schema_pointer) =
+            self.to_raw().map_err(DataFusionError::from)?;
+
+        let module = py.import("pyarrow")?;
+        let class = module.getattr("Array")?;
+        let array = class.call_method1(
+            "_import_from_c",
+            (array_pointer as uintptr_t, schema_pointer as uintptr_t),
+        )?;
+        Ok(array.to_object(py))
+    }
+}
+
+impl PyArrowConvert for RecordBatch {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // TODO(kszucs): implement the FFI conversions in arrow-rs for RecordBatches
+        let schema = value.getattr("schema")?;
+        let schema = Arc::new(Schema::from_pyarrow(schema)?);
+
+        let arrays = value.getattr("columns")?.downcast::<PyList>()?;
+        let arrays = arrays
+            .iter()
+            .map(ArrayRef::from_pyarrow)
+            .collect::<PyResult<Vec<_>>>()?;
+
+        let batch =
+            RecordBatch::try_new(schema, arrays).map_err(DataFusionError::from)?;
+        Ok(batch)
+    }
+
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        let mut py_arrays = vec![];
+        let mut py_names = vec![];
+
+        let schema = self.schema();
+        let fields = schema.fields().iter();
+        let columns = self.columns().iter();
+
+        for (array, field) in columns.zip(fields) {
+            py_arrays.push(array.to_pyarrow(py)?);
+            py_names.push(field.name());
+        }
+
+        let module = py.import("pyarrow")?;
+        let class = module.getattr("RecordBatch")?;
+        let record = class.call_method1("from_arrays", (py_arrays, py_names))?;
+
+        Ok(PyObject::from(record))
+    }
+}
+
+impl PyArrowConvert for ScalarValue {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        let t = value
+            .getattr("__class__")?
+            .getattr("__name__")?
+            .extract::<&str>()?;
+
+        let p = value.call_method0("as_py")?;
+
+        Ok(match t {
+            "Int8Scalar" => ScalarValue::Int8(Some(p.extract::<i8>()?)),
+            "Int16Scalar" => ScalarValue::Int16(Some(p.extract::<i16>()?)),
+            "Int32Scalar" => ScalarValue::Int32(Some(p.extract::<i32>()?)),
+            "Int64Scalar" => ScalarValue::Int64(Some(p.extract::<i64>()?)),
+            "UInt8Scalar" => ScalarValue::UInt8(Some(p.extract::<u8>()?)),
+            "UInt16Scalar" => ScalarValue::UInt16(Some(p.extract::<u16>()?)),
+            "UInt32Scalar" => ScalarValue::UInt32(Some(p.extract::<u32>()?)),
+            "UInt64Scalar" => ScalarValue::UInt64(Some(p.extract::<u64>()?)),
+            "FloatScalar" => ScalarValue::Float32(Some(p.extract::<f32>()?)),
+            "DoubleScalar" => ScalarValue::Float64(Some(p.extract::<f64>()?)),
+            "BooleanScalar" => ScalarValue::Boolean(Some(p.extract::<bool>()?)),
+            "StringScalar" => ScalarValue::Utf8(Some(p.extract::<String>()?)),
+            "LargeStringScalar" => ScalarValue::LargeUtf8(Some(p.extract::<String>()?)),
+            other => {
+                return Err(DataFusionError::Common(format!(
+                    "Type \"{}\" not yet implemented",
+                    other
+                ))
+                .into())
+            }
+        })
+    }
+
+    fn to_pyarrow(&self, _py: Python) -> PyResult<PyObject> {
+        Err(PyNotImplementedError::new_err("Not implemented"))
+    }
+}
diff --git a/python/src/scalar.rs b/python/src/scalar.rs
deleted file mode 100644
index 0c562a9403616..0000000000000
--- a/python/src/scalar.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use pyo3::prelude::*;
-
-use datafusion::scalar::ScalarValue as _Scalar;
-
-use crate::to_rust::to_rust_scalar;
-
-/// An expression that can be used on a DataFrame
-#[derive(Debug, Clone)]
-pub(crate) struct Scalar {
-    pub(crate) scalar: _Scalar,
-}
-
-impl<'source> FromPyObject<'source> for Scalar {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
-        Ok(Self {
-            scalar: to_rust_scalar(ob)?,
-        })
-    }
-}
diff --git a/python/src/to_py.rs b/python/src/to_py.rs
deleted file mode 100644
index 6bc0581c8c70a..0000000000000
--- a/python/src/to_py.rs
+++ /dev/null
@@ -1,75 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use datafusion::arrow::array::ArrayRef;
-use datafusion::arrow::record_batch::RecordBatch;
-use libc::uintptr_t;
-use pyo3::prelude::*;
-use pyo3::types::PyList;
-use pyo3::PyErr;
-use std::convert::From;
-
-use crate::errors;
-
-pub fn to_py_array(array: &ArrayRef, py: Python) -> PyResult<PyObject> {
-    let (array_pointer, schema_pointer) =
-        array.to_raw().map_err(errors::DataFusionError::from)?;
-
-    let pa = py.import("pyarrow")?;
-
-    let array = pa.getattr("Array")?.call_method1(
-        "_import_from_c",
-        (array_pointer as uintptr_t, schema_pointer as uintptr_t),
-    )?;
-    Ok(array.to_object(py))
-}
-
-fn to_py_batch<'a>(
-    batch: &RecordBatch,
-    py: Python,
-    pyarrow: &'a PyModule,
-) -> Result<PyObject, PyErr> {
-    let mut py_arrays = vec![];
-    let mut py_names = vec![];
-
-    let schema = batch.schema();
-    for (array, field) in batch.columns().iter().zip(schema.fields().iter()) {
-        let array = to_py_array(array, py)?;
-
-        py_arrays.push(array);
-        py_names.push(field.name());
-    }
-
-    let record = pyarrow
-        .getattr("RecordBatch")?
- .call_method1("from_arrays", (py_arrays, py_names))?; - - Ok(PyObject::from(record)) -} - -/// Converts a &[RecordBatch] into a Vec represented in PyArrow -pub fn to_py(batches: &[RecordBatch]) -> PyResult { - Python::with_gil(|py| { - let pyarrow = PyModule::import(py, "pyarrow")?; - let mut py_batches = vec![]; - for batch in batches { - py_batches.push(to_py_batch(batch, py, pyarrow)?); - } - let list = PyList::new(py, py_batches); - Ok(PyObject::from(list)) - }) -} diff --git a/python/src/to_rust.rs b/python/src/to_rust.rs deleted file mode 100644 index 7977fe4ff8ce1..0000000000000 --- a/python/src/to_rust.rs +++ /dev/null @@ -1,122 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::convert::TryFrom; -use std::sync::Arc; - -use datafusion::arrow::{ - array::{make_array_from_raw, ArrayRef}, - datatypes::Field, - datatypes::Schema, - ffi, - record_batch::RecordBatch, -}; -use datafusion::scalar::ScalarValue; -use libc::uintptr_t; -use pyo3::prelude::*; - -use crate::{errors, types::PyDataType}; - -/// converts a pyarrow Array into a Rust Array -pub fn to_rust(ob: &PyAny) -> PyResult { - // prepare a pointer to receive the Array struct - let (array_pointer, schema_pointer) = - ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() }); - - // make the conversion through PyArrow's private API - // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds - ob.call_method1( - "_export_to_c", - (array_pointer as uintptr_t, schema_pointer as uintptr_t), - )?; - - let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) } - .map_err(errors::DataFusionError::from)?; - Ok(array) -} - -/// converts a pyarrow batch into a RecordBatch -pub fn to_rust_batch(batch: &PyAny) -> PyResult { - let schema = batch.getattr("schema")?; - let names = schema.getattr("names")?.extract::>()?; - - let fields = names - .iter() - .enumerate() - .map(|(i, name)| { - let field = schema.call_method1("field", (i,))?; - let nullable = field.getattr("nullable")?.extract::()?; - let py_data_type = field.getattr("type")?; - let data_type = py_data_type.extract::()?.data_type; - Ok(Field::new(name, data_type, nullable)) - }) - .collect::>()?; - - let schema = Arc::new(Schema::new(fields)); - - let arrays = (0..names.len()) - .map(|i| { - let array = batch.call_method1("column", (i,))?; - to_rust(array) - }) - .collect::>()?; - - let batch = - RecordBatch::try_new(schema, arrays).map_err(errors::DataFusionError::from)?; - Ok(batch) -} - -/// converts a pyarrow Scalar into a Rust Scalar -pub fn to_rust_scalar(ob: &PyAny) -> PyResult { - let t = ob - .getattr("__class__")? - .getattr("__name__")? 
-        .extract::<&str>()?;
-
-    let p = ob.call_method0("as_py")?;
-
-    Ok(match t {
-        "Int8Scalar" => ScalarValue::Int8(Some(p.extract::<i8>()?)),
-        "Int16Scalar" => ScalarValue::Int16(Some(p.extract::<i16>()?)),
-        "Int32Scalar" => ScalarValue::Int32(Some(p.extract::<i32>()?)),
-        "Int64Scalar" => ScalarValue::Int64(Some(p.extract::<i64>()?)),
-        "UInt8Scalar" => ScalarValue::UInt8(Some(p.extract::<u8>()?)),
-        "UInt16Scalar" => ScalarValue::UInt16(Some(p.extract::<u16>()?)),
-        "UInt32Scalar" => ScalarValue::UInt32(Some(p.extract::<u32>()?)),
-        "UInt64Scalar" => ScalarValue::UInt64(Some(p.extract::<u64>()?)),
-        "FloatScalar" => ScalarValue::Float32(Some(p.extract::<f32>()?)),
-        "DoubleScalar" => ScalarValue::Float64(Some(p.extract::<f64>()?)),
-        "BooleanScalar" => ScalarValue::Boolean(Some(p.extract::<bool>()?)),
-        "StringScalar" => ScalarValue::Utf8(Some(p.extract::<String>()?)),
-        "LargeStringScalar" => ScalarValue::LargeUtf8(Some(p.extract::<String>()?)),
-        other => {
-            return Err(errors::DataFusionError::Common(format!(
-                "Type \"{}\"not yet implemented",
-                other
-            ))
-            .into())
-        }
-    })
-}
-
-pub fn to_rust_schema(ob: &PyAny) -> PyResult<Schema> {
-    let c_schema = ffi::FFI_ArrowSchema::empty();
-    let c_schema_ptr = &c_schema as *const ffi::FFI_ArrowSchema;
-    ob.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
-    let schema = Schema::try_from(&c_schema).map_err(errors::DataFusionError::from)?;
-    Ok(schema)
-}
diff --git a/python/src/types.rs b/python/src/types.rs
deleted file mode 100644
index bd6ef0d376e63..0000000000000
--- a/python/src/types.rs
+++ /dev/null
@@ -1,65 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use datafusion::arrow::datatypes::DataType;
-use pyo3::{FromPyObject, PyAny, PyResult};
-
-use crate::errors;
-
-/// utility struct to convert PyObj to native DataType
-#[derive(Debug, Clone)]
-pub struct PyDataType {
-    pub data_type: DataType,
-}
-
-impl<'source> FromPyObject<'source> for PyDataType {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
-        let id = ob.getattr("id")?.extract::<i32>()?;
-        let data_type = data_type_id(&id)?;
-        Ok(PyDataType { data_type })
-    }
-}
-
-fn data_type_id(id: &i32) -> Result<DataType, errors::DataFusionError> {
-    // see https://github.com/apache/arrow/blob/3694794bdfd0677b95b8c95681e392512f1c9237/python/pyarrow/includes/libarrow.pxd
-    // this is not ideal as it does not generalize for non-basic types
-    // Find a way to get a unique name from the pyarrow.DataType
-    Ok(match id {
-        1 => DataType::Boolean,
-        2 => DataType::UInt8,
-        3 => DataType::Int8,
-        4 => DataType::UInt16,
-        5 => DataType::Int16,
-        6 => DataType::UInt32,
-        7 => DataType::Int32,
-        8 => DataType::UInt64,
-        9 => DataType::Int64,
-        10 => DataType::Float16,
-        11 => DataType::Float32,
-        12 => DataType::Float64,
-        13 => DataType::Utf8,
-        14 => DataType::Binary,
-        34 => DataType::LargeUtf8,
-        35 => DataType::LargeBinary,
-        other => {
-            return Err(errors::DataFusionError::Common(format!(
-                "The type {} is not valid",
-                other
-            )))
-        }
-    })
-}
diff --git a/python/src/udaf.rs b/python/src/udaf.rs
index 83e8be05db603..ada75a035a28d 100644
--- a/python/src/udaf.rs
+++ b/python/src/udaf.rs
@@ -17,7 +17,10 @@
 
 use std::sync::Arc;
 
-use pyo3::{prelude::*, types::PyTuple};
+use pyo3::{
+    prelude::*,
+    types::{PyList, PyTuple},
+};
 
 use datafusion::arrow::array::ArrayRef;
 
@@ -27,9 +30,7 @@ use datafusion::{
     scalar::ScalarValue,
 };
 
-use crate::scalar::Scalar;
-use crate::to_py::to_py_array;
-use crate::to_rust::to_rust_scalar;
+use crate::pyarrow::PyArrowConvert;
 
 #[derive(Debug)]
 struct PyAccumulator {
@@ -43,18 +44,17 @@
 }
 
 impl Accumulator for PyAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         Python::with_gil(|py| {
-            let state = self
-                .accum
+            self.accum
                 .as_ref(py)
-                .call_method0("to_scalars")
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?
-                .extract::<Vec<Scalar>>()
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
-
-            Ok(state.into_iter().map(|v| v.scalar).collect::<Vec<ScalarValue>>())
+                .call_method0("to_scalars")?
+                .downcast::<PyList>()?
+                .iter()
+                .map(ScalarValue::from_pyarrow)
+                .collect::<PyResult<Vec<ScalarValue>>>()
         })
+        .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))
     }
 
     fn update(&mut self, _values: &[ScalarValue]) -> Result<()> {
@@ -67,17 +67,12 @@
         todo!()
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
         Python::with_gil(|py| {
-            let value = self
-                .accum
-                .as_ref(py)
-                .call_method0("evaluate")
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
-
-            to_rust_scalar(value)
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))
+            let value = self.accum.as_ref(py).call_method0("evaluate")?;
+            ScalarValue::from_pyarrow(value)
         })
+        .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
@@ -90,7 +85,7 @@
                 .iter()
                 .map(|arg| {
                     // remove unwrap
-                    to_py_array(arg, py).unwrap()
+                    arg.to_pyarrow(py).unwrap()
                 })
                 .collect::<Vec<PyObject>>();
             let py_args = PyTuple::new(py, py_args);
@@ -111,7 +106,8 @@
         // 2. merge
         let state = &states[0];
 
-        let state = to_py_array(state, py)
+        let state = state
+            .to_pyarrow(py)
             .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
 
         // 2.
diff --git a/python/src/udf.rs b/python/src/udf.rs
index 49a18d9932412..849104d946562 100644
--- a/python/src/udf.rs
+++ b/python/src/udf.rs
@@ -15,15 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use pyo3::{prelude::*, types::PyTuple};
-
-use datafusion::{arrow::array, physical_plan::functions::make_scalar_function};
-
+use datafusion::arrow::array::ArrayRef;
 use datafusion::error::DataFusionError;
 use datafusion::physical_plan::functions::ScalarFunctionImplementation;
+use datafusion::{arrow::array, physical_plan::functions::make_scalar_function};
+use pyo3::{prelude::*, types::PyTuple};
 
-use crate::to_py::to_py_array;
-use crate::to_rust::to_rust;
+use crate::pyarrow::PyArrowConvert;
 
 /// creates a DataFusion's UDF implementation from a python function that expects pyarrow arrays
 /// This is more efficient as it performs a zero-copy of the contents.
@@ -40,7 +38,7 @@
                 .iter()
                 .map(|arg| {
                     // remove unwrap
-                    to_py_array(arg, py).unwrap()
+                    arg.to_pyarrow(py).unwrap()
                 })
                 .collect::<Vec<PyObject>>();
             let py_args = PyTuple::new(py, py_args);
@@ -52,7 +50,7 @@
                 Err(error) => Err(DataFusionError::Execution(format!("{:?}", error))),
             }?;
 
-            let array = to_rust(value).unwrap();
+            let array = ArrayRef::from_pyarrow(value).unwrap();
             Ok(array)
         })
     },
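
The user-facing effect of the packaging change above: the compiled pyo3 module is renamed from `datafusion` to `internals` and installed as `datafusion.internals`, while the new pure-Python `datafusion/__init__.py` re-exports the `Py*`-prefixed classes under their public names, so user code keeps importing from `datafusion`. A minimal usage sketch, assuming a built and installed wheel; `example.parquet` is a hypothetical file name used only for illustration:

```python
# Minimal sketch of the public API after this change.
from datafusion import ExecutionContext

ctx = ExecutionContext()                      # wraps PyExecutionContext
ctx.register_parquet("t", "example.parquet")  # (name, path), as in context.rs
df = ctx.sql("SELECT * FROM t")               # DataFrame (PyDataFrame) wrapper
batches = df.collect()                        # list of pyarrow.RecordBatch
```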