diff --git a/python/Cargo.toml b/python/Cargo.toml index 8dba538ae0c79..8c81a531c20ff 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [package] -name = "datafusion" +name = "datafusion-python" version = "0.3.0" homepage = "https://github.com/apache/arrow" repository = "https://github.com/apache/arrow" @@ -31,7 +31,11 @@ libc = "0.2" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.7" pyo3 = { version = "0.14.1", features = ["extension-module", "abi3", "abi3-py36"] } -datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" } +datafusion = { path = "../datafusion", version = "5.1.0" } +# workaround for a bug introduced in +# https://github.com/dtolnay/proc-macro2/pull/286 +# TODO: remove this version pin after upstream releases a fix +proc-macro2 = { version = "=1.0.28" } [lib] name = "datafusion" diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs index 8e5657ba2f8a3..0885ae367a8e5 100644 --- a/python/src/dataframe.rs +++ b/python/src/dataframe.rs @@ -161,9 +161,13 @@ impl DataFrame { Ok(pretty::print_batches(&batches).unwrap()) } - /// Returns the join of two DataFrames `on`. - fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult { + fn join( + &self, + right: &DataFrame, + join_keys: (Vec<&str>, Vec<&str>), + how: &str, + ) -> PyResult { let builder = LogicalPlanBuilder::from(self.plan.clone()); let join_type = match how { @@ -182,7 +186,7 @@ impl DataFrame { } }; - let builder = errors::wrap(builder.join(&right.plan, join_type, on.clone(), on))?; + let builder = errors::wrap(builder.join(&right.plan, join_type, join_keys))?; let plan = errors::wrap(builder.build())?; diff --git a/python/tests/test_df.py b/python/tests/test_df.py index 5b6cbddbd74ba..14ab5ffb8b163 100644 --- a/python/tests/test_df.py +++ b/python/tests/test_df.py @@ -104,7 +104,7 @@ def test_join(): ) df1 = ctx.create_dataframe([[batch]]) - df = df.join(df1, on="a", how="inner") + df = df.join(df1, join_keys=(["a"], ["a"]), how="inner") df = df.sort([f.col("a").sort(ascending=True)]) table = pa.Table.from_batches(df.collect()) diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 669f640529eb5..beac578e99922 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -69,7 +69,7 @@ def test_register_csv(ctx, tmp_path): for table in ["csv", "csv1", "csv2"]: result = ctx.sql(f"SELECT COUNT(int) FROM {table}").collect() result = pa.Table.from_batches(result) - assert result.to_pydict() == {"COUNT(int)": [4]} + assert result.to_pydict() == {f"COUNT({table}.int)": [4]} result = ctx.sql("SELECT * FROM csv3").collect() result = pa.Table.from_batches(result) @@ -88,7 +88,7 @@ def test_register_parquet(ctx, tmp_path): result = ctx.sql("SELECT COUNT(a) FROM t").collect() result = pa.Table.from_batches(result) - assert result.to_pydict() == {"COUNT(a)": [100]} + assert result.to_pydict() == {"COUNT(t.a)": [100]} def test_execute(ctx, tmp_path): @@ -123,8 +123,8 @@ def test_execute(ctx, tmp_path): result_values = [] for result in results: pydict = result.to_pydict() - result_keys.extend(pydict["CAST(a AS Int32)"]) - result_values.extend(pydict["COUNT(a)"]) + result_keys.extend(pydict["CAST(t.a AS Int32)"]) + result_values.extend(pydict["COUNT(t.a)"]) result_keys, result_values = ( list(t) for t in zip(*sorted(zip(result_keys, result_values))) @@ -141,7 +141,7 @@ def test_execute(ctx, tmp_path): expected_cast = pa.array([50, 50], pa.int32()) expected = [ pa.RecordBatch.from_arrays( - [expected_a, expected_cast], ["a", "CAST(a AS Int32)"] + [expected_a, expected_cast], ["a", "CAST(t.a AS Int32)"] ) ] np.testing.assert_equal(expected[0].column(1), expected[0].column(1))