1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default.

61 changes: 60 additions & 1 deletion datafusion/core/tests/dataframe.rs
@@ -24,6 +24,8 @@ use arrow::{
},
record_batch::RecordBatch,
};
use arrow_array::TimestampNanosecondArray;
use arrow_schema::TimeUnit;
use datafusion::from_slice::FromSlice;
use std::sync::Arc;

@@ -153,7 +155,27 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> {

#[tokio::test]
async fn test_count_wildcard_on_window() -> Result<()> {
let ctx = create_join_context()?;
let ctx = SessionContext::new();
let t1 = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false),
]));

// Define data: a UInt32 column and a nanosecond-timestamp column.
let batch1 = RecordBatch::try_new(
t1,
vec![
Arc::new(UInt32Array::from_slice([1, 10, 11, 100])),
Arc::new(TimestampNanosecondArray::from_slice([
1664264591000000000,
1664264592000000000,
1664264592000000000,
1664264593000000000,
])),
],
)?;

ctx.register_batch("t1", batch1)?;

let sql_results = ctx
.sql("select COUNT(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
@@ -185,6 +207,43 @@ async fn test_count_wildcard_on_window() -> Result<()> {
pretty_format_batches(&sql_results)?.to_string()
);

#[allow(unused_doc_comments)]
/// Special timestamp scenario: the DataFrame API cannot produce the same logical plan as this SQL, so it is tested separately.
///```sql
/// COUNT(*) OVER (ORDER BY ts DESC RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING)
/// ```
/// Logical plan:
/// ```text
/// COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 18446744073709551616 PRECEDING AND 36893488147419103232 FOLLOWING
/// AS
/// COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING
/// ```
let sql_results = ctx
.sql("select \
COUNT(*) OVER (ORDER BY ts DESC RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING) as cnt2 \
from t1")
.await?
.explain(false, false)?
.collect()
.await?;

#[rustfmt::skip]
let expected = vec![
"+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
"| plan_type | plan |",
"+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
"| logical_plan | Projection: COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING AS cnt2 |",
"| | WindowAggr: windowExpr=[[COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 18446744073709551616 PRECEDING AND 36893488147419103232 FOLLOWING AS COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING]] |",
"| | TableScan: t1 projection=[a, ts] |",
"| physical_plan | ProjectionExec: expr=[COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING@2 as cnt2] |",
"| | BoundedWindowAggExec: wdw=[COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING: Ok(Field { name: \"COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(IntervalMonthDayNano(\"18446744073709551616\")), end_bound: Following(IntervalMonthDayNano(\"36893488147419103232\")) }] |",
"| | SortExec: expr=[ts@1 DESC] |",
"| | MemoryExec: partitions=1, partition_sizes=[1] |",
"| | |",
"+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &sql_results);

Ok(())
}
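
The huge integers in the expected plan are the raw `IntervalMonthDayNano` encodings of the SQL intervals. A minimal sketch, assuming Arrow's bit layout for that type (months in bits 96..128, days in bits 64..96, nanoseconds in bits 0..64; the packing helper below is hypothetical), shows why `INTERVAL '1' DAY` and `INTERVAL '2 DAY'` print as those values:

```rust
// Hypothetical helper mirroring Arrow's assumed IntervalMonthDayNano layout.
fn interval_month_day_nano(months: i32, days: i32, nanos: i64) -> i128 {
    (((months as i128) & 0xFFFF_FFFF) << 96)
        | (((days as i128) & 0xFFFF_FFFF) << 64)
        | ((nanos as i128) & 0xFFFF_FFFF_FFFF_FFFF)
}

fn main() {
    // INTERVAL '1' DAY is days = 1, i.e. 1 << 64.
    assert_eq!(interval_month_day_nano(0, 1, 0), 18446744073709551616_i128);
    // INTERVAL '2 DAY' is days = 2, i.e. 2 << 64.
    assert_eq!(interval_month_day_nano(0, 2, 0), 36893488147419103232_i128);
}
```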

9 changes: 0 additions & 9 deletions datafusion/core/tests/sql/mod.rs
@@ -1131,15 +1131,6 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordB
/// Execute query and return results as a Vec of RecordBatches
async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
let df = ctx.sql(sql).await.unwrap();

// We are not really interested in the direct output of optimized_logical_plan
(Contributor review comment: Why was this removed?)

// since the physical plan construction already optimizes the given logical plan
// and we want to avoid double-optimization as a consequence. So we just construct
// it here to make sure that it doesn't fail at this step and get the optimized
// schema (to assert later that the logical and optimized schemas are the same).
let optimized = df.clone().into_optimized_plan().unwrap();
assert_eq!(df.logical_plan().schema(), optimized.schema());

df.collect().await.unwrap()
}

26 changes: 17 additions & 9 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -20,7 +20,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::Expr::{Exists, InSubquery, ScalarSubquery};
use datafusion_expr::Expr::{Alias, Exists, InSubquery, ScalarSubquery};
use datafusion_expr::{
aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter,
LogicalPlan, Projection, Sort, Subquery, Window,
@@ -230,6 +230,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
negated,
}
}
Alias(expr, name) => Alias(expr, replace_count_star(name)),
_ => old_expr,
};
Ok(new_expr)
@@ -240,29 +241,35 @@ fn rewrite_schema(schema: &DFSchema) -> DFSchemaRef {
.fields()
.iter()
.map(|field| {
let mut name = field.field().name().clone();
if name.contains(COUNT_STAR) {
name = name.replace(
COUNT_STAR,
count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
);
}
DFField::new(
field.qualifier().cloned(),
&name,
&replace_count_star(field.field().name().clone()),
field.data_type().clone(),
field.is_nullable(),
)
})
.collect::<Vec<DFField>>();

DFSchemaRef::new(
DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(),
)
}

fn replace_count_star(name: String) -> String {
if name.contains(COUNT_STAR) {
name.replace(
COUNT_STAR,
count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
)
} else {
name
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::analyzer::Analyzer;
use crate::test::*;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::expr::Sort;
@@ -376,6 +383,7 @@ mod tests {
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(&plan, expected)
}

#[test]
fn test_count_wildcard_on_window() -> Result<()> {
let table_scan = test_table_scan()?;
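
The net effect of the new `Alias` arm plus the extracted `replace_count_star` helper is a pure string rewrite of display names. A self-contained sketch (the constants below are stand-ins: in the rule itself the replacement string comes from `count(lit(COUNT_STAR_EXPANSION)).to_string()`):

```rust
// Stand-ins for the rule's constants: COUNT_STAR is "COUNT(*)" and
// count(lit(COUNT_STAR_EXPANSION)) renders as "COUNT(UInt8(1))".
const COUNT_STAR: &str = "COUNT(*)";
const COUNT_STAR_REPLACEMENT: &str = "COUNT(UInt8(1))";

fn replace_count_star(name: String) -> String {
    if name.contains(COUNT_STAR) {
        name.replace(COUNT_STAR, COUNT_STAR_REPLACEMENT)
    } else {
        name
    }
}

fn main() {
    // Aliased expressions and schema field names keep the user-facing
    // COUNT(*) spelling in sync with the rewritten aggregate.
    assert_eq!(
        replace_count_star("COUNT(*) ORDER BY [t1.a DESC]".to_string()),
        "COUNT(UInt8(1)) ORDER BY [t1.a DESC]"
    );
}
```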
1 change: 1 addition & 0 deletions datafusion/sql/Cargo.toml
@@ -41,6 +41,7 @@ arrow = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { path = "../common", version = "22.0.0" }
datafusion-expr = { path = "../expr", version = "22.0.0" }
datafusion-optimizer = { path = "../optimizer", version = "22.0.0" }
log = "^0.4"
sqlparser = "0.33"

8 changes: 1 addition & 7 deletions datafusion/sql/src/expr/function.rs
@@ -17,7 +17,6 @@

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{DFSchema, DataFusionError, Result};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::window_frame::regularize;
use datafusion_expr::{
expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame,
@@ -215,12 +214,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// Special case rewrite COUNT(*) to COUNT(constant)
AggregateFunction::Count => args
.into_iter()
.map(|a| match a {
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone()))
}
_ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context),
})
.map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context))
.collect::<Result<Vec<Expr>>>()?,
_ => self.function_args_to_expr(args, schema, planner_context)?,
};
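
With this hunk, the SQL planner no longer rewrites `COUNT(*)` eagerly; the wildcard survives planning and is expanded later by the analyzer's `count_wildcard_rule`. A rough end-to-end sketch of observing that (the table name and data are made up for illustration, and the API calls assume datafusion 22):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Hypothetical in-memory table, just for illustration.
    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)").await?;

    let df = ctx.sql("SELECT COUNT(*) FROM t").await?;
    // After analysis/optimization the plan carries COUNT(UInt8(1)),
    // aliased back to the user-facing COUNT(*) name.
    println!("{}", df.clone().into_optimized_plan()?.display_indent());
    df.show().await?;
    Ok(())
}
```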
13 changes: 11 additions & 2 deletions datafusion/sql/tests/integration_test.rs
@@ -35,6 +35,8 @@ use datafusion_sql::{
planner::{ContextProvider, ParserOptions, SqlToRel},
};

use datafusion_optimizer::analyzer::Analyzer;
use datafusion_optimizer::{OptimizerConfig, OptimizerContext};
use rstest::rstest;

#[cfg(test)]
@@ -865,7 +867,7 @@ fn select_aggregate_with_having_referencing_column_not_in_select() {
assert_eq!(
"Plan(\"HAVING clause references non-aggregate values: \
Expression person.first_name could not be resolved from available columns: \
COUNT(UInt8(1))\")",
COUNT(*)\")",
format!("{err:?}")
);
}
@@ -2530,7 +2532,14 @@ fn logical_plan_with_dialect_and_options(
let planner = SqlToRel::new_with_options(&context, options);
let result = DFParser::parse_sql_with_dialect(sql, dialect);
let mut ast = result?;
planner.statement_to_plan(ast.pop_front().unwrap())
match planner.statement_to_plan(ast.pop_front().unwrap()) {
Ok(plan) => Analyzer::new().execute_and_check(
&plan,
OptimizerContext::new().options(),
|_, _| {},
),
Err(err) => Err(err),
}
}

/// Create logical plan, write with formatter, compare to expected output
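
For reference, the new `match` in `logical_plan_with_dialect_and_options` can be read as the `?`-based sketch below; the point of the change is that analyzer rules (which now own the `COUNT(*)` rewrite) run over the freshly planned statement, so the test expectations reflect the post-analysis plan:

```rust
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::analyzer::Analyzer;
use datafusion_optimizer::{OptimizerConfig, OptimizerContext};

// Equivalent, slightly more idiomatic form of the match added above.
fn analyze_planned(plan: Result<LogicalPlan>) -> Result<LogicalPlan> {
    let plan = plan?;
    Analyzer::new().execute_and_check(&plan, OptimizerContext::new().options(), |_, _| {})
}
```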