19 changes: 11 additions & 8 deletions .github/workflows/python_test.yaml
@@ -41,18 +41,21 @@ jobs:
- uses: actions/setup-python@v2
with:
python-version: "3.9"
- name: Install Python dependencies
run: python -m pip install --upgrade pip setuptools wheel
- name: Run tests
- name: Create Virtualenv
run: |
cd python/

python -m venv venv
source venv/bin/activate

pip install -r requirements.txt
pip install -r python/requirements.txt
- name: Run Linters
run: |
source venv/bin/activate
flake8 python
black --line-length 79 --check python
- name: Run tests
run: |
source venv/bin/activate
cd python
maturin develop

pytest -v .
env:
CARGO_HOME: "/home/runner/.cargo"
10 changes: 6 additions & 4 deletions python/requirements.in
@@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
maturin
toml
pyarrow
pytest

black
flake8
isort
maturin
mypy
numpy
pandas
pyarrow
pytest
toml
285 changes: 148 additions & 137 deletions python/requirements.txt

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions python/src/context.rs
@@ -15,16 +15,19 @@
// specific language governing permissions and limitations
// under the License.

use std::path::PathBuf;
use std::{collections::HashSet, sync::Arc};

use rand::distributions::Alphanumeric;
use rand::Rng;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::execution::context::ExecutionContext as _ExecutionContext;
use datafusion::prelude::CsvReadOptions;

use crate::dataframe;
use crate::errors;
@@ -97,6 +100,48 @@ impl ExecutionContext {
Ok(())
}

#[args(
schema = "None",
has_header = "true",
delimiter = "\",\"",
schema_infer_max_records = "1000",
file_extension = "\".csv\""
)]
fn register_csv(
&mut self,
name: &str,
path: PathBuf,
schema: Option<&PyAny>,
has_header: bool,
delimiter: &str,
schema_infer_max_records: usize,
file_extension: &str,
) -> PyResult<()> {
let path = path
.to_str()
.ok_or(PyValueError::new_err("Unable to convert path to a string"))?;
let schema = match schema {
Some(s) => Some(to_rust::to_rust_schema(s)?),
None => None,
};
let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
return Err(PyValueError::new_err(
"Delimiter must be a single character",
));
}

let mut options = CsvReadOptions::new()
.has_header(has_header)
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
.file_extension(file_extension);
options.schema = schema.as_ref();

errors::wrap(self.ctx.register_csv(name, path, options))?;
Ok(())
}

fn register_udf(
&mut self,
name: &str,
9 changes: 9 additions & 0 deletions python/src/to_rust.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::convert::TryFrom;
use std::sync::Arc;

use datafusion::arrow::{
@@ -111,3 +112,11 @@ pub fn to_rust_scalar(ob: &PyAny) -> PyResult<ScalarValue> {
}
})
}

pub fn to_rust_schema(ob: &PyAny) -> PyResult<Schema> {
Member Author commented:
Copied from https://github.com/apache/arrow-rs/blob/master/arrow-pyarrow-integration-testing/src/lib.rs#L136
Eventually we could add an optional module to arrow-rs where we implement the PyO3 conversion traits for arrow-rs <-> pyarrow interoperability for easier downstream integration. (A sketch of such a conversion trait follows this file's diff.)

let c_schema = ffi::FFI_ArrowSchema::empty();
let c_schema_ptr = &c_schema as *const ffi::FFI_ArrowSchema;
ob.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
let schema = Schema::try_from(&c_schema).map_err(errors::DataFusionError::from)?;
Ok(schema)
}
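
A rough sketch of the conversion-trait idea from the review comment above, not part of this PR. Assumptions: a downstream crate with pyo3, libc, and datafusion's re-exported arrow as dependencies; the name PyArrowSchema and the PyValueError error mapping are illustrative. Because the orphan rule keeps a downstream crate from implementing pyo3's FromPyObject directly for arrow-rs's Schema, the sketch wraps it in a newtype; inside arrow-rs the impl could live on Schema itself behind an optional feature. The conversion body mirrors to_rust_schema above.

use std::convert::TryFrom;

use libc::uintptr_t;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::ffi;

/// Hypothetical newtype so a pyarrow.Schema can be taken directly as a #[pyfunction] argument.
pub struct PyArrowSchema(pub Schema);

impl<'source> FromPyObject<'source> for PyArrowSchema {
    fn extract(ob: &'source PyAny) -> PyResult<Self> {
        // Ask pyarrow to export the schema through the Arrow C data interface,
        // then rebuild an arrow-rs Schema from the resulting C struct.
        let c_schema = ffi::FFI_ArrowSchema::empty();
        let c_schema_ptr = &c_schema as *const ffi::FFI_ArrowSchema;
        ob.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
        let schema = Schema::try_from(&c_schema)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(PyArrowSchema(schema))
    }
}

With such an impl available, register_csv could take schema: Option<PyArrowSchema> in its signature instead of an Option<&PyAny> routed through to_rust_schema.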
13 changes: 10 additions & 3 deletions python/tests/generic.py
@@ -19,6 +19,7 @@

import numpy as np
import pyarrow as pa
import pyarrow.csv
import pyarrow.parquet as pq

# used to write parquet files
@@ -49,7 +50,9 @@ def data_datetime(f):
datetime.datetime.now() - datetime.timedelta(days=1),
datetime.datetime.now() + datetime.timedelta(days=1),
]
return pa.array(data, type=pa.timestamp(f), mask=np.array([False, True, False]))
return pa.array(
data, type=pa.timestamp(f), mask=np.array([False, True, False])
)


def data_date32():
@@ -58,7 +61,9 @@ def data_date32():
datetime.date(1980, 1, 1),
datetime.date(2030, 1, 1),
]
return pa.array(data, type=pa.date32(), mask=np.array([False, True, False]))
return pa.array(
data, type=pa.date32(), mask=np.array([False, True, False])
)


def data_timedelta(f):
@@ -67,7 +72,9 @@ def data_timedelta(f):
datetime.timedelta(days=1),
datetime.timedelta(seconds=1),
]
return pa.array(data, type=pa.duration(f), mask=np.array([False, True, False]))
return pa.array(
data, type=pa.duration(f), mask=np.array([False, True, False])
)


def data_binary_other():
16 changes: 12 additions & 4 deletions python/tests/test_math_functions.py
@@ -26,7 +26,9 @@
def df():
ctx = ExecutionContext()
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays([pa.array([0.1, -0.7, 0.55])], names=["value"])
batch = pa.RecordBatch.from_arrays(
[pa.array([0.1, -0.7, 0.55])], names=["value"]
)
return ctx.create_dataframe([[batch]])


@@ -56,7 +58,13 @@ def test_math_functions(df):
np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values))
np.testing.assert_array_almost_equal(result.column(5), np.arccos(values))
np.testing.assert_array_almost_equal(result.column(6), np.exp(values))
np.testing.assert_array_almost_equal(result.column(7), np.log(values + 1.0))
np.testing.assert_array_almost_equal(result.column(8), np.log2(values + 1.0))
np.testing.assert_array_almost_equal(result.column(9), np.log10(values + 1.0))
np.testing.assert_array_almost_equal(
result.column(7), np.log(values + 1.0)
)
np.testing.assert_array_almost_equal(
result.column(8), np.log2(values + 1.0)
)
np.testing.assert_array_almost_equal(
result.column(9), np.log10(values + 1.0)
)
np.testing.assert_array_less(result.column(10), np.ones_like(values))
5 changes: 2 additions & 3 deletions python/tests/test_pa_types.py
@@ -19,8 +19,8 @@


def test_type_ids():
"""having this fixed is very important because internally we rely on this id to parse from
python"""
# Having this fixed is very important because internally we rely on this id
# to parse from python
for idx, arrow_type in [
(0, pa.null()),
(1, pa.bool_()),
@@ -47,5 +47,4 @@ def test_type_ids():
(34, pa.large_utf8()),
(35, pa.large_binary()),
]:

assert idx == arrow_type.id
65 changes: 60 additions & 5 deletions python/tests/test_sql.py
@@ -18,8 +18,8 @@
import numpy as np
import pyarrow as pa
import pytest
from datafusion import ExecutionContext

from datafusion import ExecutionContext
from . import generic as helpers


@@ -33,12 +33,63 @@ def test_no_table(ctx):
ctx.sql("SELECT a FROM b").collect()


def test_register(ctx, tmp_path):
def test_register_csv(ctx, tmp_path):
path = tmp_path / "test.csv"

table = pa.Table.from_arrays(
[
[1, 2, 3, 4],
["a", "b", "c", "d"],
[1.1, 2.2, 3.3, 4.4],
],
names=["int", "str", "float"],
)
pa.csv.write_csv(table, path)

ctx.register_csv("csv", path)
ctx.register_csv("csv1", str(path))
ctx.register_csv(
"csv2",
path,
has_header=True,
delimiter=",",
schema_infer_max_records=10,
)
alternative_schema = pa.schema(
[
("some_int", pa.int16()),
("some_bytes", pa.string()),
("some_floats", pa.float32()),
]
)
ctx.register_csv("csv3", path, schema=alternative_schema)

assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"}

for table in ["csv", "csv1", "csv2"]:
result = ctx.sql(f"SELECT COUNT(int) FROM {table}").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"COUNT(int)": [4]}

result = ctx.sql("SELECT * FROM csv3").collect()
result = pa.Table.from_batches(result)
assert result.schema == alternative_schema

with pytest.raises(
ValueError, match="Delimiter must be a single character"
):
ctx.register_csv("csv4", path, delimiter="wrong")


def test_register_parquet(ctx, tmp_path):
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
ctx.register_parquet("t", path)

assert ctx.tables() == {"t"}

result = ctx.sql("SELECT COUNT(a) FROM t").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"COUNT(a)": [100]}


def test_execute(ctx, tmp_path):
data = [1, 1, 2, 2, 3, 11, 12]
@@ -112,7 +163,9 @@ def test_cast(ctx, tmp_path):
"float",
]

select = ", ".join([f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)])
select = ", ".join(
[f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)]
)

# can execute, which implies that we can cast
ctx.sql(f"SELECT {select} FROM t").collect()
@@ -141,7 +194,9 @@ def test_udf(
ctx, tmp_path, fn, input_types, output_type, input_values, expected_values
):
# write to disk
path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(input_values))
path = helpers.write_parquet(
tmp_path / "a.parquet", pa.array(input_values)
)
ctx.register_parquet("t", path)
ctx.register_udf("udf", fn, input_types, output_type)
