235 changes: 209 additions & 26 deletions datafusion/core/tests/dataframe.rs
@@ -32,24 +32,169 @@ use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::JoinType;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
+use datafusion::test_util::parquet_test_data;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
+use datafusion_common::ScalarValue;
use datafusion_expr::expr::{GroupingSet, Sort};
-use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
+use datafusion_expr::Expr::Wildcard;
+use datafusion_expr::{
+    avg, col, count, exists, expr, in_subquery, lit, max, scalar_subquery, sum,
+    AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
+    WindowFrameUnits, WindowFunction,
+};

#[tokio::test]
-async fn count_wildcard() -> Result<()> {
-    let ctx = SessionContext::new();
-    let testdata = datafusion::test_util::parquet_test_data();
+async fn test_count_wildcard_on_sort() -> Result<()> {
+    let ctx = create_join_context()?;

-    ctx.register_parquet(
-        "alltypes_tiny_pages",
-        &format!("{testdata}/alltypes_tiny_pages.parquet"),
-        ParquetReadOptions::default(),
-    )
-    .await?;
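// explain(false, false) turns the query into a plain EXPLAIN, so collect()
// returns the plan text and the assertions below compare plans, not result rows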
let sql_results = ctx
.sql("select b,count(*) from t1 group by b order by count(*)")
.await?
.explain(false, false)?
.collect()
.await?;

let df_results = ctx
.table("t1")
.await?
.aggregate(vec![col("b")], vec![count(Wildcard)])?
.sort(vec![count(Wildcard).sort(true, false)])?
.explain(false, false)?
.collect()
.await?;
// make sure the SQL plan is the same as the DataFrame plan
assert_eq!(
pretty_format_batches(&sql_results)?.to_string(),
pretty_format_batches(&df_results)?.to_string()
);
Ok(())
}

#[tokio::test]
async fn test_count_wildcard_on_where_in() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
.sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
.await?
.explain(false, false)?
.collect()
.await?;

// In the same SessionContext, AliasGenerator increments the subquery alias id by 1 for each subquery:
// https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
// To compare the SQL and DataFrame logical plans, we need to create a fresh SessionContext here.
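// (e.g. successive subqueries get numbered aliases along the lines of __sq_1,
// then __sq_2, so two plans built in one shared context would differ textually
// even when they are logically equivalent)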
let ctx = create_join_context()?;
let df_results = ctx
.table("t1")
.await?
.filter(in_subquery(
col("a"),
Arc::new(
ctx.table("t2")
.await?
.aggregate(vec![], vec![count(Expr::Wildcard)])?
.select(vec![count(Expr::Wildcard)])?
.into_unoptimized_plan(),
// Usually, into_optimized_plan() should be used here, but due to
// https://github.com/apache/arrow-datafusion/issues/5771,
// subqueries in SQL cannot be optimized yet, which would make the two logical
// plans differ, so into_unoptimized_plan() is used here for now.
),
))?
.select(vec![col("a"), col("b")])?
.explain(false, false)?
.collect()
.await?;

// make sure the SQL plan is the same as the DataFrame plan
assert_eq!(
pretty_format_batches(&sql_results)?.to_string(),
pretty_format_batches(&df_results)?.to_string()
);

Ok(())
}

#[tokio::test]
async fn test_count_wildcard_on_where_exist() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
.sql("select count(*) from alltypes_tiny_pages")
.sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)")
.await?
.explain(false, false)?
.collect()
.await?;
let df_results = ctx
.table("t1")
.await?
.filter(exists(Arc::new(
ctx.table("t2")
.await?
.aggregate(vec![], vec![count(Expr::Wildcard)])?
.select(vec![count(Expr::Wildcard)])?
.into_unoptimized_plan(),
// Usually, into_optimized_plan() should be used here, but due to
// https://github.com/apache/arrow-datafusion/issues/5771,
// subqueries in SQL cannot be optimized yet, which would make the two logical
// plans differ, so into_unoptimized_plan() is used here for now.
)))?
.select(vec![col("a"), col("b")])?
.explain(false, false)?
.collect()
.await?;

// make sure the SQL plan is the same as the DataFrame plan
assert_eq!(
pretty_format_batches(&sql_results)?.to_string(),
pretty_format_batches(&df_results)?.to_string()
);

Ok(())
}

#[tokio::test]
async fn test_count_wildcard_on_window() -> Result<()> {
let ctx = create_join_context()?;

let sql_results = ctx
.sql("select COUNT(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
.await?
.explain(false, false)?
.collect()
.await?;
let df_results = ctx
.table("t1")
.await?
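// build COUNT(*) OVER (ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)
// with the expression API, mirroring the SQL above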
.select(vec![Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Count),
vec![Expr::Wildcard],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
WindowFrame {
units: WindowFrameUnits::Range,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
},
))])?
.explain(false, false)?
.collect()
.await?;

// make sure the SQL plan is the same as the DataFrame plan
assert_eq!(
pretty_format_batches(&df_results)?.to_string(),
pretty_format_batches(&sql_results)?.to_string()
);

Ok(())
}

#[tokio::test]
async fn test_count_wildcard_on_aggregate() -> Result<()> {
let ctx = create_join_context()?;
register_alltypes_tiny_pages_parquet(&ctx).await?;

let sql_results = ctx
.sql("select count(*) from t1")
.await?
.select(vec![count(Expr::Wildcard)])?
.explain(false, false)?
@@ -58,7 +203,7 @@ async fn count_wildcard() -> Result<()> {

// add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all nodes, not just the top node
let df_results = ctx
.table("alltypes_tiny_pages")
.table("t1")
.await?
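// aggregate() with no grouping expressions computes a single global COUNT(*)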
.aggregate(vec![], vec![count(Expr::Wildcard)])?
.select(vec![count(Expr::Wildcard)])?
@@ -72,24 +217,51 @@ async fn count_wildcard() -> Result<()> {
pretty_format_batches(&df_results)?.to_string()
);

-    let results = ctx
-        .table("alltypes_tiny_pages")
Ok(())
}
#[tokio::test]
async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
let ctx = create_join_context()?;

let sql_results = ctx
.sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;")
.await?
-        .aggregate(vec![], vec![count(Expr::Wildcard)])?
+        .explain(false, false)?
.collect()
.await?;

-    let expected = vec![
-        "+-----------------+",
-        "| COUNT(UInt8(1)) |",
-        "+-----------------+",
-        "| 7300            |",
-        "+-----------------+",
-    ];
-    assert_batches_sorted_eq!(expected, &results);
// In the same SessionContext, AliasGenerator increments the subquery alias id by 1 for each subquery:
// https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
// To compare the SQL and DataFrame logical plans, we need to create a fresh SessionContext here.
let ctx = create_join_context()?;
let df_results = ctx
.table("t1")
.await?
.filter(
scalar_subquery(Arc::new(
ctx.table("t2")
.await?
.filter(col("t1.a").eq(col("t2.a")))?
.aggregate(vec![], vec![count(lit(COUNT_STAR_EXPANSION))])?
.select(vec![count(lit(COUNT_STAR_EXPANSION))])?
.into_unoptimized_plan(),
))
.gt(lit(ScalarValue::UInt8(Some(0)))),
)?
.select(vec![col("t1.a"), col("t1.b")])?
.explain(false, false)?
.collect()
.await?;

// make sure the SQL plan is the same as the DataFrame plan
assert_eq!(
pretty_format_batches(&sql_results)?.to_string(),
pretty_format_batches(&df_results)?.to_string()
);

Ok(())
}

#[tokio::test]
async fn describe() -> Result<()> {
let ctx = SessionContext::new();
@@ -229,7 +401,7 @@ async fn sort_on_unprojected_columns() -> Result<()> {
let results = df.collect().await.unwrap();

#[rustfmt::skip]
let expected = vec![
"+-----+",
"| a |",
"+-----+",
@@ -275,7 +447,7 @@ async fn sort_on_distinct_columns() -> Result<()> {
let results = df.collect().await.unwrap();

#[rustfmt::skip]
let expected = vec![
"+-----+",
"| a |",
"+-----+",
@@ -417,7 +589,7 @@ async fn filter_with_alias_overwrite() -> Result<()> {
let results = df.collect().await.unwrap();

#[rustfmt::skip]
let expected = vec![
"+------+",
"| a |",
"+------+",
@@ -1047,3 +1219,14 @@ async fn table_with_nested_types(n: usize) -> Result<DataFrame> {
ctx.register_batch("shapes", batch)?;
ctx.table("shapes").await
}

pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Result<()> {
let testdata = parquet_test_data();
ctx.register_parquet(
"alltypes_tiny_pages",
&format!("{testdata}/alltypes_tiny_pages.parquet"),
ParquetReadOptions::default(),
)
.await?;
Ok(())
}
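
Note: `create_join_context()` used throughout the new tests is defined elsewhere in this file and is not part of this diff. Below is a minimal sketch of the assumed fixture, for reference only: two small in-memory tables `t1(a, b)` and `t2(a, b)` registered on a fresh `SessionContext`. The exact schemas, row values, and the `create_join_context_sketch` name are illustrative assumptions, not the real test data.

// Sketch only: the real create_join_context() lives outside this diff; the
// schemas and rows here are assumed for illustration.
fn create_join_context_sketch() -> Result<SessionContext> {
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    let ctx = SessionContext::new();
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));

    // t1 and t2 share a schema and overlap on column "a", so the subquery and
    // join tests above see both matching and non-matching rows.
    let t1 = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["x", "y", "z"])),
        ],
    )?;
    let t2 = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![1, 3, 4])),
            Arc::new(StringArray::from(vec!["x", "z", "w"])),
        ],
    )?;
    ctx.register_batch("t1", t1)?;
    ctx.register_batch("t2", t2)?;
    Ok(ctx)
}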