diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py
index 0eb970a69e834..9040b6f807f93 100644
--- a/python/datafusion/tests/test_dataframe.py
+++ b/python/datafusion/tests/test_dataframe.py
@@ -18,6 +18,7 @@
 import pyarrow as pa
 import pytest
 
+from datafusion import functions as f
 from datafusion import DataFrame, ExecutionContext, column, literal, udf
 
 
@@ -117,6 +118,24 @@ def test_join():
     assert table.to_pydict() == expected
 
 
+def test_window_lead(df):
+    """lead(b) ordered by b gives the next b, None on the last row."""
+    df = df.select(
+        column("a"),
+        f.alias(
+            f.window(
+                "lead", [column("b")], order_by=[f.order_by(column("b"))]
+            ),
+            "b_next",
+        ),
+    )
+
+    table = pa.Table.from_batches(df.collect())
+
+    expected = {"a": [1, 2, 3], "b_next": [5, 6, None]}
+    assert table.to_pydict() == expected
+
+
 def test_get_dataframe(tmp_path):
     ctx = ExecutionContext()
 
diff --git a/python/src/functions.rs b/python/src/functions.rs
index a2862202602f1..c0b4e5989012e 100644
--- a/python/src/functions.rs
+++ b/python/src/functions.rs
@@ -23,6 +23,7 @@ use datafusion::physical_plan::{
     aggregates::AggregateFunction, functions::BuiltinScalarFunction,
 };
 
+use crate::errors;
 use crate::expression::PyExpr;
 
 #[pyfunction]
@@ -85,6 +86,69 @@ fn concat_ws(sep: String, args: Vec<PyExpr>) -> PyResult<PyExpr> {
     Ok(logical_plan::concat_ws(sep, &args).into())
 }
 
+/// Creates a new Sort expression
+///
+/// `asc` and `nulls_first` both default to `true` when omitted.
+#[pyfunction]
+fn order_by(
+    expr: PyExpr,
+    asc: Option<bool>,
+    nulls_first: Option<bool>,
+) -> PyResult<PyExpr> {
+    Ok(PyExpr {
+        expr: datafusion::logical_plan::Expr::Sort {
+            expr: Box::new(expr.expr),
+            asc: asc.unwrap_or(true),
+            nulls_first: nulls_first.unwrap_or(true),
+        },
+    })
+}
+
+/// Creates a new Alias expression that names `expr` as `name`
+#[pyfunction]
+fn alias(expr: PyExpr, name: &str) -> PyResult<PyExpr> {
+    Ok(PyExpr {
+        expr: datafusion::logical_plan::Expr::Alias(
+            Box::new(expr.expr),
+            String::from(name),
+        ),
+    })
+}
+
+/// Creates a new Window function expression
+///
+/// `name` is resolved with `WindowFunction::from_str`; a name it rejects is
+/// returned as an error. Omitted `partition_by` / `order_by` become empty
+/// lists, and no window frame is set.
+#[pyfunction]
+fn window(
+    name: &str,
+    args: Vec<PyExpr>,
+    partition_by: Option<Vec<PyExpr>>,
+    order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+    use std::str::FromStr;
+    let fun = datafusion::physical_plan::window_functions::WindowFunction::from_str(name)
+        .map_err(|e| -> errors::DataFusionError { e.into() })?;
+    Ok(PyExpr {
+        expr: datafusion::logical_plan::Expr::WindowFunction {
+            fun,
+            args: args.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
+            partition_by: partition_by
+                .unwrap_or_default()
+                .into_iter()
+                .map(|x| x.expr)
+                .collect::<Vec<_>>(),
+            order_by: order_by
+                .unwrap_or_default()
+                .into_iter()
+                .map(|x| x.expr)
+                .collect::<Vec<_>>(),
+            window_frame: None,
+        },
+    })
+}
+
 macro_rules! scalar_function {
     ($NAME: ident, $FUNC: ident) => {
         scalar_function!($NAME, $FUNC, stringify!($NAME));
@@ -218,6 +282,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(abs))?;
     m.add_wrapped(wrap_pyfunction!(acos))?;
     m.add_wrapped(wrap_pyfunction!(approx_distinct))?;
+    m.add_wrapped(wrap_pyfunction!(alias))?;
     m.add_wrapped(wrap_pyfunction!(array))?;
     m.add_wrapped(wrap_pyfunction!(ascii))?;
     m.add_wrapped(wrap_pyfunction!(asin))?;
@@ -249,6 +314,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(min))?;
     m.add_wrapped(wrap_pyfunction!(now))?;
     m.add_wrapped(wrap_pyfunction!(octet_length))?;
+    m.add_wrapped(wrap_pyfunction!(order_by))?;
     m.add_wrapped(wrap_pyfunction!(random))?;
     m.add_wrapped(wrap_pyfunction!(regexp_match))?;
     m.add_wrapped(wrap_pyfunction!(regexp_replace))?;
@@ -278,5 +344,6 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(trim))?;
     m.add_wrapped(wrap_pyfunction!(trunc))?;
     m.add_wrapped(wrap_pyfunction!(upper))?;
+    m.add_wrapped(wrap_pyfunction!(window))?;
     Ok(())
 }