Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 159 additions & 20 deletions rust/datafusion/src/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,40 @@ impl fmt::Display for BinaryExpr {
}
}

// the type that both lhs and rhs can be casted to for the purpose of a string computation
/// Coercion rules for dictionary values (aka the type of the dictionary itself)
fn dictionary_value_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
numerical_coercion(lhs_type, rhs_type).or_else(|| string_coercion(lhs_type, rhs_type))
}

/// Coercion rules for Dictionaries: the type that both lhs and rhs
/// can be casted to for the purpose of a computation.
///
/// It would likely be preferable to cast primitive values to
/// dictionaries, and thus avoid unpacking dictionary as well as doing
/// faster comparisons. However, the arrow compute kernels (e.g. eq)
/// don't have DictionaryArray support yet, so fall back to unpacking
/// the dictionaries
fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
match (lhs_type, rhs_type) {
(
DataType::Dictionary(_lhs_index_type, lhs_value_type),
DataType::Dictionary(_rhs_index_type, rhs_value_type),
) => dictionary_value_coercion(lhs_value_type, rhs_value_type),
(DataType::Dictionary(_index_type, value_type), _) => {
dictionary_value_coercion(value_type, rhs_type)
}
(_, DataType::Dictionary(_index_type, value_type)) => {
dictionary_value_coercion(lhs_type, value_type)
}
_ => None,
}
}

/// Coercion rules for Strings: the type that both lhs and rhs can be
/// casted to for the purpose of a string computation
fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
Expand All @@ -1092,7 +1125,9 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
}
}

/// coercion rule for numerical types
/// Coercion rule for numerical types: The type that both lhs and rhs
/// can be casted to for numerical calculation, while maintaining
/// maximum precision
pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;

Expand Down Expand Up @@ -1150,6 +1185,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
return Some(lhs_type.clone());
}
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
}

// coercion rules that assume an ordered set, such as "less than".
Expand All @@ -1160,16 +1196,13 @@ fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
return Some(lhs_type.clone());
}

match numerical_coercion(lhs_type, rhs_type) {
None => {
// strings are naturally ordered, and thus ordering can be applied to them.
string_coercion(lhs_type, rhs_type)
}
t => t,
}
numerical_coercion(lhs_type, rhs_type)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just a refactor, it is not meant to change the semantics

.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
}

/// coercion rules for all binary operators
/// Coercion rules for all binary operators. Returns the output type
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
fn common_binary_type(
lhs_type: &DataType,
op: &Operator,
Expand Down Expand Up @@ -1526,8 +1559,8 @@ impl PhysicalExpr for CastExpr {
}
}

/// Returns a physical cast operation that casts `expr` to `cast_type`
/// if casting is needed.
/// Return a PhysicalExpression representing `expr` casted to
/// `cast_type`, if any casting is needed.
///
/// Note that such casts may lose type information
pub fn cast(
Expand Down Expand Up @@ -1665,11 +1698,14 @@ impl PhysicalSortExpr {
mod tests {
use super::*;
use crate::error::Result;
use arrow::array::{
LargeStringArray, PrimitiveArray, PrimitiveArrayOps, StringArray,
Time64NanosecondArray,
};
use arrow::datatypes::*;
use arrow::{
array::{
LargeStringArray, PrimitiveArray, PrimitiveArrayOps, PrimitiveBuilder,
StringArray, StringDictionaryBuilder, Time64NanosecondArray,
},
util::display::array_value_to_string,
};

// Create a binary expression without coercion. Used here when we do not want to coerce the expressions
// to valid types. Usage can result in an execution (after plan) error.
Expand Down Expand Up @@ -1772,11 +1808,13 @@ mod tests {

// runs an end-to-end test of physical type coercion:
// 1. construct a record batch with two columns of type A and B
// (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements)
// 2. construct a physical expression of A OP B
// 3. evaluate the expression
// 4. verify that the resulting expression is of type C
// 5. verify that the results of evaluation are $VEC
macro_rules! test_coercion {
($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the parameter names here to make this macro more consistent with its documentation

($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr) => {{
let schema = Schema::new(vec![
Field::new("a", $A_TYPE, false),
Field::new("b", $B_TYPE, false),
Expand All @@ -1792,18 +1830,18 @@ mod tests {
let expression = binary(col("a"), $OP, col("b"), &schema)?;

// verify that the expression's type is correct
assert_eq!(expression.data_type(&schema)?, $TYPE);
assert_eq!(expression.data_type(&schema)?, $C_TYPE);

// compute
let result = expression.evaluate(&batch)?;

// verify that the array's data_type is correct
assert_eq!(*result.data_type(), $TYPE);
assert_eq!(*result.data_type(), $C_TYPE);

// verify that the data itself is downcastable
let result = result
.as_any()
.downcast_ref::<$TYPEARRAY>()
.downcast_ref::<$C_ARRAY>()
.expect("failed to downcast");
// verify that the result itself is correct
for (i, x) in $VEC.iter().enumerate() {
Expand Down Expand Up @@ -1877,6 +1915,107 @@ mod tests {
Ok(())
}

#[test]
fn test_dictionary_type_coersion() -> Result<()> {
use DataType::*;

// TODO: In the future, this would ideally return Dictionary types and avoid unpacking
let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32));
let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32));

let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None);

let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
let rhs_type = Utf8;
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));

let lhs_type = Utf8;
let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));

Ok(())
}

// Note it would be nice to use the same test_coercion macro as
// above, but sadly the type of the values of the dictionary are
// not encoded in the rust type of the DictionaryArray. Thus there
// is no way at the time of this writing to create a dictionary
// array using the `From` trait
#[test]
fn test_dictionary_type_to_array_coersion() -> Result<()> {
// Test string a string dictionary
let dict_type =
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
let string_type = DataType::Utf8;

// build dictionary
let keys_builder = PrimitiveBuilder::<Int32Type>::new(10);
let values_builder = StringBuilder::new(10);
let mut dict_builder = StringDictionaryBuilder::new(keys_builder, values_builder);

dict_builder.append("one")?;
dict_builder.append_null()?;
dict_builder.append("three")?;
dict_builder.append("four")?;
let dict_array = dict_builder.finish();

let str_array =
StringArray::from(vec![Some("not one"), Some("two"), None, Some("four")]);

let schema = Arc::new(Schema::new(vec![
Field::new("dict", dict_type.clone(), true),
Field::new("str", string_type.clone(), true),
]));

let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(dict_array), Arc::new(str_array)],
)?;

let expected = "false\n\n\ntrue";

// Test 1: dict = str

// verify that we can construct the expression
let expression = binary(col("dict"), Operator::Eq, col("str"), &schema)?;
assert_eq!(expression.data_type(&schema)?, DataType::Boolean);

// evaluate and verify the result type matched
let result = expression.evaluate(&batch)?;
assert_eq!(result.data_type(), &DataType::Boolean);

// verify that the result itself is correct
assert_eq!(expected, array_to_string(&result)?);

// Test 2: now test the other direction
// str = dict

// verify that we can construct the expression
let expression = binary(col("str"), Operator::Eq, col("dict"), &schema)?;
assert_eq!(expression.data_type(&schema)?, DataType::Boolean);

// evaluate and verify the result type matched
let result = expression.evaluate(&batch)?;
assert_eq!(result.data_type(), &DataType::Boolean);

// verify that the result itself is correct
assert_eq!(expected, array_to_string(&result)?);

Ok(())
}

// Convert the array to a newline delimited string of pretty printed values
fn array_to_string(array: &ArrayRef) -> Result<String> {
let s = (0..array.len())
.map(|i| array_value_to_string(array, i))
.collect::<std::result::Result<Vec<_>, arrow::error::ArrowError>>()?
.join("\n");
Ok(s)
}

#[test]
fn test_coersion_error() -> Result<()> {
let expr =
Expand Down
72 changes: 67 additions & 5 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use std::sync::Arc;
extern crate arrow;
extern crate datafusion;

use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::TimeUnit};
use arrow::{datatypes::Int32Type, record_batch::RecordBatch};
use arrow::{
datatypes::{DataType, Field, Schema, SchemaRef},
util::display::array_value_to_string,
Expand Down Expand Up @@ -930,14 +930,20 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) {
/// Execute query and return result set as 2-d table of Vecs
/// `result[row][column]`
async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<Vec<String>> {
let plan = ctx.create_logical_plan(&sql).unwrap();
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx.create_logical_plan(&sql).expect(&msg);
let logical_schema = plan.schema();
let plan = ctx.optimize(&plan).unwrap();

let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan);
let plan = ctx.optimize(&plan).expect(&msg);
let optimized_logical_schema = plan.schema();
let plan = ctx.create_physical_plan(&plan).unwrap();

let msg = format!("Creating physical plan for '{}': {:?}", sql, plan);
let plan = ctx.create_physical_plan(&plan).expect(&msg);
let physical_schema = plan.schema();

let results = ctx.collect(plan).await.unwrap();
let msg = format!("Executing physical plan for '{}': {:?}", sql, plan);
let results = ctx.collect(plan).await.expect(&msg);

assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref());
assert_eq!(logical_schema.as_ref(), physical_schema.as_ref());
Expand Down Expand Up @@ -1238,3 +1244,59 @@ async fn query_count_distinct() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test demonstrates DictionaryArrays being used in DataFusion

#[tokio::test]
async fn query_on_string_dictionary() -> Result<()> {
// Test to ensure DataFusion can operate on dictionary types
// Use StringDictionary (32 bit indexes = keys)
let field_type =
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)]));

let keys_builder = PrimitiveBuilder::<Int32Type>::new(10);
let values_builder = StringBuilder::new(10);
let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder);

builder.append("one")?;
builder.append_null()?;
builder.append("three")?;
let array = Arc::new(builder.finish());

let data = RecordBatch::try_new(schema.clone(), vec![array])?;

let table = MemTable::new(schema, vec![vec![data]])?;
let mut ctx = ExecutionContext::new();
ctx.register_table("test", Box::new(table));

// Basic SELECT
let sql = "SELECT * FROM test";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["one"], vec!["NULL"], vec!["three"]];
assert_eq!(expected, actual);

// basic filtering
let sql = "SELECT * FROM test WHERE d1 IS NOT NULL";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["one"], vec!["three"]];
assert_eq!(expected, actual);

// filtering with constant
let sql = "SELECT * FROM test WHERE d1 = 'three'";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["three"]];
assert_eq!(expected, actual);

// Expression evaluation
let sql = "SELECT concat(d1, '-foo') FROM test";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["one-foo"], vec!["NULL"], vec!["three-foo"]];
assert_eq!(expected, actual);

// aggregation
let sql = "SELECT COUNT(d1) FROM test";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["2"]];
assert_eq!(expected, actual);

Ok(())
}