From 543639fed8650410232dd980bf1ff11dcff4335b Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 4 May 2021 08:09:57 -0600
Subject: [PATCH] Revert "Add datafusion-python (#69)"

This reverts commit 46bde0bd148aacf1677a575cb9ddbc154b6c4fb3.
---
 .github/workflows/python_build.yml |  89 ---------
 .github/workflows/python_test.yaml |  58 ------
 Cargo.toml                         |   4 +-
 dev/release/rat_exclude_files.txt  |   1 -
 python/.cargo/config               |  22 ---
 python/.dockerignore               |  19 --
 python/.gitignore                  |  20 --
 python/Cargo.toml                  |  57 ------
 python/README.md                   | 146 --------------
 python/pyproject.toml              |  20 --
 python/rust-toolchain              |   1 -
 python/src/context.rs              | 115 ----------
 python/src/dataframe.rs            | 161 ----------------
 python/src/errors.rs               |  61 ------
 python/src/expression.rs           | 162 ----------------
 python/src/functions.rs            | 165 ----------------
 python/src/lib.rs                  |  44 -----
 python/src/scalar.rs               |  36 ----
 python/src/to_py.rs                |  77 --------
 python/src/to_rust.rs              | 111 -----------
 python/src/types.rs                |  76 --------
 python/src/udaf.rs                 | 147 ---------------
 python/src/udf.rs                  |  62 ------
 python/tests/__init__.py           |  16 --
 python/tests/generic.py            |  75 --------
 python/tests/test_df.py            | 115 ----------
 python/tests/test_sql.py           | 294 -----------------------------
 python/tests/test_udaf.py          |  91 ---------
 28 files changed, 1 insertion(+), 2244 deletions(-)
 delete mode 100644 .github/workflows/python_build.yml
 delete mode 100644 .github/workflows/python_test.yaml
 delete mode 100644 python/.cargo/config
 delete mode 100644 python/.dockerignore
 delete mode 100644 python/.gitignore
 delete mode 100644 python/Cargo.toml
 delete mode 100644 python/README.md
 delete mode 100644 python/pyproject.toml
 delete mode 100644 python/rust-toolchain
 delete mode 100644 python/src/context.rs
 delete mode 100644 python/src/dataframe.rs
 delete mode 100644 python/src/errors.rs
 delete mode 100644 python/src/expression.rs
 delete mode 100644 python/src/functions.rs
 delete mode 100644 python/src/lib.rs
 delete mode 100644 python/src/scalar.rs
 delete mode 100644 python/src/to_py.rs
 delete mode 100644 python/src/to_rust.rs
 delete mode 100644 python/src/types.rs
 delete mode 100644 python/src/udaf.rs
 delete mode 100644 python/src/udf.rs
 delete mode 100644 python/tests/__init__.py
 delete mode 100644 python/tests/generic.py
 delete mode 100644 python/tests/test_df.py
 delete mode 100644 python/tests/test_sql.py
 delete mode 100644 python/tests/test_udaf.py

diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml
deleted file mode 100644
index c86bb81581a71..0000000000000
--- a/.github/workflows/python_build.yml
+++ /dev/null
@@ -1,89 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
- -name: Build -on: - push: - tags: - - v* - -jobs: - build-python-mac-win: - name: Mac/Win - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - python-version: [3.6, 3.7, 3.8] - os: [macos-latest, windows-latest] - steps: - - uses: actions/checkout@v2 - - - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly-2021-01-06 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install maturin - - - name: Build Python package - run: cd python && maturin build --release --no-sdist --strip --interpreter python${{matrix.python_version}} - - - name: List wheels - if: matrix.os == 'windows-latest' - run: dir python/target\wheels\ - - - name: List wheels - if: matrix.os != 'windows-latest' - run: find ./python/target/wheels/ - - - name: Archive wheels - uses: actions/upload-artifact@v2 - with: - name: dist - path: python/target/wheels/* - - build-manylinux: - name: Manylinux - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Build wheels - run: docker run --rm -v $(pwd):/io konstin2/maturin build --release --manylinux - - name: Archive wheels - uses: actions/upload-artifact@v2 - with: - name: dist - path: python/target/wheels/* - - release: - name: Publish in PyPI - needs: [build-manylinux, build-python-mac-win] - runs-on: ubuntu-latest - steps: - - uses: actions/download-artifact@v2 - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.pypi_password }} diff --git a/.github/workflows/python_test.yaml b/.github/workflows/python_test.yaml deleted file mode 100644 index 3b2111b59d49d..0000000000000 --- a/.github/workflows/python_test.yaml +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Python test -on: [push, pull_request] - -jobs: - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Setup Rust toolchain - run: | - rustup toolchain install nightly-2021-01-06 - rustup default nightly-2021-01-06 - rustup component add rustfmt - - name: Cache Cargo - uses: actions/cache@v2 - with: - path: /home/runner/.cargo - key: cargo-maturin-cache- - - name: Cache Rust dependencies - uses: actions/cache@v2 - with: - path: /home/runner/target - key: target-maturin-cache- - - uses: actions/setup-python@v2 - with: - python-version: '3.7' - - name: Install Python dependencies - run: python -m pip install --upgrade pip setuptools wheel - - name: Run tests - run: | - cd python/ - export CARGO_HOME="/home/runner/.cargo" - export CARGO_TARGET_DIR="/home/runner/target" - - python -m venv venv - source venv/bin/activate - - pip install maturin==0.10.4 toml==0.10.1 pyarrow==4.0.0 - maturin develop - - python -m unittest discover tests diff --git a/Cargo.toml b/Cargo.toml index 9795cb68b4456..fa36a0c0fed7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,4 @@ members = [ "ballista/rust/core", "ballista/rust/executor", "ballista/rust/scheduler", -] - -exclude = ["python"] +] \ No newline at end of file diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 6126699bbc1fa..b94c0ea1d61a6 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -104,4 +104,3 @@ rust-toolchain benchmarks/queries/q*.sql ballista/rust/scheduler/testdata/* ballista/ui/scheduler/yarn.lock -python/rust-toolchain diff --git a/python/.cargo/config b/python/.cargo/config deleted file mode 100644 index 0b24f30cf908a..0000000000000 --- a/python/.cargo/config +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[target.x86_64-apple-darwin] -rustflags = [ - "-C", "link-arg=-undefined", - "-C", "link-arg=dynamic_lookup", -] diff --git a/python/.dockerignore b/python/.dockerignore deleted file mode 100644 index 08c131c2e7d60..0000000000000 --- a/python/.dockerignore +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -target -venv diff --git a/python/.gitignore b/python/.gitignore deleted file mode 100644 index 48fe4dbe52dde..0000000000000 --- a/python/.gitignore +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -/target -Cargo.lock -venv diff --git a/python/Cargo.toml b/python/Cargo.toml deleted file mode 100644 index 070720554f0ed..0000000000000 --- a/python/Cargo.toml +++ /dev/null @@ -1,57 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -[package] -name = "datafusion" -version = "0.2.1" -homepage = "https://github.com/apache/arrow" -repository = "https://github.com/apache/arrow" -authors = ["Apache Arrow "] -description = "Build and run queries against data" -readme = "README.md" -license = "Apache-2.0" -edition = "2018" - -[dependencies] -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } -rand = "0.7" -pyo3 = { version = "0.12.1", features = ["extension-module"] } -datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "2423ff0d" } - -[lib] -name = "datafusion" -crate-type = ["cdylib"] - -[package.metadata.maturin] -requires-dist = ["pyarrow>=1"] - -classifier = [ - "Development Status :: 2 - Pre-Alpha", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "License :: OSI Approved", - "Operating System :: MacOS", - "Operating System :: Microsoft :: Windows", - "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python", - "Programming Language :: Rust", -] diff --git a/python/README.md b/python/README.md deleted file mode 100644 index 1859fca9811c0..0000000000000 --- a/python/README.md +++ /dev/null @@ -1,146 +0,0 @@ - - -## DataFusion in Python - -This is a Python library that binds to [Apache Arrow](https://arrow.apache.org/) in-memory query engine [DataFusion](https://github.com/apache/arrow/tree/master/rust/datafusion). - -Like pyspark, it allows you to build a plan through SQL or a DataFrame API against in-memory data, parquet or CSV files, run it in a multi-threaded environment, and obtain the result back in Python. - -It also allows you to use UDFs and UDAFs for complex operations. - -The major advantage of this library over other execution engines is that this library achieves zero-copy between Python and its execution engine: there is no cost in using UDFs, UDAFs, and collecting the results to Python apart from having to lock the GIL when running those operations. - -Its query engine, DataFusion, is written in [Rust](https://www.rust-lang.org/), which makes strong assumptions about thread safety and lack of memory leaks. - -Technically, zero-copy is achieved via the [c data interface](https://arrow.apache.org/docs/format/CDataInterface.html). - -## How to use it - -Simple usage: - -```python -import datafusion -import pyarrow - -# an alias -f = datafusion.functions - -# create a context -ctx = datafusion.ExecutionContext() - -# create a RecordBatch and a new DataFrame from it -batch = pyarrow.RecordBatch.from_arrays( - [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])], - names=["a", "b"], -) -df = ctx.create_dataframe([[batch]]) - -# create a new statement -df = df.select( - f.col("a") + f.col("b"), - f.col("a") - f.col("b"), -) - -# execute and collect the first (and only) batch -result = df.collect()[0] - -assert result.column(0) == pyarrow.array([5, 7, 9]) -assert result.column(1) == pyarrow.array([-3, -3, -3]) -``` - -### UDFs - -```python -def is_null(array: pyarrow.Array) -> pyarrow.Array: - return array.is_null() - -udf = f.udf(is_null, [pyarrow.int64()], pyarrow.bool_()) - -df = df.select(udf(f.col("a"))) -``` - -### UDAF - -```python -import pyarrow -import pyarrow.compute - - -class Accumulator: - """ - Interface of a user-defined accumulation. 
- """ - def __init__(self): - self._sum = pyarrow.scalar(0.0) - - def to_scalars(self) -> [pyarrow.Scalar]: - return [self._sum] - - def update(self, values: pyarrow.Array) -> None: - # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(values).as_py()) - - def merge(self, states: pyarrow.Array) -> None: - # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states).as_py()) - - def evaluate(self) -> pyarrow.Scalar: - return self._sum - - -df = ... - -udaf = f.udaf(Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]) - -df = df.aggregate( - [], - [udaf(f.col("a"))] -) -``` - -## How to install - -```bash -pip install datafusion -``` - -## How to develop - -This assumes that you have rust and cargo installed. We use the workflow recommended by [pyo3](https://github.com/PyO3/pyo3) and [maturin](https://github.com/PyO3/maturin). - -Bootstrap: - -```bash -# fetch this repo -git clone git@github.com:apache/arrow-datafusion.git - -cd arrow-datafusion/python - -# prepare development environment (used to build wheel / install in development) -python3 -m venv venv -pip install maturin==0.10.4 toml==0.10.1 pyarrow==1.0.0 -``` - -Whenever rust code changes (your changes or via git pull): - -```bash -venv/bin/maturin develop -venv/bin/python -m unittest discover tests -``` diff --git a/python/pyproject.toml b/python/pyproject.toml deleted file mode 100644 index 27480690e06cc..0000000000000 --- a/python/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[build-system] -requires = ["maturin"] -build-backend = "maturin" diff --git a/python/rust-toolchain b/python/rust-toolchain deleted file mode 100644 index 9d0cf79d367d6..0000000000000 --- a/python/rust-toolchain +++ /dev/null @@ -1 +0,0 @@ -nightly-2021-01-06 diff --git a/python/src/context.rs b/python/src/context.rs deleted file mode 100644 index 14ef0f7321f15..0000000000000 --- a/python/src/context.rs +++ /dev/null @@ -1,115 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{collections::HashSet, sync::Arc}; - -use rand::distributions::Alphanumeric; -use rand::Rng; - -use pyo3::prelude::*; - -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::datasource::MemTable; -use datafusion::execution::context::ExecutionContext as _ExecutionContext; - -use crate::dataframe; -use crate::errors; -use crate::functions; -use crate::to_rust; -use crate::types::PyDataType; - -/// `ExecutionContext` is able to plan and execute DataFusion plans. -/// It has a powerful optimizer, a physical planner for local execution, and a -/// multi-threaded execution engine to perform the execution. -#[pyclass(unsendable)] -pub(crate) struct ExecutionContext { - ctx: _ExecutionContext, -} - -#[pymethods] -impl ExecutionContext { - #[new] - fn new() -> Self { - ExecutionContext { - ctx: _ExecutionContext::new(), - } - } - - /// Returns a DataFrame whose plan corresponds to the SQL statement. - fn sql(&mut self, query: &str) -> PyResult { - let df = self - .ctx - .sql(query) - .map_err(|e| -> errors::DataFusionError { e.into() })?; - Ok(dataframe::DataFrame::new( - self.ctx.state.clone(), - df.to_logical_plan(), - )) - } - - fn create_dataframe( - &mut self, - partitions: Vec>, - py: Python, - ) -> PyResult { - let partitions: Vec> = partitions - .iter() - .map(|batches| { - batches - .iter() - .map(|batch| to_rust::to_rust_batch(batch.as_ref(py))) - .collect() - }) - .collect::>()?; - - let table = - errors::wrap(MemTable::try_new(partitions[0][0].schema(), partitions))?; - - // generate a random (unique) name for this table - let name = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(10) - .collect::(); - - errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?; - Ok(dataframe::DataFrame::new( - self.ctx.state.clone(), - errors::wrap(self.ctx.table(&*name))?.to_logical_plan(), - )) - } - - fn register_parquet(&mut self, name: &str, path: &str) -> PyResult<()> { - errors::wrap(self.ctx.register_parquet(name, path))?; - Ok(()) - } - - fn register_udf( - &mut self, - name: &str, - func: PyObject, - args_types: Vec, - return_type: PyDataType, - ) { - let function = functions::create_udf(func, args_types, return_type, name); - - self.ctx.register_udf(function.function); - } - - fn tables(&self) -> HashSet { - self.ctx.tables().unwrap() - } -} diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs deleted file mode 100644 index f90a7cf2f0dcf..0000000000000 --- a/python/src/dataframe.rs +++ /dev/null @@ -1,161 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::{Arc, Mutex}; - -use logical_plan::LogicalPlan; -use pyo3::{prelude::*, types::PyTuple}; -use tokio::runtime::Runtime; - -use datafusion::execution::context::ExecutionContext as _ExecutionContext; -use datafusion::logical_plan::{JoinType, LogicalPlanBuilder}; -use datafusion::physical_plan::collect; -use datafusion::{execution::context::ExecutionContextState, logical_plan}; - -use crate::{errors, to_py}; -use crate::{errors::DataFusionError, expression}; - -/// A DataFrame is a representation of a logical plan and an API to compose statements. -/// Use it to build a plan and `.collect()` to execute the plan and collect the result. -/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. -#[pyclass] -pub(crate) struct DataFrame { - ctx_state: Arc>, - plan: LogicalPlan, -} - -impl DataFrame { - /// creates a new DataFrame - pub fn new(ctx_state: Arc>, plan: LogicalPlan) -> Self { - Self { ctx_state, plan } - } -} - -#[pymethods] -impl DataFrame { - /// Select `expressions` from the existing DataFrame. - #[args(args = "*")] - fn select(&self, args: &PyTuple) -> PyResult { - let expressions = expression::from_tuple(args)?; - let builder = LogicalPlanBuilder::from(&self.plan); - let builder = - errors::wrap(builder.project(expressions.into_iter().map(|e| e.expr)))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) - } - - /// Filter according to the `predicate` expression - fn filter(&self, predicate: expression::Expression) -> PyResult { - let builder = LogicalPlanBuilder::from(&self.plan); - let builder = errors::wrap(builder.filter(predicate.expr))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) - } - - /// Aggregates using expressions - fn aggregate( - &self, - group_by: Vec, - aggs: Vec, - ) -> PyResult { - let builder = LogicalPlanBuilder::from(&self.plan); - let builder = errors::wrap(builder.aggregate( - group_by.into_iter().map(|e| e.expr), - aggs.into_iter().map(|e| e.expr), - ))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) - } - - /// Limits the plan to return at most `count` rows - fn limit(&self, count: usize) -> PyResult { - let builder = LogicalPlanBuilder::from(&self.plan); - let builder = errors::wrap(builder.limit(count))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) - } - - /// Executes the plan, returning a list of `RecordBatch`es. 
-    /// Unless some order is specified in the plan, there is no guarantee of the order of the result
-    fn collect(&self, py: Python) -> PyResult<PyObject> {
-        let ctx = _ExecutionContext::from(self.ctx_state.clone());
-        let plan = ctx
-            .optimize(&self.plan)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
-        let plan = ctx
-            .create_physical_plan(&plan)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
-
-        let rt = Runtime::new().unwrap();
-        let batches = py.allow_threads(|| {
-            rt.block_on(async {
-                collect(plan)
-                    .await
-                    .map_err(|e| -> errors::DataFusionError { e.into() })
-            })
-        })?;
-        to_py::to_py(&batches)
-    }
-
-    /// Returns the join of two DataFrames `on`.
-    fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> {
-        let builder = LogicalPlanBuilder::from(&self.plan);
-
-        let join_type = match how {
-            "inner" => JoinType::Inner,
-            "left" => JoinType::Left,
-            "right" => JoinType::Right,
-            how => {
-                return Err(DataFusionError::Common(format!(
-                    "The join type {} does not exist or is not implemented",
-                    how
-                ))
-                .into())
-            }
-        };
-
-        let builder = errors::wrap(builder.join(
-            &right.plan,
-            join_type,
-            on.as_slice(),
-            on.as_slice(),
-        ))?;
-
-        let plan = errors::wrap(builder.build())?;
-
-        Ok(DataFrame {
-            ctx_state: self.ctx_state.clone(),
-            plan,
-        })
-    }
-}
diff --git a/python/src/errors.rs b/python/src/errors.rs
deleted file mode 100644
index fbe98037a030f..0000000000000
--- a/python/src/errors.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use core::fmt;
-
-use datafusion::arrow::error::ArrowError;
-use datafusion::error::DataFusionError as InnerDataFusionError;
-use pyo3::{exceptions, PyErr};
-
-#[derive(Debug)]
-pub enum DataFusionError {
-    ExecutionError(InnerDataFusionError),
-    ArrowError(ArrowError),
-    Common(String),
-}
-
-impl fmt::Display for DataFusionError {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            DataFusionError::ExecutionError(e) => write!(f, "DataFusion error: {:?}", e),
-            DataFusionError::ArrowError(e) => write!(f, "Arrow error: {:?}", e),
-            DataFusionError::Common(e) => write!(f, "{}", e),
-        }
-    }
-}
-
-impl From<DataFusionError> for PyErr {
-    fn from(err: DataFusionError) -> PyErr {
-        exceptions::PyException::new_err(err.to_string())
-    }
-}
-
-impl From<InnerDataFusionError> for DataFusionError {
-    fn from(err: InnerDataFusionError) -> DataFusionError {
-        DataFusionError::ExecutionError(err)
-    }
-}
-
-impl From<ArrowError> for DataFusionError {
-    fn from(err: ArrowError) -> DataFusionError {
-        DataFusionError::ArrowError(err)
-    }
-}
-
-pub(crate) fn wrap<T>(a: Result<T, InnerDataFusionError>) -> Result<T, DataFusionError> {
-    Ok(a?)
-} diff --git a/python/src/expression.rs b/python/src/expression.rs deleted file mode 100644 index 78ca6d7e598ec..0000000000000 --- a/python/src/expression.rs +++ /dev/null @@ -1,162 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use pyo3::{ - basic::CompareOp, prelude::*, types::PyTuple, PyNumberProtocol, PyObjectProtocol, -}; - -use datafusion::logical_plan::Expr as _Expr; -use datafusion::physical_plan::udaf::AggregateUDF as _AggregateUDF; -use datafusion::physical_plan::udf::ScalarUDF as _ScalarUDF; - -/// An expression that can be used on a DataFrame -#[pyclass] -#[derive(Debug, Clone)] -pub(crate) struct Expression { - pub(crate) expr: _Expr, -} - -/// converts a tuple of expressions into a vector of Expressions -pub(crate) fn from_tuple(value: &PyTuple) -> PyResult> { - value - .iter() - .map(|e| e.extract::()) - .collect::>() -} - -#[pyproto] -impl PyNumberProtocol for Expression { - fn __add__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr + rhs.expr, - }) - } - - fn __sub__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr - rhs.expr, - }) - } - - fn __truediv__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr / rhs.expr, - }) - } - - fn __mul__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr * rhs.expr, - }) - } - - fn __and__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr.and(rhs.expr), - }) - } - - fn __or__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr.or(rhs.expr), - }) - } - - fn __invert__(&self) -> PyResult { - Ok(Expression { - expr: self.expr.clone().not(), - }) - } -} - -#[pyproto] -impl PyObjectProtocol for Expression { - fn __richcmp__(&self, other: Expression, op: CompareOp) -> Expression { - match op { - CompareOp::Lt => Expression { - expr: self.expr.clone().lt(other.expr), - }, - CompareOp::Le => Expression { - expr: self.expr.clone().lt_eq(other.expr), - }, - CompareOp::Eq => Expression { - expr: self.expr.clone().eq(other.expr), - }, - CompareOp::Ne => Expression { - expr: self.expr.clone().not_eq(other.expr), - }, - CompareOp::Gt => Expression { - expr: self.expr.clone().gt(other.expr), - }, - CompareOp::Ge => Expression { - expr: self.expr.clone().gt_eq(other.expr), - }, - } - } -} - -#[pymethods] -impl Expression { - /// assign a name to the expression - pub fn alias(&self, name: &str) -> PyResult { - Ok(Expression { - expr: self.expr.clone().alias(name), - }) - } -} - -/// Represents a ScalarUDF -#[pyclass] -#[derive(Debug, Clone)] -pub struct ScalarUDF { - pub(crate) function: _ScalarUDF, -} - -#[pymethods] -impl ScalarUDF { - /// creates a new expression with the call of the udf - #[call] 
- #[args(args = "*")] - fn __call__(&self, args: &PyTuple) -> PyResult { - let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect(); - - Ok(Expression { - expr: self.function.call(args), - }) - } -} - -/// Represents a AggregateUDF -#[pyclass] -#[derive(Debug, Clone)] -pub struct AggregateUDF { - pub(crate) function: _AggregateUDF, -} - -#[pymethods] -impl AggregateUDF { - /// creates a new expression with the call of the udf - #[call] - #[args(args = "*")] - fn __call__(&self, args: &PyTuple) -> PyResult { - let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect(); - - Ok(Expression { - expr: self.function.call(args), - }) - } -} diff --git a/python/src/functions.rs b/python/src/functions.rs deleted file mode 100644 index 68000cb1ecbf8..0000000000000 --- a/python/src/functions.rs +++ /dev/null @@ -1,165 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use datafusion::arrow::datatypes::DataType; -use pyo3::{prelude::*, wrap_pyfunction}; - -use datafusion::logical_plan; - -use crate::udaf; -use crate::udf; -use crate::{expression, types::PyDataType}; - -/// Expression representing a column on the existing plan. 
-#[pyfunction]
-#[text_signature = "(name)"]
-fn col(name: &str) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::col(name),
-    }
-}
-
-/// Expression representing a constant value
-#[pyfunction]
-#[text_signature = "(value)"]
-fn lit(value: i32) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::lit(value),
-    }
-}
-
-#[pyfunction]
-fn sum(value: expression::Expression) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::sum(value.expr),
-    }
-}
-
-#[pyfunction]
-fn avg(value: expression::Expression) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::avg(value.expr),
-    }
-}
-
-#[pyfunction]
-fn min(value: expression::Expression) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::min(value.expr),
-    }
-}
-
-#[pyfunction]
-fn max(value: expression::Expression) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::max(value.expr),
-    }
-}
-
-#[pyfunction]
-fn count(value: expression::Expression) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::count(value.expr),
-    }
-}
-
-/*
-#[pyfunction]
-fn concat(value: Vec<expression::Expression>) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::concat(value.into_iter().map(|e| e.expr)),
-    }
-}
- */
-
-pub(crate) fn create_udf(
-    fun: PyObject,
-    input_types: Vec<PyDataType>,
-    return_type: PyDataType,
-    name: &str,
-) -> expression::ScalarUDF {
-    let input_types: Vec<DataType> =
-        input_types.iter().map(|d| d.data_type.clone()).collect();
-    let return_type = Arc::new(return_type.data_type);
-
-    expression::ScalarUDF {
-        function: logical_plan::create_udf(
-            name,
-            input_types,
-            return_type,
-            udf::array_udf(fun),
-        ),
-    }
-}
-
-/// Creates a new udf.
-#[pyfunction]
-fn udf(
-    fun: PyObject,
-    input_types: Vec<PyDataType>,
-    return_type: PyDataType,
-    py: Python,
-) -> PyResult<expression::ScalarUDF> {
-    let name = fun.getattr(py, "__qualname__")?.extract::<String>(py)?;
-
-    Ok(create_udf(fun, input_types, return_type, &name))
-}
-
-/// Creates a new udaf.
-#[pyfunction]
-fn udaf(
-    accumulator: PyObject,
-    input_type: PyDataType,
-    return_type: PyDataType,
-    state_type: Vec<PyDataType>,
-    py: Python,
-) -> PyResult<expression::AggregateUDF> {
-    let name = accumulator
-        .getattr(py, "__qualname__")?
-        .extract::<String>(py)?;
-
-    let input_type = input_type.data_type;
-    let return_type = Arc::new(return_type.data_type);
-    let state_type = Arc::new(state_type.into_iter().map(|t| t.data_type).collect());
-
-    Ok(expression::AggregateUDF {
-        function: logical_plan::create_udaf(
-            &name,
-            input_type,
-            return_type,
-            udaf::array_udaf(accumulator),
-            state_type,
-        ),
-    })
-}
-
-pub fn init(module: &PyModule) -> PyResult<()> {
-    module.add_function(wrap_pyfunction!(col, module)?)?;
-    module.add_function(wrap_pyfunction!(lit, module)?)?;
-    // see https://github.com/apache/arrow-datafusion/issues/226
-    //module.add_function(wrap_pyfunction!(concat, module)?)?;
-    module.add_function(wrap_pyfunction!(udf, module)?)?;
-    module.add_function(wrap_pyfunction!(sum, module)?)?;
-    module.add_function(wrap_pyfunction!(count, module)?)?;
-    module.add_function(wrap_pyfunction!(min, module)?)?;
-    module.add_function(wrap_pyfunction!(max, module)?)?;
-    module.add_function(wrap_pyfunction!(avg, module)?)?;
-    module.add_function(wrap_pyfunction!(udaf, module)?)?;
-    Ok(())
-}
diff --git a/python/src/lib.rs b/python/src/lib.rs
deleted file mode 100644
index aecfe9994cd1a..0000000000000
--- a/python/src/lib.rs
+++ /dev/null
@@ -1,44 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use pyo3::prelude::*;
-
-mod context;
-mod dataframe;
-mod errors;
-mod expression;
-mod functions;
-mod scalar;
-mod to_py;
-mod to_rust;
-mod types;
-mod udaf;
-mod udf;
-
-/// DataFusion.
-#[pymodule]
-fn datafusion(py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<context::ExecutionContext>()?;
-    m.add_class::<dataframe::DataFrame>()?;
-    m.add_class::<expression::Expression>()?;
-
-    let functions = PyModule::new(py, "functions")?;
-    functions::init(functions)?;
-    m.add_submodule(functions)?;
-
-    Ok(())
-}
diff --git a/python/src/scalar.rs b/python/src/scalar.rs
deleted file mode 100644
index 0c562a9403616..0000000000000
--- a/python/src/scalar.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use pyo3::prelude::*;
-
-use datafusion::scalar::ScalarValue as _Scalar;
-
-use crate::to_rust::to_rust_scalar;
-
-/// A scalar value that can be used on a DataFrame
-#[derive(Debug, Clone)]
-pub(crate) struct Scalar {
-    pub(crate) scalar: _Scalar,
-}
-
-impl<'source> FromPyObject<'source> for Scalar {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
-        Ok(Self {
-            scalar: to_rust_scalar(ob)?,
-        })
-    }
-}
diff --git a/python/src/to_py.rs b/python/src/to_py.rs
deleted file mode 100644
index deeb9719891a3..0000000000000
--- a/python/src/to_py.rs
+++ /dev/null
@@ -1,77 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use datafusion::arrow::{ - array::{make_array_from_raw, ArrayRef}, - datatypes::Field, - datatypes::Schema, - ffi, - record_batch::RecordBatch, -}; -use datafusion::scalar::ScalarValue; -use pyo3::{libc::uintptr_t, prelude::*}; - -use crate::{errors, types::PyDataType}; - -/// converts a pyarrow Array into a Rust Array -pub fn to_rust(ob: &PyAny) -> PyResult { - // prepare a pointer to receive the Array struct - let (array_pointer, schema_pointer) = - ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() }); - - // make the conversion through PyArrow's private API - // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds - ob.call_method1( - "_export_to_c", - (array_pointer as uintptr_t, schema_pointer as uintptr_t), - )?; - - let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) } - .map_err(errors::DataFusionError::from)?; - Ok(array) -} - -pub fn to_rust_batch(batch: &PyAny) -> PyResult { - let schema = batch.getattr("schema")?; - let names = schema.getattr("names")?.extract::>()?; - - let fields = names - .iter() - .enumerate() - .map(|(i, name)| { - let field = schema.call_method1("field", (i,))?; - let nullable = field.getattr("nullable")?.extract::()?; - let py_data_type = field.getattr("type")?; - let data_type = py_data_type.extract::()?.data_type; - Ok(Field::new(name, data_type, nullable)) - }) - .collect::>()?; - - let schema = Arc::new(Schema::new(fields)); - - let arrays = (0..names.len()) - .map(|i| { - let array = batch.call_method1("column", (i,))?; - to_rust(array) - }) - .collect::>()?; - - let batch = - RecordBatch::try_new(schema, arrays).map_err(errors::DataFusionError::from)?; - Ok(batch) -} - -/// converts a pyarrow Scalar into a Rust Scalar -pub fn to_rust_scalar(ob: &PyAny) -> PyResult { - let t = ob - .getattr("__class__")? - .getattr("__name__")? 
-        .extract::<&str>()?;
-
-    let p = ob.call_method0("as_py")?;
-
-    Ok(match t {
-        "Int8Scalar" => ScalarValue::Int8(Some(p.extract::<i8>()?)),
-        "Int16Scalar" => ScalarValue::Int16(Some(p.extract::<i16>()?)),
-        "Int32Scalar" => ScalarValue::Int32(Some(p.extract::<i32>()?)),
-        "Int64Scalar" => ScalarValue::Int64(Some(p.extract::<i64>()?)),
-        "UInt8Scalar" => ScalarValue::UInt8(Some(p.extract::<u8>()?)),
-        "UInt16Scalar" => ScalarValue::UInt16(Some(p.extract::<u16>()?)),
-        "UInt32Scalar" => ScalarValue::UInt32(Some(p.extract::<u32>()?)),
-        "UInt64Scalar" => ScalarValue::UInt64(Some(p.extract::<u64>()?)),
-        "FloatScalar" => ScalarValue::Float32(Some(p.extract::<f32>()?)),
-        "DoubleScalar" => ScalarValue::Float64(Some(p.extract::<f64>()?)),
-        "BooleanScalar" => ScalarValue::Boolean(Some(p.extract::<bool>()?)),
-        "StringScalar" => ScalarValue::Utf8(Some(p.extract::<String>()?)),
-        "LargeStringScalar" => ScalarValue::LargeUtf8(Some(p.extract::<String>()?)),
-        other => {
-            return Err(errors::DataFusionError::Common(format!(
-                "Type \"{}\" not yet implemented",
-                other
-            ))
-            .into())
-        }
-    })
-}
diff --git a/python/src/types.rs b/python/src/types.rs
deleted file mode 100644
index ffa822e073a89..0000000000000
--- a/python/src/types.rs
+++ /dev/null
@@ -1,76 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use datafusion::arrow::datatypes::DataType;
-use pyo3::{FromPyObject, PyAny, PyResult};
-
-use crate::errors;
-
-/// utility struct to convert PyObj to native DataType
-#[derive(Debug, Clone)]
-pub struct PyDataType {
-    pub data_type: DataType,
-}
-
-impl<'source> FromPyObject<'source> for PyDataType {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
-        let id = ob.getattr("id")?.extract::<i32>()?;
-        let data_type = data_type_id(&id)?;
-        Ok(PyDataType { data_type })
-    }
-}
-
-fn data_type_id(id: &i32) -> Result<DataType, errors::DataFusionError> {
-    // see https://github.com/apache/arrow/blob/3694794bdfd0677b95b8c95681e392512f1c9237/python/pyarrow/includes/libarrow.pxd
-    // this is not ideal as it does not generalize for non-basic types
-    // Find a way to get a unique name from the pyarrow.DataType
-    Ok(match id {
-        1 => DataType::Boolean,
-        2 => DataType::UInt8,
-        3 => DataType::Int8,
-        4 => DataType::UInt16,
-        5 => DataType::Int16,
-        6 => DataType::UInt32,
-        7 => DataType::Int32,
-        8 => DataType::UInt64,
-        9 => DataType::Int64,
-
-        10 => DataType::Float16,
-        11 => DataType::Float32,
-        12 => DataType::Float64,
-
-        //13 => DataType::Decimal,
-
-        // 14 => DataType::Date32(),
-        // 15 => DataType::Date64(),
-        // 16 => DataType::Timestamp(),
-        // 17 => DataType::Time32(),
-        // 18 => DataType::Time64(),
-        // 19 => DataType::Duration()
-        20 => DataType::Binary,
-        21 => DataType::Utf8,
-        22 => DataType::LargeBinary,
-        23 => DataType::LargeUtf8,
-
-        other => {
-            return Err(errors::DataFusionError::Common(format!(
-                "The type {} is not valid",
-                other
-            )))
-        }
-    })
-}
diff --git a/python/src/udaf.rs b/python/src/udaf.rs
deleted file mode 100644
index 3ce223df9a491..0000000000000
--- a/python/src/udaf.rs
+++ /dev/null
@@ -1,147 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::sync::Arc;
-
-use pyo3::{prelude::*, types::PyTuple};
-
-use datafusion::arrow::array::ArrayRef;
-
-use datafusion::error::Result;
-use datafusion::{
-    error::DataFusionError as InnerDataFusionError, physical_plan::Accumulator,
-    scalar::ScalarValue,
-};
-
-use crate::scalar::Scalar;
-use crate::to_py::to_py_array;
-use crate::to_rust::to_rust_scalar;
-
-#[derive(Debug)]
-struct PyAccumulator {
-    accum: PyObject,
-}
-
-impl PyAccumulator {
-    fn new(accum: PyObject) -> Self {
-        Self { accum }
-    }
-}
-
-impl Accumulator for PyAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        let gil = pyo3::Python::acquire_gil();
-        let py = gil.python();
-
-        let state = self
-            .accum
-            .as_ref(py)
-            .call_method0("to_scalars")
-            .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?
- .extract::>() - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - Ok(state.into_iter().map(|v| v.scalar).collect::>()) - } - - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - // no need to implement as datafusion does not use it - todo!() - } - - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - // no need to implement as datafusion does not use it - todo!() - } - - fn evaluate(&self) -> Result { - // get GIL - let gil = pyo3::Python::acquire_gil(); - let py = gil.python(); - - let value = self - .accum - .as_ref(py) - .call_method0("evaluate") - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - to_rust_scalar(value) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e))) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // get GIL - let gil = pyo3::Python::acquire_gil(); - let py = gil.python(); - - // 1. cast args to Pyarrow array - // 2. call function - - // 1. - let py_args = values - .iter() - .map(|arg| { - // remove unwrap - to_py_array(arg, py).unwrap() - }) - .collect::>(); - let py_args = PyTuple::new(py, py_args); - - // update accumulator - self.accum - .as_ref(py) - .call_method1("update", py_args) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // get GIL - let gil = pyo3::Python::acquire_gil(); - let py = gil.python(); - - // 1. cast states to Pyarrow array - // 2. merge - let state = &states[0]; - - let state = to_py_array(state, py) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - // 2. - self.accum - .as_ref(py) - .call_method1("merge", (state,)) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - Ok(()) - } -} - -pub fn array_udaf( - accumulator: PyObject, -) -> Arc Result> + Send + Sync> { - Arc::new(move || -> Result> { - let gil = pyo3::Python::acquire_gil(); - let py = gil.python(); - - let accumulator = accumulator - .call0(py) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - Ok(Box::new(PyAccumulator::new(accumulator))) - }) -} diff --git a/python/src/udf.rs b/python/src/udf.rs deleted file mode 100644 index 7fee71008ef2f..0000000000000 --- a/python/src/udf.rs +++ /dev/null @@ -1,62 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use pyo3::{prelude::*, types::PyTuple}; - -use datafusion::{arrow::array, physical_plan::functions::make_scalar_function}; - -use datafusion::error::DataFusionError; -use datafusion::physical_plan::functions::ScalarFunctionImplementation; - -use crate::to_py::to_py_array; -use crate::to_rust::to_rust; - -/// creates a DataFusion's UDF implementation from a python function that expects pyarrow arrays -/// This is more efficient as it performs a zero-copy of the contents. -pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation { - make_scalar_function( - move |args: &[array::ArrayRef]| -> Result { - // get GIL - let gil = pyo3::Python::acquire_gil(); - let py = gil.python(); - - // 1. cast args to Pyarrow arrays - // 2. call function - // 3. cast to arrow::array::Array - - // 1. - let py_args = args - .iter() - .map(|arg| { - // remove unwrap - to_py_array(arg, py).unwrap() - }) - .collect::>(); - let py_args = PyTuple::new(py, py_args); - - // 2. - let value = func.as_ref(py).call(py_args, None); - let value = match value { - Ok(n) => Ok(n), - Err(error) => Err(DataFusionError::Execution(format!("{:?}", error))), - }?; - - let array = to_rust(value).unwrap(); - Ok(array) - }, - ) -} diff --git a/python/tests/__init__.py b/python/tests/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/python/tests/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/python/tests/generic.py b/python/tests/generic.py deleted file mode 100644 index 7362f0bb29569..0000000000000 --- a/python/tests/generic.py +++ /dev/null @@ -1,75 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-
-import unittest
-import tempfile
-import datetime
-import os.path
-import shutil
-
-import numpy
-import pyarrow
-import datafusion
-
-# used to write parquet files
-import pyarrow.parquet
-
-
-def data():
-    data = numpy.concatenate(
-        [numpy.random.normal(0, 0.01, size=50), numpy.random.normal(50, 0.01, size=50)]
-    )
-    return pyarrow.array(data)
-
-
-def data_with_nans():
-    data = numpy.random.normal(0, 0.01, size=50)
-    mask = numpy.random.randint(0, 2, size=50)
-    data[mask == 0] = numpy.NaN
-    return data
-
-
-def data_datetime(f):
-    data = [
-        datetime.datetime.now(),
-        datetime.datetime.now() - datetime.timedelta(days=1),
-        datetime.datetime.now() + datetime.timedelta(days=1),
-    ]
-    return pyarrow.array(
-        data, type=pyarrow.timestamp(f), mask=numpy.array([False, True, False])
-    )
-
-
-def data_timedelta(f):
-    data = [
-        datetime.timedelta(days=100),
-        datetime.timedelta(days=1),
-        datetime.timedelta(seconds=1),
-    ]
-    return pyarrow.array(
-        data, type=pyarrow.duration(f), mask=numpy.array([False, True, False])
-    )
-
-
-def data_binary_other():
-    return numpy.array([1, 0, 0], dtype="u4")
-
-
-def write_parquet(path, data):
-    table = pyarrow.Table.from_arrays([data], names=["a"])
-    pyarrow.parquet.write_table(table, path)
-    return path
diff --git a/python/tests/test_df.py b/python/tests/test_df.py
deleted file mode 100644
index 520d4e6a54723..0000000000000
--- a/python/tests/test_df.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
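The generic.py helpers above are consumed by the test_sql.py hunk below; the usual round trip writes an array to parquet and registers it as a table. A sketch of that pattern, assuming the wheel from this tree is installed:

    import os
    import tempfile

    import datafusion

    from tests.generic import data, write_parquet

    test_dir = tempfile.mkdtemp()
    ctx = datafusion.ExecutionContext()
    path = write_parquet(os.path.join(test_dir, "a.parquet"), data())
    ctx.register_parquet("t", path)
    assert ctx.tables() == {"t"}
    batches = ctx.sql("SELECT COUNT(a) FROM t").collect()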
-
-import unittest
-
-import pyarrow
-import datafusion
-f = datafusion.functions
-
-
-class TestCase(unittest.TestCase):
-
-    def _prepare(self):
-        ctx = datafusion.ExecutionContext()
-
-        # create a RecordBatch and a new DataFrame from it
-        batch = pyarrow.RecordBatch.from_arrays(
-            [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
-            names=["a", "b"],
-        )
-        return ctx.create_dataframe([[batch]])
-
-    def test_select(self):
-        df = self._prepare()
-
-        df = df.select(
-            f.col("a") + f.col("b"),
-            f.col("a") - f.col("b"),
-        )
-
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
-
-        self.assertEqual(result.column(0), pyarrow.array([5, 7, 9]))
-        self.assertEqual(result.column(1), pyarrow.array([-3, -3, -3]))
-
-    def test_filter(self):
-        df = self._prepare()
-
-        df = df \
-            .select(
-                f.col("a") + f.col("b"),
-                f.col("a") - f.col("b"),
-            ) \
-            .filter(f.col("a") > f.lit(2))
-
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
-
-        self.assertEqual(result.column(0), pyarrow.array([9]))
-        self.assertEqual(result.column(1), pyarrow.array([-3]))
-
-    def test_limit(self):
-        df = self._prepare()
-
-        df = df.limit(1)
-
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
-
-        self.assertEqual(len(result.column(0)), 1)
-        self.assertEqual(len(result.column(1)), 1)
-
-    def test_udf(self):
-        df = self._prepare()
-
-        # is_null is a pyarrow function over arrays
-        udf = f.udf(lambda x: x.is_null(), [pyarrow.int64()], pyarrow.bool_())
-
-        df = df.select(udf(f.col("a")))
-
-        self.assertEqual(df.collect()[0].column(0), pyarrow.array([False, False, False]))
-
-    def test_join(self):
-        ctx = datafusion.ExecutionContext()
-
-        batch = pyarrow.RecordBatch.from_arrays(
-            [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
-            names=["a", "b"],
-        )
-        df = ctx.create_dataframe([[batch]])
-
-        batch = pyarrow.RecordBatch.from_arrays(
-            [pyarrow.array([1, 2]), pyarrow.array([8, 10])],
-            names=["a", "c"],
-        )
-        df1 = ctx.create_dataframe([[batch]])
-
-        df = df.join(df1, on="a", how="inner")
-
-        # execute and collect the first (and only) batch
-        batch = df.collect()[0]
-
-        if batch.column(0) == pyarrow.array([1, 2]):
-            self.assertEqual(batch.column(0), pyarrow.array([1, 2]))
-            self.assertEqual(batch.column(1), pyarrow.array([8, 10]))
-            self.assertEqual(batch.column(2), pyarrow.array([4, 5]))
-        else:
-            self.assertEqual(batch.column(0), pyarrow.array([2, 1]))
-            self.assertEqual(batch.column(1), pyarrow.array([10, 8]))
-            self.assertEqual(batch.column(2), pyarrow.array([5, 4]))
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
deleted file mode 100644
index e9047ea6e70c3..0000000000000
--- a/python/tests/test_sql.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
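A point the "execute and collect" comments in test_df.py above gesture at: DataFrame operations are lazy. select, filter, limit, and join only extend the logical plan; execution happens at collect(), which returns a list of pyarrow.RecordBatch. A minimal sketch, again assuming the wheel from this tree is installed:

    import pyarrow
    import datafusion

    f = datafusion.functions

    ctx = datafusion.ExecutionContext()
    batch = pyarrow.RecordBatch.from_arrays(
        [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])], names=["a", "b"]
    )

    # builds a logical plan only; no data moves yet
    plan = ctx.create_dataframe([[batch]]).filter(f.col("a") > f.lit(1)).limit(1)

    # execution happens here, yielding a list of pyarrow.RecordBatch
    batches = plan.collect()
    assert len(batches[0].column(0)) == 1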
-
-import unittest
-import tempfile
-import datetime
-import os.path
-import shutil
-
-import numpy
-import pyarrow
-import datafusion
-
-# used to write parquet files
-import pyarrow.parquet
-
-from tests.generic import *
-
-
-class TestCase(unittest.TestCase):
-    def setUp(self):
-        # Create a temporary directory
-        self.test_dir = tempfile.mkdtemp()
-        numpy.random.seed(1)
-
-    def tearDown(self):
-        # Remove the directory after the test
-        shutil.rmtree(self.test_dir)
-
-    def test_no_table(self):
-        with self.assertRaises(Exception):
-            datafusion.ExecutionContext().sql("SELECT a FROM b").collect()
-
-    def test_register(self):
-        ctx = datafusion.ExecutionContext()
-
-        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data())
-
-        ctx.register_parquet("t", path)
-
-        self.assertEqual(ctx.tables(), {"t"})
-
-    def test_execute(self):
-        data = [1, 1, 2, 2, 3, 11, 12]
-
-        ctx = datafusion.ExecutionContext()
-
-        # single column, "a"
-        path = write_parquet(
-            os.path.join(self.test_dir, "a.parquet"), pyarrow.array(data)
-        )
-        ctx.register_parquet("t", path)
-
-        self.assertEqual(ctx.tables(), {"t"})
-
-        # count
-        result = ctx.sql("SELECT COUNT(a) FROM t").collect()
-
-        expected = pyarrow.array([7], pyarrow.uint64())
-        expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
-        self.assertEqual(expected, result)
-
-        # where
-        expected = pyarrow.array([2], pyarrow.uint64())
-        expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
-        self.assertEqual(
-            expected, ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect()
-        )
-
-        # group by
-        result = ctx.sql(
-            "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)"
-        ).collect()
-
-        result_keys = result[0].to_pydict()["CAST(a AS Int32)"]
-        result_values = result[0].to_pydict()["COUNT(a)"]
-        result_keys, result_values = (
-            list(t) for t in zip(*sorted(zip(result_keys, result_values)))
-        )
-
-        self.assertEqual(result_keys, [1, 2, 3, 11, 12])
-        self.assertEqual(result_values, [2, 2, 1, 1, 1])
-
-        # order by
-        result = ctx.sql(
-            "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2"
-        ).collect()
-        expected_a = pyarrow.array([50.0219, 50.0152], pyarrow.float64())
-        expected_cast = pyarrow.array([50, 50], pyarrow.int32())
-        expected = [
-            pyarrow.RecordBatch.from_arrays(
-                [expected_a, expected_cast], ["a", "CAST(a AS Int32)"]
-            )
-        ]
-        numpy.testing.assert_equal(expected[0].column(1), result[0].column(1))
-
-    def test_cast(self):
-        """
-        Verify that we can cast.
-        """
-        ctx = datafusion.ExecutionContext()
-
-        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data())
-        ctx.register_parquet("t", path)
-
-        valid_types = [
-            "smallint",
-            "int",
-            "bigint",
-            "float(32)",
-            "float(64)",
-            "float",
-        ]
-
-        select = ", ".join(
-            [f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)]
-        )
-
-        # can execute, which implies that we can cast
-        ctx.sql(f"SELECT {select} FROM t").collect()
-
-    def _test_udf(self, udf, args, return_type, array, expected):
-        ctx = datafusion.ExecutionContext()
-
-        # write to disk
-        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), array)
-        ctx.register_parquet("t", path)
-
-        ctx.register_udf("udf", udf, args, return_type)
-
-        batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect()
-
-        result = batches[0].column(0)
-
-        self.assertEqual(expected, result)
-
-    def test_udf_identity(self):
-        self._test_udf(
-            lambda x: x,
-            [pyarrow.float64()],
-            pyarrow.float64(),
-            pyarrow.array([-1.2, None, 1.2]),
-            pyarrow.array([-1.2, None, 1.2]),
-        )
-
-    def test_udf(self):
-        self._test_udf(
-            lambda x: x.is_null(),
-            [pyarrow.float64()],
-            pyarrow.bool_(),
-            pyarrow.array([-1.2, None, 1.2]),
-            pyarrow.array([False, True, False]),
-        )
-
-
-class TestIO(unittest.TestCase):
-    def setUp(self):
-        # Create a temporary directory
-        self.test_dir = tempfile.mkdtemp()
-
-    def tearDown(self):
-        # Remove the directory after the test
-        shutil.rmtree(self.test_dir)
-
-    def _test_data(self, data):
-        ctx = datafusion.ExecutionContext()
-
-        # write to disk
-        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data)
-        ctx.register_parquet("t", path)
-
-        batches = ctx.sql("SELECT a AS tt FROM t").collect()
-
-        result = batches[0].column(0)
-
-        numpy.testing.assert_equal(data, result)
-
-    def test_nans(self):
-        self._test_data(data_with_nans())
-
-    def test_utf8(self):
-        array = pyarrow.array(
-            ["a", "b", "c"], pyarrow.utf8(), numpy.array([False, True, False])
-        )
-        self._test_data(array)
-
-    def test_large_utf8(self):
-        array = pyarrow.array(
-            ["a", "b", "c"], pyarrow.large_utf8(), numpy.array([False, True, False])
-        )
-        self._test_data(array)
-
-    # Error from Arrow
-    @unittest.expectedFailure
-    def test_datetime_s(self):
-        self._test_data(data_datetime("s"))
-
-    # C data interface missing
-    @unittest.expectedFailure
-    def test_datetime_ms(self):
-        self._test_data(data_datetime("ms"))
-
-    # C data interface missing
-    @unittest.expectedFailure
-    def test_datetime_us(self):
-        self._test_data(data_datetime("us"))
-
-    # Not writable to parquet
-    @unittest.expectedFailure
-    def test_datetime_ns(self):
-        self._test_data(data_datetime("ns"))
-
-    # Not writable to parquet
-    @unittest.expectedFailure
-    def test_timedelta_s(self):
-        self._test_data(data_timedelta("s"))
-
-    # Not writable to parquet
-    @unittest.expectedFailure
-    def test_timedelta_ms(self):
-        self._test_data(data_timedelta("ms"))
-
-    # Not writable to parquet
-    @unittest.expectedFailure
-    def test_timedelta_us(self):
-        self._test_data(data_timedelta("us"))
-
-    # Not writable to parquet
-    @unittest.expectedFailure
-    def test_timedelta_ns(self):
-        self._test_data(data_timedelta("ns"))
-
-    def test_date32(self):
-        array = pyarrow.array(
-            [
-                datetime.date(2000, 1, 1),
-                datetime.date(1980, 1, 1),
-                datetime.date(2030, 1, 1),
-            ],
-            pyarrow.date32(),
-            numpy.array([False, True, False]),
-        )
-        self._test_data(array)
-
-    def test_binary_variable(self):
-        array = pyarrow.array(
-            [b"1", b"2", b"3"], pyarrow.binary(), numpy.array([False, True, False])
-        )
-        self._test_data(array)
-
-    # C data interface missing
-    @unittest.expectedFailure
-    def test_binary_fixed(self):
-        array = pyarrow.array(
-            [b"1111", b"2222", b"3333"],
-            pyarrow.binary(4),
-            numpy.array([False, True, False]),
-        )
-        self._test_data(array)
-
-    def test_large_binary(self):
-        array = pyarrow.array(
-            [b"1111", b"2222", b"3333"],
-            pyarrow.large_binary(),
-            numpy.array([False, True, False]),
-        )
-        self._test_data(array)
-
-    def test_binary_other(self):
-        self._test_data(data_binary_other())
-
-    def test_bool(self):
-        array = pyarrow.array(
-            [False, True, True], None, numpy.array([False, True, False])
-        )
-        self._test_data(array)
-
-    def test_u32(self):
-        array = pyarrow.array([0, 1, 2], None, numpy.array([False, True, False]))
-        self._test_data(array)
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
deleted file mode 100644
index ffd235e285f80..0000000000000
--- a/python/tests/test_udaf.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import unittest
-
-import pyarrow
-import pyarrow.compute
-import datafusion
-
-f = datafusion.functions
-
-
-class Accumulator:
-    """
-    Interface of a user-defined accumulator.
-    """
-
-    def __init__(self):
-        self._sum = pyarrow.scalar(0.0)
-
-    def to_scalars(self) -> [pyarrow.Scalar]:
-        return [self._sum]
-
-    def update(self, values: pyarrow.Array) -> None:
-        # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
-        self._sum = pyarrow.scalar(
-            self._sum.as_py() + pyarrow.compute.sum(values).as_py()
-        )
-
-    def merge(self, states: pyarrow.Array) -> None:
-        # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
-        self._sum = pyarrow.scalar(
-            self._sum.as_py() + pyarrow.compute.sum(states).as_py()
-        )
-
-    def evaluate(self) -> pyarrow.Scalar:
-        return self._sum
-
-
-class TestCase(unittest.TestCase):
-    def _prepare(self):
-        ctx = datafusion.ExecutionContext()
-
-        # create a RecordBatch and a new DataFrame from it
-        batch = pyarrow.RecordBatch.from_arrays(
-            [pyarrow.array([1, 2, 3]), pyarrow.array([4, 4, 6])],
-            names=["a", "b"],
-        )
-        return ctx.create_dataframe([[batch]])
-
-    def test_aggregate(self):
-        df = self._prepare()
-
-        udaf = f.udaf(
-            Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]
-        )
-
-        df = df.aggregate([], [udaf(f.col("a"))])
-
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
-
-        self.assertEqual(result.column(0), pyarrow.array([1.0 + 2.0 + 3.0]))
-
-    def test_group_by(self):
-        df = self._prepare()
-
-        udaf = f.udaf(
-            Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]
-        )
-
-        df = df.aggregate([f.col("b")], [udaf(f.col("a"))])
-
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
-
-        self.assertEqual(result.column(1), pyarrow.array([1.0 + 2.0, 3.0]))
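One caveat in the test_udaf.py hunk above: test_group_by asserts on a fixed group output order, while the SQL GROUP BY test in test_sql.py sorts keys before asserting. If group order ever proves unstable, a hypothetical helper in the same spirit as that sorting trick:

    import pyarrow


    def sorted_by_key(batch: pyarrow.RecordBatch, key: str) -> dict:
        """Return the batch's columns as Python lists ordered by `key`,
        making assertions independent of group output order."""
        columns = batch.to_pydict()
        order = sorted(range(len(columns[key])), key=lambda i: columns[key][i])
        return {name: [values[i] for i in order] for name, values in columns.items()}

With it, test_group_by could assert sorted_by_key(result, "b") against a plain dict instead of relying on batch order.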