Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 104 additions & 22 deletions datafusion/physical-expr/src/equivalence/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use std::ops::Deref;
use std::sync::Arc;

use crate::expressions::Column;
use crate::expressions::{CastColumnExpr, Column};
use crate::PhysicalExpr;

use arrow::datatypes::SchemaRef;
Expand Down Expand Up @@ -96,28 +96,56 @@ impl ProjectionMapping {
let mut map = IndexMap::<_, ProjectionTargets>::new();
for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
Some(col) => {
// Sometimes, an expression and its name in the input_schema
// doesn't match. This can cause problems, so we make sure
// that the expression name matches with the name in `input_schema`.
// Conceptually, `source_expr` and `expression` should be the same.
let idx = col.index();
let matching_field = input_schema.field(idx);
let matching_name = matching_field.name();
if col.name() != matching_name {
return internal_err!(
"Input field name {} does not match with the projection expression {}",
matching_name,
col.name()
let source_expr = expr
.transform_down(|e| {
if let Some(col) = e.as_any().downcast_ref::<Column>() {
// Sometimes, an expression and its name in the input_schema
// doesn't match. This can cause problems, so we make sure
// that the expression name matches with the name in `input_schema`.
// Conceptually, `source_expr` and `expression` should be the same.
let idx = col.index();
let matching_field = input_schema.field(idx);
let matching_name = matching_field.name();
if col.name() != matching_name {
return internal_err!(
"Input field name {} does not match with the projection expression {}",
matching_name,
col.name()
);
}
let matching_column = Column::new(matching_name, idx);
Ok(Transformed::yes(Arc::new(matching_column)))
} else if let Some(cast_column) =
e.as_any().downcast_ref::<CastColumnExpr>()
{
let new_input_field = cast_column
.expr()
.as_any()
.downcast_ref::<Column>()
.and_then(|col| {
input_schema
.fields()
.get(col.index())
.map(Arc::clone)
})
.unwrap_or_else(|| Arc::clone(cast_column.input_field()));

if new_input_field.as_ref() == cast_column.input_field().as_ref() {
return Ok(Transformed::no(e));
}

let new_expr = CastColumnExpr::new(
Arc::clone(cast_column.expr()),
new_input_field,
Arc::clone(cast_column.target_field()),
Some(cast_column.cast_options().clone()),
);
Ok(Transformed::yes(Arc::new(new_expr)))
} else {
Ok(Transformed::no(e))
}
let matching_column = Column::new(matching_name, idx);
Ok(Transformed::yes(Arc::new(matching_column)))
}
None => Ok(Transformed::no(e)),
})
.data()?;
})
.data()?;
map.entry(source_expr)
.or_default()
.push((target_expr, expr_idx));
Expand Down Expand Up @@ -253,7 +281,7 @@ mod tests {
use super::*;
use crate::equivalence::tests::output_schema;
use crate::equivalence::{convert_to_orderings, EquivalenceProperties};
use crate::expressions::{col, BinaryExpr};
use crate::expressions::{col, BinaryExpr, CastColumnExpr};
use crate::utils::tests::TestScalarUDF;
use crate::{PhysicalExprRef, ScalarFunctionExpr};

Expand All @@ -278,6 +306,12 @@ mod tests {
let col_d = &col("d", &schema)?;
let col_e = &col("e", &schema)?;
let col_ts = &col("ts", &schema)?;
let cast_column_expr = Arc::new(CastColumnExpr::new(
Arc::clone(col_a),
Arc::new(schema.field(0).clone()),
Arc::new(Field::new("a_cast", DataType::Int64, true)),
None,
)) as Arc<dyn PhysicalExpr>;
let a_plus_b = Arc::new(BinaryExpr::new(
Arc::clone(col_a),
Operator::Plus,
Expand Down Expand Up @@ -713,6 +747,26 @@ mod tests {
vec![("c_new", option_asc), ("b_new", option_desc)],
],
),
// ---------- TEST CASE 6 ------------
(
// orderings
vec![
// [a ASC]
vec![(col_a, option_asc)],
],
// projection exprs
vec![
(col_a, "a_new".to_string()),
(&cast_column_expr, "a_cast".to_string()),
],
// expected
vec![
// [a_new ASC]
vec![("a_new", option_asc)],
// [a_cast ASC]
vec![("a_cast", option_asc)],
],
),
];

for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
Expand Down Expand Up @@ -757,6 +811,34 @@ mod tests {
Ok(())
}

#[test]
fn projection_mapping_normalizes_cast_column_input_field() -> Result<()> {
    // Input schema has a single column named "a".
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    let col_a = col("a", &schema)?;

    // Build a cast whose recorded input field name ("alias_a") disagrees
    // with the actual schema field name ("a") at that column index.
    let stale_input_field = Arc::new(Field::new("alias_a", DataType::Int32, true));
    let target = Arc::new(Field::new("a", DataType::Int32, true));
    let projection_expr = Arc::new(CastColumnExpr::new(
        Arc::clone(&col_a),
        stale_input_field,
        target,
        None,
    )) as Arc<dyn PhysicalExpr>;

    let mapping = ProjectionMapping::try_new(
        vec![(projection_expr, "a_alias".to_string())],
        &schema,
    )?;

    // The mapping should have rewritten the cast's input field so it
    // matches the schema field the wrapped column actually reads.
    let (source_expr, _targets) = mapping.iter().next().unwrap();
    let normalized = source_expr
        .as_any()
        .downcast_ref::<CastColumnExpr>()
        .unwrap();
    assert_eq!(normalized.input_field().name(), "a");
    assert_eq!(normalized.input_field().data_type(), &DataType::Int32);
    Ok(())
}

#[test]
fn project_orderings2() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Expand Down
13 changes: 12 additions & 1 deletion datafusion/physical-expr/src/equivalence/properties/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use self::dependency::{
use crate::equivalence::{
AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping,
};
use crate::expressions::{with_new_schema, CastExpr, Column, Literal};
use crate::expressions::{with_new_schema, CastColumnExpr, CastExpr, Column, Literal};
use crate::{
ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr,
PhysicalSortRequirement,
Expand Down Expand Up @@ -850,6 +850,17 @@ impl EquivalenceProperties {
sort_expr.options,
));
}
} else if let Some(cast_column) =
r_expr.as_any().downcast_ref::<CastColumnExpr>()
{
if cast_column.expr().eq(&sort_expr.expr)
&& cast_column.is_widening_cast(&expr_type)
{
result.push(PhysicalSortExpr::new(
r_expr,
sort_expr.options,
));
}
}
}
result.push(sort_expr);
Expand Down
31 changes: 31 additions & 0 deletions datafusion/physical-expr/src/expressions/cast_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,37 @@ impl CastColumnExpr {
pub fn target_field(&self) -> &FieldRef {
&self.target_field
}

/// Returns the [`CastOptions`] forwarded to [`cast_column`] when this
/// expression is evaluated.
pub fn cast_options(&self) -> &CastOptions<'static> {
&self.cast_options
}

/// Returns `true` if this cast widens the source type and therefore
/// preserves ordering semantics.
///
/// A cast to the identical type trivially qualifies; otherwise only a
/// fixed set of monotonic numeric/string widenings is accepted.
pub fn is_widening_cast(&self, src: &DataType) -> bool {
    use DataType::*;

    let dst = self.target_field.data_type();
    if dst == src {
        // Identity cast: ordering is untouched.
        return true;
    }

    match (src, dst) {
        // Signed integers widen to any larger signed integer.
        (Int8, Int16 | Int32 | Int64)
        | (Int16, Int32 | Int64)
        | (Int32, Int64) => true,
        // Unsigned integers widen to any larger unsigned integer.
        (UInt8, UInt16 | UInt32 | UInt64)
        | (UInt16, UInt32 | UInt64)
        | (UInt32, UInt64) => true,
        // Integers of 32 bits or fewer widen to either float width.
        (Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32, Float32 | Float64) => true,
        // 64-bit integers only widen to the 64-bit float.
        (Int64 | UInt64, Float64) => true,
        // UTF-8 strings widen to the large-offset variant.
        (Utf8, LargeUtf8) => true,
        _ => false,
    }
}
}

impl Display for CastColumnExpr {
Expand Down
28 changes: 27 additions & 1 deletion datafusion/physical-expr/src/intervals/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use std::sync::Arc;

use crate::{
expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr},
expressions::{BinaryExpr, CastColumnExpr, CastExpr, Column, Literal, NegativeExpr},
PhysicalExpr,
};

Expand Down Expand Up @@ -55,6 +55,8 @@ pub fn check_support(expr: &Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> bool {
}
} else if let Some(cast) = expr_any.downcast_ref::<CastExpr>() {
check_support(cast.expr(), schema)
} else if let Some(cast_column) = expr_any.downcast_ref::<CastColumnExpr>() {
check_support(cast_column.expr(), schema)
} else if let Some(negative) = expr_any.downcast_ref::<NegativeExpr>() {
check_support(negative.arg(), schema)
} else {
Expand Down Expand Up @@ -191,3 +193,27 @@ fn interval_dt_to_duration_ms(dt: &IntervalDayTime) -> Result<i64> {
)
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, CastColumnExpr};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// `check_support` should recurse through a `CastColumnExpr` and
    /// accept it when the wrapped column expression is supported.
    #[test]
    fn supports_cast_column_expr() {
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let wrapped = col("a", &schema).unwrap();
        let source_field = Arc::new(schema.field(0).clone());
        let cast_target = Arc::new(Field::new("a_cast", DataType::Int64, true));
        let cast_expr = Arc::new(CastColumnExpr::new(
            Arc::clone(&wrapped),
            source_field,
            cast_target,
            None,
        )) as Arc<dyn PhysicalExpr>;

        assert!(check_support(&cast_expr, &schema));
    }
}