From 39b1c3aedee8ac6912a1ff69785124a8c2e1e387 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 5 May 2021 07:12:23 -0600 Subject: [PATCH] Revert "Revert "Add datafusion-python (#69)" (#257)" This reverts commit d0af907652aa8773d1de21dfd2f15bbcf6f50ce3. --- .github/workflows/python_build.yml | 89 +++++++++ .github/workflows/python_test.yaml | 58 ++++++ Cargo.toml | 4 +- dev/release/rat_exclude_files.txt | 1 + python/.cargo/config | 22 +++ python/.dockerignore | 19 ++ python/.gitignore | 20 ++ python/Cargo.toml | 57 ++++++ python/README.md | 146 ++++++++++++++ python/pyproject.toml | 20 ++ python/rust-toolchain | 1 + python/src/context.rs | 115 +++++++++++ python/src/dataframe.rs | 161 ++++++++++++++++ python/src/errors.rs | 61 ++++++ python/src/expression.rs | 162 ++++++++++++++++ python/src/functions.rs | 165 ++++++++++++++++ python/src/lib.rs | 44 +++++ python/src/scalar.rs | 36 ++++ python/src/to_py.rs | 77 ++++++++ python/src/to_rust.rs | 111 +++++++++++ python/src/types.rs | 76 ++++++++ python/src/udaf.rs | 147 +++++++++++++++ python/src/udf.rs | 62 ++++++ python/tests/__init__.py | 16 ++ python/tests/generic.py | 75 ++++++++ python/tests/test_df.py | 115 +++++++++++ python/tests/test_sql.py | 294 +++++++++++++++++++++++++++++ python/tests/test_udaf.py | 91 +++++++++ 28 files changed, 2244 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/python_build.yml create mode 100644 .github/workflows/python_test.yaml create mode 100644 python/.cargo/config create mode 100644 python/.dockerignore create mode 100644 python/.gitignore create mode 100644 python/Cargo.toml create mode 100644 python/README.md create mode 100644 python/pyproject.toml create mode 100644 python/rust-toolchain create mode 100644 python/src/context.rs create mode 100644 python/src/dataframe.rs create mode 100644 python/src/errors.rs create mode 100644 python/src/expression.rs create mode 100644 python/src/functions.rs create mode 100644 python/src/lib.rs create mode 100644 python/src/scalar.rs create mode 100644 python/src/to_py.rs create mode 100644 python/src/to_rust.rs create mode 100644 python/src/types.rs create mode 100644 python/src/udaf.rs create mode 100644 python/src/udf.rs create mode 100644 python/tests/__init__.py create mode 100644 python/tests/generic.py create mode 100644 python/tests/test_df.py create mode 100644 python/tests/test_sql.py create mode 100644 python/tests/test_udaf.py diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml new file mode 100644 index 0000000000000..c86bb81581a71 --- /dev/null +++ b/.github/workflows/python_build.yml @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
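+
+# Builds release wheels on every `v*` tag: macOS and Windows wheels for CPython
+# 3.6-3.8 via maturin, manylinux wheels via the maturin Docker image; all
+# artifacts are then published to PyPI by the release job below.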
+
+name: Build
+on:
+  push:
+    tags:
+      - v*
+
+jobs:
+  build-python-mac-win:
+    name: Mac/Win
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.6, 3.7, 3.8]
+        os: [macos-latest, windows-latest]
+    steps:
+      - uses: actions/checkout@v2
+
+      - uses: actions/setup-python@v1
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: nightly-2021-01-06
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install maturin
+
+      - name: Build Python package
+        run: cd python && maturin build --release --no-sdist --strip --interpreter python${{ matrix.python-version }}
+
+      - name: List wheels
+        if: matrix.os == 'windows-latest'
+        run: dir python/target\wheels\
+
+      - name: List wheels
+        if: matrix.os != 'windows-latest'
+        run: find ./python/target/wheels/
+
+      - name: Archive wheels
+        uses: actions/upload-artifact@v2
+        with:
+          name: dist
+          path: python/target/wheels/*
+
+  build-manylinux:
+    name: Manylinux
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Build wheels
+        run: docker run --rm -v $(pwd):/io konstin2/maturin build --release --manylinux
+      - name: Archive wheels
+        uses: actions/upload-artifact@v2
+        with:
+          name: dist
+          path: python/target/wheels/*
+
+  release:
+    name: Publish in PyPI
+    needs: [build-manylinux, build-python-mac-win]
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/download-artifact@v2
+      - name: Publish to PyPI
+        uses: pypa/gh-action-pypi-publish@master
+        with:
+          user: __token__
+          password: ${{ secrets.pypi_password }}
diff --git a/.github/workflows/python_test.yaml b/.github/workflows/python_test.yaml
new file mode 100644
index 0000000000000..3b2111b59d49d
--- /dev/null
+++ b/.github/workflows/python_test.yaml
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
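+
+# Test workflow: pins the nightly Rust toolchain, caches cargo artifacts, builds
+# the extension into a fresh venv with `maturin develop`, and runs the Python
+# unittest suite on every push and pull request.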
+ +name: Python test +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Setup Rust toolchain + run: | + rustup toolchain install nightly-2021-01-06 + rustup default nightly-2021-01-06 + rustup component add rustfmt + - name: Cache Cargo + uses: actions/cache@v2 + with: + path: /home/runner/.cargo + key: cargo-maturin-cache- + - name: Cache Rust dependencies + uses: actions/cache@v2 + with: + path: /home/runner/target + key: target-maturin-cache- + - uses: actions/setup-python@v2 + with: + python-version: '3.7' + - name: Install Python dependencies + run: python -m pip install --upgrade pip setuptools wheel + - name: Run tests + run: | + cd python/ + export CARGO_HOME="/home/runner/.cargo" + export CARGO_TARGET_DIR="/home/runner/target" + + python -m venv venv + source venv/bin/activate + + pip install maturin==0.10.4 toml==0.10.1 pyarrow==4.0.0 + maturin develop + + python -m unittest discover tests diff --git a/Cargo.toml b/Cargo.toml index fa36a0c0fed7c..9795cb68b4456 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,4 +25,6 @@ members = [ "ballista/rust/core", "ballista/rust/executor", "ballista/rust/scheduler", -] \ No newline at end of file +] + +exclude = ["python"] diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index b94c0ea1d61a6..6126699bbc1fa 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -104,3 +104,4 @@ rust-toolchain benchmarks/queries/q*.sql ballista/rust/scheduler/testdata/* ballista/ui/scheduler/yarn.lock +python/rust-toolchain diff --git a/python/.cargo/config b/python/.cargo/config new file mode 100644 index 0000000000000..0b24f30cf908a --- /dev/null +++ b/python/.cargo/config @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[target.x86_64-apple-darwin] +rustflags = [ + "-C", "link-arg=-undefined", + "-C", "link-arg=dynamic_lookup", +] diff --git a/python/.dockerignore b/python/.dockerignore new file mode 100644 index 0000000000000..08c131c2e7d60 --- /dev/null +++ b/python/.dockerignore @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +target +venv diff --git a/python/.gitignore b/python/.gitignore new file mode 100644 index 0000000000000..48fe4dbe52dde --- /dev/null +++ b/python/.gitignore @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +/target +Cargo.lock +venv diff --git a/python/Cargo.toml b/python/Cargo.toml new file mode 100644 index 0000000000000..070720554f0ed --- /dev/null +++ b/python/Cargo.toml @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
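+
+# This crate is excluded from the main workspace (see the root Cargo.toml) and
+# is built as a `cdylib` extension module by maturin.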
+
+[package]
+name = "datafusion"
+version = "0.2.1"
+homepage = "https://github.com/apache/arrow"
+repository = "https://github.com/apache/arrow"
+authors = ["Apache Arrow <dev@arrow.apache.org>"]
+description = "Build and run queries against data"
+readme = "README.md"
+license = "Apache-2.0"
+edition = "2018"
+
+[dependencies]
+tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
+rand = "0.7"
+pyo3 = { version = "0.12.1", features = ["extension-module"] }
+datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "2423ff0d" }
+
+[lib]
+name = "datafusion"
+crate-type = ["cdylib"]
+
+[package.metadata.maturin]
+requires-dist = ["pyarrow>=1"]
+
+classifier = [
+    "Development Status :: 2 - Pre-Alpha",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: Apache Software License",
+    "License :: OSI Approved",
+    "Operating System :: MacOS",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.6",
+    "Programming Language :: Python :: 3.7",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python",
+    "Programming Language :: Rust",
+]
diff --git a/python/README.md b/python/README.md
new file mode 100644
index 0000000000000..1859fca9811c0
--- /dev/null
+++ b/python/README.md
@@ -0,0 +1,146 @@
+<!---
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements. See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership. The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied. See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+## DataFusion in Python
+
+This is a Python library that binds to the [Apache Arrow](https://arrow.apache.org/) in-memory query engine [DataFusion](https://github.com/apache/arrow/tree/master/rust/datafusion).
+
+Like PySpark, it allows you to build a plan through SQL or a DataFrame API against in-memory data, Parquet or CSV files, run it in a multi-threaded environment, and obtain the result back in Python.
+
+It also allows you to use UDFs and UDAFs for complex operations.
+
+The major advantage of this library over other execution engines is that it achieves zero-copy between Python and its execution engine: there is no cost in using UDFs, UDAFs, or collecting results back to Python beyond locking the GIL while those operations run.
+
+Its query engine, DataFusion, is written in [Rust](https://www.rust-lang.org/), which provides strong guarantees of thread safety and memory safety.
+
+Technically, zero-copy is achieved via the [Arrow C data interface](https://arrow.apache.org/docs/format/CDataInterface.html).
+
+## How to use it
+
+Simple usage:
+
+```python
+import datafusion
+import pyarrow
+
+# an alias
+f = datafusion.functions
+
+# create a context
+ctx = datafusion.ExecutionContext()
+
+# create a RecordBatch and a new DataFrame from it
+batch = pyarrow.RecordBatch.from_arrays(
+    [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+    names=["a", "b"],
+)
+df = ctx.create_dataframe([[batch]])
+
+# create a new statement
+df = df.select(
+    f.col("a") + f.col("b"),
+    f.col("a") - f.col("b"),
+)
+
+# execute and collect the first (and only) batch
+result = df.collect()[0]
+
+assert result.column(0) == pyarrow.array([5, 7, 9])
+assert result.column(1) == pyarrow.array([-3, -3, -3])
+```
+
+### UDFs
+
+```python
+def is_null(array: pyarrow.Array) -> pyarrow.Array:
+    return array.is_null()
+
+udf = f.udf(is_null, [pyarrow.int64()], pyarrow.bool_())
+
+df = df.select(udf(f.col("a")))
+```
+
+### UDAF
+
+```python
+from typing import List
+
+import pyarrow
+import pyarrow.compute
+
+
+class Accumulator:
+    """
+    Interface of a user-defined accumulation.
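+
+    DataFusion drives this through four methods: `update` receives a batch of
+    input values, `merge` combines intermediate states from other partitions,
+    `to_scalars` exposes the current state, and `evaluate` produces the final
+    value.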
+    """
+    def __init__(self):
+        self._sum = pyarrow.scalar(0.0)
+
+    def to_scalars(self) -> List[pyarrow.Scalar]:
+        return [self._sum]
+
+    def update(self, values: pyarrow.Array) -> None:
+        # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
+        self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(values).as_py())
+
+    def merge(self, states: pyarrow.Array) -> None:
+        # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
+        self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states).as_py())
+
+    def evaluate(self) -> pyarrow.Scalar:
+        return self._sum
+
+
+df = ...
+
+udaf = f.udaf(Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()])
+
+df = df.aggregate(
+    [],
+    [udaf(f.col("a"))]
+)
+```
+
+## How to install
+
+```bash
+pip install datafusion
+```
+
+## How to develop
+
+This assumes that you have Rust and Cargo installed. We use the workflow recommended by [pyo3](https://github.com/PyO3/pyo3) and [maturin](https://github.com/PyO3/maturin).
+
+Bootstrap:
+
+```bash
+# fetch this repo
+git clone git@github.com:apache/arrow-datafusion.git
+
+cd arrow-datafusion/python
+
+# prepare development environment (used to build wheel / install in development)
+python3 -m venv venv
+venv/bin/pip install maturin==0.10.4 toml==0.10.1 pyarrow==1.0.0
+```
+
+Whenever Rust code changes (your changes or via git pull):
+
+```bash
+venv/bin/maturin develop
+venv/bin/python -m unittest discover tests
+```
diff --git a/python/pyproject.toml b/python/pyproject.toml
new file mode 100644
index 0000000000000..27480690e06cc
--- /dev/null
+++ b/python/pyproject.toml
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[build-system]
+requires = ["maturin"]
+build-backend = "maturin"
diff --git a/python/rust-toolchain b/python/rust-toolchain
new file mode 100644
index 0000000000000..9d0cf79d367d6
--- /dev/null
+++ b/python/rust-toolchain
@@ -0,0 +1 @@
+nightly-2021-01-06
diff --git a/python/src/context.rs b/python/src/context.rs
new file mode 100644
index 0000000000000..14ef0f7321f15
--- /dev/null
+++ b/python/src/context.rs
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{collections::HashSet, sync::Arc};
+
+use rand::distributions::Alphanumeric;
+use rand::Rng;
+
+use pyo3::prelude::*;
+
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::MemTable;
+use datafusion::execution::context::ExecutionContext as _ExecutionContext;
+
+use crate::dataframe;
+use crate::errors;
+use crate::functions;
+use crate::to_rust;
+use crate::types::PyDataType;
+
+/// `ExecutionContext` is able to plan and execute DataFusion plans.
+/// It has a powerful optimizer, a physical planner for local execution, and a
+/// multi-threaded execution engine to perform the execution.
+#[pyclass(unsendable)]
+pub(crate) struct ExecutionContext {
+    ctx: _ExecutionContext,
+}
+
+#[pymethods]
+impl ExecutionContext {
+    #[new]
+    fn new() -> Self {
+        ExecutionContext {
+            ctx: _ExecutionContext::new(),
+        }
+    }
+
+    /// Returns a DataFrame whose plan corresponds to the SQL statement.
+    fn sql(&mut self, query: &str) -> PyResult<dataframe::DataFrame> {
+        let df = self
+            .ctx
+            .sql(query)
+            .map_err(|e| -> errors::DataFusionError { e.into() })?;
+        Ok(dataframe::DataFrame::new(
+            self.ctx.state.clone(),
+            df.to_logical_plan(),
+        ))
+    }
+
+    fn create_dataframe(
+        &mut self,
+        partitions: Vec<Vec<PyObject>>,
+        py: Python,
+    ) -> PyResult<dataframe::DataFrame> {
+        let partitions: Vec<Vec<RecordBatch>> = partitions
+            .iter()
+            .map(|batches| {
+                batches
+                    .iter()
+                    .map(|batch| to_rust::to_rust_batch(batch.as_ref(py)))
+                    .collect()
+            })
+            .collect::<PyResult<_>>()?;
+
+        let table =
+            errors::wrap(MemTable::try_new(partitions[0][0].schema(), partitions))?;
+
+        // generate a random (unique) name for this table
+        let name = rand::thread_rng()
+            .sample_iter(&Alphanumeric)
+            .take(10)
+            .collect::<String>();
+
+        errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?;
+        Ok(dataframe::DataFrame::new(
+            self.ctx.state.clone(),
+            errors::wrap(self.ctx.table(&*name))?.to_logical_plan(),
+        ))
+    }
+
+    fn register_parquet(&mut self, name: &str, path: &str) -> PyResult<()> {
+        errors::wrap(self.ctx.register_parquet(name, path))?;
+        Ok(())
+    }
+
+    fn register_udf(
+        &mut self,
+        name: &str,
+        func: PyObject,
+        args_types: Vec<PyDataType>,
+        return_type: PyDataType,
+    ) {
+        let function = functions::create_udf(func, args_types, return_type, name);
+
+        self.ctx.register_udf(function.function);
+    }
+
+    fn tables(&self) -> HashSet<String> {
+        self.ctx.tables().unwrap()
+    }
+}
diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs
new file mode 100644
index 0000000000000..f90a7cf2f0dcf
--- /dev/null
+++ b/python/src/dataframe.rs
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::{Arc, Mutex};
+
+use logical_plan::LogicalPlan;
+use pyo3::{prelude::*, types::PyTuple};
+use tokio::runtime::Runtime;
+
+use datafusion::execution::context::ExecutionContext as _ExecutionContext;
+use datafusion::logical_plan::{JoinType, LogicalPlanBuilder};
+use datafusion::physical_plan::collect;
+use datafusion::{execution::context::ExecutionContextState, logical_plan};
+
+use crate::{errors, to_py};
+use crate::{errors::DataFusionError, expression};
+
+/// A DataFrame is a representation of a logical plan and an API to compose statements.
+/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
+/// The plan itself is executed natively by Rust and Arrow on a multi-threaded environment.
+#[pyclass]
+pub(crate) struct DataFrame {
+    ctx_state: Arc<Mutex<ExecutionContextState>>,
+    plan: LogicalPlan,
+}
+
+impl DataFrame {
+    /// creates a new DataFrame
+    pub fn new(ctx_state: Arc<Mutex<ExecutionContextState>>, plan: LogicalPlan) -> Self {
+        Self { ctx_state, plan }
+    }
+}
+
+#[pymethods]
+impl DataFrame {
+    /// Select `expressions` from the existing DataFrame.
+    #[args(args = "*")]
+    fn select(&self, args: &PyTuple) -> PyResult<DataFrame> {
+        let expressions = expression::from_tuple(args)?;
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder =
+            errors::wrap(builder.project(expressions.into_iter().map(|e| e.expr)))?;
+        let plan = errors::wrap(builder.build())?;
+
+        Ok(DataFrame {
+            ctx_state: self.ctx_state.clone(),
+            plan,
+        })
+    }
+
+    /// Filter according to the `predicate` expression
+    fn filter(&self, predicate: expression::Expression) -> PyResult<DataFrame> {
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder = errors::wrap(builder.filter(predicate.expr))?;
+        let plan = errors::wrap(builder.build())?;
+
+        Ok(DataFrame {
+            ctx_state: self.ctx_state.clone(),
+            plan,
+        })
+    }
+
+    /// Aggregates using expressions
+    fn aggregate(
+        &self,
+        group_by: Vec<expression::Expression>,
+        aggs: Vec<expression::Expression>,
+    ) -> PyResult<DataFrame> {
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder = errors::wrap(builder.aggregate(
+            group_by.into_iter().map(|e| e.expr),
+            aggs.into_iter().map(|e| e.expr),
+        ))?;
+        let plan = errors::wrap(builder.build())?;
+
+        Ok(DataFrame {
+            ctx_state: self.ctx_state.clone(),
+            plan,
+        })
+    }
+
+    /// Limits the plan to return at most `count` rows
+    fn limit(&self, count: usize) -> PyResult<DataFrame> {
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder = errors::wrap(builder.limit(count))?;
+        let plan = errors::wrap(builder.build())?;
+
+        Ok(DataFrame {
+            ctx_state: self.ctx_state.clone(),
+            plan,
+        })
+    }
+
+    /// Executes the plan, returning a list of `RecordBatch`es.
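+    /// Collecting optimizes the logical plan, builds a physical plan, and runs it
+    /// on a Tokio runtime while the GIL is released.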
+    /// Unless some order is specified in the plan, there is no guarantee of the order of the result.
+    fn collect(&self, py: Python) -> PyResult<PyObject> {
+        let ctx = _ExecutionContext::from(self.ctx_state.clone());
+        let plan = ctx
+            .optimize(&self.plan)
+            .map_err(|e| -> errors::DataFusionError { e.into() })?;
+        let plan = ctx
+            .create_physical_plan(&plan)
+            .map_err(|e| -> errors::DataFusionError { e.into() })?;
+
+        let rt = Runtime::new().unwrap();
+        let batches = py.allow_threads(|| {
+            rt.block_on(async {
+                collect(plan)
+                    .await
+                    .map_err(|e| -> errors::DataFusionError { e.into() })
+            })
+        })?;
+        to_py::to_py(&batches)
+    }
+
+    /// Returns the join of two DataFrames `on`.
+    fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult<DataFrame> {
+        let builder = LogicalPlanBuilder::from(&self.plan);
+
+        let join_type = match how {
+            "inner" => JoinType::Inner,
+            "left" => JoinType::Left,
+            "right" => JoinType::Right,
+            how => {
+                return Err(DataFusionError::Common(format!(
+                    "The join type {} does not exist or is not implemented",
+                    how
+                ))
+                .into())
+            }
+        };
+
+        let builder = errors::wrap(builder.join(
+            &right.plan,
+            join_type,
+            on.as_slice(),
+            on.as_slice(),
+        ))?;
+
+        let plan = errors::wrap(builder.build())?;
+
+        Ok(DataFrame {
+            ctx_state: self.ctx_state.clone(),
+            plan,
+        })
+    }
+}
diff --git a/python/src/errors.rs b/python/src/errors.rs
new file mode 100644
index 0000000000000..fbe98037a030f
--- /dev/null
+++ b/python/src/errors.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use core::fmt;
+
+use datafusion::arrow::error::ArrowError;
+use datafusion::error::DataFusionError as InnerDataFusionError;
+use pyo3::{exceptions, PyErr};
+
+#[derive(Debug)]
+pub enum DataFusionError {
+    ExecutionError(InnerDataFusionError),
+    ArrowError(ArrowError),
+    Common(String),
+}
+
+impl fmt::Display for DataFusionError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            DataFusionError::ExecutionError(e) => write!(f, "DataFusion error: {:?}", e),
+            DataFusionError::ArrowError(e) => write!(f, "Arrow error: {:?}", e),
+            DataFusionError::Common(e) => write!(f, "{}", e),
+        }
+    }
+}
+
+impl From<DataFusionError> for PyErr {
+    fn from(err: DataFusionError) -> PyErr {
+        exceptions::PyException::new_err(err.to_string())
+    }
+}
+
+impl From<InnerDataFusionError> for DataFusionError {
+    fn from(err: InnerDataFusionError) -> DataFusionError {
+        DataFusionError::ExecutionError(err)
+    }
+}
+
+impl From<ArrowError> for DataFusionError {
+    fn from(err: ArrowError) -> DataFusionError {
+        DataFusionError::ArrowError(err)
+    }
+}
+
+pub(crate) fn wrap<T>(a: Result<T, InnerDataFusionError>) -> Result<T, DataFusionError> {
+    Ok(a?)
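+    // `?` converts the inner DataFusion error into our wrapper via its `From` impl;
+    // callers turn the wrapper into a `PyErr` the same way.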
+}
diff --git a/python/src/expression.rs b/python/src/expression.rs
new file mode 100644
index 0000000000000..78ca6d7e598ec
--- /dev/null
+++ b/python/src/expression.rs
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::{
+    basic::CompareOp, prelude::*, types::PyTuple, PyNumberProtocol, PyObjectProtocol,
+};
+
+use datafusion::logical_plan::Expr as _Expr;
+use datafusion::physical_plan::udaf::AggregateUDF as _AggregateUDF;
+use datafusion::physical_plan::udf::ScalarUDF as _ScalarUDF;
+
+/// An expression that can be used on a DataFrame
+#[pyclass]
+#[derive(Debug, Clone)]
+pub(crate) struct Expression {
+    pub(crate) expr: _Expr,
+}
+
+/// converts a tuple of expressions into a vector of Expressions
+pub(crate) fn from_tuple(value: &PyTuple) -> PyResult<Vec<Expression>> {
+    value
+        .iter()
+        .map(|e| e.extract::<Expression>())
+        .collect::<PyResult<Vec<_>>>()
+}
+
+#[pyproto]
+impl PyNumberProtocol for Expression {
+    fn __add__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        Ok(Expression {
+            expr: lhs.expr + rhs.expr,
+        })
+    }
+
+    fn __sub__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        Ok(Expression {
+            expr: lhs.expr - rhs.expr,
+        })
+    }
+
+    fn __truediv__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        Ok(Expression {
+            expr: lhs.expr / rhs.expr,
+        })
+    }
+
+    fn __mul__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        Ok(Expression {
+            expr: lhs.expr * rhs.expr,
+        })
+    }
+
+    fn __and__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        Ok(Expression {
+            expr: lhs.expr.and(rhs.expr),
+        })
+    }
+
+    fn __or__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        Ok(Expression {
+            expr: lhs.expr.or(rhs.expr),
+        })
+    }
+
+    fn __invert__(&self) -> PyResult<Expression> {
+        Ok(Expression {
+            expr: self.expr.clone().not(),
+        })
+    }
+}
+
+#[pyproto]
+impl PyObjectProtocol for Expression {
+    fn __richcmp__(&self, other: Expression, op: CompareOp) -> Expression {
+        match op {
+            CompareOp::Lt => Expression {
+                expr: self.expr.clone().lt(other.expr),
+            },
+            CompareOp::Le => Expression {
+                expr: self.expr.clone().lt_eq(other.expr),
+            },
+            CompareOp::Eq => Expression {
+                expr: self.expr.clone().eq(other.expr),
+            },
+            CompareOp::Ne => Expression {
+                expr: self.expr.clone().not_eq(other.expr),
+            },
+            CompareOp::Gt => Expression {
+                expr: self.expr.clone().gt(other.expr),
+            },
+            CompareOp::Ge => Expression {
+                expr: self.expr.clone().gt_eq(other.expr),
+            },
+        }
+    }
+}
+
+#[pymethods]
+impl Expression {
+    /// assign a name to the expression
+    pub fn alias(&self, name: &str) -> PyResult<Expression> {
+        Ok(Expression {
+            expr: self.expr.clone().alias(name),
+        })
+    }
+}
+
+/// Represents a ScalarUDF
+#[pyclass]
+#[derive(Debug, Clone)]
+pub struct ScalarUDF {
+    pub(crate) function: _ScalarUDF,
+}
+
+#[pymethods]
+impl ScalarUDF {
+    /// creates a new expression with the call of the udf
+    #[call]
+    #[args(args = "*")]
+    fn __call__(&self, args: &PyTuple) -> PyResult<Expression> {
+        let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect();
+
+        Ok(Expression {
+            expr: self.function.call(args),
+        })
+    }
+}
+
+/// Represents an AggregateUDF
+#[pyclass]
+#[derive(Debug, Clone)]
+pub struct AggregateUDF {
+    pub(crate) function: _AggregateUDF,
+}
+
+#[pymethods]
+impl AggregateUDF {
+    /// creates a new expression with the call of the udf
+    #[call]
+    #[args(args = "*")]
+    fn __call__(&self, args: &PyTuple) -> PyResult<Expression> {
+        let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect();
+
+        Ok(Expression {
+            expr: self.function.call(args),
+        })
+    }
+}
diff --git a/python/src/functions.rs b/python/src/functions.rs
new file mode 100644
index 0000000000000..68000cb1ecbf8
--- /dev/null
+++ b/python/src/functions.rs
@@ -0,0 +1,165 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use datafusion::arrow::datatypes::DataType;
+use pyo3::{prelude::*, wrap_pyfunction};
+
+use datafusion::logical_plan;
+
+use crate::udaf;
+use crate::udf;
+use crate::{expression, types::PyDataType};
+
+/// Expression representing a column on the existing plan.
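+/// For example, `f.col("a")` in Python references column "a" of the DataFrame.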
+#[pyfunction]
+#[text_signature = "(name)"]
+fn col(name: &str) -> expression::Expression {
+    expression::Expression {
+        expr: logical_plan::col(name),
+    }
+}
+
+/// Expression representing a constant value
+#[pyfunction]
+#[text_signature = "(value)"]
+fn lit(value: i32) -> expression::Expression {
+    expression::Expression {
+        expr: logical_plan::lit(value),
+    }
+}
+
+#[pyfunction]
+fn sum(value: expression::Expression) -> expression::Expression {
+    expression::Expression {
+        expr: logical_plan::sum(value.expr),
+    }
+}
+
+#[pyfunction]
+fn avg(value: expression::Expression) -> expression::Expression {
+    expression::Expression {
+        expr: logical_plan::avg(value.expr),
+    }
+}
+
+#[pyfunction]
+fn min(value: expression::Expression) -> expression::Expression {
+    expression::Expression {
+        expr: logical_plan::min(value.expr),
+    }
+}
+
+#[pyfunction]
+fn max(value: expression::Expression) -> expression::Expression {
+    expression::Expression {
+        expr: logical_plan::max(value.expr),
+    }
+}
+
+#[pyfunction]
+fn count(value: expression::Expression) -> expression::Expression {
+    expression::Expression {
+        expr: logical_plan::count(value.expr),
+    }
+}
+
+/*
+#[pyfunction]
+fn concat(value: Vec<expression::Expression>) -> expression::Expression {
+    expression::Expression {
+        expr: logical_plan::concat(value.into_iter().map(|e| e.expr)),
+    }
+}
+ */
+
+pub(crate) fn create_udf(
+    fun: PyObject,
+    input_types: Vec<PyDataType>,
+    return_type: PyDataType,
+    name: &str,
+) -> expression::ScalarUDF {
+    let input_types: Vec<DataType> =
+        input_types.iter().map(|d| d.data_type.clone()).collect();
+    let return_type = Arc::new(return_type.data_type);
+
+    expression::ScalarUDF {
+        function: logical_plan::create_udf(
+            name,
+            input_types,
+            return_type,
+            udf::array_udf(fun),
+        ),
+    }
+}
+
+/// Creates a new udf.
+#[pyfunction]
+fn udf(
+    fun: PyObject,
+    input_types: Vec<PyDataType>,
+    return_type: PyDataType,
+    py: Python,
+) -> PyResult<expression::ScalarUDF> {
+    let name = fun.getattr(py, "__qualname__")?.extract::<String>(py)?;
+
+    Ok(create_udf(fun, input_types, return_type, &name))
+}
+
+/// Creates a new udaf.
+#[pyfunction]
+fn udaf(
+    accumulator: PyObject,
+    input_type: PyDataType,
+    return_type: PyDataType,
+    state_type: Vec<PyDataType>,
+    py: Python,
+) -> PyResult<expression::AggregateUDF> {
+    let name = accumulator
+        .getattr(py, "__qualname__")?
+        .extract::<String>(py)?;
+
+    let input_type = input_type.data_type;
+    let return_type = Arc::new(return_type.data_type);
+    let state_type = Arc::new(state_type.into_iter().map(|t| t.data_type).collect());
+
+    Ok(expression::AggregateUDF {
+        function: logical_plan::create_udaf(
+            &name,
+            input_type,
+            return_type,
+            udaf::array_udaf(accumulator),
+            state_type,
+        ),
+    })
+}
+
+pub fn init(module: &PyModule) -> PyResult<()> {
+    module.add_function(wrap_pyfunction!(col, module)?)?;
+    module.add_function(wrap_pyfunction!(lit, module)?)?;
+    // see https://github.com/apache/arrow-datafusion/issues/226
+    //module.add_function(wrap_pyfunction!(concat, module)?)?;
+    module.add_function(wrap_pyfunction!(udf, module)?)?;
+    module.add_function(wrap_pyfunction!(sum, module)?)?;
+    module.add_function(wrap_pyfunction!(count, module)?)?;
+    module.add_function(wrap_pyfunction!(min, module)?)?;
+    module.add_function(wrap_pyfunction!(max, module)?)?;
+    module.add_function(wrap_pyfunction!(avg, module)?)?;
+    module.add_function(wrap_pyfunction!(udaf, module)?)?;
+    Ok(())
+}
diff --git a/python/src/lib.rs b/python/src/lib.rs
new file mode 100644
index 0000000000000..aecfe9994cd1a
--- /dev/null
+++ b/python/src/lib.rs
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::prelude::*;
+
+mod context;
+mod dataframe;
+mod errors;
+mod expression;
+mod functions;
+mod scalar;
+mod to_py;
+mod to_rust;
+mod types;
+mod udaf;
+mod udf;
+
+/// DataFusion.
+#[pymodule]
+fn datafusion(py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<context::ExecutionContext>()?;
+    m.add_class::<dataframe::DataFrame>()?;
+    m.add_class::<expression::Expression>()?;
+
+    let functions = PyModule::new(py, "functions")?;
+    functions::init(functions)?;
+    m.add_submodule(functions)?;
+
+    Ok(())
+}
diff --git a/python/src/scalar.rs b/python/src/scalar.rs
new file mode 100644
index 0000000000000..0c562a9403616
--- /dev/null
+++ b/python/src/scalar.rs
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
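+
+// Wraps DataFusion's `ScalarValue` so pyarrow scalars can be received from
+// Python; the conversion itself lives in `to_rust::to_rust_scalar`.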
+
+use pyo3::prelude::*;
+
+use datafusion::scalar::ScalarValue as _Scalar;
+
+use crate::to_rust::to_rust_scalar;
+
+/// A scalar value converted from a pyarrow scalar
+#[derive(Debug, Clone)]
+pub(crate) struct Scalar {
+    pub(crate) scalar: _Scalar,
+}
+
+impl<'source> FromPyObject<'source> for Scalar {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        Ok(Self {
+            scalar: to_rust_scalar(ob)?,
+        })
+    }
+}
diff --git a/python/src/to_py.rs b/python/src/to_py.rs
new file mode 100644
index 0000000000000..deeb9719891a3
--- /dev/null
+++ b/python/src/to_py.rs
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::prelude::*;
+use pyo3::{libc::uintptr_t, PyErr};
+
+use std::convert::From;
+
+use datafusion::arrow::array::ArrayRef;
+use datafusion::arrow::record_batch::RecordBatch;
+
+use crate::errors;
+
+pub fn to_py_array(array: &ArrayRef, py: Python) -> PyResult<PyObject> {
+    let (array_pointer, schema_pointer) =
+        array.to_raw().map_err(errors::DataFusionError::from)?;
+
+    let pa = py.import("pyarrow")?;
+
+    let array = pa.getattr("Array")?.call_method1(
+        "_import_from_c",
+        (array_pointer as uintptr_t, schema_pointer as uintptr_t),
+    )?;
+    Ok(array.to_object(py))
+}
+
+fn to_py_batch<'a>(
+    batch: &RecordBatch,
+    py: Python,
+    pyarrow: &'a PyModule,
+) -> Result<PyObject, PyErr> {
+    let mut py_arrays = vec![];
+    let mut py_names = vec![];
+
+    let schema = batch.schema();
+    for (array, field) in batch.columns().iter().zip(schema.fields().iter()) {
+        let array = to_py_array(array, py)?;
+
+        py_arrays.push(array);
+        py_names.push(field.name());
+    }
+
+    let record = pyarrow
+        .getattr("RecordBatch")?
+        .call_method1("from_arrays", (py_arrays, py_names))?;
+
+    Ok(PyObject::from(record))
+}
+
+/// Converts a &[RecordBatch] into a Python list of pyarrow RecordBatches
+pub fn to_py(batches: &[RecordBatch]) -> PyResult<PyObject> {
+    let gil = pyo3::Python::acquire_gil();
+    let py = gil.python();
+    let pyarrow = PyModule::import(py, "pyarrow")?;
+    let builtins = PyModule::import(py, "builtins")?;
+
+    let mut py_batches = vec![];
+    for batch in batches {
+        py_batches.push(to_py_batch(batch, py, pyarrow)?);
+    }
+    let result = builtins.call1("list", (py_batches,))?;
+    Ok(PyObject::from(result))
+}
diff --git a/python/src/to_rust.rs b/python/src/to_rust.rs
new file mode 100644
index 0000000000000..d8f2307a49823
--- /dev/null
+++ b/python/src/to_rust.rs
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use datafusion::arrow::{
+    array::{make_array_from_raw, ArrayRef},
+    datatypes::Field,
+    datatypes::Schema,
+    ffi,
+    record_batch::RecordBatch,
+};
+use datafusion::scalar::ScalarValue;
+use pyo3::{libc::uintptr_t, prelude::*};
+
+use crate::{errors, types::PyDataType};
+
+/// converts a pyarrow Array into a Rust Array
+pub fn to_rust(ob: &PyAny) -> PyResult<ArrayRef> {
+    // prepare a pointer to receive the Array struct
+    let (array_pointer, schema_pointer) =
+        ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() });
+
+    // make the conversion through PyArrow's private API
+    // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds
+    ob.call_method1(
+        "_export_to_c",
+        (array_pointer as uintptr_t, schema_pointer as uintptr_t),
+    )?;
+
+    let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) }
+        .map_err(errors::DataFusionError::from)?;
+    Ok(array)
+}
+
+pub fn to_rust_batch(batch: &PyAny) -> PyResult<RecordBatch> {
+    let schema = batch.getattr("schema")?;
+    let names = schema.getattr("names")?.extract::<Vec<String>>()?;
+
+    let fields = names
+        .iter()
+        .enumerate()
+        .map(|(i, name)| {
+            let field = schema.call_method1("field", (i,))?;
+            let nullable = field.getattr("nullable")?.extract::<bool>()?;
+            let py_data_type = field.getattr("type")?;
+            let data_type = py_data_type.extract::<PyDataType>()?.data_type;
+            Ok(Field::new(name, data_type, nullable))
+        })
+        .collect::<PyResult<Vec<Field>>>()?;
+
+    let schema = Arc::new(Schema::new(fields));
+
+    let arrays = (0..names.len())
+        .map(|i| {
+            let array = batch.call_method1("column", (i,))?;
+            to_rust(array)
+        })
+        .collect::<PyResult<Vec<ArrayRef>>>()?;
+
+    let batch =
+        RecordBatch::try_new(schema, arrays).map_err(errors::DataFusionError::from)?;
+    Ok(batch)
+}
+
+/// converts a pyarrow Scalar into a Rust Scalar
+pub fn to_rust_scalar(ob: &PyAny) -> PyResult<ScalarValue> {
+    let t = ob
+        .getattr("__class__")?
+        .getattr("__name__")?
+        .extract::<&str>()?;
+
+    let p = ob.call_method0("as_py")?;
+
+    Ok(match t {
+        "Int8Scalar" => ScalarValue::Int8(Some(p.extract::<i8>()?)),
+        "Int16Scalar" => ScalarValue::Int16(Some(p.extract::<i16>()?)),
+        "Int32Scalar" => ScalarValue::Int32(Some(p.extract::<i32>()?)),
+        "Int64Scalar" => ScalarValue::Int64(Some(p.extract::<i64>()?)),
+        "UInt8Scalar" => ScalarValue::UInt8(Some(p.extract::<u8>()?)),
+        "UInt16Scalar" => ScalarValue::UInt16(Some(p.extract::<u16>()?)),
+        "UInt32Scalar" => ScalarValue::UInt32(Some(p.extract::<u32>()?)),
+        "UInt64Scalar" => ScalarValue::UInt64(Some(p.extract::<u64>()?)),
+        "FloatScalar" => ScalarValue::Float32(Some(p.extract::<f32>()?)),
+        "DoubleScalar" => ScalarValue::Float64(Some(p.extract::<f64>()?)),
+        "BooleanScalar" => ScalarValue::Boolean(Some(p.extract::<bool>()?)),
+        "StringScalar" => ScalarValue::Utf8(Some(p.extract::<String>()?)),
+        "LargeStringScalar" => ScalarValue::LargeUtf8(Some(p.extract::<String>()?)),
+        other => {
+            return Err(errors::DataFusionError::Common(format!(
+                "Type \"{}\" not yet implemented",
+                other
+            ))
+            .into())
+        }
+    })
+}
diff --git a/python/src/types.rs b/python/src/types.rs
new file mode 100644
index 0000000000000..ffa822e073a89
--- /dev/null
+++ b/python/src/types.rs
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
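+
+// Maps pyarrow's integer `DataType.id` onto the corresponding Arrow `DataType`;
+// only primitive, string, and binary types are covered for now.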
+
+use datafusion::arrow::datatypes::DataType;
+use pyo3::{FromPyObject, PyAny, PyResult};
+
+use crate::errors;
+
+/// utility struct to convert PyObj to native DataType
+#[derive(Debug, Clone)]
+pub struct PyDataType {
+    pub data_type: DataType,
+}
+
+impl<'source> FromPyObject<'source> for PyDataType {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        let id = ob.getattr("id")?.extract::<i32>()?;
+        let data_type = data_type_id(&id)?;
+        Ok(PyDataType { data_type })
+    }
+}
+
+fn data_type_id(id: &i32) -> Result<DataType, errors::DataFusionError> {
+    // see https://github.com/apache/arrow/blob/3694794bdfd0677b95b8c95681e392512f1c9237/python/pyarrow/includes/libarrow.pxd
+    // this is not ideal as it does not generalize for non-basic types
+    // Find a way to get a unique name from the pyarrow.DataType
+    Ok(match id {
+        1 => DataType::Boolean,
+        2 => DataType::UInt8,
+        3 => DataType::Int8,
+        4 => DataType::UInt16,
+        5 => DataType::Int16,
+        6 => DataType::UInt32,
+        7 => DataType::Int32,
+        8 => DataType::UInt64,
+        9 => DataType::Int64,
+
+        10 => DataType::Float16,
+        11 => DataType::Float32,
+        12 => DataType::Float64,
+
+        //13 => DataType::Decimal,
+
+        // 14 => DataType::Date32(),
+        // 15 => DataType::Date64(),
+        // 16 => DataType::Timestamp(),
+        // 17 => DataType::Time32(),
+        // 18 => DataType::Time64(),
+        // 19 => DataType::Duration()
+        20 => DataType::Binary,
+        21 => DataType::Utf8,
+        22 => DataType::LargeBinary,
+        23 => DataType::LargeUtf8,
+
+        other => {
+            return Err(errors::DataFusionError::Common(format!(
+                "The type {} is not valid",
+                other
+            )))
+        }
+    })
+}
diff --git a/python/src/udaf.rs b/python/src/udaf.rs
new file mode 100644
index 0000000000000..3ce223df9a491
--- /dev/null
+++ b/python/src/udaf.rs
@@ -0,0 +1,147 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use pyo3::{prelude::*, types::PyTuple};
+
+use datafusion::arrow::array::ArrayRef;
+
+use datafusion::error::Result;
+use datafusion::{
+    error::DataFusionError as InnerDataFusionError, physical_plan::Accumulator,
+    scalar::ScalarValue,
+};
+
+use crate::scalar::Scalar;
+use crate::to_py::to_py_array;
+use crate::to_rust::to_rust_scalar;
+
+#[derive(Debug)]
+struct PyAccumulator {
+    accum: PyObject,
+}
+
+impl PyAccumulator {
+    fn new(accum: PyObject) -> Self {
+        Self { accum }
+    }
+}
+
+impl Accumulator for PyAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        let state = self
+            .accum
+            .as_ref(py)
+            .call_method0("to_scalars")
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?
+            .extract::<Vec<Scalar>>()
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
+
+        Ok(state.into_iter().map(|v| v.scalar).collect::<Vec<ScalarValue>>())
+    }
+
+    fn update(&mut self, _values: &[ScalarValue]) -> Result<()> {
+        // no need to implement as datafusion does not use it
+        todo!()
+    }
+
+    fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> {
+        // no need to implement as datafusion does not use it
+        todo!()
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        // get GIL
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        let value = self
+            .accum
+            .as_ref(py)
+            .call_method0("evaluate")
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
+
+        to_rust_scalar(value)
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        // get GIL
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        // 1. cast args to Pyarrow array
+        // 2. call function
+
+        // 1.
+        let py_args = values
+            .iter()
+            .map(|arg| {
+                // remove unwrap
+                to_py_array(arg, py).unwrap()
+            })
+            .collect::<Vec<PyObject>>();
+        let py_args = PyTuple::new(py, py_args);
+
+        // update accumulator
+        self.accum
+            .as_ref(py)
+            .call_method1("update", py_args)
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        // get GIL
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        // 1. cast states to Pyarrow array
+        // 2. merge
+        let state = &states[0];
+
+        let state = to_py_array(state, py)
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
+
+        // 2.
+        self.accum
+            .as_ref(py)
+            .call_method1("merge", (state,))
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
+
+        Ok(())
+    }
+}
+
+pub fn array_udaf(
+    accumulator: PyObject,
+) -> Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync> {
+    Arc::new(move || -> Result<Box<dyn Accumulator>> {
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        let accumulator = accumulator
+            .call0(py)
+            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
+        Ok(Box::new(PyAccumulator::new(accumulator)))
+    })
+}
diff --git a/python/src/udf.rs b/python/src/udf.rs
new file mode 100644
index 0000000000000..7fee71008ef2f
--- /dev/null
+++ b/python/src/udf.rs
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
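+
+// Scalar UDFs round-trip each batch through Python: Arrow arrays are handed to
+// the Python callable via the Arrow C data interface (zero-copy), and the
+// returned pyarrow array is imported back as a Rust `ArrayRef`.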
+
+use pyo3::{prelude::*, types::PyTuple};
+
+use datafusion::{arrow::array, physical_plan::functions::make_scalar_function};
+
+use datafusion::error::DataFusionError;
+use datafusion::physical_plan::functions::ScalarFunctionImplementation;
+
+use crate::to_py::to_py_array;
+use crate::to_rust::to_rust;
+
+/// creates a DataFusion UDF implementation from a python function that expects pyarrow arrays
+/// This is more efficient as the array contents are passed without copying.
+pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation {
+    make_scalar_function(
+        move |args: &[array::ArrayRef]| -> Result<array::ArrayRef, DataFusionError> {
+            // get GIL
+            let gil = pyo3::Python::acquire_gil();
+            let py = gil.python();
+
+            // 1. cast args to Pyarrow arrays
+            // 2. call function
+            // 3. cast to arrow::array::Array
+
+            // 1.
+            let py_args = args
+                .iter()
+                .map(|arg| {
+                    // remove unwrap
+                    to_py_array(arg, py).unwrap()
+                })
+                .collect::<Vec<PyObject>>();
+            let py_args = PyTuple::new(py, py_args);
+
+            // 2.
+            let value = func.as_ref(py).call(py_args, None);
+            let value = match value {
+                Ok(n) => Ok(n),
+                Err(error) => Err(DataFusionError::Execution(format!("{:?}", error))),
+            }?;
+
+            let array = to_rust(value).unwrap();
+            Ok(array)
+        },
+    )
+}
diff --git a/python/tests/__init__.py b/python/tests/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/python/tests/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/tests/generic.py b/python/tests/generic.py
new file mode 100644
index 0000000000000..7362f0bb29569
--- /dev/null
+++ b/python/tests/generic.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
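+
+# Shared test helpers: in-memory data generators (numeric, datetime, timedelta)
+# and a small Parquet writer.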
+
+import unittest
+import tempfile
+import datetime
+import os.path
+import shutil
+
+import numpy
+import pyarrow
+import datafusion
+
+# used to write parquet files
+import pyarrow.parquet
+
+
+def data():
+    data = numpy.concatenate(
+        [numpy.random.normal(0, 0.01, size=50), numpy.random.normal(50, 0.01, size=50)]
+    )
+    return pyarrow.array(data)
+
+
+def data_with_nans():
+    data = numpy.random.normal(0, 0.01, size=50)
+    mask = numpy.random.randint(0, 2, size=50)
+    data[mask == 0] = numpy.NaN
+    return data
+
+
+def data_datetime(f):
+    data = [
+        datetime.datetime.now(),
+        datetime.datetime.now() - datetime.timedelta(days=1),
+        datetime.datetime.now() + datetime.timedelta(days=1),
+    ]
+    return pyarrow.array(
+        data, type=pyarrow.timestamp(f), mask=numpy.array([False, True, False])
+    )
+
+
+def data_timedelta(f):
+    data = [
+        datetime.timedelta(days=100),
+        datetime.timedelta(days=1),
+        datetime.timedelta(seconds=1),
+    ]
+    return pyarrow.array(
+        data, type=pyarrow.duration(f), mask=numpy.array([False, True, False])
+    )
+
+
+def data_binary_other():
+    return numpy.array([1, 0, 0], dtype="u4")
+
+
+def write_parquet(path, data):
+    table = pyarrow.Table.from_arrays([data], names=["a"])
+    pyarrow.parquet.write_table(table, path)
+    return path
diff --git a/python/tests/test_df.py b/python/tests/test_df.py
new file mode 100644
index 0000000000000..520d4e6a54723
--- /dev/null
+++ b/python/tests/test_df.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
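The tests below construct DataFrames with `ctx.create_dataframe([[batch]])`. The nesting is deliberate: judging from context.rs earlier in this patch, the outer list holds partitions and each inner list holds the RecordBatches of one partition. A sketch:

    import pyarrow
    import datafusion

    ctx = datafusion.ExecutionContext()
    batch1 = pyarrow.RecordBatch.from_arrays([pyarrow.array([1, 2])], names=["a"])
    batch2 = pyarrow.RecordBatch.from_arrays([pyarrow.array([3, 4])], names=["a"])

    df = ctx.create_dataframe([[batch1, batch2]])    # one partition, two batches
    df = ctx.create_dataframe([[batch1], [batch2]])  # two partitions, one batch each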
+
+import unittest
+
+import pyarrow
+import datafusion
+f = datafusion.functions
+
+
+class TestCase(unittest.TestCase):
+
+    def _prepare(self):
+        ctx = datafusion.ExecutionContext()
+
+        # create a RecordBatch and a new DataFrame from it
+        batch = pyarrow.RecordBatch.from_arrays(
+            [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+            names=["a", "b"],
+        )
+        return ctx.create_dataframe([[batch]])
+
+    def test_select(self):
+        df = self._prepare()
+
+        df = df.select(
+            f.col("a") + f.col("b"),
+            f.col("a") - f.col("b"),
+        )
+
+        # execute and collect the first (and only) batch
+        result = df.collect()[0]
+
+        self.assertEqual(result.column(0), pyarrow.array([5, 7, 9]))
+        self.assertEqual(result.column(1), pyarrow.array([-3, -3, -3]))
+
+    def test_filter(self):
+        df = self._prepare()
+
+        df = df \
+            .select(
+                f.col("a") + f.col("b"),
+                f.col("a") - f.col("b"),
+            ) \
+            .filter(f.col("a") > f.lit(2))
+
+        # execute and collect the first (and only) batch
+        result = df.collect()[0]
+
+        self.assertEqual(result.column(0), pyarrow.array([9]))
+        self.assertEqual(result.column(1), pyarrow.array([-3]))
+
+    def test_limit(self):
+        df = self._prepare()
+
+        df = df.limit(1)
+
+        # execute and collect the first (and only) batch
+        result = df.collect()[0]
+
+        self.assertEqual(len(result.column(0)), 1)
+        self.assertEqual(len(result.column(1)), 1)
+
+    def test_udf(self):
+        df = self._prepare()
+
+        # is_null is a pyarrow function over arrays
+        udf = f.udf(lambda x: x.is_null(), [pyarrow.int64()], pyarrow.bool_())
+
+        df = df.select(udf(f.col("a")))
+
+        self.assertEqual(df.collect()[0].column(0), pyarrow.array([False, False, False]))
+
+    def test_join(self):
+        ctx = datafusion.ExecutionContext()
+
+        batch = pyarrow.RecordBatch.from_arrays(
+            [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+            names=["a", "b"],
+        )
+        df = ctx.create_dataframe([[batch]])
+
+        batch = pyarrow.RecordBatch.from_arrays(
+            [pyarrow.array([1, 2]), pyarrow.array([8, 10])],
+            names=["a", "c"],
+        )
+        df1 = ctx.create_dataframe([[batch]])
+
+        df = df.join(df1, on="a", how="inner")
+
+        # execute and collect the first (and only) batch
+        batch = df.collect()[0]
+
+        if batch.column(0) == pyarrow.array([1, 2]):
+            self.assertEqual(batch.column(0), pyarrow.array([1, 2]))
+            self.assertEqual(batch.column(1), pyarrow.array([8, 10]))
+            self.assertEqual(batch.column(2), pyarrow.array([4, 5]))
+        else:
+            self.assertEqual(batch.column(0), pyarrow.array([2, 1]))
+            self.assertEqual(batch.column(1), pyarrow.array([10, 8]))
+            self.assertEqual(batch.column(2), pyarrow.array([5, 4]))
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
new file mode 100644
index 0000000000000..e9047ea6e70c3
--- /dev/null
+++ b/python/tests/test_sql.py
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
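All SQL tests below follow the same register-then-query pattern. Stripped to its core it looks like the following sketch (the parquet path here is hypothetical; the tests produce theirs with the write_parquet helper from tests/generic.py):

    import datafusion

    ctx = datafusion.ExecutionContext()
    ctx.register_parquet("t", "/tmp/a.parquet")  # hypothetical path

    # collect() returns a list of pyarrow.RecordBatch
    batches = ctx.sql("SELECT COUNT(a) FROM t").collect()
    count = batches[0].column(0)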
+
+import unittest
+import tempfile
+import datetime
+import os.path
+import shutil
+
+import numpy
+import pyarrow
+import datafusion
+
+# used to write parquet files
+import pyarrow.parquet
+
+from tests.generic import *
+
+
+class TestCase(unittest.TestCase):
+    def setUp(self):
+        # Create a temporary directory
+        self.test_dir = tempfile.mkdtemp()
+        numpy.random.seed(1)
+
+    def tearDown(self):
+        # Remove the directory after the test
+        shutil.rmtree(self.test_dir)
+
+    def test_no_table(self):
+        with self.assertRaises(Exception):
+            datafusion.ExecutionContext().sql("SELECT a FROM b").collect()
+
+    def test_register(self):
+        ctx = datafusion.ExecutionContext()
+
+        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data())
+
+        ctx.register_parquet("t", path)
+
+        self.assertEqual(ctx.tables(), {"t"})
+
+    def test_execute(self):
+        data = [1, 1, 2, 2, 3, 11, 12]
+
+        ctx = datafusion.ExecutionContext()
+
+        # single column, "a"
+        path = write_parquet(
+            os.path.join(self.test_dir, "a.parquet"), pyarrow.array(data)
+        )
+        ctx.register_parquet("t", path)
+
+        self.assertEqual(ctx.tables(), {"t"})
+
+        # count
+        result = ctx.sql("SELECT COUNT(a) FROM t").collect()
+
+        expected = pyarrow.array([7], pyarrow.uint64())
+        expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
+        self.assertEqual(expected, result)
+
+        # where
+        expected = pyarrow.array([2], pyarrow.uint64())
+        expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
+        self.assertEqual(
+            expected, ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect()
+        )
+
+        # group by
+        result = ctx.sql(
+            "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)"
+        ).collect()
+
+        result_keys = result[0].to_pydict()["CAST(a AS Int32)"]
+        result_values = result[0].to_pydict()["COUNT(a)"]
+        result_keys, result_values = (
+            list(t) for t in zip(*sorted(zip(result_keys, result_values)))
+        )
+
+        self.assertEqual(result_keys, [1, 2, 3, 11, 12])
+        self.assertEqual(result_values, [2, 2, 1, 1, 1])
+
+        # order by
+        result = ctx.sql(
+            "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2"
+        ).collect()
+        expected_a = pyarrow.array([50.0219, 50.0152], pyarrow.float64())
+        expected_cast = pyarrow.array([50, 50], pyarrow.int32())
+        expected = [
+            pyarrow.RecordBatch.from_arrays(
+                [expected_a, expected_cast], ["a", "CAST(a AS Int32)"]
+            )
+        ]
+        numpy.testing.assert_equal(result[0].column(1), expected[0].column(1))
+
+    def test_cast(self):
+        """
+        Verify that casting to each of the types below executes without error.
+        """
+        ctx = datafusion.ExecutionContext()
+
+        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data())
+        ctx.register_parquet("t", path)
+
+        valid_types = [
+            "smallint",
+            "int",
+            "bigint",
+            "float(32)",
+            "float(64)",
+            "float",
+        ]
+
+        select = ", ".join(
+            [f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)]
+        )
+
+        # can execute, which implies that we can cast
+        ctx.sql(f"SELECT {select} FROM t").collect()
+
+    def _test_udf(self, udf, args, return_type, array, expected):
+        ctx = datafusion.ExecutionContext()
+
+        # write to disk
+        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), array)
+        ctx.register_parquet("t", path)
+
+        ctx.register_udf("udf", udf, args, return_type)
+
+        batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect()
+
+        result = batches[0].column(0)
+
+        self.assertEqual(expected, result)
+
+    def test_udf_identity(self):
+        self._test_udf(
+            lambda x: x,
+            [pyarrow.float64()],
+            pyarrow.float64(),
+            pyarrow.array([-1.2, None, 1.2]),
+            pyarrow.array([-1.2, None, 1.2]),
+        )
+
+    def test_udf(self):
+        self._test_udf(
+            lambda x: x.is_null(),
+            [pyarrow.float64()],
+            pyarrow.bool_(),
+            pyarrow.array([-1.2, None, 1.2]),
+            pyarrow.array([False, True, False]),
+        )
+
+
+class TestIO(unittest.TestCase):
+    def setUp(self):
+        # Create a temporary directory
+        self.test_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        # Remove the directory after the test
+        shutil.rmtree(self.test_dir)
+
+    def _test_data(self, data):
+        ctx = datafusion.ExecutionContext()
+
+        # write to disk
+        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data)
+        ctx.register_parquet("t", path)
+
+        batches = ctx.sql("SELECT a AS tt FROM t").collect()
+
+        result = batches[0].column(0)
+
+        numpy.testing.assert_equal(data, result)
+
+    def test_nans(self):
+        self._test_data(data_with_nans())
+
+    def test_utf8(self):
+        array = pyarrow.array(
+            ["a", "b", "c"], pyarrow.utf8(), numpy.array([False, True, False])
+        )
+        self._test_data(array)
+
+    def test_large_utf8(self):
+        array = pyarrow.array(
+            ["a", "b", "c"], pyarrow.large_utf8(), numpy.array([False, True, False])
+        )
+        self._test_data(array)
+
+    # Error from Arrow
+    @unittest.expectedFailure
+    def test_datetime_s(self):
+        self._test_data(data_datetime("s"))
+
+    # C data interface missing
+    @unittest.expectedFailure
+    def test_datetime_ms(self):
+        self._test_data(data_datetime("ms"))
+
+    # C data interface missing
+    @unittest.expectedFailure
+    def test_datetime_us(self):
+        self._test_data(data_datetime("us"))
+
+    # Not writable to parquet
+    @unittest.expectedFailure
+    def test_datetime_ns(self):
+        self._test_data(data_datetime("ns"))
+
+    # Not writable to parquet
+    @unittest.expectedFailure
+    def test_timedelta_s(self):
+        self._test_data(data_timedelta("s"))
+
+    # Not writable to parquet
+    @unittest.expectedFailure
+    def test_timedelta_ms(self):
+        self._test_data(data_timedelta("ms"))
+
+    # Not writable to parquet
+    @unittest.expectedFailure
+    def test_timedelta_us(self):
+        self._test_data(data_timedelta("us"))
+
+    # Not writable to parquet
+    @unittest.expectedFailure
+    def test_timedelta_ns(self):
+        self._test_data(data_timedelta("ns"))
+
+    def test_date32(self):
+        array = pyarrow.array(
+            [
+                datetime.date(2000, 1, 1),
+                datetime.date(1980, 1, 1),
+                datetime.date(2030, 1, 1),
+            ],
+            pyarrow.date32(),
+            numpy.array([False, True, False]),
+        )
+        self._test_data(array)
+
+    def test_binary_variable(self):
+        array = pyarrow.array(
+            [b"1", b"2", b"3"], pyarrow.binary(), numpy.array([False, True, False])
+        )
+        self._test_data(array)
+
+    # C data interface missing
+    @unittest.expectedFailure
+    def test_binary_fixed(self):
+        array = pyarrow.array(
+            [b"1111", b"2222", b"3333"],
+            pyarrow.binary(4),
+            numpy.array([False, True, False]),
+        )
+        self._test_data(array)
+
+    def test_large_binary(self):
+        array = pyarrow.array(
+            [b"1111", b"2222", b"3333"],
+            pyarrow.large_binary(),
+            numpy.array([False, True, False]),
+        )
+        self._test_data(array)
+
+    def test_binary_other(self):
+        self._test_data(data_binary_other())
+
+    def test_bool(self):
+        array = pyarrow.array(
+            [False, True, True], None, numpy.array([False, True, False])
+        )
+        self._test_data(array)
+
+    def test_u32(self):
+        array = pyarrow.array([0, 1, 2], None, numpy.array([False, True, False]))
+        self._test_data(array)
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
new file mode 100644
index 0000000000000..ffd235e285f80
--- /dev/null
+++ b/python/tests/test_udaf.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+import pyarrow
+import pyarrow.compute
+import datafusion
+
+f = datafusion.functions
+
+
+class Accumulator:
+    """
+    Interface of a user-defined accumulator.
+    """
+
+    def __init__(self):
+        self._sum = pyarrow.scalar(0.0)
+
+    def to_scalars(self) -> [pyarrow.Scalar]:
+        return [self._sum]
+
+    def update(self, values: pyarrow.Array) -> None:
+        # not nice since pyarrow scalars can't be summed yet; this breaks on `None`
+        self._sum = pyarrow.scalar(
+            self._sum.as_py() + pyarrow.compute.sum(values).as_py()
+        )
+
+    def merge(self, states: pyarrow.Array) -> None:
+        # not nice since pyarrow scalars can't be summed yet; this breaks on `None`
+        self._sum = pyarrow.scalar(
+            self._sum.as_py() + pyarrow.compute.sum(states).as_py()
+        )
+
+    def evaluate(self) -> pyarrow.Scalar:
+        return self._sum
+
+
+class TestCase(unittest.TestCase):
+    def _prepare(self):
+        ctx = datafusion.ExecutionContext()
+
+        # create a RecordBatch and a new DataFrame from it
+        batch = pyarrow.RecordBatch.from_arrays(
+            [pyarrow.array([1, 2, 3]), pyarrow.array([4, 4, 6])],
+            names=["a", "b"],
+        )
+        return ctx.create_dataframe([[batch]])
+
+    def test_aggregate(self):
+        df = self._prepare()
+
+        udaf = f.udaf(
+            Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]
+        )
+
+        df = df.aggregate([], [udaf(f.col("a"))])
+
+        # execute and collect the first (and only) batch
+        result = df.collect()[0]
+
+        self.assertEqual(result.column(0), pyarrow.array([1.0 + 2.0 + 3.0]))
+
+    def test_group_by(self):
+        df = self._prepare()
+
+        udaf = f.udaf(
+            Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]
+        )
+
+        df = df.aggregate([f.col("b")], [udaf(f.col("a"))])
+
+        # execute and collect the first (and only) batch
+        result = df.collect()[0]
+
+        self.assertEqual(result.column(1), pyarrow.array([1.0 + 2.0, 3.0]))
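Tying test_udaf.py back to udaf.rs: DataFusion drives each aggregation through the Rust PyAccumulator bridge, which forwards to the Python methods above. For a two-phase aggregate the call sequence is roughly as follows (a pseudocode sketch of the engine's behavior, not user-facing API):

    # partial phase: one accumulator per partition
    acc = Accumulator()            # built by the factory passed to f.udaf
    acc.update(values)             # Rust update_batch -> Python "update", once per batch
    partial = acc.to_scalars()     # Rust state() -> Python "to_scalars"

    # final phase: partial states are merged and evaluated
    final_acc = Accumulator()
    final_acc.merge(states)        # Rust merge_batch -> Python "merge"
    result = final_acc.evaluate()  # Rust evaluate() -> Python "evaluate"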